Introduce SetJointsVisibilityCalculator
PiperOrigin-RevId: 570745171
This commit is contained in:
		
							parent
							
								
									c81624d7b2
								
							
						
					
					
						commit
						3b99f8d9dd
					
				|  | @ -1659,6 +1659,49 @@ cc_test( | |||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "set_joints_visibility_calculator", | ||||
|     srcs = ["set_joints_visibility_calculator.cc"], | ||||
|     hdrs = ["set_joints_visibility_calculator.h"], | ||||
|     deps = [ | ||||
|         ":set_joints_visibility_calculator_cc_proto", | ||||
|         "//mediapipe/framework:calculator_framework", | ||||
|         "//mediapipe/framework/api2:node", | ||||
|         "//mediapipe/framework/formats:body_rig_cc_proto", | ||||
|         "//mediapipe/framework/formats:landmark_cc_proto", | ||||
|         "//mediapipe/framework/port:ret_check", | ||||
|         "//mediapipe/framework/port:status", | ||||
|     ], | ||||
|     alwayslink = 1, | ||||
| ) | ||||
| 
 | ||||
| mediapipe_proto_library( | ||||
|     name = "set_joints_visibility_calculator_proto", | ||||
|     srcs = ["set_joints_visibility_calculator.proto"], | ||||
|     deps = [ | ||||
|         "//mediapipe/framework:calculator_options_proto", | ||||
|         "//mediapipe/framework:calculator_proto", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_test( | ||||
|     name = "set_joints_visibility_calculator_test", | ||||
|     srcs = ["set_joints_visibility_calculator_test.cc"], | ||||
|     deps = [ | ||||
|         ":set_joints_visibility_calculator", | ||||
|         "//mediapipe/framework:calculator_framework", | ||||
|         "//mediapipe/framework:calculator_runner", | ||||
|         "//mediapipe/framework:packet", | ||||
|         "//mediapipe/framework/formats:body_rig_cc_proto", | ||||
|         "//mediapipe/framework/formats:landmark_cc_proto", | ||||
|         "//mediapipe/framework/port:gtest_main", | ||||
|         "//mediapipe/framework/port:parse_text_proto", | ||||
|         "//mediapipe/framework/port:status_matchers", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@com_google_absl//absl/types:optional", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|     name = "pass_through_or_empty_detection_vector_calculator", | ||||
|     srcs = ["pass_through_or_empty_detection_vector_calculator.cc"], | ||||
|  |  | |||
							
								
								
									
										108
									
								
								mediapipe/calculators/util/set_joints_visibility_calculator.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										108
									
								
								mediapipe/calculators/util/set_joints_visibility_calculator.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,108 @@ | |||
| // Copyright 2023 The MediaPipe Authors.
 | ||||
| //
 | ||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||
| // you may not use this file except in compliance with the License.
 | ||||
| // You may obtain a copy of the License at
 | ||||
| //
 | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| //
 | ||||
| // Unless required by applicable law or agreed to in writing, software
 | ||||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||
| // See the License for the specific language governing permissions and
 | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #include "mediapipe/calculators/util/set_joints_visibility_calculator.h" | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <optional> | ||||
| #include <utility> | ||||
| 
 | ||||
| #include "mediapipe/calculators/util/set_joints_visibility_calculator.pb.h" | ||||
| #include "mediapipe/framework/api2/node.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/framework/formats/body_rig.pb.h" | ||||
| #include "mediapipe/framework/formats/landmark.pb.h" | ||||
| #include "mediapipe/framework/port/ret_check.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace api2 { | ||||
| 
 | ||||
| namespace {}  // namespace
 | ||||
| 
 | ||||
| class SetJointsVisibilityCalculatorImpl | ||||
|     : public NodeImpl<SetJointsVisibilityCalculator> { | ||||
|  public: | ||||
|   absl::Status Open(CalculatorContext* cc) override { | ||||
|     options_ = cc->Options<SetJointsVisibilityCalculatorOptions>(); | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
| 
 | ||||
|   absl::Status Process(CalculatorContext* cc) override { | ||||
|     // Skip if Joints are empty.
 | ||||
|     if (kInJoints(cc).IsEmpty()) { | ||||
|       return absl::OkStatus(); | ||||
|     } | ||||
| 
 | ||||
|     // Get joints.
 | ||||
|     const JointList& in_joints = kInJoints(cc).Get(); | ||||
|     RET_CHECK_EQ(in_joints.joint_size(), options_.mapping_size()) | ||||
|         << "Number of joints doesn't match number of mappings"; | ||||
| 
 | ||||
|     // Get landmarks.
 | ||||
|     RET_CHECK(!kInLandmarks(cc).IsEmpty()) << "Landmarks must be provided"; | ||||
|     const LandmarkList& in_landmarks = kInLandmarks(cc).Get(); | ||||
| 
 | ||||
|     // Set joints visibility.
 | ||||
|     JointList out_joints; | ||||
|     for (int i = 0; i < in_joints.joint_size(); ++i) { | ||||
|       // Initialize output joint.
 | ||||
|       Joint* out_joint = out_joints.add_joint(); | ||||
|       *out_joint = in_joints.joint(i); | ||||
| 
 | ||||
|       // Get visibility. But only if it exists in the source landmark(s).
 | ||||
|       std::optional<float> visibility; | ||||
|       auto& mapping = options_.mapping(i); | ||||
|       if (mapping.has_unchanged()) { | ||||
|         continue; | ||||
|       } else if (mapping.has_copy()) { | ||||
|         const int idx = mapping.copy().idx(); | ||||
|         RET_CHECK(idx >= 0 && idx < in_landmarks.landmark_size()) | ||||
|             << "Landmark index out of range"; | ||||
|         if (in_landmarks.landmark(idx).has_visibility()) { | ||||
|           visibility = in_landmarks.landmark(idx).visibility(); | ||||
|         } | ||||
|       } else if (mapping.has_highest()) { | ||||
|         RET_CHECK_GT(mapping.highest().idx_size(), 0) << "No indexes provided"; | ||||
|         for (int idx : mapping.highest().idx()) { | ||||
|           RET_CHECK(idx >= 0 && idx < in_landmarks.landmark_size()) | ||||
|               << "Landmark index out of range"; | ||||
|           if (in_landmarks.landmark(idx).has_visibility()) { | ||||
|             const float landmark_visibility = | ||||
|                 in_landmarks.landmark(idx).visibility(); | ||||
|             visibility = visibility.has_value() | ||||
|                              ? std::max(visibility.value(), landmark_visibility) | ||||
|                              : landmark_visibility; | ||||
|           } | ||||
|         } | ||||
|       } else { | ||||
|         RET_CHECK_FAIL() << "Unknown mapping"; | ||||
|       } | ||||
| 
 | ||||
|       // Set visibility. But only if it was possible to obtain it.
 | ||||
|       if (visibility.has_value()) { | ||||
|         out_joint->set_visibility(visibility.value()); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     kOutJoints(cc).Send(std::move(out_joints)); | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   SetJointsVisibilityCalculatorOptions options_; | ||||
| }; | ||||
| MEDIAPIPE_NODE_IMPLEMENTATION(SetJointsVisibilityCalculatorImpl); | ||||
| 
 | ||||
| }  // namespace api2
 | ||||
| }  // namespace mediapipe
 | ||||
|  | @ -0,0 +1,68 @@ | |||
| // Copyright 2023 The MediaPipe Authors.
 | ||||
| //
 | ||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||
| // you may not use this file except in compliance with the License.
 | ||||
| // You may obtain a copy of the License at
 | ||||
| //
 | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| //
 | ||||
| // Unless required by applicable law or agreed to in writing, software
 | ||||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||
| // See the License for the specific language governing permissions and
 | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #ifndef MEDIAPIPE_CALCULATORS_UTIL_SET_JOINTS_VISIBILITY_CALCULATOR_H_ | ||||
| #define MEDIAPIPE_CALCULATORS_UTIL_SET_JOINTS_VISIBILITY_CALCULATOR_H_ | ||||
| 
 | ||||
| #include "mediapipe/framework/api2/node.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/framework/formats/body_rig.pb.h" | ||||
| #include "mediapipe/framework/formats/landmark.pb.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace api2 { | ||||
| 
 | ||||
| // A calculator set Joints visibility from Landmarks.
 | ||||
| //
 | ||||
| // Calculator allows to either copy visibility right from the landmark or
 | ||||
| // somehow combine visibilities of several landmarks.
 | ||||
| //
 | ||||
| // Input:
 | ||||
| //   JOINTS - JointList
 | ||||
| //     Joints to to update visibility.
 | ||||
| //   LANDMARKS - LandmarkList
 | ||||
| //     Landmarks to take visibility from.
 | ||||
| //
 | ||||
| // Output:
 | ||||
| //   JOINTS - JointList
 | ||||
| //     Joints with updated visibility.
 | ||||
| //
 | ||||
| // Example:
 | ||||
| //   node {
 | ||||
| //     calculator: "SetJointsVisibilityCalculator"
 | ||||
| //     input_stream: "JOINTS:joints"
 | ||||
| //     input_stream: "LANDMARKS:landmarks"
 | ||||
| //     output_stream: "JOINTS:joints_with_visibility"
 | ||||
| //     options: {
 | ||||
| //       [mediapipe.SetJointsVisibilityCalculatorOptions.ext] {
 | ||||
| //         mapping: [
 | ||||
| //           { copy: { idx: 0 } },
 | ||||
| //           { highest: { idx: [5, 6] } }
 | ||||
| //         ]
 | ||||
| //       }
 | ||||
| //     }
 | ||||
| //   }
 | ||||
| class SetJointsVisibilityCalculator : public NodeIntf { | ||||
|  public: | ||||
|   static constexpr Input<mediapipe::JointList> kInJoints{"JOINTS"}; | ||||
|   static constexpr Input<mediapipe::LandmarkList> kInLandmarks{"LANDMARKS"}; | ||||
|   static constexpr Output<mediapipe::JointList> kOutJoints{"JOINTS"}; | ||||
|   MEDIAPIPE_NODE_INTERFACE(SetJointsVisibilityCalculator, kInJoints, | ||||
|                            kInLandmarks, kOutJoints); | ||||
| }; | ||||
| 
 | ||||
| }  // namespace api2
 | ||||
| }  // namespace mediapipe
 | ||||
| 
 | ||||
| #endif  // MEDIAPIPE_CALCULATORS_UTIL_SET_JOINTS_VISIBILITY_CALCULATOR_H_
 | ||||
|  | @ -0,0 +1,55 @@ | |||
| // Copyright 2023 The MediaPipe Authors. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| syntax = "proto2"; | ||||
| 
 | ||||
| package mediapipe; | ||||
| 
 | ||||
| import "mediapipe/framework/calculator.proto"; | ||||
| 
 | ||||
| message SetJointsVisibilityCalculatorOptions { | ||||
|   extend CalculatorOptions { | ||||
|     optional SetJointsVisibilityCalculatorOptions ext = 406440186; | ||||
|   } | ||||
| 
 | ||||
|   // Mapping that tells where to take visibility for the joint. | ||||
|   message Mapping { | ||||
|     // Keep visibility unchanged. | ||||
|     message Unchanged {} | ||||
| 
 | ||||
|     // Copy visibility as is from the given landmark. | ||||
|     message Copy { | ||||
|       // Index of the landmark. | ||||
|       optional int32 idx = 1; | ||||
|     } | ||||
| 
 | ||||
|     // Take the highest visibility among the given landmarks. | ||||
|     message Highest { | ||||
|       // Indexes of landmarks to take the highest visibility value from. At | ||||
|       // least one index must be provided. | ||||
|       repeated int32 idx = 1 [packed = true]; | ||||
|     } | ||||
| 
 | ||||
|     oneof mapping { | ||||
|       Unchanged unchanged = 1; | ||||
|       Copy copy = 2; | ||||
|       Highest highest = 3; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Mapping that tells where to take visibility for each joint. | ||||
|   // Number of mappings must be equal to number of provided joints. Each mapping | ||||
|   // must contain exactly one rule for how to set the joint visibility. | ||||
|   repeated Mapping mapping = 1; | ||||
| } | ||||
|  | @ -0,0 +1,155 @@ | |||
| // Copyright 2023 The MediaPipe Authors.
 | ||||
| //
 | ||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||
| // you may not use this file except in compliance with the License.
 | ||||
| // You may obtain a copy of the License at
 | ||||
| //
 | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| //
 | ||||
| // Unless required by applicable law or agreed to in writing, software
 | ||||
| // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||
| // See the License for the specific language governing permissions and
 | ||||
| // limitations under the License.
 | ||||
| 
 | ||||
| #include <string> | ||||
| #include <utility> | ||||
| #include <vector> | ||||
| 
 | ||||
| #include "absl/strings/substitute.h" | ||||
| #include "absl/types/optional.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/framework/calculator_runner.h" | ||||
| #include "mediapipe/framework/formats/body_rig.pb.h" | ||||
| #include "mediapipe/framework/formats/landmark.pb.h" | ||||
| #include "mediapipe/framework/packet.h" | ||||
| #include "mediapipe/framework/port/gmock.h" | ||||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/framework/port/parse_text_proto.h" | ||||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace api2 { | ||||
| namespace { | ||||
| 
 | ||||
| using Node = ::mediapipe::CalculatorGraphConfig::Node; | ||||
| 
 | ||||
| struct SetJointsVisibilityTestCase { | ||||
|   std::string test_name; | ||||
|   std::string mapping; | ||||
|   std::vector<absl::optional<float>> in_joint_visibilities; | ||||
|   std::vector<absl::optional<float>> landmark_visibilities; | ||||
|   std::vector<absl::optional<float>> out_joint_visibilities; | ||||
| }; | ||||
| 
 | ||||
| using SetJointsVisibilityTest = | ||||
|     ::testing::TestWithParam<SetJointsVisibilityTestCase>; | ||||
| 
 | ||||
| TEST_P(SetJointsVisibilityTest, SetJointsVisibilityTest) { | ||||
|   const SetJointsVisibilityTestCase& tc = GetParam(); | ||||
| 
 | ||||
|   // Prepare graph.
 | ||||
|   mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(absl::Substitute( | ||||
|       R"( | ||||
|       calculator: "SetJointsVisibilityCalculator" | ||||
|       input_stream: "JOINTS:joints" | ||||
|       input_stream: "LANDMARKS:landmarks" | ||||
|       output_stream: "JOINTS:joints_with_visibility" | ||||
|       options: { | ||||
|         [mediapipe.SetJointsVisibilityCalculatorOptions.ext] { | ||||
|           mapping: [ | ||||
|             $0 | ||||
|           ] | ||||
|         } | ||||
|       } | ||||
|   )", | ||||
|       tc.mapping))); | ||||
| 
 | ||||
|   // Prepare joints.
 | ||||
|   JointList in_joints; | ||||
|   for (auto vis_opt : tc.in_joint_visibilities) { | ||||
|     Joint* joint = in_joints.add_joint(); | ||||
|     if (vis_opt) { | ||||
|       joint->set_visibility(vis_opt.value()); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Prepare landmarks.
 | ||||
|   LandmarkList landmarks; | ||||
|   for (auto vis_opt : tc.landmark_visibilities) { | ||||
|     Landmark* lmk = landmarks.add_landmark(); | ||||
|     if (vis_opt) { | ||||
|       lmk->set_visibility(vis_opt.value()); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Send joints and landmarks to the graph.
 | ||||
|   runner.MutableInputs()->Tag("JOINTS").packets.push_back( | ||||
|       MakePacket<JointList>(std::move(in_joints)).At(mediapipe::Timestamp(0))); | ||||
|   runner.MutableInputs() | ||||
|       ->Tag("LANDMARKS") | ||||
|       .packets.push_back(MakePacket<LandmarkList>(std::move(landmarks)) | ||||
|                              .At(mediapipe::Timestamp(0))); | ||||
| 
 | ||||
|   // Run the graph.
 | ||||
|   MP_ASSERT_OK(runner.Run()); | ||||
| 
 | ||||
|   const auto& output_packets = runner.Outputs().Tag("JOINTS").packets; | ||||
|   EXPECT_EQ(1, output_packets.size()); | ||||
| 
 | ||||
|   const auto& out_joints = output_packets[0].Get<JointList>(); | ||||
|   EXPECT_EQ(out_joints.joint_size(), tc.out_joint_visibilities.size()); | ||||
|   for (int i = 0; i < out_joints.joint_size(); ++i) { | ||||
|     const Joint& joint = out_joints.joint(i); | ||||
|     auto expected_vis_opt = tc.out_joint_visibilities[i]; | ||||
|     if (expected_vis_opt) { | ||||
|       EXPECT_NEAR(joint.visibility(), expected_vis_opt.value(), 1e-5); | ||||
|     } else { | ||||
|       EXPECT_FALSE(joint.has_visibility()); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| INSTANTIATE_TEST_SUITE_P( | ||||
|     SetJointsVisibilityTests, SetJointsVisibilityTest, | ||||
|     testing::ValuesIn<SetJointsVisibilityTestCase>({ | ||||
|         {"Empty_LandmarksAndJoints", "", {}, {}, {}}, | ||||
|         {"Empty_Joints", "", {}, {1, 2, 3}, {}}, | ||||
|         {"Empty_Landmarks", | ||||
|          "{ unchanged: {} }, { unchanged: {} }, { unchanged: {} }", | ||||
|          {1, 2, absl::nullopt}, | ||||
|          {}, | ||||
|          {1, 2, absl::nullopt}}, | ||||
| 
 | ||||
|         {"Mapping_Unchanged", "{ unchanged: {} }", {1}, {2}, {1}}, | ||||
|         {"Mapping_Unchanged_UnsetJointVisRemainsUnset", | ||||
|          "{ unchanged: {} }", | ||||
|          {absl::nullopt}, | ||||
|          {2}, | ||||
|          {absl::nullopt}}, | ||||
| 
 | ||||
|         {"Mapping_Copy", "{ copy: { idx: 0 } }", {1}, {2}, {2}}, | ||||
|         {"Mapping_Copy_UnsetLmkVisResultsIntoZeroJointVis", | ||||
|          "{ copy: { idx: 0 } }", | ||||
|          {absl::nullopt}, | ||||
|          {absl::nullopt}, | ||||
|          {0}}, | ||||
| 
 | ||||
|         {"Mapping_Highest", | ||||
|          "{ highest: { idx: [0, 1, 2] } }", | ||||
|          {absl::nullopt}, | ||||
|          {2, 4, 3}, | ||||
|          {4}}, | ||||
|         {"Mapping_Highest_UnsetLmkIsIgnored", | ||||
|          "{ highest: { idx: [0, 1, 2] } }", | ||||
|          {absl::nullopt}, | ||||
|          {-2, absl::nullopt, -3}, | ||||
|          {-2}}, | ||||
|     }), | ||||
|     [](const testing::TestParamInfo<SetJointsVisibilityTest::ParamType>& info) { | ||||
|       return info.param.test_name; | ||||
|     }); | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace api2
 | ||||
| }  // namespace mediapipe
 | ||||
|  | @ -248,7 +248,6 @@ rewrite_target_list = [ | |||
|     "segmenter_proto", | ||||
|     "sequence_shift_calculator_proto", | ||||
|     "set_alpha_calculator_proto", | ||||
|     "set_joints_visibility_calculator_proto", | ||||
|     "sharpen_calculator_gl_proto", | ||||
|     "simple_calculator_proto", | ||||
|     "single_shot_detector_gpu_cpu_proto", | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user