Merge branch 'master' into image-classification-python-impl
This commit is contained in:
		
						commit
						4204c8b8a9
					
				|  | @ -157,6 +157,13 @@ http_archive( | ||||||
|     urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"], |     urls = ["https://github.com/google/multichannel-audio-tools/archive/master.zip"], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | http_archive( | ||||||
|  |     name = "pffft", | ||||||
|  |     strip_prefix = "jpommier-pffft-7c3b5a7dc510", | ||||||
|  |     urls = ["https://bitbucket.org/jpommier/pffft/get/7c3b5a7dc510.zip"], | ||||||
|  |     build_file = "@//third_party:pffft.BUILD", | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| # sentencepiece | # sentencepiece | ||||||
| http_archive( | http_archive( | ||||||
|     name = "com_google_sentencepiece", |     name = "com_google_sentencepiece", | ||||||
|  |  | ||||||
|  | @ -217,7 +217,7 @@ A list of pose landmarks. Each landmark consists of the following: | ||||||
| 
 | 
 | ||||||
| *Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* | | *Fig 5. Example of MediaPipe Pose real-world 3D coordinates.* | | ||||||
| :-----------------------------------------------------------: | | :-----------------------------------------------------------: | | ||||||
| <video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_world_landmarks.mp4" type="video/mp4"></video> | | <video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/mobile/pose_world_landmarks.mp4" type="video/mp4"></video> | | ||||||
| 
 | 
 | ||||||
| Another list of pose landmarks in world coordinates. Each landmark consists of | Another list of pose landmarks in world coordinates. Each landmark consists of | ||||||
| the following: | the following: | ||||||
|  | @ -238,7 +238,7 @@ for usage details. | ||||||
| 
 | 
 | ||||||
| *Fig 6. Example of MediaPipe Pose segmentation mask.* | | *Fig 6. Example of MediaPipe Pose segmentation mask.* | | ||||||
| :---------------------------------------------------: | | :---------------------------------------------------: | | ||||||
| <video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/mobile/pose_segmentation.mp4" type="video/mp4"></video> | | <video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/mobile/pose_segmentation.mp4" type="video/mp4"></video> | | ||||||
| 
 | 
 | ||||||
| ### Python Solution API | ### Python Solution API | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -22,7 +22,7 @@ nav_order: 7 | ||||||
| 
 | 
 | ||||||
| *Fig 1. Example of MediaPipe Selfie Segmentation.* | | *Fig 1. Example of MediaPipe Selfie Segmentation.* | | ||||||
| :------------------------------------------------: | | :------------------------------------------------: | | ||||||
| <video autoplay muted loop preload style="height: auto; width: 480px"><source src="../images/selfie_segmentation_web.mp4" type="video/mp4"></video> | | <video autoplay muted loop preload style="height: auto; width: 480px"><source src="https://mediapipe.dev/images/selfie_segmentation_web.mp4" type="video/mp4"></video> | | ||||||
| 
 | 
 | ||||||
| MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can | MediaPipe Selfie Segmentation segments the prominent humans in the scene. It can | ||||||
| run in real-time on both smartphones and laptops. The intended use cases include | run in real-time on both smartphones and laptops. The intended use cases include | ||||||
|  |  | ||||||
|  | @ -1294,8 +1294,8 @@ cc_library( | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":get_vector_item_calculator_cc_proto", |         ":get_vector_item_calculator_cc_proto", | ||||||
|         "//mediapipe/framework:calculator_framework", |         "//mediapipe/framework:calculator_framework", | ||||||
|         "//mediapipe/framework:packet", |  | ||||||
|         "//mediapipe/framework/api2:node", |         "//mediapipe/framework/api2:node", | ||||||
|  |         "//mediapipe/framework/api2:packet", | ||||||
|         "//mediapipe/framework/api2:port", |         "//mediapipe/framework/api2:port", | ||||||
|         "//mediapipe/framework/formats:classification_cc_proto", |         "//mediapipe/framework/formats:classification_cc_proto", | ||||||
|         "//mediapipe/framework/formats:landmark_cc_proto", |         "//mediapipe/framework/formats:landmark_cc_proto", | ||||||
|  | @ -1319,6 +1319,32 @@ cc_test( | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "vector_indices_calculator", | ||||||
|  |     srcs = ["vector_indices_calculator.cc"], | ||||||
|  |     hdrs = ["vector_indices_calculator.h"], | ||||||
|  |     visibility = ["//visibility:public"], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework/api2:node", | ||||||
|  |         "//mediapipe/framework/formats:landmark_cc_proto", | ||||||
|  |         "//mediapipe/framework/port:status", | ||||||
|  |     ], | ||||||
|  |     alwayslink = 1, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_test( | ||||||
|  |     name = "vector_indices_calculator_test", | ||||||
|  |     srcs = ["vector_indices_calculator_test.cc"], | ||||||
|  |     deps = [ | ||||||
|  |         ":vector_indices_calculator", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework:calculator_runner", | ||||||
|  |         "//mediapipe/framework/port:gtest_main", | ||||||
|  |         "//mediapipe/framework/port:parse_text_proto", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| cc_library( | cc_library( | ||||||
|     name = "vector_size_calculator", |     name = "vector_size_calculator", | ||||||
|     srcs = ["vector_size_calculator.cc"], |     srcs = ["vector_size_calculator.cc"], | ||||||
|  |  | ||||||
|  | @ -40,6 +40,9 @@ REGISTER_CALCULATOR(EndLoopNormalizedLandmarkListVectorCalculator); | ||||||
| typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator; | typedef EndLoopCalculator<std::vector<bool>> EndLoopBooleanCalculator; | ||||||
| REGISTER_CALCULATOR(EndLoopBooleanCalculator); | REGISTER_CALCULATOR(EndLoopBooleanCalculator); | ||||||
| 
 | 
 | ||||||
|  | typedef EndLoopCalculator<std::vector<float>> EndLoopFloatCalculator; | ||||||
|  | REGISTER_CALCULATOR(EndLoopFloatCalculator); | ||||||
|  | 
 | ||||||
| typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>> | typedef EndLoopCalculator<std::vector<::mediapipe::RenderData>> | ||||||
|     EndLoopRenderDataCalculator; |     EndLoopRenderDataCalculator; | ||||||
| REGISTER_CALCULATOR(EndLoopRenderDataCalculator); | REGISTER_CALCULATOR(EndLoopRenderDataCalculator); | ||||||
|  |  | ||||||
|  | @ -24,6 +24,10 @@ using GetLandmarkListVectorItemCalculator = | ||||||
|     GetVectorItemCalculator<mediapipe::LandmarkList>; |     GetVectorItemCalculator<mediapipe::LandmarkList>; | ||||||
| REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator); | REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator); | ||||||
| 
 | 
 | ||||||
|  | using GetNormalizedLandmarkListVectorItemCalculator = | ||||||
|  |     GetVectorItemCalculator<mediapipe::NormalizedLandmarkList>; | ||||||
|  | REGISTER_CALCULATOR(GetNormalizedLandmarkListVectorItemCalculator); | ||||||
|  | 
 | ||||||
| using GetClassificationListVectorItemCalculator = | using GetClassificationListVectorItemCalculator = | ||||||
|     GetVectorItemCalculator<mediapipe::ClassificationList>; |     GetVectorItemCalculator<mediapipe::ClassificationList>; | ||||||
| REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); | REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator); | ||||||
|  |  | ||||||
|  | @ -19,6 +19,7 @@ | ||||||
| 
 | 
 | ||||||
| #include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" | #include "mediapipe/calculators/core/get_vector_item_calculator.pb.h" | ||||||
| #include "mediapipe/framework/api2/node.h" | #include "mediapipe/framework/api2/node.h" | ||||||
|  | #include "mediapipe/framework/api2/packet.h" | ||||||
| #include "mediapipe/framework/api2/port.h" | #include "mediapipe/framework/api2/port.h" | ||||||
| #include "mediapipe/framework/calculator_framework.h" | #include "mediapipe/framework/calculator_framework.h" | ||||||
| #include "mediapipe/framework/port/ret_check.h" | #include "mediapipe/framework/port/ret_check.h" | ||||||
|  | @ -58,7 +59,7 @@ template <typename T> | ||||||
| class GetVectorItemCalculator : public Node { | class GetVectorItemCalculator : public Node { | ||||||
|  public: |  public: | ||||||
|   static constexpr Input<std::vector<T>> kIn{"VECTOR"}; |   static constexpr Input<std::vector<T>> kIn{"VECTOR"}; | ||||||
|   static constexpr Input<int>::Optional kIdx{"INDEX"}; |   static constexpr Input<OneOf<int, uint64_t>>::Optional kIdx{"INDEX"}; | ||||||
|   static constexpr Output<T> kOut{"ITEM"}; |   static constexpr Output<T> kOut{"ITEM"}; | ||||||
| 
 | 
 | ||||||
|   MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); |   MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut); | ||||||
|  | @ -80,7 +81,9 @@ class GetVectorItemCalculator : public Node { | ||||||
| 
 | 
 | ||||||
|     int idx = 0; |     int idx = 0; | ||||||
|     if (kIdx(cc).IsConnected() && !kIdx(cc).IsEmpty()) { |     if (kIdx(cc).IsConnected() && !kIdx(cc).IsEmpty()) { | ||||||
|       idx = kIdx(cc).Get(); |       idx = kIdx(cc).Visit( | ||||||
|  |           [](uint64_t idx_uint64_t) { return static_cast<int>(idx_uint64_t); }, | ||||||
|  |           [](int idx_int) { return idx_int; }); | ||||||
|     } else if (options.has_item_index()) { |     } else if (options.has_item_index()) { | ||||||
|       idx = options.item_index(); |       idx = options.item_index(); | ||||||
|     } else { |     } else { | ||||||
|  |  | ||||||
|  | @ -227,4 +227,15 @@ TEST(TestGetIntVectorItemCalculatorTest, IndexOptionsTwoTimestamps) { | ||||||
|               testing::ElementsAre(TimestampValue(1), TimestampValue(2))); |               testing::ElementsAre(TimestampValue(1), TimestampValue(2))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(TestGetIntVectorItemCalculatorTest, IndexUint64) { | ||||||
|  |   CalculatorRunner runner = MakeRunnerWithStream(); | ||||||
|  |   const std::vector<int> inputs = {1, 2, 3}; | ||||||
|  |   const uint64_t index = 1; | ||||||
|  |   AddInputVector(runner, inputs, 1); | ||||||
|  |   AddInputIndex(runner, index, 1); | ||||||
|  |   MP_ASSERT_OK(runner.Run()); | ||||||
|  |   const std::vector<Packet>& outputs = runner.Outputs().Tag("ITEM").packets; | ||||||
|  |   EXPECT_THAT(outputs, testing::ElementsAre(IntPacket(inputs[index]))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
							
								
								
									
										33
									
								
								mediapipe/calculators/core/vector_indices_calculator.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								mediapipe/calculators/core/vector_indices_calculator.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,33 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/calculators/core/vector_indices_calculator.h" | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/framework/formats/landmark.pb.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace api2 { | ||||||
|  | 
 | ||||||
|  | using IntVectorIndicesCalculator = VectorIndicesCalculator<int>; | ||||||
|  | REGISTER_CALCULATOR(IntVectorIndicesCalculator); | ||||||
|  | 
 | ||||||
|  | using Uint64tVectorIndicesCalculator = VectorIndicesCalculator<uint64_t>; | ||||||
|  | REGISTER_CALCULATOR(Uint64tVectorIndicesCalculator); | ||||||
|  | 
 | ||||||
|  | using NormalizedLandmarkListVectorIndicesCalculator = | ||||||
|  |     VectorIndicesCalculator<mediapipe::NormalizedLandmarkList>; | ||||||
|  | REGISTER_CALCULATOR(NormalizedLandmarkListVectorIndicesCalculator); | ||||||
|  | 
 | ||||||
|  | }  // namespace api2
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
							
								
								
									
										65
									
								
								mediapipe/calculators/core/vector_indices_calculator.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								mediapipe/calculators/core/vector_indices_calculator.h
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,65 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_ | ||||||
|  | #define MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_ | ||||||
|  | 
 | ||||||
|  | #include <optional> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/framework/api2/node.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/port/status.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace api2 { | ||||||
|  | // Calculator that takes a vector and constructs an index range vector based on
 | ||||||
|  | // the size of the input vector.
 | ||||||
|  | //
 | ||||||
|  | // Inputs:
 | ||||||
|  | //   VECTOR - std::vector<T>
 | ||||||
|  | //     Vector whose range of indices to return.
 | ||||||
|  | //
 | ||||||
|  | // Outputs:
 | ||||||
|  | //   INDICES - std::vector<int>
 | ||||||
|  | //     Indices vector of the input vector.
 | ||||||
|  | //
 | ||||||
|  | // Example config:
 | ||||||
|  | //  node {
 | ||||||
|  | //    calculator: "{SpecificType}VectorIndicesCalculator"
 | ||||||
|  | //    input_stream: "VECTOR:vector"
 | ||||||
|  | //    output_stream: "INDICES:indices"
 | ||||||
|  | //  }
 | ||||||
|  | //
 | ||||||
|  | template <typename T> | ||||||
|  | class VectorIndicesCalculator : public Node { | ||||||
|  |  public: | ||||||
|  |   static constexpr Input<std::vector<T>> kVector{"VECTOR"}; | ||||||
|  |   static constexpr Output<std::vector<int>> kRange{"INDICES"}; | ||||||
|  | 
 | ||||||
|  |   MEDIAPIPE_NODE_CONTRACT(kVector, kRange); | ||||||
|  | 
 | ||||||
|  |   absl::Status Process(CalculatorContext* cc) final { | ||||||
|  |     // Get the size of the input vector.
 | ||||||
|  |     const int vector_size = kVector(cc).Get().size(); | ||||||
|  |     std::vector<int> out_idxs(vector_size); | ||||||
|  |     std::iota(out_idxs.begin(), out_idxs.end(), 0); | ||||||
|  |     kRange(cc).Send(out_idxs); | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace api2
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | 
 | ||||||
|  | #endif  // MEDIAPIPE_CALCULATORS_CORE_VECTOR_INDICES_CALCULATOR_H_
 | ||||||
							
								
								
									
										87
									
								
								mediapipe/calculators/core/vector_indices_calculator_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								mediapipe/calculators/core/vector_indices_calculator_test.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,87 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/calculators/core/vector_indices_calculator.h" | ||||||
|  | 
 | ||||||
|  | #include <memory> | ||||||
|  | #include <string> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/calculator_runner.h" | ||||||
|  | #include "mediapipe/framework/port/gmock.h" | ||||||
|  | #include "mediapipe/framework/port/status_matchers.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | 
 | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | using ::testing::TestParamInfo; | ||||||
|  | using ::testing::TestWithParam; | ||||||
|  | using ::testing::Values; | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | void AddInputVector(CalculatorRunner& runner, const std::vector<T>& inputs, | ||||||
|  |                     int timestamp) { | ||||||
|  |   runner.MutableInputs()->Tag("VECTOR").packets.push_back( | ||||||
|  |       MakePacket<std::vector<T>>(inputs).At(Timestamp(timestamp))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | struct TestParams { | ||||||
|  |   const std::string test_name; | ||||||
|  |   const std::vector<T> inputs; | ||||||
|  |   const int timestamp; | ||||||
|  |   const std::vector<int> expected_indices; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | class IntVectorIndicesCalculatorTest | ||||||
|  |     : public testing::TestWithParam<TestParams<int>> {}; | ||||||
|  | 
 | ||||||
|  | TEST_P(IntVectorIndicesCalculatorTest, Succeeds) { | ||||||
|  |   CalculatorRunner runner = CalculatorRunner(R"( | ||||||
|  |     calculator: "IntVectorIndicesCalculator" | ||||||
|  |     input_stream: "VECTOR:vector_stream" | ||||||
|  |     output_stream: "INDICES:indices_stream" | ||||||
|  |   )"); | ||||||
|  |   const std::vector<int>& inputs = GetParam().inputs; | ||||||
|  |   std::vector<int> expected_indices(inputs.size()); | ||||||
|  |   AddInputVector(runner, inputs, GetParam().timestamp); | ||||||
|  |   MP_ASSERT_OK(runner.Run()); | ||||||
|  |   const std::vector<Packet>& outputs = runner.Outputs().Tag("INDICES").packets; | ||||||
|  |   EXPECT_EQ(1, outputs.size()); | ||||||
|  |   EXPECT_THAT(outputs[0].Get<std::vector<int>>(), | ||||||
|  |               testing::ElementsAreArray(GetParam().expected_indices)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | INSTANTIATE_TEST_SUITE_P( | ||||||
|  |     IntVectorIndicesCalculatorTest, IntVectorIndicesCalculatorTest, | ||||||
|  |     Values(TestParams<int>{ | ||||||
|  |                /* test_name= */ "IntVectorIndices", | ||||||
|  |                /* inputs= */ {1, 2, 3}, | ||||||
|  |                /* timestamp= */ 1, | ||||||
|  |                /* expected_indices= */ {0, 1, 2}, | ||||||
|  |            }, | ||||||
|  |            TestParams<int>{ | ||||||
|  |                /* test_name= */ "EmptyVector", | ||||||
|  |                /* inputs= */ {}, | ||||||
|  |                /* timestamp= */ 1, | ||||||
|  |                /* expected_indices= */ {}, | ||||||
|  |            }), | ||||||
|  |     [](const TestParamInfo<IntVectorIndicesCalculatorTest::ParamType>& info) { | ||||||
|  |       return info.param.test_name; | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -55,6 +55,14 @@ mediapipe_proto_library( | ||||||
| cc_library( | cc_library( | ||||||
|     name = "audio_to_tensor_calculator", |     name = "audio_to_tensor_calculator", | ||||||
|     srcs = ["audio_to_tensor_calculator.cc"], |     srcs = ["audio_to_tensor_calculator.cc"], | ||||||
|  |     copts = select({ | ||||||
|  |         # b/215212850 | ||||||
|  |         "//mediapipe:apple": [ | ||||||
|  |             "-x objective-c++", | ||||||
|  |             "-fobjc-arc", | ||||||
|  |         ], | ||||||
|  |         "//conditions:default": [], | ||||||
|  |     }), | ||||||
|     visibility = [ |     visibility = [ | ||||||
|         "//mediapipe/framework:mediapipe_internal", |         "//mediapipe/framework:mediapipe_internal", | ||||||
|     ], |     ], | ||||||
|  | @ -67,13 +75,16 @@ cc_library( | ||||||
|         "//mediapipe/framework/formats:matrix", |         "//mediapipe/framework/formats:matrix", | ||||||
|         "//mediapipe/framework/formats:tensor", |         "//mediapipe/framework/formats:tensor", | ||||||
|         "//mediapipe/framework/formats:time_series_header_cc_proto", |         "//mediapipe/framework/formats:time_series_header_cc_proto", | ||||||
|  |         "//mediapipe/framework/port:ret_check", | ||||||
|         "//mediapipe/util:time_series_util", |         "//mediapipe/util:time_series_util", | ||||||
|         "@com_google_absl//absl/memory", |         "@com_google_absl//absl/memory", | ||||||
|         "@com_google_absl//absl/status", |         "@com_google_absl//absl/status", | ||||||
|         "@com_google_absl//absl/status:statusor", |         "@com_google_absl//absl/status:statusor", | ||||||
|         "@com_google_absl//absl/strings:str_format", |         "@com_google_absl//absl/strings:str_format", | ||||||
|         "@com_google_audio_tools//audio/dsp:resampler_q", |         "@com_google_audio_tools//audio/dsp:resampler_q", | ||||||
|  |         "@com_google_audio_tools//audio/dsp:window_functions", | ||||||
|         "@org_tensorflow//tensorflow/lite/c:common", |         "@org_tensorflow//tensorflow/lite/c:common", | ||||||
|  |         "@pffft", | ||||||
|     ], |     ], | ||||||
|     alwayslink = 1, |     alwayslink = 1, | ||||||
| ) | ) | ||||||
|  | @ -83,6 +94,7 @@ cc_test( | ||||||
|     srcs = ["audio_to_tensor_calculator_test.cc"], |     srcs = ["audio_to_tensor_calculator_test.cc"], | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":audio_to_tensor_calculator", |         ":audio_to_tensor_calculator", | ||||||
|  |         ":audio_to_tensor_calculator_cc_proto", | ||||||
|         "//mediapipe/framework:calculator_cc_proto", |         "//mediapipe/framework:calculator_cc_proto", | ||||||
|         "//mediapipe/framework:calculator_framework", |         "//mediapipe/framework:calculator_framework", | ||||||
|         "//mediapipe/framework:timestamp", |         "//mediapipe/framework:timestamp", | ||||||
|  | @ -97,6 +109,58 @@ cc_test( | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | mediapipe_proto_library( | ||||||
|  |     name = "feedback_tensors_calculator_proto", | ||||||
|  |     srcs = ["feedback_tensors_calculator.proto"], | ||||||
|  |     visibility = [ | ||||||
|  |         "//mediapipe/framework:mediapipe_internal", | ||||||
|  |     ], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/framework:calculator_options_proto", | ||||||
|  |         "//mediapipe/framework:calculator_proto", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "feedback_tensors_calculator", | ||||||
|  |     srcs = ["feedback_tensors_calculator.cc"], | ||||||
|  |     copts = select({ | ||||||
|  |         # b/215212850 | ||||||
|  |         "//mediapipe:apple": [ | ||||||
|  |             "-x objective-c++", | ||||||
|  |             "-fobjc-arc", | ||||||
|  |         ], | ||||||
|  |         "//conditions:default": [], | ||||||
|  |     }), | ||||||
|  |     visibility = [ | ||||||
|  |         "//mediapipe/framework:mediapipe_internal", | ||||||
|  |     ], | ||||||
|  |     deps = [ | ||||||
|  |         ":feedback_tensors_calculator_cc_proto", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework/api2:node", | ||||||
|  |         "//mediapipe/framework/formats:tensor", | ||||||
|  |         "@com_google_absl//absl/status", | ||||||
|  |     ], | ||||||
|  |     alwayslink = 1, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_test( | ||||||
|  |     name = "feedback_tensors_calculator_test", | ||||||
|  |     srcs = ["feedback_tensors_calculator_test.cc"], | ||||||
|  |     deps = [ | ||||||
|  |         ":feedback_tensors_calculator", | ||||||
|  |         ":feedback_tensors_calculator_cc_proto", | ||||||
|  |         "//mediapipe/framework:calculator_cc_proto", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework:timestamp", | ||||||
|  |         "//mediapipe/framework/formats:tensor", | ||||||
|  |         "//mediapipe/framework/port:gtest_main", | ||||||
|  |         "//mediapipe/framework/port:parse_text_proto", | ||||||
|  |         "@org_tensorflow//tensorflow/lite/c:common", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| mediapipe_proto_library( | mediapipe_proto_library( | ||||||
|     name = "inference_calculator_proto", |     name = "inference_calculator_proto", | ||||||
|     srcs = ["inference_calculator.proto"], |     srcs = ["inference_calculator.proto"], | ||||||
|  | @ -346,6 +410,10 @@ cc_library( | ||||||
|     }), |     }), | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | # This target provides the InferenceCalculator and a default set of implementations tailored for the | ||||||
|  | # current build platforms. More implementations can be added as separate dependencies to a client; | ||||||
|  | # for clients that want a narrower set of implementations than the default should see the comment on | ||||||
|  | # inference_calculator_interface. | ||||||
| cc_library( | cc_library( | ||||||
|     name = "inference_calculator", |     name = "inference_calculator", | ||||||
|     visibility = ["//visibility:public"], |     visibility = ["//visibility:public"], | ||||||
|  |  | ||||||
|  | @ -12,9 +12,8 @@ | ||||||
| // See the License for the specific language governing permissions and
 | // See the License for the specific language governing permissions and
 | ||||||
| // limitations under the License.
 | // limitations under the License.
 | ||||||
| 
 | 
 | ||||||
| #include <math.h> |  | ||||||
| 
 |  | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
|  | #include <cmath> | ||||||
| #include <cstring> | #include <cstring> | ||||||
| #include <memory> | #include <memory> | ||||||
| #include <string> | #include <string> | ||||||
|  | @ -26,6 +25,7 @@ | ||||||
| #include "absl/status/statusor.h" | #include "absl/status/statusor.h" | ||||||
| #include "absl/strings/str_format.h" | #include "absl/strings/str_format.h" | ||||||
| #include "audio/dsp/resampler_q.h" | #include "audio/dsp/resampler_q.h" | ||||||
|  | #include "audio/dsp/window_functions.h" | ||||||
| #include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" | #include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" | ||||||
| #include "mediapipe/framework/api2/node.h" | #include "mediapipe/framework/api2/node.h" | ||||||
| #include "mediapipe/framework/api2/packet.h" | #include "mediapipe/framework/api2/packet.h" | ||||||
|  | @ -34,19 +34,60 @@ | ||||||
| #include "mediapipe/framework/formats/matrix.h" | #include "mediapipe/framework/formats/matrix.h" | ||||||
| #include "mediapipe/framework/formats/tensor.h" | #include "mediapipe/framework/formats/tensor.h" | ||||||
| #include "mediapipe/framework/formats/time_series_header.pb.h" | #include "mediapipe/framework/formats/time_series_header.pb.h" | ||||||
|  | #include "mediapipe/framework/port/ret_check.h" | ||||||
| #include "mediapipe/util/time_series_util.h" | #include "mediapipe/util/time_series_util.h" | ||||||
|  | #include "pffft.h" | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| namespace api2 { | namespace api2 { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | using Options = ::mediapipe::AudioToTensorCalculatorOptions; | ||||||
|  | using FlushMode = Options::FlushMode; | ||||||
|  | 
 | ||||||
|  | std::vector<float> HannWindow(int window_size, bool sqrt_hann) { | ||||||
|  |   std::vector<float> hann_window(window_size); | ||||||
|  |   audio_dsp::HannWindow().GetPeriodicSamples(window_size, &hann_window); | ||||||
|  |   if (sqrt_hann) { | ||||||
|  |     absl::c_transform(hann_window, hann_window.begin(), | ||||||
|  |                       [](double x) { return std::sqrt(x); }); | ||||||
|  |   } | ||||||
|  |   return hann_window; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // PFFFT only supports transforms for inputs of length N of the form
 | ||||||
|  | // N = (2^a)*(3^b)*(5^c) where b >=0 and c >= 0 and a >= 5 for the real FFT.
 | ||||||
|  | bool IsValidFftSize(int size) { | ||||||
|  |   if (size <= 0) { | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |   constexpr int kFactors[] = {2, 3, 5}; | ||||||
|  |   int factorization[] = {0, 0, 0}; | ||||||
|  |   int n = static_cast<int>(size); | ||||||
|  |   for (int i = 0; i < 3; ++i) { | ||||||
|  |     while (n % kFactors[i] == 0) { | ||||||
|  |       n = n / kFactors[i]; | ||||||
|  |       ++factorization[i]; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   return factorization[0] >= 5 && n == 1; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| // Converts audio buffers into tensors, possibly with resampling, buffering
 | // Converts audio buffers into tensors, possibly with resampling, buffering
 | ||||||
| // and framing, according to specified inputs and options. All input audio
 | // and framing, according to specified inputs and options. All input audio
 | ||||||
| // buffers will be first resampled from the input sample rate to the target
 | // buffers will be first resampled from the input sample rate to the target
 | ||||||
| // sample rate if they are not equal. The resampled audio data (with the
 | // sample rate if they are not equal. The resampled audio data (with the
 | ||||||
| // buffered samples from the previous runs in the streaming mode) will be broken
 | // buffered samples from the previous runs in the streaming mode) will be broken
 | ||||||
| // into fixed-sized, possibly overlapping frames. Finally, all frames will be
 | // into fixed-sized, possibly overlapping frames. If the calculator is not asked
 | ||||||
| // converted to and outputted as MediaPipe Tensors. The last output tensor will
 | // to perform fft (the fft_size is not set in the calculator options), all
 | ||||||
| // be zero-padding if the remaining samples are insufficient.
 | // frames will be converted to and outputted as MediaPipe Tensors. The last
 | ||||||
|  | // output tensor will be zero-padding if the remaining samples are insufficient.
 | ||||||
|  | // Otherwise, when the fft_size is set and valid, the calculator will perform
 | ||||||
|  | // fft on the fixed-sized audio frames, the complex DFT results will be
 | ||||||
|  | // converted to and outputted as 2D MediaPipe float Tensors where the first
 | ||||||
|  | // rows are the DFT real parts and the second rows are the DFT imagery parts.
 | ||||||
| //
 | //
 | ||||||
| // This calculator assumes that the input timestamps refer to the first
 | // This calculator assumes that the input timestamps refer to the first
 | ||||||
| // sample in each Matrix. The output timestamps follow this same convention.
 | // sample in each Matrix. The output timestamps follow this same convention.
 | ||||||
|  | @ -86,11 +127,15 @@ namespace api2 { | ||||||
| // Outputs:
 | // Outputs:
 | ||||||
| //   TENSORS - std::vector<Tensor>
 | //   TENSORS - std::vector<Tensor>
 | ||||||
| //     Vector containing a single Tensor that represents a fix-sized audio
 | //     Vector containing a single Tensor that represents a fix-sized audio
 | ||||||
| //     frame.
 | //     frame or the complex DFT results.
 | ||||||
| //   TIMESTAMPS - std::vector<Timestamp> @Optional
 | //   TIMESTAMPS - std::vector<Timestamp> @Optional
 | ||||||
| //     Vector containing the output timestamps emitted by the current Process()
 | //     Vector containing the output timestamps emitted by the current Process()
 | ||||||
| //     invocation. In the non-streaming mode, the vector contains all of the
 | //     invocation. In the non-streaming mode, the vector contains all of the
 | ||||||
| //     output timestamps for an input audio buffer.
 | //     output timestamps for an input audio buffer.
 | ||||||
|  | //   DC_AND_NYQUIST - std::pair<float, float> @Optional.
 | ||||||
|  | //     A pair of dc component and nyquest component. Only can be connected when
 | ||||||
|  | //     the calculator performs fft (the fft_size is set in the calculator
 | ||||||
|  | //     options).
 | ||||||
| //
 | //
 | ||||||
| // Example:
 | // Example:
 | ||||||
| // node {
 | // node {
 | ||||||
|  | @ -116,12 +161,14 @@ class AudioToTensorCalculator : public Node { | ||||||
|   // such as sample rate.
 |   // such as sample rate.
 | ||||||
|   static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"}; |   static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"}; | ||||||
|   static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"}; |   static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"}; | ||||||
|  |   static constexpr Output<std::pair<float, float>>::Optional kDcAndNyquistOut{ | ||||||
|  |       "DC_AND_NYQUIST"}; | ||||||
|   // A vector of the output timestamps emitted by the current Process()
 |   // A vector of the output timestamps emitted by the current Process()
 | ||||||
|   // invocation. The packet timestamp is the last emitted timestamp.
 |   // invocation. The packet timestamp is the last emitted timestamp.
 | ||||||
|   static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{ |   static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{ | ||||||
|       "TIMESTAMPS"}; |       "TIMESTAMPS"}; | ||||||
|   MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut, |   MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut, | ||||||
|                           kTimestampsOut); |                           kDcAndNyquistOut, kTimestampsOut); | ||||||
| 
 | 
 | ||||||
|   static absl::Status UpdateContract(CalculatorContract* cc); |   static absl::Status UpdateContract(CalculatorContract* cc); | ||||||
|   absl::Status Open(CalculatorContext* cc); |   absl::Status Open(CalculatorContext* cc); | ||||||
|  | @ -138,6 +185,9 @@ class AudioToTensorCalculator : public Node { | ||||||
|   int frame_step_; |   int frame_step_; | ||||||
|   bool stream_mode_; |   bool stream_mode_; | ||||||
|   bool check_inconsistent_timestamps_; |   bool check_inconsistent_timestamps_; | ||||||
|  |   int padding_samples_before_; | ||||||
|  |   int padding_samples_after_; | ||||||
|  |   FlushMode flush_mode_; | ||||||
|   Timestamp initial_timestamp_ = Timestamp::Unstarted(); |   Timestamp initial_timestamp_ = Timestamp::Unstarted(); | ||||||
|   int64 cumulative_input_samples_ = 0; |   int64 cumulative_input_samples_ = 0; | ||||||
|   Timestamp next_output_timestamp_ = Timestamp::Unstarted(); |   Timestamp next_output_timestamp_ = Timestamp::Unstarted(); | ||||||
|  | @ -151,22 +201,33 @@ class AudioToTensorCalculator : public Node { | ||||||
|   Matrix sample_buffer_; |   Matrix sample_buffer_; | ||||||
|   int processed_buffer_cols_ = 0; |   int processed_buffer_cols_ = 0; | ||||||
| 
 | 
 | ||||||
|  |   // The internal state of the FFT library.
 | ||||||
|  |   PFFFT_Setup* fft_state_ = nullptr; | ||||||
|  |   int fft_size_ = 0; | ||||||
|  |   std::vector<float> fft_window_; | ||||||
|  |   std::vector<float, Eigen::aligned_allocator<float>> fft_input_buffer_; | ||||||
|  |   // pffft requires memory to work with to avoid using the stack.
 | ||||||
|  |   std::vector<float, Eigen::aligned_allocator<float>> fft_workplace_; | ||||||
|  |   std::vector<float, Eigen::aligned_allocator<float>> fft_output_; | ||||||
|  | 
 | ||||||
|   absl::Status ProcessStreamingData(CalculatorContext* cc, const Matrix& input); |   absl::Status ProcessStreamingData(CalculatorContext* cc, const Matrix& input); | ||||||
|   absl::Status ProcessNonStreamingData(CalculatorContext* cc, |   absl::Status ProcessNonStreamingData(CalculatorContext* cc, | ||||||
|                                        const Matrix& input); |                                        const Matrix& input); | ||||||
| 
 | 
 | ||||||
|   absl::Status SetupStreamingResampler(double input_sample_rate_); |   absl::Status SetupStreamingResampler(double input_sample_rate_); | ||||||
|   void AppendToSampleBuffer(Matrix buffer_to_append); |   void AppendToSampleBuffer(Matrix buffer_to_append); | ||||||
|  |   void AppendZerosToSampleBuffer(int num_samples); | ||||||
| 
 | 
 | ||||||
|   absl::StatusOr<std::vector<Tensor>> ConvertToTensor( |   absl::StatusOr<std::vector<Tensor>> ConvertToTensor( | ||||||
|       const Matrix& frame_to_convert); |       const Matrix& block, std::vector<int> tensor_dims); | ||||||
|   absl::Status OutputTensors(const Matrix& buffer, bool should_flush, |   absl::Status OutputTensor(const Matrix& block, Timestamp timestamp, | ||||||
|  |                             CalculatorContext* cc); | ||||||
|  |   absl::Status ProcessBuffer(const Matrix& buffer, bool should_flush, | ||||||
|                              CalculatorContext* cc); |                              CalculatorContext* cc); | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) { | absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) { | ||||||
|   const auto& options = |   const auto& options = cc->Options<Options>(); | ||||||
|       cc->Options<mediapipe::AudioToTensorCalculatorOptions>(); |  | ||||||
|   if (!options.has_num_channels() || !options.has_num_samples() || |   if (!options.has_num_channels() || !options.has_num_samples() || | ||||||
|       !options.has_target_sample_rate()) { |       !options.has_target_sample_rate()) { | ||||||
|     return absl::InvalidArgumentError( |     return absl::InvalidArgumentError( | ||||||
|  | @ -174,13 +235,21 @@ absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) { | ||||||
|         "`num_channels`, `num_samples`, and `target_sample_rate`."); |         "`num_channels`, `num_samples`, and `target_sample_rate`."); | ||||||
|   } |   } | ||||||
|   if (options.stream_mode()) { |   if (options.stream_mode()) { | ||||||
|     // Explicitly disables tiemstamp offset to disallow the timestamp bound
 |     // Explicitly disables timestamp offset to disallow the timestamp bound
 | ||||||
|     // from the input streams to be propagated to the output streams.
 |     // from the input streams to be propagated to the output streams.
 | ||||||
|     // In the streaming mode, the output timestamp bound is based on
 |     // In the streaming mode, the output timestamp bound is based on
 | ||||||
|     // next_output_timestamp_, which can be smaller than the current input
 |     // next_output_timestamp_, which can be smaller than the current input
 | ||||||
|     // timestamps.
 |     // timestamps.
 | ||||||
|     cc->SetTimestampOffset(TimestampDiff::Unset()); |     cc->SetTimestampOffset(TimestampDiff::Unset()); | ||||||
|   } |   } | ||||||
|  |   if (options.padding_samples_before() < 0 || | ||||||
|  |       options.padding_samples_after() < 0) { | ||||||
|  |     return absl::InvalidArgumentError("Negative zero padding unsupported"); | ||||||
|  |   } | ||||||
|  |   if (options.flush_mode() != Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX && | ||||||
|  |       options.flush_mode() != Options::PROCEED_AS_USUAL) { | ||||||
|  |     return absl::InvalidArgumentError("Unsupported flush mode"); | ||||||
|  |   } | ||||||
|   return absl::OkStatus(); |   return absl::OkStatus(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -202,6 +271,9 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { | ||||||
|     check_inconsistent_timestamps_ = options.check_inconsistent_timestamps(); |     check_inconsistent_timestamps_ = options.check_inconsistent_timestamps(); | ||||||
|     sample_buffer_.resize(num_channels_, Eigen::NoChange); |     sample_buffer_.resize(num_channels_, Eigen::NoChange); | ||||||
|   } |   } | ||||||
|  |   padding_samples_before_ = options.padding_samples_before(); | ||||||
|  |   padding_samples_after_ = options.padding_samples_after(); | ||||||
|  |   flush_mode_ = options.flush_mode(); | ||||||
| 
 | 
 | ||||||
|   RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ |   RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^ | ||||||
|             !kAudioIn(cc).Header().IsEmpty()) |             !kAudioIn(cc).Header().IsEmpty()) | ||||||
|  | @ -217,6 +289,25 @@ absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) { | ||||||
|       source_sample_rate_ = input_header.sample_rate(); |       source_sample_rate_ = input_header.sample_rate(); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |   AppendZerosToSampleBuffer(padding_samples_before_); | ||||||
|  |   if (options.has_fft_size()) { | ||||||
|  |     RET_CHECK(IsValidFftSize(options.fft_size())) | ||||||
|  |         << "FFT size must be of the form fft_size = (2^a)*(3^b)*(5^c) where b " | ||||||
|  |            ">=0 and c >= 0 and a >= 5, the requested fft size is " | ||||||
|  |         << options.fft_size(); | ||||||
|  |     RET_CHECK_EQ(1, num_channels_) | ||||||
|  |         << "Currently only support applying FFT on mono channel."; | ||||||
|  |     fft_size_ = options.fft_size(); | ||||||
|  |     fft_state_ = pffft_new_setup(fft_size_, PFFFT_REAL); | ||||||
|  |     fft_window_ = HannWindow(fft_size_, /* sqrt_hann = */ false); | ||||||
|  |     fft_input_buffer_.resize(fft_size_); | ||||||
|  |     fft_workplace_.resize(fft_size_); | ||||||
|  |     fft_output_.resize(fft_size_); | ||||||
|  |   } else { | ||||||
|  |     RET_CHECK(!kDcAndNyquistOut(cc).IsConnected()) | ||||||
|  |         << "The DC_AND_NYQUIST output stream can only be connected when the " | ||||||
|  |            "calculator outputs fft tensors"; | ||||||
|  |   } | ||||||
|   return absl::OkStatus(); |   return absl::OkStatus(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -262,7 +353,12 @@ absl::Status AudioToTensorCalculator::Close(CalculatorContext* cc) { | ||||||
|     resampler_->Flush(&resampled_buffer); |     resampler_->Flush(&resampled_buffer); | ||||||
|     AppendToSampleBuffer(std::move(resampled_buffer)); |     AppendToSampleBuffer(std::move(resampled_buffer)); | ||||||
|   } |   } | ||||||
|   return OutputTensors(sample_buffer_, /*should_flush=*/true, cc); |   AppendZerosToSampleBuffer(padding_samples_after_); | ||||||
|  |   MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/true, cc)); | ||||||
|  |   if (fft_state_) { | ||||||
|  |     pffft_destroy_setup(fft_state_); | ||||||
|  |   } | ||||||
|  |   return absl::OkStatus(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::Status AudioToTensorCalculator::ProcessStreamingData( | absl::Status AudioToTensorCalculator::ProcessStreamingData( | ||||||
|  | @ -303,7 +399,7 @@ absl::Status AudioToTensorCalculator::ProcessStreamingData( | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MP_RETURN_IF_ERROR(OutputTensors(sample_buffer_, /*should_flush=*/false, cc)); |   MP_RETURN_IF_ERROR(ProcessBuffer(sample_buffer_, /*should_flush=*/false, cc)); | ||||||
|   // Removes the processed samples from the global sample buffer.
 |   // Removes the processed samples from the global sample buffer.
 | ||||||
|   sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() - |   sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() - | ||||||
|                                                    processed_buffer_cols_ - 1)); |                                                    processed_buffer_cols_ - 1)); | ||||||
|  | @ -323,9 +419,9 @@ absl::Status AudioToTensorCalculator::ProcessNonStreamingData( | ||||||
|         input_frame); |         input_frame); | ||||||
|     Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_, |     Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_, | ||||||
|                                             resampled.size() / num_channels_); |                                             resampled.size() / num_channels_); | ||||||
|     return OutputTensors(matrix_mapping, /*should_flush=*/true, cc); |     return ProcessBuffer(matrix_mapping, /*should_flush=*/true, cc); | ||||||
|   } |   } | ||||||
|   return OutputTensors(input_frame, /*should_flush=*/true, cc); |   return ProcessBuffer(input_frame, /*should_flush=*/true, cc); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::Status AudioToTensorCalculator::SetupStreamingResampler( | absl::Status AudioToTensorCalculator::SetupStreamingResampler( | ||||||
|  | @ -344,6 +440,16 @@ absl::Status AudioToTensorCalculator::SetupStreamingResampler( | ||||||
|   return absl::OkStatus(); |   return absl::OkStatus(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | void AudioToTensorCalculator::AppendZerosToSampleBuffer(int num_samples) { | ||||||
|  |   CHECK_GE(num_samples, 0);  // Ensured by `UpdateContract`.
 | ||||||
|  |   if (num_samples == 0) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |   sample_buffer_.conservativeResize(Eigen::NoChange, | ||||||
|  |                                     sample_buffer_.cols() + num_samples); | ||||||
|  |   sample_buffer_.rightCols(num_samples).setZero(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) { | void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) { | ||||||
|   sample_buffer_.conservativeResize( |   sample_buffer_.conservativeResize( | ||||||
|       Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols()); |       Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols()); | ||||||
|  | @ -351,49 +457,89 @@ void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor( | absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor( | ||||||
|     const Matrix& frame_to_convert) { |     const Matrix& block, std::vector<int> tensor_dims) { | ||||||
|   Tensor tensor(Tensor::ElementType::kFloat32, |   Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape(tensor_dims)); | ||||||
|                 Tensor::Shape({num_channels_, num_samples_})); |  | ||||||
|   auto buffer_view = tensor.GetCpuWriteView(); |   auto buffer_view = tensor.GetCpuWriteView(); | ||||||
|   if (frame_to_convert.size() < num_channels_ * num_samples_) { |   int total_size = 1; | ||||||
|  |   for (int dim : tensor_dims) { | ||||||
|  |     total_size *= dim; | ||||||
|  |   } | ||||||
|  |   if (block.size() < total_size) { | ||||||
|     std::memset(buffer_view.buffer<float>(), 0, tensor.bytes()); |     std::memset(buffer_view.buffer<float>(), 0, tensor.bytes()); | ||||||
|   } |   } | ||||||
|   std::memcpy(buffer_view.buffer<float>(), frame_to_convert.data(), |   std::memcpy(buffer_view.buffer<float>(), block.data(), | ||||||
|               frame_to_convert.size() * sizeof(float)); |               block.size() * sizeof(float)); | ||||||
|   std::vector<Tensor> tensor_vector; |   std::vector<Tensor> tensor_vector; | ||||||
|   tensor_vector.push_back(std::move(tensor)); |   tensor_vector.push_back(std::move(tensor)); | ||||||
|   return tensor_vector; |   return tensor_vector; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| absl::Status AudioToTensorCalculator::OutputTensors(const Matrix& buffer, | absl::Status AudioToTensorCalculator::OutputTensor(const Matrix& block, | ||||||
|  |                                                    Timestamp timestamp, | ||||||
|  |                                                    CalculatorContext* cc) { | ||||||
|  |   std::vector<Tensor> output_tensor; | ||||||
|  |   if (fft_state_) { | ||||||
|  |     Eigen::VectorXf time_series_data = | ||||||
|  |         Eigen::VectorXf::Map(block.data(), block.size()); | ||||||
|  |     //  Window on input audio prior to FFT.
 | ||||||
|  |     std::transform(time_series_data.begin(), time_series_data.end(), | ||||||
|  |                    fft_window_.begin(), fft_input_buffer_.begin(), | ||||||
|  |                    std::multiplies<float>()); | ||||||
|  |     pffft_transform_ordered(fft_state_, fft_input_buffer_.data(), | ||||||
|  |                             fft_output_.data(), fft_workplace_.data(), | ||||||
|  |                             PFFFT_FORWARD); | ||||||
|  |     if (kDcAndNyquistOut(cc).IsConnected()) { | ||||||
|  |       kDcAndNyquistOut(cc).Send(std::make_pair(fft_output_[0], fft_output_[1]), | ||||||
|  |                                 timestamp); | ||||||
|  |     } | ||||||
|  |     Matrix fft_output_matrix = | ||||||
|  |         Eigen::Map<const Matrix>(fft_output_.data() + 2, 1, fft_size_ - 2); | ||||||
|  |     fft_output_matrix.conservativeResize(Eigen::NoChange, fft_size_); | ||||||
|  |     // The last two elements are the DFT Nyquist values.
 | ||||||
|  |     fft_output_matrix(fft_size_ - 2) = fft_output_[1];  // Nyquist real part
 | ||||||
|  |     fft_output_matrix(fft_size_ - 1) = 0.0f;            // Nyquist imagery part
 | ||||||
|  |     ASSIGN_OR_RETURN(output_tensor, | ||||||
|  |                      ConvertToTensor(fft_output_matrix, {2, fft_size_ / 2})); | ||||||
|  |   } else { | ||||||
|  |     ASSIGN_OR_RETURN(output_tensor, | ||||||
|  |                      ConvertToTensor(block, {num_channels_, num_samples_})); | ||||||
|  |   } | ||||||
|  |   kTensorsOut(cc).Send(std::move(output_tensor), timestamp); | ||||||
|  |   return absl::OkStatus(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | absl::Status AudioToTensorCalculator::ProcessBuffer(const Matrix& buffer, | ||||||
|                                                     bool should_flush, |                                                     bool should_flush, | ||||||
|                                                     CalculatorContext* cc) { |                                                     CalculatorContext* cc) { | ||||||
|  |   const bool should_flush_at_timestamp_max = | ||||||
|  |       stream_mode_ && should_flush && | ||||||
|  |       flush_mode_ == Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX; | ||||||
|   int next_frame_first_col = 0; |   int next_frame_first_col = 0; | ||||||
|   std::vector<Timestamp> timestamps; |   std::vector<Timestamp> timestamps; | ||||||
|   while ((!stream_mode_ || !should_flush) && |   if (!should_flush_at_timestamp_max) { | ||||||
|          next_frame_first_col + num_samples_ <= buffer.cols()) { |     while (next_frame_first_col + num_samples_ <= buffer.cols()) { | ||||||
|     ASSIGN_OR_RETURN(auto output_tensor, ConvertToTensor(buffer.block( |       MP_RETURN_IF_ERROR(OutputTensor( | ||||||
|                                              0, next_frame_first_col, |           buffer.block(0, next_frame_first_col, num_channels_, num_samples_), | ||||||
|                                              num_channels_, num_samples_))); |           next_output_timestamp_, cc)); | ||||||
|     kTensorsOut(cc).Send(std::move(output_tensor), next_output_timestamp_); |  | ||||||
|       timestamps.push_back(next_output_timestamp_); |       timestamps.push_back(next_output_timestamp_); | ||||||
|       next_output_timestamp_ += round(frame_step_ / target_sample_rate_ * |       next_output_timestamp_ += round(frame_step_ / target_sample_rate_ * | ||||||
|                                       Timestamp::kTimestampUnitsPerSecond); |                                       Timestamp::kTimestampUnitsPerSecond); | ||||||
|       next_frame_first_col += frame_step_; |       next_frame_first_col += frame_step_; | ||||||
|     } |     } | ||||||
|  |   } | ||||||
|   if (should_flush && next_frame_first_col < buffer.cols()) { |   if (should_flush && next_frame_first_col < buffer.cols()) { | ||||||
|     ASSIGN_OR_RETURN(auto output_tensor, |  | ||||||
|                      ConvertToTensor(buffer.block( |  | ||||||
|                          0, next_frame_first_col, num_channels_, |  | ||||||
|                          std::min(num_samples_, |  | ||||||
|                                   (int)buffer.cols() - next_frame_first_col)))); |  | ||||||
|     // In the streaming mode, the flush happens in Close() and a packet at
 |     // In the streaming mode, the flush happens in Close() and a packet at
 | ||||||
|     // Timestamp::Max() will be emitted. In the non-streaming mode, each
 |     // Timestamp::Max() will be emitted. In the non-streaming mode, each
 | ||||||
|     // Process() invocation will process the entire buffer completely.
 |     // Process() invocation will process the entire buffer completely.
 | ||||||
|     Timestamp timestamp = |     Timestamp timestamp = should_flush_at_timestamp_max | ||||||
|         stream_mode_ ? Timestamp::Max() : next_output_timestamp_; |                               ? Timestamp::Max() | ||||||
|  |                               : next_output_timestamp_; | ||||||
|  |     MP_RETURN_IF_ERROR(OutputTensor( | ||||||
|  |         buffer.block( | ||||||
|  |             0, next_frame_first_col, num_channels_, | ||||||
|  |             std::min(num_samples_, (int)buffer.cols() - next_frame_first_col)), | ||||||
|  |         timestamp, cc)); | ||||||
|     timestamps.push_back(timestamp); |     timestamps.push_back(timestamp); | ||||||
|     kTensorsOut(cc).Send(std::move(output_tensor), timestamp); |  | ||||||
|   } |   } | ||||||
|   if (kTimestampsOut(cc).IsConnected()) { |   if (kTimestampsOut(cc).IsConnected()) { | ||||||
|     Timestamp timestamp = timestamps.back(); |     Timestamp timestamp = timestamps.back(); | ||||||
|  |  | ||||||
|  | @ -44,4 +44,28 @@ message AudioToTensorCalculatorOptions { | ||||||
|   // Set to false to disable checks for jitter in timestamp values. Useful with |   // Set to false to disable checks for jitter in timestamp values. Useful with | ||||||
|   // live audio input. |   // live audio input. | ||||||
|   optional bool check_inconsistent_timestamps = 6 [default = true]; |   optional bool check_inconsistent_timestamps = 6 [default = true]; | ||||||
|  | 
 | ||||||
|  |   // Size of the fft in number of bins. If set, the calculator outputs fft | ||||||
|  |   // tensors. | ||||||
|  |   optional int64 fft_size = 7; | ||||||
|  | 
 | ||||||
|  |   // The amount of padding samples to add before the audio after resampling. | ||||||
|  |   // Note that the timestamps shift. Currently, only zero padding is supported. | ||||||
|  |   optional int64 padding_samples_before = 8; | ||||||
|  | 
 | ||||||
|  |   // The amount of padding samples to add after the audio after resampling. | ||||||
|  |   // Currently, only zero padding is supported. | ||||||
|  |   optional int64 padding_samples_after = 9; | ||||||
|  | 
 | ||||||
|  |   // Determines the "flushing" behavior in stream mode. | ||||||
|  |   enum FlushMode { | ||||||
|  |     // Unspecified (causes an error). Won't be used because of the default. | ||||||
|  |     NONE = 0; | ||||||
|  |     // Emit a packet with the entire remainder at `Timestamp::Max`. | ||||||
|  |     ENTIRE_TAIL_AT_TIMESTAMP_MAX = 1; | ||||||
|  |     // Continue emitting framed packets with relevant timestamps. | ||||||
|  |     PROCEED_AS_USUAL = 2; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   optional FlushMode flush_mode = 10 [default = ENTIRE_TAIL_AT_TIMESTAMP_MAX]; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -12,13 +12,13 @@ | ||||||
| // See the License for the specific language governing permissions and
 | // See the License for the specific language governing permissions and
 | ||||||
| // limitations under the License.
 | // limitations under the License.
 | ||||||
| 
 | 
 | ||||||
| #include <cmath> |  | ||||||
| #include <memory> | #include <memory> | ||||||
| #include <string> | #include <string> | ||||||
| #include <vector> | #include <vector> | ||||||
| 
 | 
 | ||||||
| #include "absl/strings/substitute.h" | #include "absl/strings/substitute.h" | ||||||
| #include "audio/dsp/resampler_q.h" | #include "audio/dsp/resampler_q.h" | ||||||
|  | #include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h" | ||||||
| #include "mediapipe/framework/api2/packet.h" | #include "mediapipe/framework/api2/packet.h" | ||||||
| #include "mediapipe/framework/calculator.pb.h" | #include "mediapipe/framework/calculator.pb.h" | ||||||
| #include "mediapipe/framework/calculator_framework.h" | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | @ -32,6 +32,14 @@ | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | using ::testing::Not; | ||||||
|  | using Options = ::mediapipe::AudioToTensorCalculatorOptions; | ||||||
|  | using FlushMode = Options::FlushMode; | ||||||
|  | 
 | ||||||
|  | int DivideRoundedUp(int dividend, int divisor) { | ||||||
|  |   return (dividend + divisor - 1) / divisor; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples, | std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples, | ||||||
|                                          int timestamp) { |                                          int timestamp) { | ||||||
|   auto matrix = std::make_unique<Matrix>(num_channels, num_samples); |   auto matrix = std::make_unique<Matrix>(num_channels, num_samples); | ||||||
|  | @ -292,16 +300,17 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { | ||||||
|     num_iterations_ = num_iterations; |     num_iterations_ = num_iterations; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   int GetExpectedNumOfSamples() { |   int GetExpectedNumOfSamples() { return output_sample_buffer_->cols(); } | ||||||
|     Matrix* expected_matrix = |  | ||||||
|         resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get(); |  | ||||||
|     return expected_matrix->cols(); |  | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   void Run(int num_samples, int num_overlapping_samples, |   void Run(int num_samples, int num_overlapping_samples, | ||||||
|            double resampling_factor) { |            double resampling_factor, int padding_before = 0, | ||||||
|  |            int padding_after = 0, bool expect_init_error = false) { | ||||||
|     double input_sample_rate = 10000; |     double input_sample_rate = 10000; | ||||||
|     double target_sample_rate = input_sample_rate * resampling_factor; |     double target_sample_rate = input_sample_rate * resampling_factor; | ||||||
|  |     FlushMode flush_mode = (padding_before != 0 || padding_after != 0) | ||||||
|  |                                ? Options::PROCEED_AS_USUAL | ||||||
|  |                                : Options::ENTIRE_TAIL_AT_TIMESTAMP_MAX; | ||||||
|  | 
 | ||||||
|     auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>( |     auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>( | ||||||
|         absl::Substitute(R"( |         absl::Substitute(R"( | ||||||
|         input_stream: "audio" |         input_stream: "audio" | ||||||
|  | @ -319,16 +328,25 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { | ||||||
|               num_overlapping_samples: $1 |               num_overlapping_samples: $1 | ||||||
|               target_sample_rate: $2 |               target_sample_rate: $2 | ||||||
|               stream_mode:true |               stream_mode:true | ||||||
|  |               padding_samples_before: $3 | ||||||
|  |               padding_samples_after: $4 | ||||||
|  |               flush_mode: $5 | ||||||
|             } |             } | ||||||
|           } |           } | ||||||
|         } |         } | ||||||
|         )", |         )", | ||||||
|                          /*$0=*/num_samples, /*$1=*/num_overlapping_samples, |                          /*$0=*/num_samples, /*$1=*/num_overlapping_samples, | ||||||
|                          /*$2=*/target_sample_rate)); |                          /*$2=*/target_sample_rate, /*$3=*/padding_before, | ||||||
|  |                          /*$4=*/padding_after, /*$5=*/flush_mode)); | ||||||
|     tool::AddVectorSink("tensors", &graph_config, &tensors_packets_); |     tool::AddVectorSink("tensors", &graph_config, &tensors_packets_); | ||||||
| 
 | 
 | ||||||
|     // Run the graph.
 |     // Run the graph.
 | ||||||
|     MP_ASSERT_OK(graph_.Initialize(graph_config)); |     const absl::Status init_status = graph_.Initialize(graph_config); | ||||||
|  |     if (expect_init_error) { | ||||||
|  |       EXPECT_THAT(init_status, Not(IsOk())); | ||||||
|  |       return; | ||||||
|  |     } | ||||||
|  |     MP_ASSERT_OK(init_status); | ||||||
|     MP_ASSERT_OK(graph_.StartRun({})); |     MP_ASSERT_OK(graph_.StartRun({})); | ||||||
|     for (int i = 0; i < num_iterations_; ++i) { |     for (int i = 0; i < num_iterations_; ++i) { | ||||||
|       Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i); |       Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i); | ||||||
|  | @ -345,8 +363,18 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { | ||||||
|     } |     } | ||||||
|     MP_ASSERT_OK(graph_.CloseAllInputStreams()); |     MP_ASSERT_OK(graph_.CloseAllInputStreams()); | ||||||
|     MP_ASSERT_OK(graph_.WaitUntilIdle()); |     MP_ASSERT_OK(graph_.WaitUntilIdle()); | ||||||
|     if (resampling_factor != 1) { |     if (resampling_factor == 1) { | ||||||
|       resampled_buffer_ = ResampleBuffer(*sample_buffer_, resampling_factor); |       output_sample_buffer_ = std::make_unique<Matrix>(*sample_buffer_); | ||||||
|  |     } else { | ||||||
|  |       output_sample_buffer_ = | ||||||
|  |           ResampleBuffer(*sample_buffer_, resampling_factor); | ||||||
|  |     } | ||||||
|  |     if (padding_before != 0 || padding_after != 0) { | ||||||
|  |       Matrix padded = Matrix::Zero( | ||||||
|  |           2, padding_before + output_sample_buffer_->cols() + padding_after); | ||||||
|  |       padded.block(0, padding_before, 2, output_sample_buffer_->cols()) = | ||||||
|  |           *output_sample_buffer_; | ||||||
|  |       output_sample_buffer_->swap(padded); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -372,14 +400,12 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { | ||||||
|     auto buffer = output_tensor.GetCpuReadView().buffer<float>(); |     auto buffer = output_tensor.GetCpuReadView().buffer<float>(); | ||||||
|     int num_values = output_tensor.shape().num_elements(); |     int num_values = output_tensor.shape().num_elements(); | ||||||
|     std::vector<float> output_floats(buffer, buffer + num_values); |     std::vector<float> output_floats(buffer, buffer + num_values); | ||||||
|     Matrix* expected_matrix = |  | ||||||
|         resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get(); |  | ||||||
|     for (int i = 0; i < num_values; ++i) { |     for (int i = 0; i < num_values; ++i) { | ||||||
|       if (i + sample_offset >= expected_matrix->size()) { |       if (i + sample_offset >= output_sample_buffer_->size()) { | ||||||
|         EXPECT_FLOAT_EQ(output_floats[i], 0); |         EXPECT_FLOAT_EQ(output_floats[i], 0); | ||||||
|       } else { |       } else { | ||||||
|         EXPECT_NEAR(output_floats[i], |         EXPECT_NEAR(output_floats[i], | ||||||
|                     expected_matrix->coeff((i + sample_offset) % 2, |                     output_sample_buffer_->coeff((i + sample_offset) % 2, | ||||||
|                                                  (i + sample_offset) / 2), |                                                  (i + sample_offset) / 2), | ||||||
|                     0.001) |                     0.001) | ||||||
|             << "i=" << i << ", sample_offset=" << sample_offset |             << "i=" << i << ", sample_offset=" << sample_offset | ||||||
|  | @ -391,7 +417,8 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { | ||||||
| 
 | 
 | ||||||
|   // Fully close graph at end, otherwise calculator+tensors are destroyed
 |   // Fully close graph at end, otherwise calculator+tensors are destroyed
 | ||||||
|   // after calling WaitUntilDone().
 |   // after calling WaitUntilDone().
 | ||||||
|   void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); } |   absl::Status TryCloseGraph() { return graph_.WaitUntilDone(); } | ||||||
|  |   void CloseGraph() { MP_EXPECT_OK(TryCloseGraph()); } | ||||||
| 
 | 
 | ||||||
|  private: |  private: | ||||||
|   int input_buffer_num_samples_ = 10; |   int input_buffer_num_samples_ = 10; | ||||||
|  | @ -399,7 +426,7 @@ class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test { | ||||||
|   CalculatorGraph graph_; |   CalculatorGraph graph_; | ||||||
|   std::vector<Packet> tensors_packets_; |   std::vector<Packet> tensors_packets_; | ||||||
|   std::unique_ptr<Matrix> sample_buffer_; |   std::unique_ptr<Matrix> sample_buffer_; | ||||||
|   std::unique_ptr<Matrix> resampled_buffer_; |   std::unique_ptr<Matrix> output_sample_buffer_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| TEST_F(AudioToTensorCalculatorStreamingModeTest, | TEST_F(AudioToTensorCalculatorStreamingModeTest, | ||||||
|  | @ -408,7 +435,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, | ||||||
|       /*resampling_factor=*/1.0f); |       /*resampling_factor=*/1.0f); | ||||||
|   CheckTensorsOutputPackets( |   CheckTensorsOutputPackets( | ||||||
|       /*sample_offset=*/10, |       /*sample_offset=*/10, | ||||||
|       /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 5), |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 5), | ||||||
|       /*timestamp_interval=*/500, |       /*timestamp_interval=*/500, | ||||||
|       /*output_last_at_close=*/false); |       /*output_last_at_close=*/false); | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
|  | @ -419,7 +446,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputRemainingInCloseMethod) { | ||||||
|       /*resampling_factor=*/1.0f); |       /*resampling_factor=*/1.0f); | ||||||
|   CheckTensorsOutputPackets( |   CheckTensorsOutputPackets( | ||||||
|       /*sample_offset=*/12, |       /*sample_offset=*/12, | ||||||
|       /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 6), |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 6), | ||||||
|       /*timestamp_interval=*/600, |       /*timestamp_interval=*/600, | ||||||
|       /*output_last_at_close=*/true); |       /*output_last_at_close=*/true); | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
|  | @ -431,7 +458,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputOverlappingFp32Tensors) { | ||||||
|       /*resampling_factor=*/1.0f); |       /*resampling_factor=*/1.0f); | ||||||
|   CheckTensorsOutputPackets( |   CheckTensorsOutputPackets( | ||||||
|       /*sample_offset=*/16, |       /*sample_offset=*/16, | ||||||
|       /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 8), |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 8), | ||||||
|       /*timestamp_interval=*/800, |       /*timestamp_interval=*/800, | ||||||
|       /*output_last_at_close=*/true); |       /*output_last_at_close=*/true); | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
|  | @ -443,7 +470,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Downsampling) { | ||||||
|       /*resampling_factor=*/0.5f); |       /*resampling_factor=*/0.5f); | ||||||
|   CheckTensorsOutputPackets( |   CheckTensorsOutputPackets( | ||||||
|       /*sample_offset=*/512, |       /*sample_offset=*/512, | ||||||
|       /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256), | ||||||
|       /*timestamp_interval=*/51200, |       /*timestamp_interval=*/51200, | ||||||
|       /*output_last_at_close=*/true); |       /*output_last_at_close=*/true); | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
|  | @ -455,7 +482,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, DownsamplingWithOverlapping) { | ||||||
|       /*resampling_factor=*/0.5f); |       /*resampling_factor=*/0.5f); | ||||||
|   CheckTensorsOutputPackets( |   CheckTensorsOutputPackets( | ||||||
|       /*sample_offset=*/384, |       /*sample_offset=*/384, | ||||||
|       /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192), | ||||||
|       /*timestamp_interval=*/38400, |       /*timestamp_interval=*/38400, | ||||||
|       /*output_last_at_close=*/true); |       /*output_last_at_close=*/true); | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
|  | @ -467,7 +494,7 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, Upsampling) { | ||||||
|       /*resampling_factor=*/2.0f); |       /*resampling_factor=*/2.0f); | ||||||
|   CheckTensorsOutputPackets( |   CheckTensorsOutputPackets( | ||||||
|       /*sample_offset=*/512, |       /*sample_offset=*/512, | ||||||
|       /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256), |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 256), | ||||||
|       /*timestamp_interval=*/12800, |       /*timestamp_interval=*/12800, | ||||||
|       /*output_last_at_close=*/true); |       /*output_last_at_close=*/true); | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
|  | @ -479,12 +506,33 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, UpsamplingWithOverlapping) { | ||||||
|       /*resampling_factor=*/2.0f); |       /*resampling_factor=*/2.0f); | ||||||
|   CheckTensorsOutputPackets( |   CheckTensorsOutputPackets( | ||||||
|       /*sample_offset=*/384, |       /*sample_offset=*/384, | ||||||
|       /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192), |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192), | ||||||
|       /*timestamp_interval=*/9600, |       /*timestamp_interval=*/9600, | ||||||
|       /*output_last_at_close=*/true); |       /*output_last_at_close=*/true); | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(AudioToTensorCalculatorStreamingModeTest, | ||||||
|  |        UpsamplingWithOverlappingAndPadding) { | ||||||
|  |   SetInputBufferNumSamplesPerChannel(1024); | ||||||
|  |   Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, | ||||||
|  |       /*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/999); | ||||||
|  |   CheckTensorsOutputPackets( | ||||||
|  |       /*sample_offset=*/384, | ||||||
|  |       /*num_packets=*/DivideRoundedUp(GetExpectedNumOfSamples(), 192), | ||||||
|  |       /*timestamp_interval=*/9600, | ||||||
|  |       /*output_last_at_close=*/false); | ||||||
|  |   CloseGraph(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(AudioToTensorCalculatorStreamingModeTest, NegativePaddingUnsupported) { | ||||||
|  |   SetInputBufferNumSamplesPerChannel(1024); | ||||||
|  |   Run(/*num_samples=*/256, /*num_overlapping_samples=*/64, | ||||||
|  |       /*resampling_factor=*/2.0f, /*padding_before=*/13, /*padding_after=*/-3, | ||||||
|  |       /*expect_init_error=*/true); | ||||||
|  |   EXPECT_THAT(TryCloseGraph(), Not(IsOk())); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST_F(AudioToTensorCalculatorStreamingModeTest, | TEST_F(AudioToTensorCalculatorStreamingModeTest, | ||||||
|        OnlyOutputInCloseIfNoSufficientSamples) { |        OnlyOutputInCloseIfNoSufficientSamples) { | ||||||
|   SetNumIterations(1); |   SetNumIterations(1); | ||||||
|  | @ -498,5 +546,122 @@ TEST_F(AudioToTensorCalculatorStreamingModeTest, | ||||||
|   CloseGraph(); |   CloseGraph(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | class AudioToTensorCalculatorFftTest : public ::testing::Test { | ||||||
|  |  protected: | ||||||
|  |   // Creates an audio matrix containing a single sample of 1.0 at a specified
 | ||||||
|  |   // offset.
 | ||||||
|  |   std::unique_ptr<Matrix> CreateImpulseSignalData(int64 num_samples, | ||||||
|  |                                                   int impulse_offset_idx) { | ||||||
|  |     Matrix impulse = Matrix::Zero(1, num_samples); | ||||||
|  |     impulse(0, impulse_offset_idx) = 1.0; | ||||||
|  |     return std::make_unique<Matrix>(std::move(impulse)); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   void ConfigGraph(int num_channels, int num_samples, | ||||||
|  |                    int num_overlapping_samples, double sample_rate, | ||||||
|  |                    int fft_size) { | ||||||
|  |     graph_config_ = ParseTextProtoOrDie<CalculatorGraphConfig>( | ||||||
|  |         absl::Substitute(R"( | ||||||
|  |         input_stream: "audio" | ||||||
|  |         input_stream: "sample_rate" | ||||||
|  |         output_stream: "tensors" | ||||||
|  |         output_stream: "dc_and_nyquist" | ||||||
|  |         node { | ||||||
|  |           calculator: "AudioToTensorCalculator" | ||||||
|  |           input_stream: "AUDIO:audio" | ||||||
|  |           input_stream: "SAMPLE_RATE:sample_rate" | ||||||
|  |           output_stream: "TENSORS:tensors" | ||||||
|  |           output_stream: "DC_AND_NYQUIST:dc_and_nyquist" | ||||||
|  |           options { | ||||||
|  |             [mediapipe.AudioToTensorCalculatorOptions.ext] { | ||||||
|  |               num_channels: $0 | ||||||
|  |               num_samples: $1 | ||||||
|  |               num_overlapping_samples: $2 | ||||||
|  |               target_sample_rate: $3 | ||||||
|  |               fft_size: $4 | ||||||
|  |             } | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |         )", | ||||||
|  |                          /*$0=*/num_channels, | ||||||
|  |                          /*$1=*/num_samples, | ||||||
|  |                          /*$2=*/num_overlapping_samples, | ||||||
|  |                          /*$3=*/sample_rate, /*$4=*/fft_size)); | ||||||
|  |     std::vector<Packet> tensors_packets; | ||||||
|  |     tool::AddVectorSink("tensors", &graph_config_, &tensors_packets_); | ||||||
|  |     std::vector<Packet> dc_and_nyquist_packets; | ||||||
|  |     tool::AddVectorSink("dc_and_nyquist", &graph_config_, | ||||||
|  |                         &dc_and_nyquist_packets_); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   void RunGraph(std::unique_ptr<Matrix> input_data, double sample_rate) { | ||||||
|  |     MP_ASSERT_OK(graph_.Initialize(graph_config_)); | ||||||
|  |     MP_ASSERT_OK(graph_.StartRun({})); | ||||||
|  |     MP_ASSERT_OK(graph_.AddPacketToInputStream( | ||||||
|  |         "sample_rate", MakePacket<double>(sample_rate).At(Timestamp(0)))); | ||||||
|  |     MP_ASSERT_OK(graph_.AddPacketToInputStream( | ||||||
|  |         "audio", MakePacket<Matrix>(*input_data).At(Timestamp(0)))); | ||||||
|  |     MP_ASSERT_OK(graph_.CloseAllInputStreams()); | ||||||
|  |     MP_ASSERT_OK(graph_.WaitUntilIdle()); | ||||||
|  |     ASSERT_EQ(tensors_packets_.size(), dc_and_nyquist_packets_.size()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Fully close graph at end, otherwise calculator+tensors are destroyed
 | ||||||
|  |   // after calling WaitUntilDone().
 | ||||||
|  |   void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); } | ||||||
|  | 
 | ||||||
|  |   std::vector<Packet> tensors_packets_; | ||||||
|  |   std::vector<Packet> dc_and_nyquist_packets_; | ||||||
|  |   CalculatorGraphConfig graph_config_; | ||||||
|  |   CalculatorGraph graph_; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | TEST_F(AudioToTensorCalculatorFftTest, TestInvalidFftSize) { | ||||||
|  |   ConfigGraph(1, 320, 160, 16000, 103); | ||||||
|  |   MP_ASSERT_OK(graph_.Initialize(graph_config_)); | ||||||
|  |   MP_ASSERT_OK(graph_.StartRun({})); | ||||||
|  |   auto status = graph_.WaitUntilIdle(); | ||||||
|  |   EXPECT_EQ(status.code(), absl::StatusCode::kInternal); | ||||||
|  |   EXPECT_THAT(status.message(), | ||||||
|  |               ::testing::HasSubstr("FFT size must be of the form")); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(AudioToTensorCalculatorFftTest, TestInvalidNumChannels) { | ||||||
|  |   ConfigGraph(3, 320, 160, 16000, 256); | ||||||
|  |   MP_ASSERT_OK(graph_.Initialize(graph_config_)); | ||||||
|  |   MP_ASSERT_OK(graph_.StartRun({})); | ||||||
|  |   auto status = graph_.WaitUntilIdle(); | ||||||
|  |   EXPECT_EQ(status.code(), absl::StatusCode::kInternal); | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       status.message(), | ||||||
|  |       ::testing::HasSubstr("only support applying FFT on mono channel")); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(AudioToTensorCalculatorFftTest, TestImpulseSignal) { | ||||||
|  |   constexpr double sample_rate = 16000; | ||||||
|  |   ConfigGraph(1, 320, 160, sample_rate, 320); | ||||||
|  |   RunGraph(CreateImpulseSignalData(320, 160), sample_rate); | ||||||
|  |   for (int i = 0; i < tensors_packets_.size(); ++i) { | ||||||
|  |     const auto& tensors = tensors_packets_[i].Get<std::vector<Tensor>>(); | ||||||
|  |     ASSERT_EQ(1, tensors.size()); | ||||||
|  |     const Tensor& output_tensor = | ||||||
|  |         tensors_packets_[0].Get<std::vector<Tensor>>()[0]; | ||||||
|  |     auto* buffer = output_tensor.GetCpuReadView().buffer<float>(); | ||||||
|  |     int num_values = output_tensor.shape().num_elements(); | ||||||
|  |     const std::vector<float> output_floats(buffer, buffer + num_values); | ||||||
|  |     // Impulse signal should have (approximately) const power across all
 | ||||||
|  |     // frequency bins.
 | ||||||
|  |     const auto& pair = | ||||||
|  |         dc_and_nyquist_packets_[i].Get<std::pair<float, float>>(); | ||||||
|  |     EXPECT_FLOAT_EQ(pair.first, 1.0f); | ||||||
|  |     EXPECT_FLOAT_EQ(pair.second, 1.0f); | ||||||
|  |     for (int j = 0; j < num_values / 2; ++j) { | ||||||
|  |       std::complex<float> cf(output_floats[j * 2], output_floats[j * 2 + 1]); | ||||||
|  |       EXPECT_FLOAT_EQ(std::norm(cf), 1.0f); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   CloseGraph(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
							
								
								
									
										165
									
								
								mediapipe/calculators/tensor/feedback_tensors_calculator.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										165
									
								
								mediapipe/calculators/tensor/feedback_tensors_calculator.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,165 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include <algorithm> | ||||||
|  | #include <memory> | ||||||
|  | #include <utility> | ||||||
|  | 
 | ||||||
|  | #include "absl/status/status.h" | ||||||
|  | #include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h" | ||||||
|  | #include "mediapipe/framework/api2/node.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/formats/tensor.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace api2 { | ||||||
|  | 
 | ||||||
|  | namespace { | ||||||
|  | constexpr char kInputTensorsTag[] = "INPUT_TENSORS"; | ||||||
|  | constexpr char kFeedbackTensorsTag[] = "FEEDBACK_TENSORS"; | ||||||
|  | constexpr char kOutputTensorsTag[] = "TENSORS"; | ||||||
|  | 
 | ||||||
|  | using Tensors = std::vector<Tensor>; | ||||||
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | // FeedbackTensorsCalculator groups the input and the feedback (typically
 | ||||||
|  | // recurrent neural network cell state output tensors from the previous run)
 | ||||||
|  | // tensor vectors as the input tensor vector for the next recurrent model cell
 | ||||||
|  | // inference. On the first step, the feedback tensor is filled with zeros to
 | ||||||
|  | // jumpstart the loop.
 | ||||||
|  | class FeedbackTensorsCalculator : public Node { | ||||||
|  |  public: | ||||||
|  |   static constexpr Input<Tensors> kFeedbackTensorsIn{kFeedbackTensorsTag}; | ||||||
|  |   static constexpr Input<Tensors> kInputTensorsIn{kInputTensorsTag}; | ||||||
|  |   static constexpr Output<Tensors> kTensorsOut{kOutputTensorsTag}; | ||||||
|  | 
 | ||||||
|  |   MEDIAPIPE_NODE_CONTRACT(kFeedbackTensorsIn, kInputTensorsIn, kTensorsOut); | ||||||
|  | 
 | ||||||
|  |   static absl::Status GetContract(CalculatorContract* cc) { | ||||||
|  |     cc->SetProcessTimestampBounds(true); | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   absl::Status Open(CalculatorContext* cc) override { | ||||||
|  |     const auto& options = | ||||||
|  |         cc->Options<mediapipe::FeedbackTensorsCalculatorOptions>(); | ||||||
|  | 
 | ||||||
|  |     const auto& shape_dims = options.feedback_tensor_shape().dims(); | ||||||
|  |     feedback_tensor_shape_.dims.assign(shape_dims.begin(), shape_dims.end()); | ||||||
|  |     feedback_tensor_size_ = feedback_tensor_shape_.num_elements(); | ||||||
|  | 
 | ||||||
|  |     num_feedback_tensors_ = options.num_feedback_tensors(); | ||||||
|  | 
 | ||||||
|  |     feedback_tensors_location_ = options.location(); | ||||||
|  | 
 | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   absl::Status Process(CalculatorContext* cc) override { | ||||||
|  |     if (feedback_tensors_location_ == | ||||||
|  |         mediapipe::FeedbackTensorsCalculatorOptions::NONE) { | ||||||
|  |       kTensorsOut(cc).Send(kInputTensorsIn(cc).packet().As<Tensors>()); | ||||||
|  |       return absl::OkStatus(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     std::vector<Tensor> outputs; | ||||||
|  |     switch (feedback_tensors_location_) { | ||||||
|  |       case mediapipe::FeedbackTensorsCalculatorOptions::PREPENDED: | ||||||
|  |         MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs)); | ||||||
|  |         MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs)); | ||||||
|  |         break; | ||||||
|  |       case mediapipe::FeedbackTensorsCalculatorOptions::APPENDED: | ||||||
|  |         MP_RETURN_IF_ERROR(AddInputTensors(cc, outputs)); | ||||||
|  |         MP_RETURN_IF_ERROR(AddFeedbackTensors(cc, outputs)); | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         return absl::InvalidArgumentError( | ||||||
|  |             "Unsupported feedback tensors location"); | ||||||
|  |     } | ||||||
|  |     kTensorsOut(cc).Send(std::move(outputs)); | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   absl::Status AddInputTensors(CalculatorContext* cc, | ||||||
|  |                                std::vector<Tensor>& outputs) { | ||||||
|  |     absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> input_tensors = | ||||||
|  |         cc->Inputs() | ||||||
|  |             .Tag(kInputTensorsTag) | ||||||
|  |             .Value() | ||||||
|  |             .Consume<std::vector<Tensor>>(); | ||||||
|  |     if (!input_tensors.ok()) { | ||||||
|  |       return absl::InternalError("The input tensors packet is not consumable"); | ||||||
|  |     } | ||||||
|  |     RET_CHECK(*input_tensors); | ||||||
|  |     std::vector<Tensor>& inputs = **input_tensors; | ||||||
|  |     outputs.insert(outputs.end(), std::make_move_iterator(inputs.begin()), | ||||||
|  |                    std::make_move_iterator(inputs.end())); | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   absl::Status AddFeedbackTensors(CalculatorContext* cc, | ||||||
|  |                                   std::vector<Tensor>& outputs) { | ||||||
|  |     if (first_run_) { | ||||||
|  |       for (int index = 0; index < num_feedback_tensors_; ++index) { | ||||||
|  |         Tensor initial_feedback_tensor(Tensor::ElementType::kFloat32, | ||||||
|  |                                        feedback_tensor_shape_); | ||||||
|  |         float* data = initial_feedback_tensor.GetCpuWriteView().buffer<float>(); | ||||||
|  |         std::fill_n(data, feedback_tensor_size_, 0.0f); | ||||||
|  |         outputs.push_back(std::move(initial_feedback_tensor)); | ||||||
|  |       } | ||||||
|  |       first_run_ = false; | ||||||
|  |       return absl::OkStatus(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (num_feedback_tensors_ != kFeedbackTensorsIn(cc)->size()) { | ||||||
|  |       return absl::InvalidArgumentError( | ||||||
|  |           "The number of tensors fed back differs from the configuration"); | ||||||
|  |     } | ||||||
|  |     absl::StatusOr<std::unique_ptr<std::vector<Tensor>>> feedback_tensors = | ||||||
|  |         cc->Inputs() | ||||||
|  |             .Tag(kFeedbackTensorsTag) | ||||||
|  |             .Value() | ||||||
|  |             .Consume<std::vector<Tensor>>(); | ||||||
|  |     if (!feedback_tensors.ok()) { | ||||||
|  |       return absl::InternalError( | ||||||
|  |           "The feedback tensors packet is not consumable"); | ||||||
|  |     } | ||||||
|  |     RET_CHECK(*feedback_tensors); | ||||||
|  |     std::vector<Tensor>& feedbacks = **feedback_tensors; | ||||||
|  |     for (const auto& feedback : feedbacks) { | ||||||
|  |       if (feedback.shape().dims != feedback_tensor_shape_.dims) { | ||||||
|  |         return absl::InvalidArgumentError( | ||||||
|  |             "The shape of a tensor fed back differs from the configuration"); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     outputs.insert(outputs.end(), std::make_move_iterator(feedbacks.begin()), | ||||||
|  |                    std::make_move_iterator(feedbacks.end())); | ||||||
|  | 
 | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   Tensor::Shape feedback_tensor_shape_; | ||||||
|  |   int num_feedback_tensors_ = 0; | ||||||
|  |   mediapipe::FeedbackTensorsCalculatorOptions::FeedbackTensorsLocation | ||||||
|  |       feedback_tensors_location_; | ||||||
|  | 
 | ||||||
|  |   int feedback_tensor_size_ = 0; | ||||||
|  |   bool first_run_ = true; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | MEDIAPIPE_REGISTER_NODE(FeedbackTensorsCalculator); | ||||||
|  | 
 | ||||||
|  | }  // namespace api2
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -0,0 +1,47 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors. | ||||||
|  | // | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | // you may not use this file except in compliance with the License. | ||||||
|  | // You may obtain a copy of the License at | ||||||
|  | // | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | // | ||||||
|  | // Unless required by applicable law or agreed to in writing, software | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | // See the License for the specific language governing permissions and | ||||||
|  | // limitations under the License. | ||||||
|  | 
 | ||||||
|  | syntax = "proto2"; | ||||||
|  | 
 | ||||||
|  | package mediapipe; | ||||||
|  | 
 | ||||||
|  | import "mediapipe/framework/calculator.proto"; | ||||||
|  | 
 | ||||||
|  | message FeedbackTensorsCalculatorOptions { | ||||||
|  |   extend mediapipe.CalculatorOptions { | ||||||
|  |     optional FeedbackTensorsCalculatorOptions ext = 474496252; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Represents the dimensions of a tensor starting from the outermost size. | ||||||
|  |   message TensorShape { | ||||||
|  |     repeated int32 dims = 1 [packed = true]; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // The shape of the feedback tensors to add. | ||||||
|  |   optional TensorShape feedback_tensor_shape = 1; | ||||||
|  |   // The number of the feedback tensors to add. | ||||||
|  |   optional int32 num_feedback_tensors = 2 [default = 1]; | ||||||
|  | 
 | ||||||
|  |   enum FeedbackTensorsLocation { | ||||||
|  |     // The feedback tensors will not be added. | ||||||
|  |     NONE = 0; | ||||||
|  |     // The feedback tensors will be added before the input tensors. | ||||||
|  |     PREPENDED = 1; | ||||||
|  |     // The feedback tensors will be added after the input tensors. | ||||||
|  |     APPENDED = 2; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Determines the location of the feedback tensor(s) in the output vector. | ||||||
|  |   optional FeedbackTensorsLocation location = 3 [default = APPENDED]; | ||||||
|  | } | ||||||
							
								
								
									
										389
									
								
								mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										389
									
								
								mediapipe/calculators/tensor/feedback_tensors_calculator_test.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,389 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include <functional> | ||||||
|  | #include <initializer_list> | ||||||
|  | #include <memory> | ||||||
|  | #include <utility> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/calculators/tensor/feedback_tensors_calculator.pb.h" | ||||||
|  | #include "mediapipe/framework/calculator.pb.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/formats/tensor.h" | ||||||
|  | #include "mediapipe/framework/port/gmock.h" | ||||||
|  | #include "mediapipe/framework/port/gtest.h" | ||||||
|  | #include "mediapipe/framework/port/parse_text_proto.h" | ||||||
|  | #include "mediapipe/framework/port/status_matchers.h" | ||||||
|  | #include "mediapipe/framework/timestamp.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | using ::mediapipe::CalculatorGraphConfig; | ||||||
|  | using ::testing::ElementsAreArray; | ||||||
|  | using ::testing::Not; | ||||||
|  | using Tensors = std::vector<Tensor>; | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | struct TensorElementType { | ||||||
|  |   static constexpr Tensor::ElementType value = Tensor::ElementType::kNone; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | template <> | ||||||
|  | struct TensorElementType<float> { | ||||||
|  |   static constexpr Tensor::ElementType value = Tensor::ElementType::kFloat32; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | template <> | ||||||
|  | struct TensorElementType<std::int8_t> { | ||||||
|  |   static constexpr Tensor::ElementType value = Tensor::ElementType::kInt8; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | template <> | ||||||
|  | struct TensorElementType<std::uint8_t> { | ||||||
|  |   static constexpr Tensor::ElementType value = Tensor::ElementType::kUInt8; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | template <> | ||||||
|  | struct TensorElementType<std::int32_t> { | ||||||
|  |   static constexpr Tensor::ElementType value = Tensor::ElementType::kInt32; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | Tensor MakeTensor(std::initializer_list<int> shape, | ||||||
|  |                   std::initializer_list<T> values) { | ||||||
|  |   Tensor tensor(TensorElementType<T>::value, shape); | ||||||
|  |   CHECK_EQ(values.size(), tensor.shape().num_elements()) | ||||||
|  |       << "The size of `values` is incompatible with `shape`"; | ||||||
|  |   absl::c_copy(values, tensor.GetCpuWriteView().buffer<T>()); | ||||||
|  |   return tensor; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | template <typename T> | ||||||
|  | void ValidateTensor(const Tensor& tensor, | ||||||
|  |                     const std::vector<int>& expected_shape, | ||||||
|  |                     const std::vector<T>& expected_values) { | ||||||
|  |   ASSERT_EQ(tensor.element_type(), TensorElementType<T>::value); | ||||||
|  |   EXPECT_EQ(tensor.shape().dims, expected_shape); | ||||||
|  |   EXPECT_EQ(tensor.shape().num_elements(), expected_values.size()); | ||||||
|  | 
 | ||||||
|  |   auto* tensor_buffer = tensor.GetCpuReadView().buffer<T>(); | ||||||
|  |   const std::vector<T> tensor_values( | ||||||
|  |       tensor_buffer, tensor_buffer + tensor.shape().num_elements()); | ||||||
|  |   EXPECT_THAT(tensor_values, ElementsAreArray(expected_values)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(FeedbackTensorsCalculatorTest, AppendsFeedback) { | ||||||
|  |   auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||||
|  |     input_stream: "input" | ||||||
|  |     input_stream: "feedback" | ||||||
|  |     node { | ||||||
|  |       calculator: "FeedbackTensorsCalculator" | ||||||
|  |       input_stream: "INPUT_TENSORS:input" | ||||||
|  |       input_stream: "FEEDBACK_TENSORS:feedback" | ||||||
|  |       output_stream: "TENSORS:output" | ||||||
|  |       options: { | ||||||
|  |         [mediapipe.FeedbackTensorsCalculatorOptions.ext] { | ||||||
|  |           feedback_tensor_shape: { dims: 2 dims: 3 } | ||||||
|  |           location: APPENDED | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb"); | ||||||
|  |   std::vector<Packet> output_packets; | ||||||
|  |   tool::AddVectorSink("output", &graph_config, &output_packets); | ||||||
|  | 
 | ||||||
|  |   CalculatorGraph graph; | ||||||
|  |   MP_ASSERT_OK(graph.Initialize(graph_config)); | ||||||
|  |   MP_ASSERT_OK(graph.StartRun({})); | ||||||
|  | 
 | ||||||
|  |   auto initial_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   initial_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::int32_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); | ||||||
|  |   // At the beginning, the loopback packet with the model feedback is missing.
 | ||||||
|  |   // The calculator has to assume it's all-zero with the shape from the options.
 | ||||||
|  | 
 | ||||||
|  |   auto later_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::int32_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); | ||||||
|  |   auto later_feedback_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_feedback_tensors->push_back( | ||||||
|  |       MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(graph.CloseAllInputStreams()) | ||||||
|  |       << "Couldn't close the graph inputs"; | ||||||
|  |   MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run"; | ||||||
|  | 
 | ||||||
|  |   ASSERT_EQ(output_packets.size(), 2); | ||||||
|  | 
 | ||||||
|  |   const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>(); | ||||||
|  |   ASSERT_EQ(initial_combined_tensors.size(), 2); | ||||||
|  |   ValidateTensor<std::int32_t>(initial_combined_tensors[0], | ||||||
|  |                                /*expected_shape=*/{2, 4}, | ||||||
|  |                                /*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8}); | ||||||
|  |   // The initial feedback is zero.
 | ||||||
|  |   ValidateTensor<float>(initial_combined_tensors[1], /*expected_shape=*/{2, 3}, | ||||||
|  |                         /*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); | ||||||
|  | 
 | ||||||
|  |   const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>(); | ||||||
|  |   ASSERT_EQ(later_combined_tensors.size(), 2); | ||||||
|  |   ValidateTensor<std::int32_t>(later_combined_tensors[0], | ||||||
|  |                                /*expected_shape=*/{2, 4}, | ||||||
|  |                                /*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1}); | ||||||
|  |   // Afterwards, the provided feedback is passed through.
 | ||||||
|  |   ValidateTensor<float>( | ||||||
|  |       later_combined_tensors[1], /*expected_shape=*/{2, 3}, | ||||||
|  |       /*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(FeedbackTensorsCalculatorTest, PrependsFeedback) { | ||||||
|  |   auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||||
|  |     input_stream: "input" | ||||||
|  |     input_stream: "feedback" | ||||||
|  |     node { | ||||||
|  |       calculator: "FeedbackTensorsCalculator" | ||||||
|  |       input_stream: "INPUT_TENSORS:input" | ||||||
|  |       input_stream: "FEEDBACK_TENSORS:feedback" | ||||||
|  |       output_stream: "TENSORS:output" | ||||||
|  |       options: { | ||||||
|  |         [mediapipe.FeedbackTensorsCalculatorOptions.ext] { | ||||||
|  |           feedback_tensor_shape: { dims: 3 dims: 2 } | ||||||
|  |           location: PREPENDED | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb"); | ||||||
|  |   std::vector<Packet> output_packets; | ||||||
|  |   tool::AddVectorSink("output", &graph_config, &output_packets); | ||||||
|  | 
 | ||||||
|  |   CalculatorGraph graph; | ||||||
|  |   MP_ASSERT_OK(graph.Initialize(graph_config)); | ||||||
|  |   MP_ASSERT_OK(graph.StartRun({})); | ||||||
|  | 
 | ||||||
|  |   auto initial_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   initial_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::int8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); | ||||||
|  |   // At the beginning, the loopback packet with the model feedback is missing.
 | ||||||
|  |   // The calculator has to assume it's all-zero with the shape from the options.
 | ||||||
|  | 
 | ||||||
|  |   auto later_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::int8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); | ||||||
|  |   auto later_feedback_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_feedback_tensors->push_back( | ||||||
|  |       MakeTensor({3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(graph.CloseAllInputStreams()) | ||||||
|  |       << "Couldn't close the graph inputs"; | ||||||
|  |   MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run"; | ||||||
|  | 
 | ||||||
|  |   ASSERT_EQ(output_packets.size(), 2); | ||||||
|  | 
 | ||||||
|  |   const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>(); | ||||||
|  |   ASSERT_EQ(initial_combined_tensors.size(), 2); | ||||||
|  |   // The initial feedback is zero.
 | ||||||
|  |   ValidateTensor<float>(initial_combined_tensors[0], /*expected_shape=*/{3, 2}, | ||||||
|  |                         /*expected_values=*/{0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); | ||||||
|  |   ValidateTensor<std::int8_t>(initial_combined_tensors[1], | ||||||
|  |                               /*expected_shape=*/{2, 4}, | ||||||
|  |                               /*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8}); | ||||||
|  | 
 | ||||||
|  |   const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>(); | ||||||
|  |   ASSERT_EQ(later_combined_tensors.size(), 2); | ||||||
|  |   // Afterwards, the provided feedback is passed through.
 | ||||||
|  |   ValidateTensor<float>( | ||||||
|  |       later_combined_tensors[0], /*expected_shape=*/{3, 2}, | ||||||
|  |       /*expected_values=*/{-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); | ||||||
|  |   ValidateTensor<std::int8_t>(later_combined_tensors[1], | ||||||
|  |                               /*expected_shape=*/{2, 4}, | ||||||
|  |                               /*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1}); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(FeedbackTensorsCalculatorTest, NoFeedback) { | ||||||
|  |   auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||||
|  |     input_stream: "input" | ||||||
|  |     input_stream: "feedback" | ||||||
|  |     node { | ||||||
|  |       calculator: "FeedbackTensorsCalculator" | ||||||
|  |       input_stream: "INPUT_TENSORS:input" | ||||||
|  |       input_stream: "FEEDBACK_TENSORS:feedback" | ||||||
|  |       output_stream: "TENSORS:output" | ||||||
|  |       options: { | ||||||
|  |         [mediapipe.FeedbackTensorsCalculatorOptions.ext] { | ||||||
|  |           feedback_tensor_shape: { dims: 3 dims: 4 } | ||||||
|  |           location: NONE | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb"); | ||||||
|  |   std::vector<Packet> output_packets; | ||||||
|  |   tool::AddVectorSink("output", &graph_config, &output_packets); | ||||||
|  | 
 | ||||||
|  |   CalculatorGraph graph; | ||||||
|  |   MP_ASSERT_OK(graph.Initialize(graph_config)); | ||||||
|  |   MP_ASSERT_OK(graph.StartRun({})); | ||||||
|  | 
 | ||||||
|  |   auto initial_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   initial_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); | ||||||
|  |   // At the beginning, the loopback packet with the model feedback is missing.
 | ||||||
|  | 
 | ||||||
|  |   auto later_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); | ||||||
|  |   // This feedback should be ignored due to `location: NONE`.
 | ||||||
|  |   auto later_feedback_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_feedback_tensors->push_back( | ||||||
|  |       MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(graph.CloseAllInputStreams()) | ||||||
|  |       << "Couldn't close the graph inputs"; | ||||||
|  |   MP_ASSERT_OK(graph.WaitUntilDone()) << "Couldn't finalize the graph run"; | ||||||
|  | 
 | ||||||
|  |   ASSERT_EQ(output_packets.size(), 2); | ||||||
|  | 
 | ||||||
|  |   const Tensors& initial_combined_tensors = output_packets[0].Get<Tensors>(); | ||||||
|  |   ASSERT_EQ(initial_combined_tensors.size(), 1); | ||||||
|  |   ValidateTensor<std::uint8_t>(initial_combined_tensors[0], | ||||||
|  |                                /*expected_shape=*/{2, 4}, | ||||||
|  |                                /*expected_values=*/{1, 2, 3, 4, 5, 6, 7, 8}); | ||||||
|  |   // No feedback due to `location: NONE`.
 | ||||||
|  | 
 | ||||||
|  |   const Tensors& later_combined_tensors = output_packets[1].Get<Tensors>(); | ||||||
|  |   ASSERT_EQ(later_combined_tensors.size(), 1); | ||||||
|  |   ValidateTensor<std::uint8_t>(later_combined_tensors[0], | ||||||
|  |                                /*expected_shape=*/{2, 4}, | ||||||
|  |                                /*expected_values=*/{8, 7, 6, 5, 4, 3, 2, 1}); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(FeedbackTensorsCalculatorTest, ChecksTensorNumber) { | ||||||
|  |   auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||||
|  |     input_stream: "input" | ||||||
|  |     input_stream: "feedback" | ||||||
|  |     node { | ||||||
|  |       calculator: "FeedbackTensorsCalculator" | ||||||
|  |       input_stream: "INPUT_TENSORS:input" | ||||||
|  |       input_stream: "FEEDBACK_TENSORS:feedback" | ||||||
|  |       output_stream: "TENSORS:output" | ||||||
|  |       options: { | ||||||
|  |         [mediapipe.FeedbackTensorsCalculatorOptions.ext] { | ||||||
|  |           num_feedback_tensors: 2 | ||||||
|  |           feedback_tensor_shape: { dims: 2 dims: 3 } | ||||||
|  |           location: PREPENDED | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb"); | ||||||
|  |   std::vector<Packet> output_packets; | ||||||
|  |   tool::AddVectorSink("output", &graph_config, &output_packets); | ||||||
|  | 
 | ||||||
|  |   CalculatorGraph graph; | ||||||
|  |   MP_ASSERT_OK(graph.Initialize(graph_config)); | ||||||
|  |   MP_ASSERT_OK(graph.StartRun({})); | ||||||
|  | 
 | ||||||
|  |   auto initial_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   initial_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); | ||||||
|  |   // At the beginning, the loopback packet with the model feedback is missing.
 | ||||||
|  | 
 | ||||||
|  |   auto later_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); | ||||||
|  |   // This feedback should be ignored due to `location: NONE`.
 | ||||||
|  |   auto later_feedback_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_feedback_tensors->push_back( | ||||||
|  |       MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(graph.CloseAllInputStreams()) | ||||||
|  |       << "Couldn't close the graph inputs"; | ||||||
|  |   EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk())) | ||||||
|  |       << "Tensor number mismatch missed"; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(FeedbackTensorsCalculatorTest, ChecksShape) { | ||||||
|  |   auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||||
|  |     input_stream: "input" | ||||||
|  |     input_stream: "feedback" | ||||||
|  |     node { | ||||||
|  |       calculator: "FeedbackTensorsCalculator" | ||||||
|  |       input_stream: "INPUT_TENSORS:input" | ||||||
|  |       input_stream: "FEEDBACK_TENSORS:feedback" | ||||||
|  |       output_stream: "TENSORS:output" | ||||||
|  |       options: { | ||||||
|  |         [mediapipe.FeedbackTensorsCalculatorOptions.ext] { | ||||||
|  |           feedback_tensor_shape: { dims: 3 dims: 4 } | ||||||
|  |           location: APPENDED | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb"); | ||||||
|  |   std::vector<Packet> output_packets; | ||||||
|  |   tool::AddVectorSink("output", &graph_config, &output_packets); | ||||||
|  | 
 | ||||||
|  |   CalculatorGraph graph; | ||||||
|  |   MP_ASSERT_OK(graph.Initialize(graph_config)); | ||||||
|  |   MP_ASSERT_OK(graph.StartRun({})); | ||||||
|  | 
 | ||||||
|  |   auto initial_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   initial_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::uint8_t>({2, 4}, {1, 2, 3, 4, 5, 6, 7, 8})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(initial_input_tensors.release()).At(Timestamp(1)))); | ||||||
|  |   // At the beginning, the loopback packet with the model feedback is missing.
 | ||||||
|  | 
 | ||||||
|  |   auto later_input_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_input_tensors->push_back( | ||||||
|  |       MakeTensor<std::uint8_t>({2, 4}, {8, 7, 6, 5, 4, 3, 2, 1})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "input", Adopt(later_input_tensors.release()).At(Timestamp(2)))); | ||||||
|  |   // This feedback should be ignored due to `location: NONE`.
 | ||||||
|  |   auto later_feedback_tensors = std::make_unique<Tensors>(); | ||||||
|  |   later_feedback_tensors->push_back( | ||||||
|  |       MakeTensor({2, 3}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f})); | ||||||
|  |   MP_ASSERT_OK(graph.AddPacketToInputStream( | ||||||
|  |       "feedback", Adopt(later_feedback_tensors.release()).At(Timestamp(2)))); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(graph.CloseAllInputStreams()) | ||||||
|  |       << "Couldn't close the graph inputs"; | ||||||
|  |   EXPECT_THAT(graph.WaitUntilDone(), Not(IsOk())) | ||||||
|  |       << "Tensor shape mismatch missed"; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
							
								
								
									
										74
									
								
								mediapipe/calculators/tensor/text_to_tensor_calculator.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								mediapipe/calculators/tensor/text_to_tensor_calculator.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,74 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include <cstring> | ||||||
|  | #include <string> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "absl/status/status.h" | ||||||
|  | #include "absl/strings/string_view.h" | ||||||
|  | #include "mediapipe/framework/api2/node.h" | ||||||
|  | #include "mediapipe/framework/api2/port.h" | ||||||
|  | #include "mediapipe/framework/calculator_context.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/formats/tensor.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace api2 { | ||||||
|  | 
 | ||||||
|  | // Trivially converts an input string into a Tensor that stores a copy of
 | ||||||
|  | // the string.
 | ||||||
|  | //
 | ||||||
|  | // Inputs:
 | ||||||
|  | //   TEXT - std::string
 | ||||||
|  | //
 | ||||||
|  | // Outputs:
 | ||||||
|  | //   TENSORS - std::vector<Tensor>
 | ||||||
|  | //     Vector containing a single Tensor storing a copy of the input string.
 | ||||||
|  | //     Note that the underlying buffer of the Tensor is not necessarily
 | ||||||
|  | //     null-terminated. It is the graph writer's responsibility to copy the
 | ||||||
|  | //     correct number of characters when copying from this Tensor's buffer.
 | ||||||
|  | //
 | ||||||
|  | // Example:
 | ||||||
|  | //   node {
 | ||||||
|  | //     calculator: "TextToTensorCalculator"
 | ||||||
|  | //     input_stream: "TEXT:text"
 | ||||||
|  | //     output_stream: "TENSORS:tensors"
 | ||||||
|  | //   }
 | ||||||
|  | class TextToTensorCalculator : public Node { | ||||||
|  |  public: | ||||||
|  |   static constexpr Input<std::string> kTextIn{"TEXT"}; | ||||||
|  |   static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"}; | ||||||
|  | 
 | ||||||
|  |   MEDIAPIPE_NODE_CONTRACT(kTextIn, kTensorsOut); | ||||||
|  | 
 | ||||||
|  |   absl::Status Process(CalculatorContext* cc) override; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | absl::Status TextToTensorCalculator::Process(CalculatorContext* cc) { | ||||||
|  |   absl::string_view text = kTextIn(cc).Get(); | ||||||
|  |   int input_len = static_cast<int>(text.length()); | ||||||
|  | 
 | ||||||
|  |   std::vector<Tensor> result; | ||||||
|  |   result.push_back({Tensor::ElementType::kChar, Tensor::Shape({input_len})}); | ||||||
|  |   std::memcpy(result[0].GetCpuWriteView().buffer<char>(), text.data(), | ||||||
|  |               input_len * sizeof(char)); | ||||||
|  |   kTensorsOut(cc).Send(std::move(result)); | ||||||
|  |   return absl::OkStatus(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | MEDIAPIPE_REGISTER_NODE(TextToTensorCalculator); | ||||||
|  | 
 | ||||||
|  | }  // namespace api2
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -0,0 +1,88 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include <cstring> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "absl/status/status.h" | ||||||
|  | #include "absl/status/statusor.h" | ||||||
|  | #include "absl/strings/string_view.h" | ||||||
|  | #include "absl/strings/substitute.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/calculator_graph.h" | ||||||
|  | #include "mediapipe/framework/formats/tensor.h" | ||||||
|  | #include "mediapipe/framework/packet.h" | ||||||
|  | #include "mediapipe/framework/port/gmock.h" | ||||||
|  | #include "mediapipe/framework/port/gtest.h" | ||||||
|  | #include "mediapipe/framework/port/parse_text_proto.h" | ||||||
|  | #include "mediapipe/framework/port/status_matchers.h" | ||||||
|  | #include "mediapipe/framework/tool/options_map.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | using ::testing::StrEq; | ||||||
|  | 
 | ||||||
|  | absl::StatusOr<std::string> RunTextToTensorCalculator(absl::string_view text) { | ||||||
|  |   auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>( | ||||||
|  |       R"pb( | ||||||
|  |         input_stream: "text" | ||||||
|  |         output_stream: "tensors" | ||||||
|  |         node { | ||||||
|  |           calculator: "TextToTensorCalculator" | ||||||
|  |           input_stream: "TEXT:text" | ||||||
|  |           output_stream: "TENSORS:tensors" | ||||||
|  |         } | ||||||
|  |       )pb"); | ||||||
|  |   std::vector<Packet> output_packets; | ||||||
|  |   tool::AddVectorSink("tensors", &graph_config, &output_packets); | ||||||
|  | 
 | ||||||
|  |   // Run the graph.
 | ||||||
|  |   CalculatorGraph graph; | ||||||
|  |   MP_RETURN_IF_ERROR(graph.Initialize(graph_config)); | ||||||
|  |   MP_RETURN_IF_ERROR(graph.StartRun({})); | ||||||
|  |   MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( | ||||||
|  |       "text", MakePacket<std::string>(text).At(Timestamp(0)))); | ||||||
|  |   MP_RETURN_IF_ERROR(graph.WaitUntilIdle()); | ||||||
|  | 
 | ||||||
|  |   if (output_packets.size() != 1) { | ||||||
|  |     return absl::InvalidArgumentError(absl::Substitute( | ||||||
|  |         "output_packets has size $0, expected 1", output_packets.size())); | ||||||
|  |   } | ||||||
|  |   const std::vector<Tensor>& tensor_vec = | ||||||
|  |       output_packets[0].Get<std::vector<Tensor>>(); | ||||||
|  |   if (tensor_vec.size() != 1) { | ||||||
|  |     return absl::InvalidArgumentError(absl::Substitute( | ||||||
|  |         "tensor_vec has size $0, expected 1", tensor_vec.size())); | ||||||
|  |   } | ||||||
|  |   if (tensor_vec[0].element_type() != Tensor::ElementType::kChar) { | ||||||
|  |     return absl::InvalidArgumentError(absl::Substitute( | ||||||
|  |         "tensor has element type $0, expected $1", tensor_vec[0].element_type(), | ||||||
|  |         Tensor::ElementType::kChar)); | ||||||
|  |   } | ||||||
|  |   const char* buffer = tensor_vec[0].GetCpuReadView().buffer<char>(); | ||||||
|  |   return std::string(buffer, text.length()); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TextToTensorCalculatorTest, FooBarBaz) { | ||||||
|  |   EXPECT_THAT(RunTextToTensorCalculator("Foo. Bar? Baz!"), | ||||||
|  |               IsOkAndHolds(StrEq("Foo. Bar? Baz!"))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TextToTensorCalculatorTest, Empty) { | ||||||
|  |   EXPECT_THAT(RunTextToTensorCalculator(""), IsOkAndHolds(StrEq(""))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -231,7 +231,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, | ||||||
|   // Session must be set.
 |   // Session must be set.
 | ||||||
|   ASSERT_NE(session.session, nullptr); |   ASSERT_NE(session.session, nullptr); | ||||||
|   std::vector<tensorflow::DeviceAttributes> devices; |   std::vector<tensorflow::DeviceAttributes> devices; | ||||||
|   ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); |   ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus()); | ||||||
|   EXPECT_THAT(devices.size(), 10); |   EXPECT_THAT(devices.size(), 10); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -220,7 +220,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, | ||||||
|   // Session must be set.
 |   // Session must be set.
 | ||||||
|   ASSERT_NE(session.session, nullptr); |   ASSERT_NE(session.session, nullptr); | ||||||
|   std::vector<tensorflow::DeviceAttributes> devices; |   std::vector<tensorflow::DeviceAttributes> devices; | ||||||
|   ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::Status::OK()); |   ASSERT_EQ(session.session->ListDevices(&devices), tensorflow::OkStatus()); | ||||||
|   EXPECT_THAT(devices.size(), 10); |   EXPECT_THAT(devices.size(), 10); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -135,6 +135,7 @@ filegroup( | ||||||
|     srcs = [ |     srcs = [ | ||||||
|         "testdata/anchor_golden_file_0.txt", |         "testdata/anchor_golden_file_0.txt", | ||||||
|         "testdata/anchor_golden_file_1.txt", |         "testdata/anchor_golden_file_1.txt", | ||||||
|  |         "testdata/anchor_golden_file_2.txt", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -13,6 +13,7 @@ | ||||||
| // limitations under the License.
 | // limitations under the License.
 | ||||||
| 
 | 
 | ||||||
| #include <cmath> | #include <cmath> | ||||||
|  | #include <utility> | ||||||
| #include <vector> | #include <vector> | ||||||
| 
 | 
 | ||||||
| #include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" | #include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" | ||||||
|  | @ -24,6 +25,19 @@ namespace mediapipe { | ||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | struct MultiScaleAnchorInfo { | ||||||
|  |   int32 level; | ||||||
|  |   std::vector<float> aspect_ratios; | ||||||
|  |   std::vector<float> scales; | ||||||
|  |   std::pair<float, float> base_anchor_size; | ||||||
|  |   std::pair<float, float> anchor_stride; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | struct FeatureMapDim { | ||||||
|  |   int height; | ||||||
|  |   int width; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
| float CalculateScale(float min_scale, float max_scale, int stride_index, | float CalculateScale(float min_scale, float max_scale, int stride_index, | ||||||
|                      int num_strides) { |                      int num_strides) { | ||||||
|   if (num_strides == 1) { |   if (num_strides == 1) { | ||||||
|  | @ -34,6 +48,71 @@ float CalculateScale(float min_scale, float max_scale, int stride_index, | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | int GetNumLayers(const SsdAnchorsCalculatorOptions& options) { | ||||||
|  |   if (options.multiscale_anchor_generation()) { | ||||||
|  |     return (options.max_level() - options.min_level() + 1); | ||||||
|  |   } | ||||||
|  |   return options.num_layers(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | FeatureMapDim GetFeatureMapDimensions( | ||||||
|  |     const SsdAnchorsCalculatorOptions& options, int index) { | ||||||
|  |   FeatureMapDim feature_map_dims; | ||||||
|  |   if (options.feature_map_height_size()) { | ||||||
|  |     feature_map_dims.height = options.feature_map_height(index); | ||||||
|  |     feature_map_dims.width = options.feature_map_width(index); | ||||||
|  |   } else { | ||||||
|  |     const int stride = options.strides(index); | ||||||
|  |     feature_map_dims.height = | ||||||
|  |         std::ceil(1.0f * options.input_size_height() / stride); | ||||||
|  |     feature_map_dims.width = | ||||||
|  |         std::ceil(1.0f * options.input_size_width() / stride); | ||||||
|  |   } | ||||||
|  |   return feature_map_dims; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Although we have stride for both x and y, only one value is used for offset
 | ||||||
|  | // calculation. See
 | ||||||
|  | // tensorflow_models/object_detection/anchor_generators/multiscale_grid_anchor_generator.py;l=121
 | ||||||
|  | std::pair<float, float> GetMultiScaleAnchorOffset( | ||||||
|  |     const SsdAnchorsCalculatorOptions& options, const float stride, | ||||||
|  |     const int level) { | ||||||
|  |   std::pair<float, float> result(0., 0.); | ||||||
|  |   int denominator = std::pow(2, level); | ||||||
|  |   if (options.input_size_height() % denominator == 0 || | ||||||
|  |       options.input_size_height() == 1) { | ||||||
|  |     result.first = stride / 2.0; | ||||||
|  |   } | ||||||
|  |   if (options.input_size_width() % denominator == 0 || | ||||||
|  |       options.input_size_width() == 1) { | ||||||
|  |     result.second = stride / 2.0; | ||||||
|  |   } | ||||||
|  |   return result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void NormalizeAnchor(const int input_height, const int input_width, | ||||||
|  |                      Anchor* anchor) { | ||||||
|  |   anchor->set_h(anchor->h() / (float)input_height); | ||||||
|  |   anchor->set_w(anchor->w() / (float)input_width); | ||||||
|  |   anchor->set_y_center(anchor->y_center() / (float)input_height); | ||||||
|  |   anchor->set_x_center(anchor->x_center() / (float)input_width); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | Anchor CalculateAnchorBox(const int y_center, const int x_center, | ||||||
|  |                           const float scale, const float aspect_ratio, | ||||||
|  |                           const std::pair<float, float> base_anchor_size, | ||||||
|  |                           // y-height first
 | ||||||
|  |                           const std::pair<float, float> anchor_stride, | ||||||
|  |                           const std::pair<float, float> anchor_offset) { | ||||||
|  |   Anchor result; | ||||||
|  |   float ratio_sqrt = std::sqrt(aspect_ratio); | ||||||
|  |   result.set_h(scale * base_anchor_size.first / ratio_sqrt); | ||||||
|  |   result.set_w(scale * ratio_sqrt * base_anchor_size.second); | ||||||
|  |   result.set_y_center(y_center * anchor_stride.first + anchor_offset.first); | ||||||
|  |   result.set_x_center(x_center * anchor_stride.second + anchor_offset.second); | ||||||
|  |   return result; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
| 
 | 
 | ||||||
| // Generate anchors for SSD object detection model.
 | // Generate anchors for SSD object detection model.
 | ||||||
|  | @ -95,9 +174,77 @@ class SsdAnchorsCalculator : public CalculatorBase { | ||||||
|  private: |  private: | ||||||
|   static absl::Status GenerateAnchors( |   static absl::Status GenerateAnchors( | ||||||
|       std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options); |       std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options); | ||||||
|  | 
 | ||||||
|  |   static absl::Status GenerateMultiScaleAnchors( | ||||||
|  |       std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options); | ||||||
| }; | }; | ||||||
| REGISTER_CALCULATOR(SsdAnchorsCalculator); | REGISTER_CALCULATOR(SsdAnchorsCalculator); | ||||||
| 
 | 
 | ||||||
|  | // Generates grid anchors on the fly corresponding to multiple CNN layers as
 | ||||||
|  | // described in:
 | ||||||
|  | // "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002)
 | ||||||
|  | // T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar
 | ||||||
|  | absl::Status SsdAnchorsCalculator::GenerateMultiScaleAnchors( | ||||||
|  |     std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) { | ||||||
|  |   std::vector<MultiScaleAnchorInfo> anchor_infos; | ||||||
|  |   for (int i = options.min_level(); i <= options.max_level(); ++i) { | ||||||
|  |     MultiScaleAnchorInfo current_anchor_info; | ||||||
|  |     // level
 | ||||||
|  |     current_anchor_info.level = i; | ||||||
|  |     // aspect_ratios
 | ||||||
|  |     for (const float aspect_ratio : options.aspect_ratios()) { | ||||||
|  |       current_anchor_info.aspect_ratios.push_back(aspect_ratio); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // scale
 | ||||||
|  |     for (int i = 0; i < options.scales_per_octave(); ++i) { | ||||||
|  |       current_anchor_info.scales.push_back( | ||||||
|  |           std::pow(2.0, (double)i / (double)options.scales_per_octave())); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // anchor stride
 | ||||||
|  |     float anchor_stride = std::pow(2.0, i); | ||||||
|  |     current_anchor_info.anchor_stride = | ||||||
|  |         std::make_pair(anchor_stride, anchor_stride); | ||||||
|  | 
 | ||||||
|  |     // base_anchor_size
 | ||||||
|  |     current_anchor_info.base_anchor_size = | ||||||
|  |         std::make_pair(anchor_stride * options.anchor_scale(), | ||||||
|  |                        anchor_stride * options.anchor_scale()); | ||||||
|  |     anchor_infos.push_back(current_anchor_info); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   for (unsigned int i = 0; i < anchor_infos.size(); ++i) { | ||||||
|  |     FeatureMapDim dimensions = GetFeatureMapDimensions(options, i); | ||||||
|  |     for (int y = 0; y < dimensions.height; ++y) { | ||||||
|  |       for (int x = 0; x < dimensions.width; ++x) { | ||||||
|  |         // loop over combination of scale and aspect ratio
 | ||||||
|  |         for (unsigned int j = 0; j < anchor_infos[i].aspect_ratios.size(); | ||||||
|  |              ++j) { | ||||||
|  |           for (unsigned int k = 0; k < anchor_infos[i].scales.size(); ++k) { | ||||||
|  |             Anchor anchor = CalculateAnchorBox( | ||||||
|  |                 /*y_center=*/y, /*x_center=*/x, anchor_infos[i].scales[k], | ||||||
|  |                 anchor_infos[i].aspect_ratios[j], | ||||||
|  |                 anchor_infos[i].base_anchor_size, | ||||||
|  |                 /*anchor_stride=*/anchor_infos[i].anchor_stride, | ||||||
|  |                 /*anchor_offset=*/ | ||||||
|  |                 GetMultiScaleAnchorOffset(options, | ||||||
|  |                                           anchor_infos[i].anchor_stride.first, | ||||||
|  |                                           anchor_infos[i].level)); | ||||||
|  |             if (options.normalize_coordinates()) { | ||||||
|  |               NormalizeAnchor(options.input_size_height(), | ||||||
|  |                               options.input_size_width(), &anchor); | ||||||
|  |             } | ||||||
|  |             anchors->push_back(anchor); | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   return absl::OkStatus(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| absl::Status SsdAnchorsCalculator::GenerateAnchors( | absl::Status SsdAnchorsCalculator::GenerateAnchors( | ||||||
|     std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) { |     std::vector<Anchor>* anchors, const SsdAnchorsCalculatorOptions& options) { | ||||||
|   // Verify the options.
 |   // Verify the options.
 | ||||||
|  | @ -106,15 +253,21 @@ absl::Status SsdAnchorsCalculator::GenerateAnchors( | ||||||
|         "Both feature map shape and strides are missing. Must provide either " |         "Both feature map shape and strides are missing. Must provide either " | ||||||
|         "one."); |         "one."); | ||||||
|   } |   } | ||||||
|  |   const int kNumLayers = GetNumLayers(options); | ||||||
|  | 
 | ||||||
|   if (options.feature_map_height_size()) { |   if (options.feature_map_height_size()) { | ||||||
|     if (options.strides_size()) { |     if (options.strides_size()) { | ||||||
|       LOG(ERROR) << "Found feature map shapes. Strides will be ignored."; |       LOG(ERROR) << "Found feature map shapes. Strides will be ignored."; | ||||||
|     } |     } | ||||||
|     CHECK_EQ(options.feature_map_height_size(), options.num_layers()); |     CHECK_EQ(options.feature_map_height_size(), kNumLayers); | ||||||
|     CHECK_EQ(options.feature_map_height_size(), |     CHECK_EQ(options.feature_map_height_size(), | ||||||
|              options.feature_map_width_size()); |              options.feature_map_width_size()); | ||||||
|   } else { |   } else { | ||||||
|     CHECK_EQ(options.strides_size(), options.num_layers()); |     CHECK_EQ(options.strides_size(), kNumLayers); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   if (options.multiscale_anchor_generation()) { | ||||||
|  |     return GenerateMultiScaleAnchors(anchors, options); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   int layer_id = 0; |   int layer_id = 0; | ||||||
|  |  | ||||||
|  | @ -60,4 +60,30 @@ message SsdAnchorsCalculatorOptions { | ||||||
|   // This option can be used when the predicted anchor width and height are in |   // This option can be used when the predicted anchor width and height are in | ||||||
|   // pixels. |   // pixels. | ||||||
|   optional bool fixed_anchor_size = 14 [default = false]; |   optional bool fixed_anchor_size = 14 [default = false]; | ||||||
|  | 
 | ||||||
|  |   // Generates grid anchors on the fly corresponding to multiple CNN layers as | ||||||
|  |   // described in: | ||||||
|  |   // "Focal Loss for Dense Object Detection" (https://arxiv.org/abs/1708.02002) | ||||||
|  |   //  T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar | ||||||
|  |   optional bool multiscale_anchor_generation = 15 [default = false]; | ||||||
|  | 
 | ||||||
|  |   // minimum level in feature pyramid | ||||||
|  |   // for multiscale_anchor_generation only! | ||||||
|  |   optional int32 min_level = 16 [default = 3]; | ||||||
|  | 
 | ||||||
|  |   // maximum level in feature pyramid | ||||||
|  |   // for multiscale_anchor_generation only! | ||||||
|  |   optional int32 max_level = 17 [default = 7]; | ||||||
|  | 
 | ||||||
|  |   // Scale of anchor to feature stride | ||||||
|  |   // for multiscale_anchor_generation only! | ||||||
|  |   optional float anchor_scale = 18 [default = 4.0]; | ||||||
|  | 
 | ||||||
|  |   // Number of intermediate scale each scale octave | ||||||
|  |   // for multiscale_anchor_generation only! | ||||||
|  |   optional int32 scales_per_octave = 19 [default = 2]; | ||||||
|  | 
 | ||||||
|  |   // Whether to produce anchors in normalized coordinates. | ||||||
|  |   // for multiscale_anchor_generation only! | ||||||
|  |   optional bool normalize_coordinates = 20 [default = true]; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -33,9 +33,6 @@ std::string GetGoldenFilePath(const std::string& filename) { | ||||||
| 
 | 
 | ||||||
| void ParseAnchorsFromText(const std::string& text, | void ParseAnchorsFromText(const std::string& text, | ||||||
|                           std::vector<Anchor>* anchors) { |                           std::vector<Anchor>* anchors) { | ||||||
|   const std::string line_delimiter = "\n"; |  | ||||||
|   const std::string number_delimiter = ","; |  | ||||||
| 
 |  | ||||||
|   std::istringstream stream(text); |   std::istringstream stream(text); | ||||||
|   std::string line; |   std::string line; | ||||||
|   while (std::getline(stream, line)) { |   while (std::getline(stream, line)) { | ||||||
|  | @ -64,6 +61,8 @@ void CompareAnchors(const std::vector<Anchor>& anchors_0, | ||||||
|                 testing::FloatNear(anchor_1.x_center(), 1e-5)); |                 testing::FloatNear(anchor_1.x_center(), 1e-5)); | ||||||
|     EXPECT_THAT(anchor_0.y_center(), |     EXPECT_THAT(anchor_0.y_center(), | ||||||
|                 testing::FloatNear(anchor_1.y_center(), 1e-5)); |                 testing::FloatNear(anchor_1.y_center(), 1e-5)); | ||||||
|  |     EXPECT_THAT(anchor_0.h(), testing::FloatNear(anchor_1.h(), 1e-5)); | ||||||
|  |     EXPECT_THAT(anchor_0.w(), testing::FloatNear(anchor_1.w(), 1e-5)); | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -148,4 +147,40 @@ TEST(SsdAnchorCalculatorTest, MobileSSDConfig) { | ||||||
|   CompareAnchors(anchors, anchors_golden); |   CompareAnchors(anchors, anchors_golden); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(SsdAnchorCalculatorTest, RetinaNetSSDConfig) { | ||||||
|  |   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( | ||||||
|  |     calculator: "SsdAnchorsCalculator" | ||||||
|  |     output_side_packet: "anchors" | ||||||
|  |     options { | ||||||
|  |       [mediapipe.SsdAnchorsCalculatorOptions.ext] { | ||||||
|  |         input_size_height: 640 | ||||||
|  |         input_size_width: 640 | ||||||
|  |         strides: 64 | ||||||
|  |         strides: 128 | ||||||
|  |         aspect_ratios: 1.0 | ||||||
|  |         aspect_ratios: 2.0 | ||||||
|  |         aspect_ratios: 0.5 | ||||||
|  |         multiscale_anchor_generation: true | ||||||
|  |         min_level: 6 | ||||||
|  |         max_level: 7 | ||||||
|  |         anchor_scale: 3.0 | ||||||
|  |         scales_per_octave: 3 | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb")); | ||||||
|  | 
 | ||||||
|  |   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; | ||||||
|  |   const auto& anchors = | ||||||
|  |       runner.OutputSidePackets().Index(0).Get<std::vector<Anchor>>(); | ||||||
|  | 
 | ||||||
|  |   std::string anchors_string; | ||||||
|  |   MP_EXPECT_OK(mediapipe::file::GetContents( | ||||||
|  |       GetGoldenFilePath("anchor_golden_file_2.txt"), &anchors_string)); | ||||||
|  | 
 | ||||||
|  |   std::vector<Anchor> anchors_golden; | ||||||
|  |   ParseAnchorsFromText(anchors_string, &anchors_golden); | ||||||
|  | 
 | ||||||
|  |   CompareAnchors(anchors, anchors_golden); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
							
								
								
									
										1125
									
								
								mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										1125
									
								
								mediapipe/calculators/tflite/testdata/anchor_golden_file_2.txt
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							|  | @ -19,6 +19,7 @@ | ||||||
| #include "mediapipe/framework/calculator_framework.h" | #include "mediapipe/framework/calculator_framework.h" | ||||||
| #include "mediapipe/framework/packet.h" | #include "mediapipe/framework/packet.h" | ||||||
| #include "mediapipe/framework/port/ret_check.h" | #include "mediapipe/framework/port/ret_check.h" | ||||||
|  | #include "tensorflow/lite/allocation.h" | ||||||
| #include "tensorflow/lite/model.h" | #include "tensorflow/lite/model.h" | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
|  | @ -32,6 +33,8 @@ namespace mediapipe { | ||||||
| //                it to the graph as input side packet or you can use some of
 | //                it to the graph as input side packet or you can use some of
 | ||||||
| //                calculators like LocalFileContentsCalculator to get model
 | //                calculators like LocalFileContentsCalculator to get model
 | ||||||
| //                blob and use it as input here.
 | //                blob and use it as input here.
 | ||||||
|  | //   MODEL_FD   - Tflite model file descriptor std::tuple<int, size_t, size_t>
 | ||||||
|  | //                containing (fd, offset, size).
 | ||||||
| //
 | //
 | ||||||
| // Output side packets:
 | // Output side packets:
 | ||||||
| //   MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
 | //   MODEL - TfLite model. (std::unique_ptr<tflite::FlatBufferModel,
 | ||||||
|  | @ -52,17 +55,42 @@ class TfLiteModelCalculator : public CalculatorBase { | ||||||
|                       std::function<void(tflite::FlatBufferModel*)>>; |                       std::function<void(tflite::FlatBufferModel*)>>; | ||||||
| 
 | 
 | ||||||
|   static absl::Status GetContract(CalculatorContract* cc) { |   static absl::Status GetContract(CalculatorContract* cc) { | ||||||
|  |     if (cc->InputSidePackets().HasTag("MODEL_BLOB")) { | ||||||
|       cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>(); |       cc->InputSidePackets().Tag("MODEL_BLOB").Set<std::string>(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (cc->InputSidePackets().HasTag("MODEL_FD")) { | ||||||
|  |       cc->InputSidePackets() | ||||||
|  |           .Tag("MODEL_FD") | ||||||
|  |           .Set<std::tuple<int, size_t, size_t>>(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>(); |     cc->OutputSidePackets().Tag("MODEL").Set<TfLiteModelPtr>(); | ||||||
|     return absl::OkStatus(); |     return absl::OkStatus(); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   absl::Status Open(CalculatorContext* cc) override { |   absl::Status Open(CalculatorContext* cc) override { | ||||||
|     const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB"); |     Packet model_packet; | ||||||
|  |     std::unique_ptr<tflite::FlatBufferModel> model; | ||||||
|  | 
 | ||||||
|  |     if (cc->InputSidePackets().HasTag("MODEL_BLOB")) { | ||||||
|  |       model_packet = cc->InputSidePackets().Tag("MODEL_BLOB"); | ||||||
|       const std::string& model_blob = model_packet.Get<std::string>(); |       const std::string& model_blob = model_packet.Get<std::string>(); | ||||||
|     std::unique_ptr<tflite::FlatBufferModel> model = |       model = tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(), | ||||||
|         tflite::FlatBufferModel::BuildFromBuffer(model_blob.data(), |  | ||||||
|                                                        model_blob.size()); |                                                        model_blob.size()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if (cc->InputSidePackets().HasTag("MODEL_FD")) { | ||||||
|  |       model_packet = cc->InputSidePackets().Tag("MODEL_FD"); | ||||||
|  |       const auto& model_fd = | ||||||
|  |           model_packet.Get<std::tuple<int, size_t, size_t>>(); | ||||||
|  |       auto model_allocation = std::make_unique<tflite::MMAPAllocation>( | ||||||
|  |           std::get<0>(model_fd), std::get<1>(model_fd), std::get<2>(model_fd), | ||||||
|  |           tflite::DefaultErrorReporter()); | ||||||
|  |       model = tflite::FlatBufferModel::BuildFromAllocation( | ||||||
|  |           std::move(model_allocation), tflite::DefaultErrorReporter()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     RET_CHECK(model) << "Failed to load TfLite model from blob."; |     RET_CHECK(model) << "Failed to load TfLite model from blob."; | ||||||
| 
 | 
 | ||||||
|     cc->OutputSidePackets().Tag("MODEL").Set( |     cc->OutputSidePackets().Tag("MODEL").Set( | ||||||
|  |  | ||||||
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 9.3 KiB | 
|  | @ -90,6 +90,7 @@ | ||||||
|     { |     { | ||||||
|       "idiom" : "ipad", |       "idiom" : "ipad", | ||||||
|       "size" : "83.5x83.5", |       "size" : "83.5x83.5", | ||||||
|  |       "filename" : "83.5_c_Ipad_2x.png", | ||||||
|       "scale" : "2x" |       "scale" : "2x" | ||||||
|     }, |     }, | ||||||
|     { |     { | ||||||
|  |  | ||||||
|  | @ -21,7 +21,9 @@ cc_library( | ||||||
|         ":port", |         ":port", | ||||||
|         "//mediapipe/framework:calculator_base", |         "//mediapipe/framework:calculator_base", | ||||||
|         "//mediapipe/framework:calculator_contract", |         "//mediapipe/framework:calculator_contract", | ||||||
|  |         "@com_google_absl//absl/container:btree", | ||||||
|         "@com_google_absl//absl/container:flat_hash_map", |         "@com_google_absl//absl/container:flat_hash_map", | ||||||
|  |         "@com_google_absl//absl/strings", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -4,7 +4,9 @@ | ||||||
| #include <string> | #include <string> | ||||||
| #include <type_traits> | #include <type_traits> | ||||||
| 
 | 
 | ||||||
|  | #include "absl/container/btree_map.h" | ||||||
| #include "absl/container/flat_hash_map.h" | #include "absl/container/flat_hash_map.h" | ||||||
|  | #include "absl/strings/string_view.h" | ||||||
| #include "mediapipe/framework/api2/const_str.h" | #include "mediapipe/framework/api2/const_str.h" | ||||||
| #include "mediapipe/framework/api2/contract.h" | #include "mediapipe/framework/api2/contract.h" | ||||||
| #include "mediapipe/framework/api2/node.h" | #include "mediapipe/framework/api2/node.h" | ||||||
|  | @ -46,7 +48,7 @@ struct TagIndexLocation { | ||||||
| template <typename T> | template <typename T> | ||||||
| class TagIndexMap { | class TagIndexMap { | ||||||
|  public: |  public: | ||||||
|   std::vector<std::unique_ptr<T>>& operator[](const std::string& tag) { |   std::vector<std::unique_ptr<T>>& operator[](absl::string_view tag) { | ||||||
|     return map_[tag]; |     return map_[tag]; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -72,7 +74,7 @@ class TagIndexMap { | ||||||
| 
 | 
 | ||||||
|   // Note: entries are held by a unique_ptr to ensure pointers remain valid.
 |   // Note: entries are held by a unique_ptr to ensure pointers remain valid.
 | ||||||
|   // Should use absl::flat_hash_map but ordering keys for now.
 |   // Should use absl::flat_hash_map but ordering keys for now.
 | ||||||
|   std::map<std::string, std::vector<std::unique_ptr<T>>> map_; |   absl::btree_map<std::string, std::vector<std::unique_ptr<T>>> map_; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| class Graph; | class Graph; | ||||||
|  | @ -169,6 +171,16 @@ class SourceImpl { | ||||||
|     return AddTarget(dest); |     return AddTarget(dest); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   template <typename U> | ||||||
|  |   struct AllowCast | ||||||
|  |       : public std::integral_constant<bool, std::is_same_v<T, AnyType> && | ||||||
|  |                                                 !std::is_same_v<T, U>> {}; | ||||||
|  | 
 | ||||||
|  |   template <typename U, std::enable_if_t<AllowCast<U>{}, int> = 0> | ||||||
|  |   SourceImpl<IsSide, U> Cast() { | ||||||
|  |     return SourceImpl<IsSide, U>(base_); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  private: |  private: | ||||||
|   // Never null.
 |   // Never null.
 | ||||||
|   SourceBase* base_; |   SourceBase* base_; | ||||||
|  | @ -212,19 +224,19 @@ class NodeBase { | ||||||
|   // of its entries by index. However, for nodes without visible contracts we
 |   // of its entries by index. However, for nodes without visible contracts we
 | ||||||
|   // can't know whether a tag is indexable or not, so we would need the
 |   // can't know whether a tag is indexable or not, so we would need the
 | ||||||
|   // multi-port to also be usable as a port directly (representing index 0).
 |   // multi-port to also be usable as a port directly (representing index 0).
 | ||||||
|   MultiSource<> Out(const std::string& tag) { |   MultiSource<> Out(absl::string_view tag) { | ||||||
|     return MultiSource<>(&out_streams_[tag]); |     return MultiSource<>(&out_streams_[tag]); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MultiDestination<> In(const std::string& tag) { |   MultiDestination<> In(absl::string_view tag) { | ||||||
|     return MultiDestination<>(&in_streams_[tag]); |     return MultiDestination<>(&in_streams_[tag]); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MultiSideSource<> SideOut(const std::string& tag) { |   MultiSideSource<> SideOut(absl::string_view tag) { | ||||||
|     return MultiSideSource<>(&out_sides_[tag]); |     return MultiSideSource<>(&out_sides_[tag]); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MultiSideDestination<> SideIn(const std::string& tag) { |   MultiSideDestination<> SideIn(absl::string_view tag) { | ||||||
|     return MultiSideDestination<>(&in_sides_[tag]); |     return MultiSideDestination<>(&in_sides_[tag]); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -359,11 +371,11 @@ class PacketGenerator { | ||||||
|  public: |  public: | ||||||
|   PacketGenerator(std::string type) : type_(std::move(type)) {} |   PacketGenerator(std::string type) : type_(std::move(type)) {} | ||||||
| 
 | 
 | ||||||
|   MultiSideSource<> SideOut(const std::string& tag) { |   MultiSideSource<> SideOut(absl::string_view tag) { | ||||||
|     return MultiSideSource<>(&out_sides_[tag]); |     return MultiSideSource<>(&out_sides_[tag]); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MultiSideDestination<> SideIn(const std::string& tag) { |   MultiSideDestination<> SideIn(absl::string_view tag) { | ||||||
|     return MultiSideDestination<>(&in_sides_[tag]); |     return MultiSideDestination<>(&in_sides_[tag]); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -452,19 +464,19 @@ class Graph { | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // Graph ports, non-typed.
 |   // Graph ports, non-typed.
 | ||||||
|   MultiSource<> In(const std::string& graph_input) { |   MultiSource<> In(absl::string_view graph_input) { | ||||||
|     return graph_boundary_.Out(graph_input); |     return graph_boundary_.Out(graph_input); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MultiDestination<> Out(const std::string& graph_output) { |   MultiDestination<> Out(absl::string_view graph_output) { | ||||||
|     return graph_boundary_.In(graph_output); |     return graph_boundary_.In(graph_output); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MultiSideSource<> SideIn(const std::string& graph_input) { |   MultiSideSource<> SideIn(absl::string_view graph_input) { | ||||||
|     return graph_boundary_.SideOut(graph_input); |     return graph_boundary_.SideOut(graph_input); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   MultiSideDestination<> SideOut(const std::string& graph_output) { |   MultiSideDestination<> SideOut(absl::string_view graph_output) { | ||||||
|     return graph_boundary_.SideIn(graph_output); |     return graph_boundary_.SideIn(graph_output); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -2,6 +2,7 @@ | ||||||
| 
 | 
 | ||||||
| #include <functional> | #include <functional> | ||||||
| 
 | 
 | ||||||
|  | #include "absl/strings/string_view.h" | ||||||
| #include "absl/strings/substitute.h" | #include "absl/strings/substitute.h" | ||||||
| #include "mediapipe/framework/api2/node.h" | #include "mediapipe/framework/api2/node.h" | ||||||
| #include "mediapipe/framework/api2/packet.h" | #include "mediapipe/framework/api2/packet.h" | ||||||
|  | @ -296,6 +297,32 @@ TEST(BuilderTest, EmptyTag) { | ||||||
|   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); |   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(BuilderTest, StringLikeTags) { | ||||||
|  |   const char kA[] = "A"; | ||||||
|  |   const std::string kB = "B"; | ||||||
|  |   constexpr absl::string_view kC = "C"; | ||||||
|  | 
 | ||||||
|  |   builder::Graph graph; | ||||||
|  |   auto& foo = graph.AddNode("Foo"); | ||||||
|  |   graph.In(kA).SetName("a") >> foo.In(kA); | ||||||
|  |   graph.In(kB).SetName("b") >> foo.In(kB); | ||||||
|  |   foo.Out(kC).SetName("c") >> graph.Out(kC); | ||||||
|  | 
 | ||||||
|  |   CalculatorGraphConfig expected = | ||||||
|  |       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||||
|  |         input_stream: "A:a" | ||||||
|  |         input_stream: "B:b" | ||||||
|  |         output_stream: "C:c" | ||||||
|  |         node { | ||||||
|  |           calculator: "Foo" | ||||||
|  |           input_stream: "A:a" | ||||||
|  |           input_stream: "B:b" | ||||||
|  |           output_stream: "C:c" | ||||||
|  |         } | ||||||
|  |       )pb"); | ||||||
|  |   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST(BuilderTest, GraphIndexes) { | TEST(BuilderTest, GraphIndexes) { | ||||||
|   builder::Graph graph; |   builder::Graph graph; | ||||||
|   auto& foo = graph.AddNode("Foo"); |   auto& foo = graph.AddNode("Foo"); | ||||||
|  | @ -326,52 +353,63 @@ TEST(BuilderTest, GraphIndexes) { | ||||||
| 
 | 
 | ||||||
| class AnyAndSameTypeCalculator : public NodeIntf { | class AnyAndSameTypeCalculator : public NodeIntf { | ||||||
|  public: |  public: | ||||||
|   static constexpr Input<AnyType> kAnyTypeInput{"INPUT"}; |   static constexpr Input<AnyType>::Optional kAnyTypeInput{"INPUT"}; | ||||||
|   static constexpr Output<AnyType> kAnyTypeOutput{"ANY_OUTPUT"}; |   static constexpr Output<AnyType>::Optional kAnyTypeOutput{"ANY_OUTPUT"}; | ||||||
|   static constexpr Output<SameType<kAnyTypeInput>> kSameTypeOutput{ |   static constexpr Output<SameType<kAnyTypeInput>>::Optional kSameTypeOutput{ | ||||||
|       "SAME_OUTPUT"}; |       "SAME_OUTPUT"}; | ||||||
|  |   static constexpr Output<SameType<kSameTypeOutput>> kRecursiveSameTypeOutput{ | ||||||
|  |       "RECURSIVE_SAME_OUTPUT"}; | ||||||
| 
 | 
 | ||||||
|   static constexpr Input<int> kIntInput{"INT_INPUT"}; |   static constexpr Input<int>::Optional kIntInput{"INT_INPUT"}; | ||||||
|   // `SameType` usage for this output is only for testing purposes.
 |   // `SameType` usage for this output is only for testing purposes.
 | ||||||
|   //
 |   //
 | ||||||
|   // `SameType` is designed to work with inputs of `AnyType` and, normally, you
 |   // `SameType` is designed to work with inputs of `AnyType` and, normally, you
 | ||||||
|   // would not use `Output<SameType<kIntInput>>` in a real calculator. You
 |   // would not use `Output<SameType<kIntInput>>` in a real calculator. You
 | ||||||
|   // should write `Output<int>` instead, since the type is known.
 |   // should write `Output<int>` instead, since the type is known.
 | ||||||
|   static constexpr Output<SameType<kIntInput>> kSameIntOutput{ |   static constexpr Output<SameType<kIntInput>>::Optional kSameIntOutput{ | ||||||
|       "SAME_INT_OUTPUT"}; |       "SAME_INT_OUTPUT"}; | ||||||
|  |   static constexpr Output<SameType<kSameIntOutput>> kRecursiveSameIntOutput{ | ||||||
|  |       "RECURSIVE_SAME_INT_OUTPUT"}; | ||||||
| 
 | 
 | ||||||
|   MEDIAPIPE_NODE_INTERFACE(AnyTypeCalculator, kAnyTypeInput, kAnyTypeOutput, |   MEDIAPIPE_NODE_INTERFACE(AnyAndSameTypeCalculator, kAnyTypeInput, | ||||||
|                            kSameTypeOutput); |                            kAnyTypeOutput, kSameTypeOutput); | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| TEST(BuilderTest, AnyAndSameTypeHandledProperly) { | TEST(BuilderTest, AnyAndSameTypeHandledProperly) { | ||||||
|   builder::Graph graph; |   builder::Graph graph; | ||||||
|   builder::Source<internal::Generic> any_input = |   builder::Source<AnyType> any_input = graph[Input<AnyType>{"GRAPH_ANY_INPUT"}]; | ||||||
|       graph[Input<AnyType>{"GRAPH_ANY_INPUT"}]; |  | ||||||
|   builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}]; |   builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}]; | ||||||
| 
 | 
 | ||||||
|   auto& node = graph.AddNode("AnyAndSameTypeCalculator"); |   auto& node = graph.AddNode("AnyAndSameTypeCalculator"); | ||||||
|   any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; |   any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; | ||||||
|   int_input >> node[AnyAndSameTypeCalculator::kIntInput]; |   int_input >> node[AnyAndSameTypeCalculator::kIntInput]; | ||||||
| 
 | 
 | ||||||
|   builder::Source<internal::Generic> any_type_output = |   builder::Source<AnyType> any_type_output = | ||||||
|       node[AnyAndSameTypeCalculator::kAnyTypeOutput]; |       node[AnyAndSameTypeCalculator::kAnyTypeOutput]; | ||||||
|   any_type_output.SetName("any_type_output"); |   any_type_output.SetName("any_type_output"); | ||||||
| 
 | 
 | ||||||
|   builder::Source<internal::Generic> same_type_output = |   builder::Source<AnyType> same_type_output = | ||||||
|       node[AnyAndSameTypeCalculator::kSameTypeOutput]; |       node[AnyAndSameTypeCalculator::kSameTypeOutput]; | ||||||
|   same_type_output.SetName("same_type_output"); |   same_type_output.SetName("same_type_output"); | ||||||
|   builder::Source<internal::Generic> same_int_output = |   builder::Source<AnyType> recursive_same_type_output = | ||||||
|  |       node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput]; | ||||||
|  |   recursive_same_type_output.SetName("recursive_same_type_output"); | ||||||
|  |   builder::Source<int> same_int_output = | ||||||
|       node[AnyAndSameTypeCalculator::kSameIntOutput]; |       node[AnyAndSameTypeCalculator::kSameIntOutput]; | ||||||
|   same_int_output.SetName("same_int_output"); |   same_int_output.SetName("same_int_output"); | ||||||
|  |   builder::Source<int> recursive_same_int_type_output = | ||||||
|  |       node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput]; | ||||||
|  |   recursive_same_int_type_output.SetName("recursive_same_int_type_output"); | ||||||
| 
 | 
 | ||||||
|   CalculatorGraphConfig expected = |   CalculatorGraphConfig expected = mediapipe::ParseTextProtoOrDie< | ||||||
|       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( |       CalculatorGraphConfig>(R"pb( | ||||||
|     node { |     node { | ||||||
|       calculator: "AnyAndSameTypeCalculator" |       calculator: "AnyAndSameTypeCalculator" | ||||||
|       input_stream: "INPUT:__stream_0" |       input_stream: "INPUT:__stream_0" | ||||||
|       input_stream: "INT_INPUT:__stream_1" |       input_stream: "INT_INPUT:__stream_1" | ||||||
|       output_stream: "ANY_OUTPUT:any_type_output" |       output_stream: "ANY_OUTPUT:any_type_output" | ||||||
|  |       output_stream: "RECURSIVE_SAME_INT_OUTPUT:recursive_same_int_type_output" | ||||||
|  |       output_stream: "RECURSIVE_SAME_OUTPUT:recursive_same_type_output" | ||||||
|       output_stream: "SAME_INT_OUTPUT:same_int_output" |       output_stream: "SAME_INT_OUTPUT:same_int_output" | ||||||
|       output_stream: "SAME_OUTPUT:same_type_output" |       output_stream: "SAME_OUTPUT:same_type_output" | ||||||
|     } |     } | ||||||
|  | @ -381,6 +419,29 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) { | ||||||
|   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); |   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(BuilderTest, AnyTypeCanBeCast) { | ||||||
|  |   builder::Graph graph; | ||||||
|  |   builder::Source<std::string> any_input = | ||||||
|  |       graph.In("GRAPH_ANY_INPUT").Cast<std::string>(); | ||||||
|  | 
 | ||||||
|  |   auto& node = graph.AddNode("AnyAndSameTypeCalculator"); | ||||||
|  |   any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput]; | ||||||
|  |   builder::Source<double> any_type_output = | ||||||
|  |       node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>(); | ||||||
|  |   any_type_output.SetName("any_type_output"); | ||||||
|  | 
 | ||||||
|  |   CalculatorGraphConfig expected = | ||||||
|  |       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||||
|  |         node { | ||||||
|  |           calculator: "AnyAndSameTypeCalculator" | ||||||
|  |           input_stream: "INPUT:__stream_0" | ||||||
|  |           output_stream: "ANY_OUTPUT:any_type_output" | ||||||
|  |         } | ||||||
|  |         input_stream: "GRAPH_ANY_INPUT:__stream_0" | ||||||
|  |       )pb"); | ||||||
|  |   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace test
 | }  // namespace test
 | ||||||
| }  // namespace api2
 | }  // namespace api2
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
|  | @ -27,9 +27,7 @@ using HolderBase = mediapipe::packet_internal::HolderBase; | ||||||
| template <typename T> | template <typename T> | ||||||
| class Packet; | class Packet; | ||||||
| 
 | 
 | ||||||
| struct DynamicType {}; | struct AnyType { | ||||||
| 
 |  | ||||||
| struct AnyType : public DynamicType { |  | ||||||
|   AnyType() = delete; |   AnyType() = delete; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -73,14 +73,12 @@ class SideOutputBase : public PortBase { | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| struct NoneType { | struct NoneType { | ||||||
|  private: |  | ||||||
|   NoneType() = delete; |   NoneType() = delete; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| template <auto& P> | template <auto& kP> | ||||||
| class SameType : public DynamicType { | struct SameType { | ||||||
|  public: |   static constexpr const decltype(kP)& kPort = kP; | ||||||
|   static constexpr const decltype(P)& kPort = P; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| class PacketTypeAccess; | class PacketTypeAccess; | ||||||
|  | @ -137,21 +135,28 @@ struct IsOneOf : std::false_type {}; | ||||||
| template <class... T> | template <class... T> | ||||||
| struct IsOneOf<OneOf<T...>> : std::true_type {}; | struct IsOneOf<OneOf<T...>> : std::true_type {}; | ||||||
| 
 | 
 | ||||||
| template <typename T, typename std::enable_if< | template <class T> | ||||||
|                           !std::is_base_of<DynamicType, T>{} && !IsOneOf<T>{}, | struct IsSameType : std::false_type {}; | ||||||
|  | 
 | ||||||
|  | template <class P, P& kP> | ||||||
|  | struct IsSameType<SameType<kP>> : std::true_type {}; | ||||||
|  | 
 | ||||||
|  | template <typename T, | ||||||
|  |           typename std::enable_if<!std::is_same<T, AnyType>{} && | ||||||
|  |                                       !IsOneOf<T>{} && !IsSameType<T>{}, | ||||||
|                                   int>::type = 0> |                                   int>::type = 0> | ||||||
| inline void SetType(CalculatorContract* cc, PacketType& pt) { | inline void SetType(CalculatorContract* cc, PacketType& pt) { | ||||||
|   pt.Set<T>(); |   pt.Set<T>(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <typename T, typename std::enable_if<std::is_base_of<DynamicType, T>{}, | template <typename T, typename std::enable_if<IsSameType<T>{}, int>::type = 0> | ||||||
|                                               int>::type = 0> |  | ||||||
| inline void SetType(CalculatorContract* cc, PacketType& pt) { | inline void SetType(CalculatorContract* cc, PacketType& pt) { | ||||||
|   pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag())); |   pt.SetSameAs(&internal::GetCollection(cc, T::kPort).Tag(T::kPort.Tag())); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| template <> | template <typename T, | ||||||
| inline void SetType<AnyType>(CalculatorContract* cc, PacketType& pt) { |           typename std::enable_if<std::is_same<T, AnyType>{}, int>::type = 0> | ||||||
|  | inline void SetType(CalculatorContract* cc, PacketType& pt) { | ||||||
|   pt.SetAny(); |   pt.SetAny(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -289,15 +294,15 @@ struct SideBase<InputBase> { | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| // TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
 | // TODO: maybe return a PacketBase instead of a Packet<internal::Generic>?
 | ||||||
| template <typename T, class = void> | template <typename T, typename = void> | ||||||
| struct ActualPayloadType { | struct ActualPayloadType { | ||||||
|   using type = T; |   using type = T; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T> | ||||||
| struct ActualPayloadType< | struct ActualPayloadType<T, std::enable_if_t<IsSameType<T>{}, void>> { | ||||||
|     T, std::enable_if_t<std::is_base_of<DynamicType, T>{}, void>> { |   using type = typename ActualPayloadType< | ||||||
|   using type = internal::Generic; |       typename std::decay_t<decltype(T::kPort)>::value_t>::type; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace internal
 | }  // namespace internal
 | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| // Copyright 2019 The MediaPipe Authors.
 | // Copyright 2022 The MediaPipe Authors.
 | ||||||
| //
 | //
 | ||||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
| // you may not use this file except in compliance with the License.
 | // you may not use this file except in compliance with the License.
 | ||||||
|  | @ -85,15 +85,9 @@ cv::Mat MatView(const ImageFrame* image) { | ||||||
|   const size_t steps[] = {static_cast<size_t>(image->WidthStep()), |   const size_t steps[] = {static_cast<size_t>(image->WidthStep()), | ||||||
|                           static_cast<size_t>(image->ByteDepth())}; |                           static_cast<size_t>(image->ByteDepth())}; | ||||||
|   // Use ImageFrame to initialize in-place. ImageFrame still owns memory.
 |   // Use ImageFrame to initialize in-place. ImageFrame still owns memory.
 | ||||||
|   if (steps[0] == sizes[1] * image->NumberOfChannels() * image->ByteDepth()) { |   return cv::Mat(dims, sizes, type, const_cast<uint8_t*>(image->PixelData()), | ||||||
|     // Contiguous memory optimization. See b/78570764
 |  | ||||||
|     return cv::Mat(dims, sizes, type, const_cast<uint8*>(image->PixelData())); |  | ||||||
|   } else { |  | ||||||
|     // Custom width step.
 |  | ||||||
|     return cv::Mat(dims, sizes, type, const_cast<uint8*>(image->PixelData()), |  | ||||||
|                  steps); |                  steps); | ||||||
| } | } | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| }  // namespace formats
 | }  // namespace formats
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| // Copyright 2019 The MediaPipe Authors.
 | // Copyright 2022 The MediaPipe Authors.
 | ||||||
| //
 | //
 | ||||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
| // you may not use this file except in compliance with the License.
 | // you may not use this file except in compliance with the License.
 | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| // Copyright 2019 The MediaPipe Authors.
 | // Copyright 2022 The MediaPipe Authors.
 | ||||||
| //
 | //
 | ||||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
| // you may not use this file except in compliance with the License.
 | // you may not use this file except in compliance with the License.
 | ||||||
|  | @ -21,7 +21,6 @@ | ||||||
| #include "mediapipe/framework/port/logging.h" | #include "mediapipe/framework/port/logging.h" | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| 
 |  | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| // Set image_frame to a constant per-channel pix_value.
 | // Set image_frame to a constant per-channel pix_value.
 | ||||||
|  | @ -50,8 +49,8 @@ TEST(ImageFrameOpencvTest, ConvertToMat) { | ||||||
|   ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height); |   ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height); | ||||||
| 
 | 
 | ||||||
|   // Check adding constant images.
 |   // Check adding constant images.
 | ||||||
|   const uint8 frame1_val = 12; |   const uint8_t frame1_val = 12; | ||||||
|   const uint8 frame2_val = 34; |   const uint8_t frame2_val = 34; | ||||||
|   SetToColor<uint8>(&frame1_val, &frame1); |   SetToColor<uint8>(&frame1_val, &frame1); | ||||||
|   SetToColor<uint8>(&frame2_val, &frame2); |   SetToColor<uint8>(&frame2_val, &frame2); | ||||||
|   // Get Mat wrapper around ImageFrame memory (zero copy).
 |   // Get Mat wrapper around ImageFrame memory (zero copy).
 | ||||||
|  | @ -77,6 +76,37 @@ TEST(ImageFrameOpencvTest, ConvertToMat) { | ||||||
|   EXPECT_EQ(max_loc.y, i_height - 6); |   EXPECT_EQ(max_loc.y, i_height - 6); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST(ImageFrameOpencvTest, ConvertToIpl) { | ||||||
|  |   const int i_width = 123, i_height = 45; | ||||||
|  |   ImageFrame frame1(ImageFormat::GRAY8, i_width, i_height); | ||||||
|  |   ImageFrame frame2(ImageFormat::GRAY8, i_width, i_height); | ||||||
|  | 
 | ||||||
|  |   // Check adding constant images.
 | ||||||
|  |   const uint8_t frame1_val = 12; | ||||||
|  |   const uint8_t frame2_val = 34; | ||||||
|  |   SetToColor<uint8>(&frame1_val, &frame1); | ||||||
|  |   SetToColor<uint8>(&frame2_val, &frame2); | ||||||
|  |   const cv::Mat frame1_mat = formats::MatView(&frame1); | ||||||
|  |   const cv::Mat frame2_mat = formats::MatView(&frame2); | ||||||
|  |   const cv::Mat frame_sum = frame1_mat + frame2_mat; | ||||||
|  |   const auto frame_avg = static_cast<int>(cv::mean(frame_sum).val[0]); | ||||||
|  |   EXPECT_EQ(frame_avg, frame1_val + frame2_val); | ||||||
|  | 
 | ||||||
|  |   // Check setting min/max pixels.
 | ||||||
|  |   uint8* frame1_ptr = frame1.MutablePixelData(); | ||||||
|  |   frame1_ptr[(i_width - 5) + (i_height - 5) * frame1.WidthStep()] = 1; | ||||||
|  |   frame1_ptr[(i_width - 6) + (i_height - 6) * frame1.WidthStep()] = 100; | ||||||
|  |   double min, max; | ||||||
|  |   cv::Point min_loc, max_loc; | ||||||
|  |   cv::minMaxLoc(frame1_mat, &min, &max, &min_loc, &max_loc); | ||||||
|  |   EXPECT_EQ(min, 1); | ||||||
|  |   EXPECT_EQ(min_loc.x, i_width - 5); | ||||||
|  |   EXPECT_EQ(min_loc.y, i_height - 5); | ||||||
|  |   EXPECT_EQ(max, 100); | ||||||
|  |   EXPECT_EQ(max_loc.x, i_width - 6); | ||||||
|  |   EXPECT_EQ(max_loc.y, i_height - 6); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST(ImageFrameOpencvTest, ImageFormats) { | TEST(ImageFrameOpencvTest, ImageFormats) { | ||||||
|   const int i_width = 123, i_height = 45; |   const int i_width = 123, i_height = 45; | ||||||
|   ImageFrame frame_g8(ImageFormat::GRAY8, i_width, i_height); |   ImageFrame frame_g8(ImageFormat::GRAY8, i_width, i_height); | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| // Copyright 2019 The MediaPipe Authors.
 | // Copyright 2022 The MediaPipe Authors.
 | ||||||
| //
 | //
 | ||||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
| // you may not use this file except in compliance with the License.
 | // you may not use this file except in compliance with the License.
 | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| // Copyright 2019-2020 The MediaPipe Authors.
 | // Copyright 2022 The MediaPipe Authors.
 | ||||||
| //
 | //
 | ||||||
| // Licensed under the Apache License, Version 2.0 (the "License");
 | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
| // you may not use this file except in compliance with the License.
 | // you may not use this file except in compliance with the License.
 | ||||||
|  |  | ||||||
|  | @ -37,26 +37,21 @@ namespace mediapipe { | ||||||
| bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; } | bool IsPowerOfTwo(int v) { return (v & (v - 1)) == 0; } | ||||||
| 
 | 
 | ||||||
| int BhwcBatchFromShape(const Tensor::Shape& shape) { | int BhwcBatchFromShape(const Tensor::Shape& shape) { | ||||||
|   LOG_IF(FATAL, shape.dims.empty()) |   if (shape.dims.empty()) { | ||||||
|       << "Tensor::Shape must be non-empty to retrieve a named dimension"; |     return 1; | ||||||
|  |   } | ||||||
|   return shape.dims[0]; |   return shape.dims[0]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int BhwcHeightFromShape(const Tensor::Shape& shape) { | int BhwcHeightFromShape(const Tensor::Shape& shape) { | ||||||
|   LOG_IF(FATAL, shape.dims.empty()) |  | ||||||
|       << "Tensor::Shape must be non-empty to retrieve a named dimension"; |  | ||||||
|   return shape.dims.size() < 4 ? 1 : shape.dims[shape.dims.size() - 3]; |   return shape.dims.size() < 4 ? 1 : shape.dims[shape.dims.size() - 3]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int BhwcWidthFromShape(const Tensor::Shape& shape) { | int BhwcWidthFromShape(const Tensor::Shape& shape) { | ||||||
|   LOG_IF(FATAL, shape.dims.empty()) |  | ||||||
|       << "Tensor::Shape must be non-empty to retrieve a named dimension"; |  | ||||||
|   return shape.dims.size() < 3 ? 1 : shape.dims[shape.dims.size() - 2]; |   return shape.dims.size() < 3 ? 1 : shape.dims[shape.dims.size() - 2]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| int BhwcDepthFromShape(const Tensor::Shape& shape) { | int BhwcDepthFromShape(const Tensor::Shape& shape) { | ||||||
|   LOG_IF(FATAL, shape.dims.empty()) |  | ||||||
|       << "Tensor::Shape must be non-empty to retrieve a named dimension"; |  | ||||||
|   return shape.dims.size() < 2 ? 1 : shape.dims[shape.dims.size() - 1]; |   return shape.dims.size() < 2 ? 1 : shape.dims[shape.dims.size() - 1]; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -424,6 +419,11 @@ Tensor::Tensor(ElementType element_type, const Shape& shape, | ||||||
| 
 | 
 | ||||||
| #if MEDIAPIPE_METAL_ENABLED | #if MEDIAPIPE_METAL_ENABLED | ||||||
| void Tensor::Invalidate() { | void Tensor::Invalidate() { | ||||||
|  | #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 | ||||||
|  |   GLuint cleanup_gl_tex = GL_INVALID_INDEX; | ||||||
|  |   GLuint cleanup_gl_fb = GL_INVALID_INDEX; | ||||||
|  | #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
 | ||||||
|  |   { | ||||||
|     absl::MutexLock lock(&view_mutex_); |     absl::MutexLock lock(&view_mutex_); | ||||||
|     // If memory is allocated and not owned by the metal buffer.
 |     // If memory is allocated and not owned by the metal buffer.
 | ||||||
|     // TODO: Re-design cpu buffer memory management.
 |     // TODO: Re-design cpu buffer memory management.
 | ||||||
|  | @ -432,6 +432,23 @@ void Tensor::Invalidate() { | ||||||
|     } |     } | ||||||
|     metal_buffer_ = nil; |     metal_buffer_ = nil; | ||||||
|     cpu_buffer_ = nullptr; |     cpu_buffer_ = nullptr; | ||||||
|  | #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 | ||||||
|  |     // Don't need to wait for the resource to be deleted bacause if will be
 | ||||||
|  |     // released on last reference deletion inside the OpenGL driver.
 | ||||||
|  |     std::swap(cleanup_gl_tex, opengl_texture2d_); | ||||||
|  |     std::swap(cleanup_gl_fb, frame_buffer_); | ||||||
|  | #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
 | ||||||
|  |   } | ||||||
|  | #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 | ||||||
|  |   // Do not hold the view mutex while invoking GlContext::RunWithoutWaiting,
 | ||||||
|  |   // since that method may acquire the context's own lock.
 | ||||||
|  |   if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX) { | ||||||
|  |     gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb]() { | ||||||
|  |       glDeleteTextures(1, &cleanup_gl_tex); | ||||||
|  |       glDeleteFramebuffers(1, &cleanup_gl_fb); | ||||||
|  |     }); | ||||||
|  |   } | ||||||
|  | #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #else | #else | ||||||
|  |  | ||||||
|  | @ -18,6 +18,7 @@ | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
| #include <functional> | #include <functional> | ||||||
| #include <initializer_list> | #include <initializer_list> | ||||||
|  | #include <numeric> | ||||||
| #include <tuple> | #include <tuple> | ||||||
| #include <type_traits> | #include <type_traits> | ||||||
| #include <utility> | #include <utility> | ||||||
|  | @ -89,15 +90,23 @@ class Tensor { | ||||||
| 
 | 
 | ||||||
|  public: |  public: | ||||||
|   // No resources are allocated here.
 |   // No resources are allocated here.
 | ||||||
|   enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8, kInt32 }; |   enum class ElementType { | ||||||
|  |     kNone, | ||||||
|  |     kFloat16, | ||||||
|  |     kFloat32, | ||||||
|  |     kUInt8, | ||||||
|  |     kInt8, | ||||||
|  |     kInt32, | ||||||
|  |     // TODO: Update the inference runner to handle kTfLiteString.
 | ||||||
|  |     kChar | ||||||
|  |   }; | ||||||
|   struct Shape { |   struct Shape { | ||||||
|     Shape() = default; |     Shape() = default; | ||||||
|     Shape(std::initializer_list<int> dimensions) : dims(dimensions) {} |     Shape(std::initializer_list<int> dimensions) : dims(dimensions) {} | ||||||
|     Shape(const std::vector<int>& dimensions) : dims(dimensions) {} |     Shape(const std::vector<int>& dimensions) : dims(dimensions) {} | ||||||
|     int num_elements() const { |     int num_elements() const { | ||||||
|       int res = dims.empty() ? 0 : 1; |       return std::accumulate(dims.begin(), dims.end(), 1, | ||||||
|       std::for_each(dims.begin(), dims.end(), [&res](int i) { res *= i; }); |                              std::multiplies<int>()); | ||||||
|       return res; |  | ||||||
|     } |     } | ||||||
|     std::vector<int> dims; |     std::vector<int> dims; | ||||||
|   }; |   }; | ||||||
|  | @ -319,6 +328,8 @@ class Tensor { | ||||||
|         return 1; |         return 1; | ||||||
|       case ElementType::kInt32: |       case ElementType::kInt32: | ||||||
|         return sizeof(int32_t); |         return sizeof(int32_t); | ||||||
|  |       case ElementType::kChar: | ||||||
|  |         return sizeof(char); | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|   int bytes() const { return shape_.num_elements() * element_size(); } |   int bytes() const { return shape_.num_elements() * element_size(); } | ||||||
|  |  | ||||||
|  | @ -1,5 +1,8 @@ | ||||||
| #include "mediapipe/framework/formats/tensor.h" | #include "mediapipe/framework/formats/tensor.h" | ||||||
| 
 | 
 | ||||||
|  | #include <cstring> | ||||||
|  | #include <string> | ||||||
|  | 
 | ||||||
| #include "mediapipe/framework/port/gmock.h" | #include "mediapipe/framework/port/gmock.h" | ||||||
| #include "mediapipe/framework/port/gtest.h" | #include "mediapipe/framework/port/gtest.h" | ||||||
| #if !MEDIAPIPE_DISABLE_GPU | #if !MEDIAPIPE_DISABLE_GPU | ||||||
|  | @ -23,6 +26,9 @@ TEST(General, TestDataTypes) { | ||||||
| 
 | 
 | ||||||
|   Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3}); |   Tensor t2(Tensor::ElementType::kFloat16, Tensor::Shape{4, 3, 2, 3}); | ||||||
|   EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2); |   EXPECT_EQ(t2.bytes(), t2.shape().num_elements() * 2); | ||||||
|  | 
 | ||||||
|  |   Tensor t_char(Tensor::ElementType::kChar, Tensor::Shape{4}); | ||||||
|  |   EXPECT_EQ(t_char.bytes(), t_char.shape().num_elements() * sizeof(char)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| TEST(Cpu, TestMemoryAllocation) { | TEST(Cpu, TestMemoryAllocation) { | ||||||
|  |  | ||||||
|  | @ -15,7 +15,7 @@ def mediapipe_cc_test( | ||||||
|         platforms = ["linux", "android", "ios", "wasm"], |         platforms = ["linux", "android", "ios", "wasm"], | ||||||
|         exclude_platforms = None, |         exclude_platforms = None, | ||||||
|         # ios_unit_test arguments |         # ios_unit_test arguments | ||||||
|         ios_minimum_os_version = "9.0", |         ios_minimum_os_version = "11.0", | ||||||
|         # android_cc_test arguments |         # android_cc_test arguments | ||||||
|         open_gl_driver = None, |         open_gl_driver = None, | ||||||
|         emulator_mini_boot = True, |         emulator_mini_boot = True, | ||||||
|  |  | ||||||
|  | @ -108,6 +108,7 @@ cc_library( | ||||||
|         ":sharded_map", |         ":sharded_map", | ||||||
|         "//mediapipe/framework:calculator_cc_proto", |         "//mediapipe/framework:calculator_cc_proto", | ||||||
|         "//mediapipe/framework:calculator_profile_cc_proto", |         "//mediapipe/framework:calculator_profile_cc_proto", | ||||||
|  |         "//mediapipe/framework/port:file_helpers", | ||||||
|         "//mediapipe/framework/port:integral_types", |         "//mediapipe/framework/port:integral_types", | ||||||
|         "@com_google_absl//absl/memory", |         "@com_google_absl//absl/memory", | ||||||
|         "@com_google_absl//absl/types:optional", |         "@com_google_absl//absl/types:optional", | ||||||
|  |  | ||||||
|  | @ -22,6 +22,7 @@ | ||||||
| #include "absl/time/time.h" | #include "absl/time/time.h" | ||||||
| #include "mediapipe/framework/port/advanced_proto_lite_inc.h" | #include "mediapipe/framework/port/advanced_proto_lite_inc.h" | ||||||
| #include "mediapipe/framework/port/canonical_errors.h" | #include "mediapipe/framework/port/canonical_errors.h" | ||||||
|  | #include "mediapipe/framework/port/file_helpers.h" | ||||||
| #include "mediapipe/framework/port/logging.h" | #include "mediapipe/framework/port/logging.h" | ||||||
| #include "mediapipe/framework/port/proto_ns.h" | #include "mediapipe/framework/port/proto_ns.h" | ||||||
| #include "mediapipe/framework/port/re2.h" | #include "mediapipe/framework/port/re2.h" | ||||||
|  | @ -244,7 +245,16 @@ absl::Status GraphProfiler::Start(mediapipe::Executor* executor) { | ||||||
|       executor != nullptr) { |       executor != nullptr) { | ||||||
|     // Inform the user via logging the path to the trace logs.
 |     // Inform the user via logging the path to the trace logs.
 | ||||||
|     ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath()); |     ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath()); | ||||||
|  |     // Check that we can actually write to it.
 | ||||||
|  |     auto status = | ||||||
|  |         file::SetContents(absl::StrCat(trace_log_path, "trace_writing_check"), | ||||||
|  |                           "can write trace logs to this location"); | ||||||
|  |     if (status.ok()) { | ||||||
|       LOG(INFO) << "trace_log_path: " << trace_log_path; |       LOG(INFO) << "trace_log_path: " << trace_log_path; | ||||||
|  |     } else { | ||||||
|  |       LOG(ERROR) << "cannot write to trace_log_path: " << trace_log_path << ": " | ||||||
|  |                  << status; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     is_running_ = true; |     is_running_ = true; | ||||||
|     executor->Schedule([this] { |     executor->Schedule([this] { | ||||||
|  |  | ||||||
|  | @ -5,7 +5,7 @@ | ||||||
| 
 | 
 | ||||||
| buffers: { | buffers: { | ||||||
|   size_kb: 150000 |   size_kb: 150000 | ||||||
|   fill_policy: DISCARD |   fill_policy: RING_BUFFER | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| data_sources: { | data_sources: { | ||||||
|  | @ -21,12 +21,17 @@ data_sources: { | ||||||
|       # - what is happening on each CPU at each moment |       # - what is happening on each CPU at each moment | ||||||
|       ftrace_events: "power/cpu_frequency" |       ftrace_events: "power/cpu_frequency" | ||||||
|       ftrace_events: "power/cpu_idle" |       ftrace_events: "power/cpu_idle" | ||||||
|  |       # TODO: CPU frequency does not show up without scheduling | ||||||
|       ftrace_events: "sched/sched_switch" |       ftrace_events: "sched/sched_switch" | ||||||
|       compact_sched { |       compact_sched { | ||||||
|         enabled: true |         enabled: true | ||||||
|       } |       } | ||||||
|  |       # GPU | ||||||
|  |       ftrace_events: "power/gpu_frequency" | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
| write_into_file: true | write_into_file: true | ||||||
| file_write_period_ms: 500 | file_write_period_ms: 500 | ||||||
|  | # b/243571696 Added to remove Perfetto timeouts when running benchmarks remotely. | ||||||
|  | duration_ms: 60000 | ||||||
|  |  | ||||||
|  | @ -821,6 +821,19 @@ cc_library( | ||||||
|     alwayslink = 1, |     alwayslink = 1, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | mediapipe_cc_test( | ||||||
|  |     name = "switch_demux_calculator_test", | ||||||
|  |     srcs = ["switch_demux_calculator_test.cc"], | ||||||
|  |     deps = [ | ||||||
|  |         ":container_util", | ||||||
|  |         ":switch_demux_calculator", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework/port:gtest_main", | ||||||
|  |         "//mediapipe/framework/port:logging", | ||||||
|  |         "@com_google_absl//absl/strings", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| cc_library( | cc_library( | ||||||
|     name = "switch_mux_calculator", |     name = "switch_mux_calculator", | ||||||
|     srcs = ["switch_mux_calculator.cc"], |     srcs = ["switch_mux_calculator.cc"], | ||||||
|  |  | ||||||
|  | @ -129,12 +129,12 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { | ||||||
|   // Relay side packets to all channels.
 |   // Relay side packets to all channels.
 | ||||||
|   // Note: This is necessary because Calculator::Open only proceeds when every
 |   // Note: This is necessary because Calculator::Open only proceeds when every
 | ||||||
|   // anticipated side-packet arrives.
 |   // anticipated side-packet arrives.
 | ||||||
|   int channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap()); |   int side_channel_count = tool::ChannelCount(cc->OutputSidePackets().TagMap()); | ||||||
|   for (const std::string& tag : ChannelTags(cc->OutputSidePackets().TagMap())) { |   for (const std::string& tag : ChannelTags(cc->OutputSidePackets().TagMap())) { | ||||||
|     int num_entries = cc->InputSidePackets().NumEntries(tag); |     int num_entries = cc->InputSidePackets().NumEntries(tag); | ||||||
|     for (int index = 0; index < num_entries; ++index) { |     for (int index = 0; index < num_entries; ++index) { | ||||||
|       Packet input = cc->InputSidePackets().Get(tag, index); |       Packet input = cc->InputSidePackets().Get(tag, index); | ||||||
|       for (int channel = 0; channel < channel_count; ++channel) { |       for (int channel = 0; channel < side_channel_count; ++channel) { | ||||||
|         std::string output_tag = tool::ChannelTag(tag, channel); |         std::string output_tag = tool::ChannelTag(tag, channel); | ||||||
|         auto output_id = cc->OutputSidePackets().GetId(output_tag, index); |         auto output_id = cc->OutputSidePackets().GetId(output_tag, index); | ||||||
|         if (output_id.IsValid()) { |         if (output_id.IsValid()) { | ||||||
|  | @ -143,6 +143,23 @@ absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   // Relay headers to all channels.
 | ||||||
|  |   int output_channel_count = tool::ChannelCount(cc->Outputs().TagMap()); | ||||||
|  |   for (const std::string& tag : ChannelTags(cc->Outputs().TagMap())) { | ||||||
|  |     int num_entries = cc->Inputs().NumEntries(tag); | ||||||
|  |     for (int index = 0; index < num_entries; ++index) { | ||||||
|  |       auto& input = cc->Inputs().Get(tag, index); | ||||||
|  |       if (input.Header().IsEmpty()) continue; | ||||||
|  |       for (int channel = 0; channel < output_channel_count; ++channel) { | ||||||
|  |         std::string output_tag = tool::ChannelTag(tag, channel); | ||||||
|  |         auto output_id = cc->Outputs().GetId(output_tag, index); | ||||||
|  |         if (output_id.IsValid()) { | ||||||
|  |           cc->Outputs().Get(output_tag, index).SetHeader(input.Header()); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|   return absl::OkStatus(); |   return absl::OkStatus(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										135
									
								
								mediapipe/framework/tool/switch_demux_calculator_test.cc
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								mediapipe/framework/tool/switch_demux_calculator_test.cc
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,135 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include <string> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "absl/strings/str_cat.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/port/gmock.h" | ||||||
|  | #include "mediapipe/framework/port/gtest.h" | ||||||
|  | #include "mediapipe/framework/port/logging.h" | ||||||
|  | #include "mediapipe/framework/tool/container_util.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | // Returns a CalculatorGraph to run a single calculator.
 | ||||||
|  | CalculatorGraph BuildCalculatorGraph(CalculatorGraphConfig::Node node_config) { | ||||||
|  |   CalculatorGraphConfig config; | ||||||
|  |   *config.add_node() = node_config; | ||||||
|  |   *config.mutable_input_stream() = node_config.input_stream(); | ||||||
|  |   *config.mutable_output_stream() = node_config.output_stream(); | ||||||
|  |   *config.mutable_input_side_packet() = node_config.input_side_packet(); | ||||||
|  |   *config.mutable_output_side_packet() = node_config.output_side_packet(); | ||||||
|  |   return CalculatorGraph(config); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Creates a string packet.
 | ||||||
|  | Packet pack(std::string data, int timestamp) { | ||||||
|  |   return MakePacket<std::string>(data).At(Timestamp(timestamp)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Creates an int packet.
 | ||||||
|  | Packet pack(int data, int timestamp) { | ||||||
|  |   return MakePacket<int>(data).At(Timestamp(timestamp)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Tests showing packet channel synchronization through SwitchDemuxCalculator.
 | ||||||
|  | class SwitchDemuxCalculatorTest : public ::testing::Test { | ||||||
|  |  protected: | ||||||
|  |   SwitchDemuxCalculatorTest() {} | ||||||
|  |   ~SwitchDemuxCalculatorTest() override {} | ||||||
|  |   void SetUp() override {} | ||||||
|  |   void TearDown() override {} | ||||||
|  | 
 | ||||||
|  |   // Defines a SwitchDemuxCalculator CalculatorGraphConfig::Node.
 | ||||||
|  |   CalculatorGraphConfig::Node BuildNodeConfig() { | ||||||
|  |     CalculatorGraphConfig::Node result; | ||||||
|  |     *result.mutable_calculator() = "SwitchDemuxCalculator"; | ||||||
|  |     *result.add_input_stream() = "SELECT:select"; | ||||||
|  |     for (int c = 0; c < 2; ++c) { | ||||||
|  |       *result.add_output_stream() = | ||||||
|  |           absl::StrCat(tool::ChannelTag("FRAME", c), ":frame_", c); | ||||||
|  |       *result.add_output_stream() = | ||||||
|  |           absl::StrCat(tool::ChannelTag("MASK", c), ":mask_", c); | ||||||
|  |     } | ||||||
|  |     *result.add_input_stream() = "FRAME:frame"; | ||||||
|  |     *result.add_input_stream() = "MASK:mask"; | ||||||
|  |     return result; | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | // Shows the SwitchMuxCalculator is available.
 | ||||||
|  | TEST_F(SwitchDemuxCalculatorTest, IsRegistered) { | ||||||
|  |   EXPECT_TRUE(CalculatorBaseRegistry::IsRegistered("SwitchDemuxCalculator")); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST_F(SwitchDemuxCalculatorTest, BasicDataFlow) { | ||||||
|  |   CalculatorGraphConfig::Node node_config = BuildNodeConfig(); | ||||||
|  |   CalculatorGraph graph = BuildCalculatorGraph(node_config); | ||||||
|  |   std::vector<Packet> output_frames0; | ||||||
|  |   EXPECT_TRUE(graph | ||||||
|  |                   .ObserveOutputStream("frame_0", | ||||||
|  |                                        [&](const Packet& p) { | ||||||
|  |                                          output_frames0.push_back(p); | ||||||
|  |                                          return absl::OkStatus(); | ||||||
|  |                                        }) | ||||||
|  |                   .ok()); | ||||||
|  |   std::vector<Packet> output_frames1; | ||||||
|  |   EXPECT_TRUE(graph | ||||||
|  |                   .ObserveOutputStream("frame_1", | ||||||
|  |                                        [&](const Packet& p) { | ||||||
|  |                                          output_frames1.push_back(p); | ||||||
|  |                                          return absl::OkStatus(); | ||||||
|  |                                        }) | ||||||
|  |                   .ok()); | ||||||
|  |   EXPECT_TRUE( | ||||||
|  |       graph.StartRun({}, {{"frame", MakePacket<std::string>("frame_header")}}) | ||||||
|  |           .ok()); | ||||||
|  | 
 | ||||||
|  |   // Finalize input for the "mask" input stream.
 | ||||||
|  |   EXPECT_TRUE(graph.CloseInputStream("mask").ok()); | ||||||
|  | 
 | ||||||
|  |   // Channel 0 is selected just before corresponding packets arrive.
 | ||||||
|  |   EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 1)).ok()); | ||||||
|  |   EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(0, 10)).ok()); | ||||||
|  |   EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p0_t10", 10)).ok()); | ||||||
|  |   EXPECT_TRUE(graph.WaitUntilIdle().ok()); | ||||||
|  |   EXPECT_EQ(output_frames0.size(), 1); | ||||||
|  |   EXPECT_EQ(output_frames1.size(), 0); | ||||||
|  |   EXPECT_EQ(output_frames0[0].Get<std::string>(), "p0_t10"); | ||||||
|  | 
 | ||||||
|  |   // Channel 1 is selected just before corresponding packets arrive.
 | ||||||
|  |   EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 11)).ok()); | ||||||
|  |   EXPECT_TRUE(graph.AddPacketToInputStream("select", pack(1, 20)).ok()); | ||||||
|  |   EXPECT_TRUE(graph.AddPacketToInputStream("frame", pack("p1_t20", 20)).ok()); | ||||||
|  |   EXPECT_TRUE(graph.WaitUntilIdle().ok()); | ||||||
|  |   EXPECT_EQ(output_frames0.size(), 1); | ||||||
|  |   EXPECT_EQ(output_frames1.size(), 1); | ||||||
|  |   EXPECT_EQ(output_frames1[0].Get<std::string>(), "p1_t20"); | ||||||
|  | 
 | ||||||
|  |   EXPECT_EQ( | ||||||
|  |       graph.FindOutputStreamManager("frame_0")->Header().Get<std::string>(), | ||||||
|  |       "frame_header"); | ||||||
|  |   EXPECT_EQ( | ||||||
|  |       graph.FindOutputStreamManager("frame_1")->Header().Get<std::string>(), | ||||||
|  |       "frame_header"); | ||||||
|  | 
 | ||||||
|  |   EXPECT_TRUE(graph.CloseAllPacketSources().ok()); | ||||||
|  |   EXPECT_TRUE(graph.WaitUntilDone().ok()); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -271,6 +271,7 @@ cc_library( | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":gpu_buffer_format", |         ":gpu_buffer_format", | ||||||
|         ":gpu_buffer_storage", |         ":gpu_buffer_storage", | ||||||
|  |         "@com_google_absl//absl/strings", | ||||||
|         "//mediapipe/framework/formats:image_frame", |         "//mediapipe/framework/formats:image_frame", | ||||||
|         "//mediapipe/framework/port:logging", |         "//mediapipe/framework/port:logging", | ||||||
|         ":gpu_buffer_storage_image_frame", |         ":gpu_buffer_storage_image_frame", | ||||||
|  | @ -366,6 +367,23 @@ cc_library( | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "gpu_buffer_storage_ahwb", | ||||||
|  |     srcs = ["gpu_buffer_storage_ahwb.cc"], | ||||||
|  |     hdrs = ["gpu_buffer_storage_ahwb.h"], | ||||||
|  |     linkopts = select({ | ||||||
|  |         "//conditions:default": [], | ||||||
|  |         "//mediapipe:android": [ | ||||||
|  |             "-landroid", | ||||||
|  |         ], | ||||||
|  |     }), | ||||||
|  |     deps = [ | ||||||
|  |         ":gpu_buffer_format", | ||||||
|  |         ":gpu_buffer_storage", | ||||||
|  |         "@com_google_absl//absl/strings:str_format", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| mediapipe_proto_library( | mediapipe_proto_library( | ||||||
|     name = "gpu_origin_proto", |     name = "gpu_origin_proto", | ||||||
|     srcs = ["gpu_origin.proto"], |     srcs = ["gpu_origin.proto"], | ||||||
|  | @ -1087,3 +1105,19 @@ ios_unit_test( | ||||||
|     ], |     ], | ||||||
|     deps = [":gl_ios_test_lib"], |     deps = [":gl_ios_test_lib"], | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | mediapipe_cc_test( | ||||||
|  |     name = "gpu_buffer_storage_ahwb_test", | ||||||
|  |     size = "small", | ||||||
|  |     srcs = ["gpu_buffer_storage_ahwb_test.cc"], | ||||||
|  |     exclude_platforms = [ | ||||||
|  |         "ios", | ||||||
|  |         "wasm", | ||||||
|  |     ], | ||||||
|  |     requires_full_emulation = True, | ||||||
|  |     deps = [ | ||||||
|  |         ":gpu_buffer_format", | ||||||
|  |         ":gpu_buffer_storage_ahwb", | ||||||
|  |         "//mediapipe/framework/port:gtest_main", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @ -620,7 +620,9 @@ class GlSyncWrapper { | ||||||
| #endif | #endif | ||||||
|     GLenum result = glClientWaitSync(sync_, flags, timeout); |     GLenum result = glClientWaitSync(sync_, flags, timeout); | ||||||
|     if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { |     if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { | ||||||
|       Clear(); |       // TODO: we could clear at this point so later calls are faster,
 | ||||||
|  |       // but we need to do so in a thread-safe way.
 | ||||||
|  |       // Clear();
 | ||||||
|     } |     } | ||||||
|     // TODO: do something if the wait fails?
 |     // TODO: do something if the wait fails?
 | ||||||
|   } |   } | ||||||
|  | @ -646,7 +648,9 @@ class GlSyncWrapper { | ||||||
| #endif | #endif | ||||||
|     GLenum result = glClientWaitSync(sync_, flags, 0); |     GLenum result = glClientWaitSync(sync_, flags, 0); | ||||||
|     if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { |     if (result == GL_ALREADY_SIGNALED || result == GL_CONDITION_SATISFIED) { | ||||||
|       Clear(); |       // TODO: we could clear at this point so later calls are faster,
 | ||||||
|  |       // but we need to do so in a thread-safe way.
 | ||||||
|  |       // Clear();
 | ||||||
|       return true; |       return true; | ||||||
|     } |     } | ||||||
|     return false; |     return false; | ||||||
|  | @ -822,10 +826,17 @@ std::shared_ptr<GlSyncPoint> GlContext::CreateSyncToken() { | ||||||
|   return token; |   return token; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | bool GlContext::IsAnyContextCurrent() { | ||||||
|  |   ContextBinding ctx; | ||||||
|  |   GetCurrentContextBinding(&ctx); | ||||||
|  |   return ctx.context != kPlatformGlContextNone; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| std::shared_ptr<GlSyncPoint> | std::shared_ptr<GlSyncPoint> | ||||||
| GlContext::CreateSyncTokenForCurrentExternalContext( | GlContext::CreateSyncTokenForCurrentExternalContext( | ||||||
|     const std::shared_ptr<GlContext>& delegate_graph_context) { |     const std::shared_ptr<GlContext>& delegate_graph_context) { | ||||||
|   CHECK(delegate_graph_context); |   CHECK(delegate_graph_context); | ||||||
|  |   if (!IsAnyContextCurrent()) return nullptr; | ||||||
|   if (delegate_graph_context->ShouldUseFenceSync()) { |   if (delegate_graph_context->ShouldUseFenceSync()) { | ||||||
|     return std::shared_ptr<GlSyncPoint>( |     return std::shared_ptr<GlSyncPoint>( | ||||||
|         new GlExternalFenceSyncPoint(delegate_graph_context)); |         new GlExternalFenceSyncPoint(delegate_graph_context)); | ||||||
|  |  | ||||||
|  | @ -303,6 +303,10 @@ class GlContext : public std::enable_shared_from_this<GlContext> { | ||||||
|     return *static_cast<T*>(entry.get()); |     return *static_cast<T*>(entry.get()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   // Returns true if any GL context, including external contexts not managed by
 | ||||||
|  |   // the GlContext class, is current.
 | ||||||
|  |   static bool IsAnyContextCurrent(); | ||||||
|  | 
 | ||||||
|   // Creates a synchronization token for the current, non-GlContext-owned
 |   // Creates a synchronization token for the current, non-GlContext-owned
 | ||||||
|   // context. This can be passed to MediaPipe so it can synchronize with the
 |   // context. This can be passed to MediaPipe so it can synchronize with the
 | ||||||
|   // commands issued in the external context up to this point.
 |   // commands issued in the external context up to this point.
 | ||||||
|  |  | ||||||
|  | @ -145,9 +145,13 @@ bool GlTextureBuffer::CreateInternal(const void* data, int alignment) { | ||||||
|     CHECK_NE(name_, 0); |     CHECK_NE(name_, 0); | ||||||
|     GLuint name_to_delete = name_; |     GLuint name_to_delete = name_; | ||||||
|     context->RunWithoutWaiting([name_to_delete, sync_token]() { |     context->RunWithoutWaiting([name_to_delete, sync_token]() { | ||||||
|  |       if (sync_token) { | ||||||
|         // TODO: maybe we do not actually have to wait for the
 |         // TODO: maybe we do not actually have to wait for the
 | ||||||
|         // consumer sync here. Check docs.
 |         // consumer sync here. Check docs.
 | ||||||
|         sync_token->WaitOnGpu(); |         sync_token->WaitOnGpu(); | ||||||
|  |       } else { | ||||||
|  |         LOG_FIRST_N(WARNING, 5) << "unexpected null sync in deletion_callback"; | ||||||
|  |       } | ||||||
|       DLOG_IF(ERROR, !glIsTexture(name_to_delete)) |       DLOG_IF(ERROR, !glIsTexture(name_to_delete)) | ||||||
|           << "Deleting invalid texture id: " << name_to_delete; |           << "Deleting invalid texture id: " << name_to_delete; | ||||||
|       glDeleteTextures(1, &name_to_delete); |       glDeleteTextures(1, &name_to_delete); | ||||||
|  | @ -179,13 +183,19 @@ void GlTextureBuffer::Reuse() { | ||||||
| void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) { | void GlTextureBuffer::Updated(std::shared_ptr<GlSyncPoint> prod_token) { | ||||||
|   CHECK(!producer_sync_) |   CHECK(!producer_sync_) | ||||||
|       << "Updated existing texture which had not been marked for reuse!"; |       << "Updated existing texture which had not been marked for reuse!"; | ||||||
|  |   CHECK(prod_token); | ||||||
|   producer_sync_ = std::move(prod_token); |   producer_sync_ = std::move(prod_token); | ||||||
|   producer_context_ = producer_sync_->GetContext(); |   producer_context_ = producer_sync_->GetContext(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const { | void GlTextureBuffer::DidRead(std::shared_ptr<GlSyncPoint> cons_token) const { | ||||||
|   absl::MutexLock lock(&consumer_sync_mutex_); |   absl::MutexLock lock(&consumer_sync_mutex_); | ||||||
|  |   if (cons_token) { | ||||||
|     consumer_multi_sync_->Add(std::move(cons_token)); |     consumer_multi_sync_->Add(std::move(cons_token)); | ||||||
|  |   } else { | ||||||
|  |     // TODO: change to a CHECK.
 | ||||||
|  |     LOG_FIRST_N(WARNING, 5) << "unexpected null sync in DidRead"; | ||||||
|  |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| GlTextureBuffer::~GlTextureBuffer() { | GlTextureBuffer::~GlTextureBuffer() { | ||||||
|  |  | ||||||
|  | @ -2,6 +2,8 @@ | ||||||
| 
 | 
 | ||||||
| #include <memory> | #include <memory> | ||||||
| 
 | 
 | ||||||
|  | #include "absl/strings/str_cat.h" | ||||||
|  | #include "absl/strings/str_join.h" | ||||||
| #include "mediapipe/framework/port/logging.h" | #include "mediapipe/framework/port/logging.h" | ||||||
| 
 | 
 | ||||||
| #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER | #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER | ||||||
|  | @ -10,6 +12,23 @@ | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| 
 | 
 | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | struct StorageTypeFormatter { | ||||||
|  |   void operator()(std::string* out, | ||||||
|  |                   const std::shared_ptr<internal::GpuBufferStorage>& s) const { | ||||||
|  |     absl::StrAppend(out, s->storage_type().name()); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | std::string GpuBuffer::DebugString() const { | ||||||
|  |   return absl::StrCat("GpuBuffer[", | ||||||
|  |                       absl::StrJoin(storages_, ", ", StorageTypeFormatter()), | ||||||
|  |                       "]"); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| internal::GpuBufferStorage& GpuBuffer::GetStorageForView( | internal::GpuBufferStorage& GpuBuffer::GetStorageForView( | ||||||
|     TypeId view_provider_type, bool for_writing) const { |     TypeId view_provider_type, bool for_writing) const { | ||||||
|   const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr; |   const std::shared_ptr<internal::GpuBufferStorage>* chosen_storage = nullptr; | ||||||
|  | @ -52,7 +71,10 @@ internal::GpuBufferStorage& GpuBuffer::GetStorageForView( | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   CHECK(chosen_storage) << "no view provider found"; |   CHECK(chosen_storage) << "no view provider found for requested view " | ||||||
|  |                         << view_provider_type.name() << "; storages available: " | ||||||
|  |                         << absl::StrJoin(storages_, ", ", | ||||||
|  |                                          StorageTypeFormatter()); | ||||||
|   DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); |   DCHECK((*chosen_storage)->can_down_cast_to(view_provider_type)); | ||||||
|   return **chosen_storage; |   return **chosen_storage; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -129,6 +129,8 @@ class GpuBuffer { | ||||||
|     return nullptr; |     return nullptr; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   std::string DebugString() const; | ||||||
|  | 
 | ||||||
|  private: |  private: | ||||||
|   class PlaceholderGpuBufferStorage |   class PlaceholderGpuBufferStorage | ||||||
|       : public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> { |       : public internal::GpuBufferStorageImpl<PlaceholderGpuBufferStorage> { | ||||||
|  |  | ||||||
|  | @ -21,18 +21,29 @@ licenses(["notice"]) | ||||||
| 
 | 
 | ||||||
| package(default_visibility = ["//visibility:public"]) | package(default_visibility = ["//visibility:public"]) | ||||||
| 
 | 
 | ||||||
|  | mediapipe_simple_subgraph( | ||||||
|  |     name = "pose_landmarks_to_render_data", | ||||||
|  |     graph = "pose_landmarks_to_render_data.pbtxt", | ||||||
|  |     register_as = "PoseLandmarksToRenderData", | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/calculators/core:concatenate_vector_calculator", | ||||||
|  |         "//mediapipe/calculators/core:split_proto_list_calculator", | ||||||
|  |         "//mediapipe/calculators/util:landmarks_to_render_data_calculator", | ||||||
|  |         "//mediapipe/calculators/util:rect_to_render_scale_calculator", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
| mediapipe_simple_subgraph( | mediapipe_simple_subgraph( | ||||||
|     name = "pose_renderer_gpu", |     name = "pose_renderer_gpu", | ||||||
|     graph = "pose_renderer_gpu.pbtxt", |     graph = "pose_renderer_gpu.pbtxt", | ||||||
|     register_as = "PoseRendererGpu", |     register_as = "PoseRendererGpu", | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/calculators/core:split_proto_list_calculator", |         ":pose_landmarks_to_render_data", | ||||||
|  |         "//mediapipe/calculators/image:image_properties_calculator", | ||||||
|         "//mediapipe/calculators/image:recolor_calculator", |         "//mediapipe/calculators/image:recolor_calculator", | ||||||
|         "//mediapipe/calculators/util:annotation_overlay_calculator", |         "//mediapipe/calculators/util:annotation_overlay_calculator", | ||||||
|         "//mediapipe/calculators/util:detections_to_render_data_calculator", |         "//mediapipe/calculators/util:detections_to_render_data_calculator", | ||||||
|         "//mediapipe/calculators/util:landmarks_to_render_data_calculator", |  | ||||||
|         "//mediapipe/calculators/util:rect_to_render_data_calculator", |         "//mediapipe/calculators/util:rect_to_render_data_calculator", | ||||||
|         "//mediapipe/calculators/util:rect_to_render_scale_calculator", |  | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -41,12 +52,11 @@ mediapipe_simple_subgraph( | ||||||
|     graph = "pose_renderer_cpu.pbtxt", |     graph = "pose_renderer_cpu.pbtxt", | ||||||
|     register_as = "PoseRendererCpu", |     register_as = "PoseRendererCpu", | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/calculators/core:split_proto_list_calculator", |         ":pose_landmarks_to_render_data", | ||||||
|  |         "//mediapipe/calculators/image:image_properties_calculator", | ||||||
|         "//mediapipe/calculators/image:recolor_calculator", |         "//mediapipe/calculators/image:recolor_calculator", | ||||||
|         "//mediapipe/calculators/util:annotation_overlay_calculator", |         "//mediapipe/calculators/util:annotation_overlay_calculator", | ||||||
|         "//mediapipe/calculators/util:detections_to_render_data_calculator", |         "//mediapipe/calculators/util:detections_to_render_data_calculator", | ||||||
|         "//mediapipe/calculators/util:landmarks_to_render_data_calculator", |  | ||||||
|         "//mediapipe/calculators/util:rect_to_render_data_calculator", |         "//mediapipe/calculators/util:rect_to_render_data_calculator", | ||||||
|         "//mediapipe/calculators/util:rect_to_render_scale_calculator", |  | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,236 @@ | ||||||
|  | # MediaPipe pose landmarks to render data subgraph. | ||||||
|  | 
 | ||||||
|  | type: "PoseLandmarksToRenderData" | ||||||
|  | 
 | ||||||
|  | # Pose landmarks. (NormalizedLandmarkList) | ||||||
|  | input_stream: "LANDMARKS:pose_landmarks" | ||||||
|  | # Region of interest calculated based on landmarks. (NormalizedRect) | ||||||
|  | input_stream: "ROI:roi" | ||||||
|  | # Image size. (pair<int, int>) | ||||||
|  | input_stream: "IMAGE_SIZE:image_size" | ||||||
|  | 
 | ||||||
|  | # The resulting render data. (vector<RenderData>) | ||||||
|  | output_stream: "RENDER_DATA:merged_render_data" | ||||||
|  | 
 | ||||||
|  | # Calculates rendering scale based on the pose roi. | ||||||
|  | node { | ||||||
|  |   calculator: "RectToRenderScaleCalculator" | ||||||
|  |   input_stream: "NORM_RECT:roi" | ||||||
|  |   input_stream: "IMAGE_SIZE:image_size" | ||||||
|  |   output_stream: "RENDER_SCALE:render_scale" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { | ||||||
|  |       multiplier: 0.0012 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | node { | ||||||
|  |   calculator: "SplitNormalizedLandmarkListCalculator" | ||||||
|  |   input_stream: "pose_landmarks" | ||||||
|  |   output_stream: "visible_pose_landmarks" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { | ||||||
|  |       ranges: { begin: 0 end: 25 } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | # Converts landmarks to drawing primitives for annotation overlay. | ||||||
|  | node { | ||||||
|  |   calculator: "LandmarksToRenderDataCalculator" | ||||||
|  |   input_stream: "NORM_LANDMARKS:pose_landmarks" | ||||||
|  |   input_stream: "RENDER_SCALE:render_scale" | ||||||
|  |   output_stream: "RENDER_DATA:landmarks_render_data" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { | ||||||
|  |       landmark_connections: 0 | ||||||
|  |       landmark_connections: 1 | ||||||
|  |       landmark_connections: 1 | ||||||
|  |       landmark_connections: 2 | ||||||
|  |       landmark_connections: 2 | ||||||
|  |       landmark_connections: 3 | ||||||
|  |       landmark_connections: 3 | ||||||
|  |       landmark_connections: 7 | ||||||
|  |       landmark_connections: 0 | ||||||
|  |       landmark_connections: 4 | ||||||
|  |       landmark_connections: 4 | ||||||
|  |       landmark_connections: 5 | ||||||
|  |       landmark_connections: 5 | ||||||
|  |       landmark_connections: 6 | ||||||
|  |       landmark_connections: 6 | ||||||
|  |       landmark_connections: 8 | ||||||
|  |       landmark_connections: 9 | ||||||
|  |       landmark_connections: 10 | ||||||
|  |       landmark_connections: 11 | ||||||
|  |       landmark_connections: 12 | ||||||
|  |       landmark_connections: 11 | ||||||
|  |       landmark_connections: 13 | ||||||
|  |       landmark_connections: 13 | ||||||
|  |       landmark_connections: 15 | ||||||
|  |       landmark_connections: 15 | ||||||
|  |       landmark_connections: 17 | ||||||
|  |       landmark_connections: 15 | ||||||
|  |       landmark_connections: 19 | ||||||
|  |       landmark_connections: 15 | ||||||
|  |       landmark_connections: 21 | ||||||
|  |       landmark_connections: 17 | ||||||
|  |       landmark_connections: 19 | ||||||
|  |       landmark_connections: 12 | ||||||
|  |       landmark_connections: 14 | ||||||
|  |       landmark_connections: 14 | ||||||
|  |       landmark_connections: 16 | ||||||
|  |       landmark_connections: 16 | ||||||
|  |       landmark_connections: 18 | ||||||
|  |       landmark_connections: 16 | ||||||
|  |       landmark_connections: 20 | ||||||
|  |       landmark_connections: 16 | ||||||
|  |       landmark_connections: 22 | ||||||
|  |       landmark_connections: 18 | ||||||
|  |       landmark_connections: 20 | ||||||
|  |       landmark_connections: 11 | ||||||
|  |       landmark_connections: 23 | ||||||
|  |       landmark_connections: 12 | ||||||
|  |       landmark_connections: 24 | ||||||
|  |       landmark_connections: 23 | ||||||
|  |       landmark_connections: 24 | ||||||
|  |       landmark_connections: 23 | ||||||
|  |       landmark_connections: 25 | ||||||
|  |       landmark_connections: 24 | ||||||
|  |       landmark_connections: 26 | ||||||
|  |       landmark_connections: 25 | ||||||
|  |       landmark_connections: 27 | ||||||
|  |       landmark_connections: 26 | ||||||
|  |       landmark_connections: 28 | ||||||
|  |       landmark_connections: 27 | ||||||
|  |       landmark_connections: 29 | ||||||
|  |       landmark_connections: 28 | ||||||
|  |       landmark_connections: 30 | ||||||
|  |       landmark_connections: 29 | ||||||
|  |       landmark_connections: 31 | ||||||
|  |       landmark_connections: 30 | ||||||
|  |       landmark_connections: 32 | ||||||
|  |       landmark_connections: 27 | ||||||
|  |       landmark_connections: 31 | ||||||
|  |       landmark_connections: 28 | ||||||
|  |       landmark_connections: 32 | ||||||
|  | 
 | ||||||
|  |       landmark_color { r: 255 g: 255 b: 255 } | ||||||
|  |       connection_color { r: 255 g: 255 b: 255 } | ||||||
|  |       thickness: 3.0 | ||||||
|  |       visualize_landmark_depth: false | ||||||
|  |       utilize_visibility: true | ||||||
|  |       visibility_threshold: 0.5 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | # Take left pose landmarks. | ||||||
|  | node { | ||||||
|  |   calculator: "SplitNormalizedLandmarkListCalculator" | ||||||
|  |   input_stream: "pose_landmarks" | ||||||
|  |   output_stream: "landmarks_left_side" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { | ||||||
|  |       ranges: { begin: 1 end: 4 } | ||||||
|  |       ranges: { begin: 7 end: 8 } | ||||||
|  |       ranges: { begin: 9 end: 10 } | ||||||
|  |       ranges: { begin: 11 end: 12 } | ||||||
|  |       ranges: { begin: 13 end: 14 } | ||||||
|  |       ranges: { begin: 15 end: 16 } | ||||||
|  |       ranges: { begin: 17 end: 18 } | ||||||
|  |       ranges: { begin: 19 end: 20 } | ||||||
|  |       ranges: { begin: 21 end: 22 } | ||||||
|  |       ranges: { begin: 23 end: 24 } | ||||||
|  | 
 | ||||||
|  |       combine_outputs: true | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | # Take right pose landmarks. | ||||||
|  | node { | ||||||
|  |   calculator: "SplitNormalizedLandmarkListCalculator" | ||||||
|  |   input_stream: "pose_landmarks" | ||||||
|  |   output_stream: "landmarks_right_side" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { | ||||||
|  |       ranges: { begin: 4 end: 7 } | ||||||
|  |       ranges: { begin: 8 end: 9 } | ||||||
|  |       ranges: { begin: 10 end: 11 } | ||||||
|  |       ranges: { begin: 12 end: 13 } | ||||||
|  |       ranges: { begin: 14 end: 15 } | ||||||
|  |       ranges: { begin: 16 end: 17 } | ||||||
|  |       ranges: { begin: 18 end: 19 } | ||||||
|  |       ranges: { begin: 20 end: 21 } | ||||||
|  |       ranges: { begin: 22 end: 23 } | ||||||
|  |       ranges: { begin: 24 end: 25 } | ||||||
|  | 
 | ||||||
|  |       combine_outputs: true | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | # Render pose joints as big white circles. | ||||||
|  | node { | ||||||
|  |   calculator: "LandmarksToRenderDataCalculator" | ||||||
|  |   input_stream: "NORM_LANDMARKS:visible_pose_landmarks" | ||||||
|  |   input_stream: "RENDER_SCALE:render_scale" | ||||||
|  |   output_stream: "RENDER_DATA:landmarks_background_joints_render_data" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { | ||||||
|  |       landmark_color { r: 255 g: 255 b: 255 } | ||||||
|  |       connection_color { r: 255 g: 255 b: 255 } | ||||||
|  |       thickness: 5.0 | ||||||
|  |       visualize_landmark_depth: false | ||||||
|  |       utilize_visibility: true | ||||||
|  |       visibility_threshold: 0.5 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | # Render pose left side joints as orange circles (inside white ones). | ||||||
|  | node { | ||||||
|  |   calculator: "LandmarksToRenderDataCalculator" | ||||||
|  |   input_stream: "NORM_LANDMARKS:landmarks_left_side" | ||||||
|  |   input_stream: "RENDER_SCALE:render_scale" | ||||||
|  |   output_stream: "RENDER_DATA:landmarks_left_joints_render_data" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { | ||||||
|  |       landmark_color { r: 255 g: 138 b: 0 } | ||||||
|  |       connection_color { r: 255 g: 138 b: 0 } | ||||||
|  |       thickness: 3.0 | ||||||
|  |       visualize_landmark_depth: false | ||||||
|  |       utilize_visibility: true | ||||||
|  |       visibility_threshold: 0.5 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | # Render pose right side joints as cyan circles (inside white ones). | ||||||
|  | node { | ||||||
|  |   calculator: "LandmarksToRenderDataCalculator" | ||||||
|  |   input_stream: "NORM_LANDMARKS:landmarks_right_side" | ||||||
|  |   input_stream: "RENDER_SCALE:render_scale" | ||||||
|  |   output_stream: "RENDER_DATA:landmarks_right_joints_render_data" | ||||||
|  |   node_options: { | ||||||
|  |     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { | ||||||
|  |       landmark_color { r: 0 g: 217 b: 231 } | ||||||
|  |       connection_color { r: 0 g: 217 b: 231 } | ||||||
|  |       thickness: 3.0 | ||||||
|  |       visualize_landmark_depth: false | ||||||
|  |       utilize_visibility: true | ||||||
|  |       visibility_threshold: 0.5 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | # Merges annotations into one result. | ||||||
|  | node { | ||||||
|  |   calculator: "ConcatenateRenderDataVectorCalculator" | ||||||
|  |   input_stream: "landmarks_render_data" | ||||||
|  |   input_stream: "landmarks_background_joints_render_data" | ||||||
|  |   input_stream: "landmarks_left_joints_render_data" | ||||||
|  |   input_stream: "landmarks_right_joints_render_data" | ||||||
|  |   output_stream: "merged_render_data" | ||||||
|  | } | ||||||
|  | @ -22,19 +22,6 @@ node { | ||||||
|   output_stream: "SIZE:image_size" |   output_stream: "SIZE:image_size" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| # Calculates rendering scale based on the pose roi. |  | ||||||
| node { |  | ||||||
|   calculator: "RectToRenderScaleCalculator" |  | ||||||
|   input_stream: "NORM_RECT:roi" |  | ||||||
|   input_stream: "IMAGE_SIZE:image_size" |  | ||||||
|   output_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { |  | ||||||
|       multiplier: 0.0012 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Converts detections to drawing primitives for annotation overlay. | # Converts detections to drawing primitives for annotation overlay. | ||||||
| node { | node { | ||||||
|   calculator: "DetectionsToRenderDataCalculator" |   calculator: "DetectionsToRenderDataCalculator" | ||||||
|  | @ -48,204 +35,13 @@ node { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | # Computes render data for landmarks. | ||||||
| node { | node { | ||||||
|   calculator: "SplitNormalizedLandmarkListCalculator" |   calculator: "PoseLandmarksToRenderData" | ||||||
|   input_stream: "pose_landmarks" |   input_stream: "LANDMARKS:pose_landmarks" | ||||||
|   output_stream: "visible_pose_landmarks" |   input_stream: "ROI:roi" | ||||||
|   node_options: { |   input_stream: "IMAGE_SIZE:image_size" | ||||||
|     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { |  | ||||||
|       ranges: { begin: 0 end: 25 } |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Converts landmarks to drawing primitives for annotation overlay. |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:pose_landmarks" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_render_data" |   output_stream: "RENDER_DATA:landmarks_render_data" | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_connections: 0 |  | ||||||
|       landmark_connections: 1 |  | ||||||
|       landmark_connections: 1 |  | ||||||
|       landmark_connections: 2 |  | ||||||
|       landmark_connections: 2 |  | ||||||
|       landmark_connections: 3 |  | ||||||
|       landmark_connections: 3 |  | ||||||
|       landmark_connections: 7 |  | ||||||
|       landmark_connections: 0 |  | ||||||
|       landmark_connections: 4 |  | ||||||
|       landmark_connections: 4 |  | ||||||
|       landmark_connections: 5 |  | ||||||
|       landmark_connections: 5 |  | ||||||
|       landmark_connections: 6 |  | ||||||
|       landmark_connections: 6 |  | ||||||
|       landmark_connections: 8 |  | ||||||
|       landmark_connections: 9 |  | ||||||
|       landmark_connections: 10 |  | ||||||
|       landmark_connections: 11 |  | ||||||
|       landmark_connections: 12 |  | ||||||
|       landmark_connections: 11 |  | ||||||
|       landmark_connections: 13 |  | ||||||
|       landmark_connections: 13 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 17 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 19 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 21 |  | ||||||
|       landmark_connections: 17 |  | ||||||
|       landmark_connections: 19 |  | ||||||
|       landmark_connections: 12 |  | ||||||
|       landmark_connections: 14 |  | ||||||
|       landmark_connections: 14 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 18 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 20 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 22 |  | ||||||
|       landmark_connections: 18 |  | ||||||
|       landmark_connections: 20 |  | ||||||
|       landmark_connections: 11 |  | ||||||
|       landmark_connections: 23 |  | ||||||
|       landmark_connections: 12 |  | ||||||
|       landmark_connections: 24 |  | ||||||
|       landmark_connections: 23 |  | ||||||
|       landmark_connections: 24 |  | ||||||
|       landmark_connections: 23 |  | ||||||
|       landmark_connections: 25 |  | ||||||
|       landmark_connections: 24 |  | ||||||
|       landmark_connections: 26 |  | ||||||
|       landmark_connections: 25 |  | ||||||
|       landmark_connections: 27 |  | ||||||
|       landmark_connections: 26 |  | ||||||
|       landmark_connections: 28 |  | ||||||
|       landmark_connections: 27 |  | ||||||
|       landmark_connections: 29 |  | ||||||
|       landmark_connections: 28 |  | ||||||
|       landmark_connections: 30 |  | ||||||
|       landmark_connections: 29 |  | ||||||
|       landmark_connections: 31 |  | ||||||
|       landmark_connections: 30 |  | ||||||
|       landmark_connections: 32 |  | ||||||
|       landmark_connections: 27 |  | ||||||
|       landmark_connections: 31 |  | ||||||
|       landmark_connections: 28 |  | ||||||
|       landmark_connections: 32 |  | ||||||
| 
 |  | ||||||
|       landmark_color { r: 255 g: 255 b: 255 } |  | ||||||
|       connection_color { r: 255 g: 255 b: 255 } |  | ||||||
|       thickness: 3.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Take left pose landmarks. |  | ||||||
| node { |  | ||||||
|   calculator: "SplitNormalizedLandmarkListCalculator" |  | ||||||
|   input_stream: "pose_landmarks" |  | ||||||
|   output_stream: "landmarks_left_side" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { |  | ||||||
|       ranges: { begin: 1 end: 4 } |  | ||||||
|       ranges: { begin: 7 end: 8 } |  | ||||||
|       ranges: { begin: 9 end: 10 } |  | ||||||
|       ranges: { begin: 11 end: 12 } |  | ||||||
|       ranges: { begin: 13 end: 14 } |  | ||||||
|       ranges: { begin: 15 end: 16 } |  | ||||||
|       ranges: { begin: 17 end: 18 } |  | ||||||
|       ranges: { begin: 19 end: 20 } |  | ||||||
|       ranges: { begin: 21 end: 22 } |  | ||||||
|       ranges: { begin: 23 end: 24 } |  | ||||||
| 
 |  | ||||||
|       combine_outputs: true |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Take right pose landmarks. |  | ||||||
| node { |  | ||||||
|   calculator: "SplitNormalizedLandmarkListCalculator" |  | ||||||
|   input_stream: "pose_landmarks" |  | ||||||
|   output_stream: "landmarks_right_side" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { |  | ||||||
|       ranges: { begin: 4 end: 7 } |  | ||||||
|       ranges: { begin: 8 end: 9 } |  | ||||||
|       ranges: { begin: 10 end: 11 } |  | ||||||
|       ranges: { begin: 12 end: 13 } |  | ||||||
|       ranges: { begin: 14 end: 15 } |  | ||||||
|       ranges: { begin: 16 end: 17 } |  | ||||||
|       ranges: { begin: 18 end: 19 } |  | ||||||
|       ranges: { begin: 20 end: 21 } |  | ||||||
|       ranges: { begin: 22 end: 23 } |  | ||||||
|       ranges: { begin: 24 end: 25 } |  | ||||||
| 
 |  | ||||||
|       combine_outputs: true |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Render pose joints as big white circles. |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:visible_pose_landmarks" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_background_joints_render_data" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_color { r: 255 g: 255 b: 255 } |  | ||||||
|       connection_color { r: 255 g: 255 b: 255 } |  | ||||||
|       thickness: 5.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Render pose left side joints as orange circles (inside white ones). |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:landmarks_left_side" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_left_joints_render_data" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_color { r: 255 g: 138 b: 0 } |  | ||||||
|       connection_color { r: 255 g: 138 b: 0 } |  | ||||||
|       thickness: 3.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Render pose right side joints as cyan circles (inside white ones). |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:landmarks_right_side" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_right_joints_render_data" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_color { r: 0 g: 217 b: 231 } |  | ||||||
|       connection_color { r: 0 g: 217 b: 231 } |  | ||||||
|       thickness: 3.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| # Converts normalized rects to drawing primitives for annotation overlay. | # Converts normalized rects to drawing primitives for annotation overlay. | ||||||
|  | @ -283,10 +79,7 @@ node { | ||||||
|   calculator: "AnnotationOverlayCalculator" |   calculator: "AnnotationOverlayCalculator" | ||||||
|   input_stream: "IMAGE:segmented_image" |   input_stream: "IMAGE:segmented_image" | ||||||
|   input_stream: "detection_render_data" |   input_stream: "detection_render_data" | ||||||
|   input_stream: "landmarks_render_data" |   input_stream: "VECTOR:landmarks_render_data" | ||||||
|   input_stream: "landmarks_background_joints_render_data" |  | ||||||
|   input_stream: "landmarks_left_joints_render_data" |  | ||||||
|   input_stream: "landmarks_right_joints_render_data" |  | ||||||
|   input_stream: "roi_render_data" |   input_stream: "roi_render_data" | ||||||
|   output_stream: "IMAGE:output_image" |   output_stream: "IMAGE:output_image" | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -22,19 +22,6 @@ node { | ||||||
|   output_stream: "SIZE:image_size" |   output_stream: "SIZE:image_size" | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| # Calculates rendering scale based on the pose roi. |  | ||||||
| node { |  | ||||||
|   calculator: "RectToRenderScaleCalculator" |  | ||||||
|   input_stream: "NORM_RECT:roi" |  | ||||||
|   input_stream: "IMAGE_SIZE:image_size" |  | ||||||
|   output_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.RectToRenderScaleCalculatorOptions] { |  | ||||||
|       multiplier: 0.0012 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Converts detections to drawing primitives for annotation overlay. | # Converts detections to drawing primitives for annotation overlay. | ||||||
| node { | node { | ||||||
|   calculator: "DetectionsToRenderDataCalculator" |   calculator: "DetectionsToRenderDataCalculator" | ||||||
|  | @ -48,204 +35,13 @@ node { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | # Computes render data for landmarks. | ||||||
| node { | node { | ||||||
|   calculator: "SplitNormalizedLandmarkListCalculator" |   calculator: "PoseLandmarksToRenderData" | ||||||
|   input_stream: "pose_landmarks" |   input_stream: "LANDMARKS:pose_landmarks" | ||||||
|   output_stream: "visible_pose_landmarks" |   input_stream: "ROI:roi" | ||||||
|   node_options: { |   input_stream: "IMAGE_SIZE:image_size" | ||||||
|     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { |  | ||||||
|       ranges: { begin: 0 end: 25 } |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Converts landmarks to drawing primitives for annotation overlay. |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:pose_landmarks" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_render_data" |   output_stream: "RENDER_DATA:landmarks_render_data" | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_connections: 0 |  | ||||||
|       landmark_connections: 1 |  | ||||||
|       landmark_connections: 1 |  | ||||||
|       landmark_connections: 2 |  | ||||||
|       landmark_connections: 2 |  | ||||||
|       landmark_connections: 3 |  | ||||||
|       landmark_connections: 3 |  | ||||||
|       landmark_connections: 7 |  | ||||||
|       landmark_connections: 0 |  | ||||||
|       landmark_connections: 4 |  | ||||||
|       landmark_connections: 4 |  | ||||||
|       landmark_connections: 5 |  | ||||||
|       landmark_connections: 5 |  | ||||||
|       landmark_connections: 6 |  | ||||||
|       landmark_connections: 6 |  | ||||||
|       landmark_connections: 8 |  | ||||||
|       landmark_connections: 9 |  | ||||||
|       landmark_connections: 10 |  | ||||||
|       landmark_connections: 11 |  | ||||||
|       landmark_connections: 12 |  | ||||||
|       landmark_connections: 11 |  | ||||||
|       landmark_connections: 13 |  | ||||||
|       landmark_connections: 13 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 17 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 19 |  | ||||||
|       landmark_connections: 15 |  | ||||||
|       landmark_connections: 21 |  | ||||||
|       landmark_connections: 17 |  | ||||||
|       landmark_connections: 19 |  | ||||||
|       landmark_connections: 12 |  | ||||||
|       landmark_connections: 14 |  | ||||||
|       landmark_connections: 14 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 18 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 20 |  | ||||||
|       landmark_connections: 16 |  | ||||||
|       landmark_connections: 22 |  | ||||||
|       landmark_connections: 18 |  | ||||||
|       landmark_connections: 20 |  | ||||||
|       landmark_connections: 11 |  | ||||||
|       landmark_connections: 23 |  | ||||||
|       landmark_connections: 12 |  | ||||||
|       landmark_connections: 24 |  | ||||||
|       landmark_connections: 23 |  | ||||||
|       landmark_connections: 24 |  | ||||||
|       landmark_connections: 23 |  | ||||||
|       landmark_connections: 25 |  | ||||||
|       landmark_connections: 24 |  | ||||||
|       landmark_connections: 26 |  | ||||||
|       landmark_connections: 25 |  | ||||||
|       landmark_connections: 27 |  | ||||||
|       landmark_connections: 26 |  | ||||||
|       landmark_connections: 28 |  | ||||||
|       landmark_connections: 27 |  | ||||||
|       landmark_connections: 29 |  | ||||||
|       landmark_connections: 28 |  | ||||||
|       landmark_connections: 30 |  | ||||||
|       landmark_connections: 29 |  | ||||||
|       landmark_connections: 31 |  | ||||||
|       landmark_connections: 30 |  | ||||||
|       landmark_connections: 32 |  | ||||||
|       landmark_connections: 27 |  | ||||||
|       landmark_connections: 31 |  | ||||||
|       landmark_connections: 28 |  | ||||||
|       landmark_connections: 32 |  | ||||||
| 
 |  | ||||||
|       landmark_color { r: 255 g: 255 b: 255 } |  | ||||||
|       connection_color { r: 255 g: 255 b: 255 } |  | ||||||
|       thickness: 3.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Take left pose landmarks. |  | ||||||
| node { |  | ||||||
|   calculator: "SplitNormalizedLandmarkListCalculator" |  | ||||||
|   input_stream: "pose_landmarks" |  | ||||||
|   output_stream: "landmarks_left_side" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { |  | ||||||
|       ranges: { begin: 1 end: 4 } |  | ||||||
|       ranges: { begin: 7 end: 8 } |  | ||||||
|       ranges: { begin: 9 end: 10 } |  | ||||||
|       ranges: { begin: 11 end: 12 } |  | ||||||
|       ranges: { begin: 13 end: 14 } |  | ||||||
|       ranges: { begin: 15 end: 16 } |  | ||||||
|       ranges: { begin: 17 end: 18 } |  | ||||||
|       ranges: { begin: 19 end: 20 } |  | ||||||
|       ranges: { begin: 21 end: 22 } |  | ||||||
|       ranges: { begin: 23 end: 24 } |  | ||||||
| 
 |  | ||||||
|       combine_outputs: true |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Take right pose landmarks. |  | ||||||
| node { |  | ||||||
|   calculator: "SplitNormalizedLandmarkListCalculator" |  | ||||||
|   input_stream: "pose_landmarks" |  | ||||||
|   output_stream: "landmarks_right_side" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { |  | ||||||
|       ranges: { begin: 4 end: 7 } |  | ||||||
|       ranges: { begin: 8 end: 9 } |  | ||||||
|       ranges: { begin: 10 end: 11 } |  | ||||||
|       ranges: { begin: 12 end: 13 } |  | ||||||
|       ranges: { begin: 14 end: 15 } |  | ||||||
|       ranges: { begin: 16 end: 17 } |  | ||||||
|       ranges: { begin: 18 end: 19 } |  | ||||||
|       ranges: { begin: 20 end: 21 } |  | ||||||
|       ranges: { begin: 22 end: 23 } |  | ||||||
|       ranges: { begin: 24 end: 25 } |  | ||||||
| 
 |  | ||||||
|       combine_outputs: true |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Render pose joints as big white circles. |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:visible_pose_landmarks" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_background_joints_render_data" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_color { r: 255 g: 255 b: 255 } |  | ||||||
|       connection_color { r: 255 g: 255 b: 255 } |  | ||||||
|       thickness: 5.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Render pose left side joints as orange circles (inside white ones). |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:landmarks_left_side" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_left_joints_render_data" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_color { r: 255 g: 138 b: 0 } |  | ||||||
|       connection_color { r: 255 g: 138 b: 0 } |  | ||||||
|       thickness: 3.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| # Render pose right side joints as cyan circles (inside white ones). |  | ||||||
| node { |  | ||||||
|   calculator: "LandmarksToRenderDataCalculator" |  | ||||||
|   input_stream: "NORM_LANDMARKS:landmarks_right_side" |  | ||||||
|   input_stream: "RENDER_SCALE:render_scale" |  | ||||||
|   output_stream: "RENDER_DATA:landmarks_right_joints_render_data" |  | ||||||
|   node_options: { |  | ||||||
|     [type.googleapis.com/mediapipe.LandmarksToRenderDataCalculatorOptions] { |  | ||||||
|       landmark_color { r: 0 g: 217 b: 231 } |  | ||||||
|       connection_color { r: 0 g: 217 b: 231 } |  | ||||||
|       thickness: 3.0 |  | ||||||
|       visualize_landmark_depth: false |  | ||||||
|       utilize_visibility: true |  | ||||||
|       visibility_threshold: 0.5 |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| # Converts normalized rects to drawing primitives for annotation overlay. | # Converts normalized rects to drawing primitives for annotation overlay. | ||||||
|  | @ -283,10 +79,7 @@ node { | ||||||
|   calculator: "AnnotationOverlayCalculator" |   calculator: "AnnotationOverlayCalculator" | ||||||
|   input_stream: "IMAGE_GPU:segmented_image" |   input_stream: "IMAGE_GPU:segmented_image" | ||||||
|   input_stream: "detection_render_data" |   input_stream: "detection_render_data" | ||||||
|   input_stream: "landmarks_render_data" |   input_stream: "VECTOR:landmarks_render_data" | ||||||
|   input_stream: "landmarks_background_joints_render_data" |  | ||||||
|   input_stream: "landmarks_left_joints_render_data" |  | ||||||
|   input_stream: "landmarks_right_joints_render_data" |  | ||||||
|   input_stream: "roi_render_data" |   input_stream: "roi_render_data" | ||||||
|   output_stream: "IMAGE_GPU:output_image" |   output_stream: "IMAGE_GPU:output_image" | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -174,6 +174,14 @@ public class ExternalTextureConverter implements TextureFrameProducer { | ||||||
|     thread.setRotation(rotation); |     thread.setRotation(rotation); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /** | ||||||
|  |    * Sets whether the timestamps of each frame should be adjusted to be always monotonically | ||||||
|  |    * increasing. The default behavior is that this is {@code true}. | ||||||
|  |    */ | ||||||
|  |   public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) { | ||||||
|  |     thread.setShouldAdjustTimestamps(shouldAdjustTimestamps); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /** |   /** | ||||||
|    * Sets an offset that can be used to adjust the timestamps on the camera frames, for example to |    * Sets an offset that can be used to adjust the timestamps on the camera frames, for example to | ||||||
|    * conform to a preferred time-base or to account for a known device latency. The offset is added |    * conform to a preferred time-base or to account for a known device latency. The offset is added | ||||||
|  | @ -298,6 +306,7 @@ public class ExternalTextureConverter implements TextureFrameProducer { | ||||||
|     private int bufferPoolMaxSize; |     private int bufferPoolMaxSize; | ||||||
| 
 | 
 | ||||||
|     private ExternalTextureRenderer renderer = null; |     private ExternalTextureRenderer renderer = null; | ||||||
|  |     private boolean shouldAdjustTimestamps = true; | ||||||
|     private long nextFrameTimestampOffset = 0; |     private long nextFrameTimestampOffset = 0; | ||||||
|     private long timestampOffsetNanos = 0; |     private long timestampOffsetNanos = 0; | ||||||
|     private long previousTimestamp = 0; |     private long previousTimestamp = 0; | ||||||
|  | @ -433,6 +442,10 @@ public class ExternalTextureConverter implements TextureFrameProducer { | ||||||
|       super.releaseGl(); // This releases the EGL context, so must do it after any GL calls. |       super.releaseGl(); // This releases the EGL context, so must do it after any GL calls. | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     public void setShouldAdjustTimestamps(boolean shouldAdjustTimestamps) { | ||||||
|  |       this.shouldAdjustTimestamps = shouldAdjustTimestamps; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     public void setTimestampOffsetNanos(long offsetInNanos) { |     public void setTimestampOffsetNanos(long offsetInNanos) { | ||||||
|       timestampOffsetNanos = offsetInNanos; |       timestampOffsetNanos = offsetInNanos; | ||||||
|     } |     } | ||||||
|  | @ -565,7 +578,8 @@ public class ExternalTextureConverter implements TextureFrameProducer { | ||||||
|       // |nextFrameTimestampOffset| to ensure that timestamps increase monotonically.) |       // |nextFrameTimestampOffset| to ensure that timestamps increase monotonically.) | ||||||
|       long textureTimestamp = |       long textureTimestamp = | ||||||
|           (surfaceTexture.getTimestamp() + timestampOffsetNanos) / NANOS_PER_MICRO; |           (surfaceTexture.getTimestamp() + timestampOffsetNanos) / NANOS_PER_MICRO; | ||||||
|       if (previousTimestampValid |       if (shouldAdjustTimestamps | ||||||
|  |           && previousTimestampValid | ||||||
|           && textureTimestamp + nextFrameTimestampOffset <= previousTimestamp) { |           && textureTimestamp + nextFrameTimestampOffset <= previousTimestamp) { | ||||||
|         nextFrameTimestampOffset = previousTimestamp + 1 - textureTimestamp; |         nextFrameTimestampOffset = previousTimestamp + 1 - textureTimestamp; | ||||||
|       } |       } | ||||||
|  |  | ||||||
|  | @ -15,6 +15,10 @@ | ||||||
| package com.google.mediapipe.framework; | package com.google.mediapipe.framework; | ||||||
| 
 | 
 | ||||||
| import android.graphics.Bitmap; | import android.graphics.Bitmap; | ||||||
|  | import com.google.mediapipe.framework.image.BitmapExtractor; | ||||||
|  | import com.google.mediapipe.framework.image.ByteBufferExtractor; | ||||||
|  | import com.google.mediapipe.framework.image.Image; | ||||||
|  | import com.google.mediapipe.framework.image.ImageProperties; | ||||||
| import java.nio.ByteBuffer; | import java.nio.ByteBuffer; | ||||||
| 
 | 
 | ||||||
| // TODO: use Preconditions in this file. | // TODO: use Preconditions in this file. | ||||||
|  | @ -55,6 +59,50 @@ public class AndroidPacketCreator extends PacketCreator { | ||||||
|     return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); |     return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates an Image packet from an {@link Image}. | ||||||
|  |    * | ||||||
|  |    * <p>The ImageContainerType must be IMAGE_CONTAINER_BYTEBUFFER or IMAGE_CONTAINER_BITMAP. | ||||||
|  |    */ | ||||||
|  |   public Packet createImage(Image image) { | ||||||
|  |     // TODO: Choose the best storage from multiple containers. | ||||||
|  |     ImageProperties properties = image.getContainedImageProperties().get(0); | ||||||
|  |     if (properties.getStorageType() == Image.STORAGE_TYPE_BYTEBUFFER) { | ||||||
|  |       ByteBuffer buffer = ByteBufferExtractor.extract(image); | ||||||
|  |       int numChannels = 0; | ||||||
|  |       switch (properties.getImageFormat()) { | ||||||
|  |         case Image.IMAGE_FORMAT_RGBA: | ||||||
|  |           numChannels = 4; | ||||||
|  |           break; | ||||||
|  |         case Image.IMAGE_FORMAT_RGB: | ||||||
|  |           numChannels = 3; | ||||||
|  |           break; | ||||||
|  |         case Image.IMAGE_FORMAT_ALPHA: | ||||||
|  |           numChannels = 1; | ||||||
|  |           break; | ||||||
|  |         default: // fall out | ||||||
|  |       } | ||||||
|  |       if (numChannels == 0) { | ||||||
|  |         throw new UnsupportedOperationException( | ||||||
|  |             "Unsupported MediaPipe Image image format: " + properties.getImageFormat()); | ||||||
|  |       } | ||||||
|  |       int width = image.getWidth(); | ||||||
|  |       int height = image.getHeight(); | ||||||
|  |       return createImage(buffer, width, height, numChannels); | ||||||
|  |     } | ||||||
|  |     if (properties.getImageFormat() == Image.STORAGE_TYPE_BITMAP) { | ||||||
|  |       Bitmap bitmap = BitmapExtractor.extract(image); | ||||||
|  |       if (bitmap.getConfig() != Bitmap.Config.ARGB_8888) { | ||||||
|  |         throw new UnsupportedOperationException("bitmap must use ARGB_8888 config."); | ||||||
|  |       } | ||||||
|  |       return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Unsupported type. | ||||||
|  |     throw new UnsupportedOperationException( | ||||||
|  |         "Unsupported Image container type: " + properties.getImageFormat()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /** |   /** | ||||||
|    * Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on |    * Returns the native handle of a new internal::PacketWithContext object on success. Returns 0 on | ||||||
|    * failure. |    * failure. | ||||||
|  |  | ||||||
|  | @ -57,6 +57,7 @@ android_library( | ||||||
|     ], |     ], | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":android_core", |         ":android_core", | ||||||
|  |         "//mediapipe/java/com/google/mediapipe/framework/image", | ||||||
|         "//third_party:androidx_annotation", |         "//third_party:androidx_annotation", | ||||||
|         "//third_party:androidx_legacy_support_v4", |         "//third_party:androidx_legacy_support_v4", | ||||||
|         "@maven//:com_google_code_findbugs_jsr305", |         "@maven//:com_google_code_findbugs_jsr305", | ||||||
|  | @ -75,6 +76,7 @@ android_library( | ||||||
|     srcs = glob( |     srcs = glob( | ||||||
|         ["**/*.java"], |         ["**/*.java"], | ||||||
|         exclude = [ |         exclude = [ | ||||||
|  |             "image/**", | ||||||
|             "Android*", |             "Android*", | ||||||
|             "AssetCache.java", |             "AssetCache.java", | ||||||
|             "AssetCacheDbHelper.java", |             "AssetCacheDbHelper.java", | ||||||
|  |  | ||||||
|  | @ -0,0 +1,6 @@ | ||||||
|  | <?xml version="1.0" encoding="utf-8"?> | ||||||
|  | <manifest xmlns:android="http://schemas.android.com/apk/res/android" | ||||||
|  |     package="com.google.mediapipe.framework.image"> | ||||||
|  |   <uses-sdk android:minSdkVersion="16" /> | ||||||
|  |   <application /> | ||||||
|  | </manifest> | ||||||
							
								
								
									
										32
									
								
								mediapipe/java/com/google/mediapipe/framework/image/BUILD
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								mediapipe/java/com/google/mediapipe/framework/image/BUILD
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,32 @@ | ||||||
|  | # Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | #      http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | 
 | ||||||
|  | load("@build_bazel_rules_android//android:rules.bzl", "android_library") | ||||||
|  | 
 | ||||||
|  | licenses(["notice"]) | ||||||
|  | 
 | ||||||
|  | android_library( | ||||||
|  |     name = "image", | ||||||
|  |     srcs = glob(["*.java"]), | ||||||
|  |     manifest = "AndroidManifest.xml", | ||||||
|  |     visibility = [ | ||||||
|  |         "//mediapipe:__subpackages__", | ||||||
|  |     ], | ||||||
|  |     deps = [ | ||||||
|  |         "//third_party:androidx_legacy_support_v4", | ||||||
|  |         "//third_party:autovalue", | ||||||
|  |         "@maven//:androidx_annotation_annotation", | ||||||
|  |         "@maven//:com_google_guava_guava", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | @ -0,0 +1,49 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import android.graphics.Bitmap; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Utility for extracting {@link android.graphics.Bitmap} from {@link Image}. | ||||||
|  |  * | ||||||
|  |  * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BITMAP}, otherwise | ||||||
|  |  * {@link IllegalArgumentException} will be thrown. | ||||||
|  |  */ | ||||||
|  | public final class BitmapExtractor { | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Extracts a {@link android.graphics.Bitmap} from an {@link Image}. | ||||||
|  |    * | ||||||
|  |    * @param image the image to extract {@link android.graphics.Bitmap} from. | ||||||
|  |    * @return the {@link android.graphics.Bitmap} stored in {@link Image} | ||||||
|  |    * @throws IllegalArgumentException when the extraction requires unsupported format or data type | ||||||
|  |    *     conversions. | ||||||
|  |    */ | ||||||
|  |   public static Bitmap extract(Image image) { | ||||||
|  |     ImageContainer imageContainer = image.getContainer(Image.STORAGE_TYPE_BITMAP); | ||||||
|  |     if (imageContainer != null) { | ||||||
|  |       return ((BitmapImageContainer) imageContainer).getBitmap(); | ||||||
|  |     } else { | ||||||
|  |       // TODO: Support ByteBuffer -> Bitmap conversion. | ||||||
|  |       throw new IllegalArgumentException( | ||||||
|  |           "Extracting Bitmap from an Image created by objects other than Bitmap is not" | ||||||
|  |               + " supported"); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private BitmapExtractor() {} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,72 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import android.content.Context; | ||||||
|  | import android.graphics.Bitmap; | ||||||
|  | import android.net.Uri; | ||||||
|  | import android.provider.MediaStore; | ||||||
|  | import java.io.IOException; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Builds {@link Image} from {@link android.graphics.Bitmap}. | ||||||
|  |  * | ||||||
|  |  * <p>You can pass in either mutable or immutable {@link android.graphics.Bitmap}. However once | ||||||
|  |  * {@link android.graphics.Bitmap} is passed in, to keep data integrity you shouldn't modify content | ||||||
|  |  * in it. | ||||||
|  |  * | ||||||
|  |  * <p>Use {@link BitmapExtractor} to get {@link android.graphics.Bitmap} you passed in. | ||||||
|  |  */ | ||||||
|  | public class BitmapImageBuilder { | ||||||
|  | 
 | ||||||
|  |   // Mandatory fields. | ||||||
|  |   private final Bitmap bitmap; | ||||||
|  | 
 | ||||||
|  |   // Optional fields. | ||||||
|  |   private long timestamp; | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates the builder with a mandatory {@link android.graphics.Bitmap}. | ||||||
|  |    * | ||||||
|  |    * @param bitmap image data object. | ||||||
|  |    */ | ||||||
|  |   public BitmapImageBuilder(Bitmap bitmap) { | ||||||
|  |     this.bitmap = bitmap; | ||||||
|  |     timestamp = 0; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates the builder to build {@link Image} from a file. | ||||||
|  |    * | ||||||
|  |    * @param context the application context. | ||||||
|  |    * @param uri the path to the resource file. | ||||||
|  |    */ | ||||||
|  |   public BitmapImageBuilder(Context context, Uri uri) throws IOException { | ||||||
|  |     this(MediaStore.Images.Media.getBitmap(context.getContentResolver(), uri)); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Sets value for {@link Image#getTimestamp()}. */ | ||||||
|  |   BitmapImageBuilder setTimestamp(long timestamp) { | ||||||
|  |     this.timestamp = timestamp; | ||||||
|  |     return this; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Builds an {@link Image} instance. */ | ||||||
|  |   public Image build() { | ||||||
|  |     return new Image( | ||||||
|  |         new BitmapImageContainer(bitmap), timestamp, bitmap.getWidth(), bitmap.getHeight()); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,60 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import android.graphics.Bitmap; | ||||||
|  | import com.google.mediapipe.framework.image.Image.ImageFormat; | ||||||
|  | 
 | ||||||
|  | class BitmapImageContainer implements ImageContainer { | ||||||
|  | 
 | ||||||
|  |   private final Bitmap bitmap; | ||||||
|  |   private final ImageProperties properties; | ||||||
|  | 
 | ||||||
|  |   public BitmapImageContainer(Bitmap bitmap) { | ||||||
|  |     this.bitmap = bitmap; | ||||||
|  |     this.properties = | ||||||
|  |         ImageProperties.builder() | ||||||
|  |             .setImageFormat(convertFormatCode(bitmap.getConfig())) | ||||||
|  |             .setStorageType(Image.STORAGE_TYPE_BITMAP) | ||||||
|  |             .build(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   public Bitmap getBitmap() { | ||||||
|  |     return bitmap; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public ImageProperties getImageProperties() { | ||||||
|  |     return properties; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public void close() { | ||||||
|  |     bitmap.recycle(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @ImageFormat | ||||||
|  |   static int convertFormatCode(Bitmap.Config config) { | ||||||
|  |     switch (config) { | ||||||
|  |       case ALPHA_8: | ||||||
|  |         return Image.IMAGE_FORMAT_ALPHA; | ||||||
|  |       case ARGB_8888: | ||||||
|  |         return Image.IMAGE_FORMAT_RGBA; | ||||||
|  |       default: | ||||||
|  |         return Image.IMAGE_FORMAT_UNKNOWN; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,254 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import android.annotation.SuppressLint; | ||||||
|  | import android.graphics.Bitmap; | ||||||
|  | import android.graphics.Bitmap.Config; | ||||||
|  | import android.os.Build.VERSION; | ||||||
|  | import android.os.Build.VERSION_CODES; | ||||||
|  | import com.google.auto.value.AutoValue; | ||||||
|  | import com.google.mediapipe.framework.image.Image.ImageFormat; | ||||||
|  | import java.nio.ByteBuffer; | ||||||
|  | import java.nio.ByteOrder; | ||||||
|  | import java.util.Locale; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Utility for extracting {@link ByteBuffer} from {@link Image}. | ||||||
|  |  * | ||||||
|  |  * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_BYTEBUFFER}, otherwise | ||||||
|  |  * {@link IllegalArgumentException} will be thrown. | ||||||
|  |  */ | ||||||
|  | public class ByteBufferExtractor { | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Extracts a {@link ByteBuffer} from an {@link Image}. | ||||||
|  |    * | ||||||
|  |    * <p>The returned {@link ByteBuffer} is a read-only view, with the first available {@link | ||||||
|  |    * ImageProperties} whose storage type is {@code Image.STORAGE_TYPE_BYTEBUFFER}. | ||||||
|  |    * | ||||||
|  |    * @see Image#getContainedImageProperties() | ||||||
|  |    * @return A read-only {@link ByteBuffer}. | ||||||
|  |    * @throws IllegalArgumentException when the image doesn't contain a {@link ByteBuffer} storage. | ||||||
|  |    */ | ||||||
|  |   @SuppressLint("SwitchIntDef") | ||||||
|  |   public static ByteBuffer extract(Image image) { | ||||||
|  |     ImageContainer container = image.getContainer(); | ||||||
|  |     switch (container.getImageProperties().getStorageType()) { | ||||||
|  |       case Image.STORAGE_TYPE_BYTEBUFFER: | ||||||
|  |         ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; | ||||||
|  |         return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); | ||||||
|  |       default: | ||||||
|  |         throw new IllegalArgumentException( | ||||||
|  |             "Extract ByteBuffer from an Image created by objects other than Bytebuffer is not" | ||||||
|  |                 + " supported"); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Extracts a readonly {@link ByteBuffer} in given {@code targetFormat} from an {@link Image}. | ||||||
|  |    * | ||||||
|  |    * <p>Format conversion spec: | ||||||
|  |    * | ||||||
|  |    * <ul> | ||||||
|  |    *   <li>When extracting RGB images to RGBA format, A channel will always set to 255. | ||||||
|  |    *   <li>When extracting RGBA images to RGB format, A channel will be dropped. | ||||||
|  |    * </ul> | ||||||
|  |    * | ||||||
|  |    * @param image the image to extract buffer from. | ||||||
|  |    * @param targetFormat the image format of the result bytebuffer. | ||||||
|  |    * @return the readonly {@link ByteBuffer} stored in {@link Image} | ||||||
|  |    * @throws IllegalArgumentException when the extraction requires unsupported format or data type | ||||||
|  |    *     conversions. | ||||||
|  |    */ | ||||||
|  |   static ByteBuffer extract(Image image, @ImageFormat int targetFormat) { | ||||||
|  |     ImageContainer container; | ||||||
|  |     ImageProperties byteBufferProperties = | ||||||
|  |         ImageProperties.builder() | ||||||
|  |             .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) | ||||||
|  |             .setImageFormat(targetFormat) | ||||||
|  |             .build(); | ||||||
|  |     if ((container = image.getContainer(byteBufferProperties)) != null) { | ||||||
|  |       ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; | ||||||
|  |       return byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(); | ||||||
|  |     } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { | ||||||
|  |       ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; | ||||||
|  |       @ImageFormat int sourceFormat = byteBufferImageContainer.getImageFormat(); | ||||||
|  |       return convertByteBuffer(byteBufferImageContainer.getByteBuffer(), sourceFormat, targetFormat) | ||||||
|  |           .asReadOnlyBuffer(); | ||||||
|  |     } else if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { | ||||||
|  |       BitmapImageContainer bitmapImageContainer = (BitmapImageContainer) container; | ||||||
|  |       ByteBuffer byteBuffer = | ||||||
|  |           extractByteBufferFromBitmap(bitmapImageContainer.getBitmap(), targetFormat) | ||||||
|  |               .asReadOnlyBuffer(); | ||||||
|  |       boolean unused = image.addContainer(new ByteBufferImageContainer(byteBuffer, targetFormat)); | ||||||
|  |       return byteBuffer; | ||||||
|  |     } else { | ||||||
|  |       throw new IllegalArgumentException( | ||||||
|  |           "Extracting ByteBuffer from an Image created by objects other than Bitmap or" | ||||||
|  |               + " Bytebuffer is not supported"); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** A wrapper for a {@link ByteBuffer} and its {@link ImageFormat}. */ | ||||||
|  |   @AutoValue | ||||||
|  |   abstract static class Result { | ||||||
|  |     /** Gets the {@link ByteBuffer} in the result of {@link ByteBufferExtractor#extract(Image)}. */ | ||||||
|  |     public abstract ByteBuffer buffer(); | ||||||
|  | 
 | ||||||
|  |     /** Gets the {@link ImageFormat} in the result of {@link ByteBufferExtractor#extract(Image)}. */ | ||||||
|  |     @ImageFormat | ||||||
|  |     public abstract int format(); | ||||||
|  | 
 | ||||||
|  |     static Result create(ByteBuffer buffer, @ImageFormat int imageFormat) { | ||||||
|  |       return new AutoValue_ByteBufferExtractor_Result(buffer, imageFormat); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Extracts a {@link ByteBuffer} in any available {@code imageFormat} from an {@link Image}. | ||||||
|  |    * | ||||||
|  |    * <p>It will make the best effort to return an already existed {@link ByteBuffer} to avoid copy. | ||||||
|  |    * | ||||||
|  |    * @return the readonly {@link ByteBuffer} stored in {@link Image} | ||||||
|  |    * @throws IllegalArgumentException when {@code image} doesn't contain {@link ByteBuffer} with | ||||||
|  |    *     given {@code imageFormat} | ||||||
|  |    */ | ||||||
|  |   static Result extractInRecommendedFormat(Image image) { | ||||||
|  |     ImageContainer container; | ||||||
|  |     if ((container = image.getContainer(Image.STORAGE_TYPE_BITMAP)) != null) { | ||||||
|  |       Bitmap bitmap = ((BitmapImageContainer) container).getBitmap(); | ||||||
|  |       @ImageFormat int format = adviseImageFormat(bitmap); | ||||||
|  |       Result result = | ||||||
|  |           Result.create(extractByteBufferFromBitmap(bitmap, format).asReadOnlyBuffer(), format); | ||||||
|  | 
 | ||||||
|  |       boolean unused = | ||||||
|  |           image.addContainer(new ByteBufferImageContainer(result.buffer(), result.format())); | ||||||
|  |       return result; | ||||||
|  |     } else if ((container = image.getContainer(Image.STORAGE_TYPE_BYTEBUFFER)) != null) { | ||||||
|  |       ByteBufferImageContainer byteBufferImageContainer = (ByteBufferImageContainer) container; | ||||||
|  |       return Result.create( | ||||||
|  |           byteBufferImageContainer.getByteBuffer().asReadOnlyBuffer(), | ||||||
|  |           byteBufferImageContainer.getImageFormat()); | ||||||
|  |     } else { | ||||||
|  |       throw new IllegalArgumentException( | ||||||
|  |           "Extract ByteBuffer from an Image created by objects other than Bitmap or Bytebuffer" | ||||||
|  |               + " is not supported"); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @ImageFormat | ||||||
|  |   private static int adviseImageFormat(Bitmap bitmap) { | ||||||
|  |     if (bitmap.getConfig() == Config.ARGB_8888) { | ||||||
|  |       return Image.IMAGE_FORMAT_RGBA; | ||||||
|  |     } else { | ||||||
|  |       throw new IllegalArgumentException( | ||||||
|  |           String.format( | ||||||
|  |               "Extracting ByteBuffer from an Image created by a Bitmap in config %s is not" | ||||||
|  |                   + " supported", | ||||||
|  |               bitmap.getConfig())); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private static ByteBuffer extractByteBufferFromBitmap( | ||||||
|  |       Bitmap bitmap, @ImageFormat int imageFormat) { | ||||||
|  |     if (VERSION.SDK_INT >= VERSION_CODES.JELLY_BEAN_MR1 && bitmap.isPremultiplied()) { | ||||||
|  |       throw new IllegalArgumentException( | ||||||
|  |           "Extracting ByteBuffer from an Image created by a premultiplied Bitmap is not" | ||||||
|  |               + " supported"); | ||||||
|  |     } | ||||||
|  |     if (bitmap.getConfig() == Config.ARGB_8888) { | ||||||
|  |       if (imageFormat == Image.IMAGE_FORMAT_RGBA) { | ||||||
|  |         ByteBuffer buffer = ByteBuffer.allocateDirect(bitmap.getByteCount()); | ||||||
|  |         bitmap.copyPixelsToBuffer(buffer); | ||||||
|  |         buffer.rewind(); | ||||||
|  |         return buffer; | ||||||
|  |       } else if (imageFormat == Image.IMAGE_FORMAT_RGB) { | ||||||
|  |         // TODO: Try Use RGBA buffer to create RGB buffer which might be faster. | ||||||
|  |         int w = bitmap.getWidth(); | ||||||
|  |         int h = bitmap.getHeight(); | ||||||
|  |         int[] pixels = new int[w * h]; | ||||||
|  |         bitmap.getPixels(pixels, 0, w, 0, 0, w, h); | ||||||
|  |         ByteBuffer buffer = ByteBuffer.allocateDirect(w * h * 3); | ||||||
|  |         buffer.order(ByteOrder.nativeOrder()); | ||||||
|  |         for (int pixel : pixels) { | ||||||
|  |           // getPixels returns Color in ARGB rather than copyPixelsToBuffer which returns RGBA | ||||||
|  |           buffer.put((byte) ((pixel >> 16) & 0xff)); | ||||||
|  |           buffer.put((byte) ((pixel >> 8) & 0xff)); | ||||||
|  |           buffer.put((byte) (pixel & 0xff)); | ||||||
|  |         } | ||||||
|  |         buffer.rewind(); | ||||||
|  |         return buffer; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     throw new IllegalArgumentException( | ||||||
|  |         String.format( | ||||||
|  |             "Extracting ByteBuffer from an Image created by Bitmap and convert from %s to format" | ||||||
|  |                 + " %d is not supported", | ||||||
|  |             bitmap.getConfig(), imageFormat)); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private static ByteBuffer convertByteBuffer( | ||||||
|  |       ByteBuffer source, @ImageFormat int sourceFormat, @ImageFormat int targetFormat) { | ||||||
|  |     if (sourceFormat == Image.IMAGE_FORMAT_RGB && targetFormat == Image.IMAGE_FORMAT_RGBA) { | ||||||
|  |       ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 3 * 4); | ||||||
|  |       // Extend the buffer when the target is longer than the source. Use two cursors and sweep the | ||||||
|  |       // array reversely to convert in-place. | ||||||
|  |       byte[] array = new byte[target.capacity()]; | ||||||
|  |       source.get(array, 0, source.capacity()); | ||||||
|  |       source.rewind(); | ||||||
|  |       int rgbCursor = source.capacity(); | ||||||
|  |       int rgbaCursor = target.capacity(); | ||||||
|  |       while (rgbCursor != rgbaCursor) { | ||||||
|  |         array[--rgbaCursor] = (byte) 0xff; // A | ||||||
|  |         array[--rgbaCursor] = array[--rgbCursor]; // B | ||||||
|  |         array[--rgbaCursor] = array[--rgbCursor]; // G | ||||||
|  |         array[--rgbaCursor] = array[--rgbCursor]; // R | ||||||
|  |       } | ||||||
|  |       target.put(array, 0, target.capacity()); | ||||||
|  |       target.rewind(); | ||||||
|  |       return target; | ||||||
|  |     } else if (sourceFormat == Image.IMAGE_FORMAT_RGBA && targetFormat == Image.IMAGE_FORMAT_RGB) { | ||||||
|  |       ByteBuffer target = ByteBuffer.allocateDirect(source.capacity() / 4 * 3); | ||||||
|  |       // Shrink the buffer when the target is shorter than the source. Use two cursors and sweep the | ||||||
|  |       // array to convert in-place. | ||||||
|  |       byte[] array = new byte[source.capacity()]; | ||||||
|  |       source.get(array, 0, source.capacity()); | ||||||
|  |       source.rewind(); | ||||||
|  |       int rgbaCursor = 0; | ||||||
|  |       int rgbCursor = 0; | ||||||
|  |       while (rgbaCursor < array.length) { | ||||||
|  |         array[rgbCursor++] = array[rgbaCursor++]; // R | ||||||
|  |         array[rgbCursor++] = array[rgbaCursor++]; // G | ||||||
|  |         array[rgbCursor++] = array[rgbaCursor++]; // B | ||||||
|  |         rgbaCursor++; | ||||||
|  |       } | ||||||
|  |       target.put(array, 0, target.capacity()); | ||||||
|  |       target.rewind(); | ||||||
|  |       return target; | ||||||
|  |     } else { | ||||||
|  |       throw new IllegalArgumentException( | ||||||
|  |           String.format( | ||||||
|  |               Locale.ENGLISH, | ||||||
|  |               "Convert bytebuffer image format from %d to %d is not supported", | ||||||
|  |               sourceFormat, | ||||||
|  |               targetFormat)); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // ByteBuffer is not able to be instantiated. | ||||||
|  |   private ByteBufferExtractor() {} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,71 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import com.google.mediapipe.framework.image.Image.ImageFormat; | ||||||
|  | import java.nio.ByteBuffer; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Builds a {@link Image} from a {@link ByteBuffer}. | ||||||
|  |  * | ||||||
|  |  * <p>You can pass in either mutable or immutable {@link ByteBuffer}. However once {@link | ||||||
|  |  * ByteBuffer} is passed in, to keep data integrity you shouldn't modify content in it. | ||||||
|  |  * | ||||||
|  |  * <p>Use {@link ByteBufferExtractor} to get {@link ByteBuffer} you passed in. | ||||||
|  |  */ | ||||||
|  | public class ByteBufferImageBuilder { | ||||||
|  | 
 | ||||||
|  |   // Mandatory fields. | ||||||
|  |   private final ByteBuffer buffer; | ||||||
|  |   private final int width; | ||||||
|  |   private final int height; | ||||||
|  |   @ImageFormat private final int imageFormat; | ||||||
|  | 
 | ||||||
|  |   // Optional fields. | ||||||
|  |   private long timestamp; | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates the builder with mandatory {@link ByteBuffer} and the represented image. | ||||||
|  |    * | ||||||
|  |    * <p>We will validate the size of the {@code byteBuffer} with given {@code width}, {@code height} | ||||||
|  |    * and {@code imageFormat}. | ||||||
|  |    * | ||||||
|  |    * @param byteBuffer image data object. | ||||||
|  |    * @param width the width of the represented image. | ||||||
|  |    * @param height the height of the represented image. | ||||||
|  |    * @param imageFormat how the data encode the image. | ||||||
|  |    */ | ||||||
|  |   public ByteBufferImageBuilder( | ||||||
|  |       ByteBuffer byteBuffer, int width, int height, @ImageFormat int imageFormat) { | ||||||
|  |     this.buffer = byteBuffer; | ||||||
|  |     this.width = width; | ||||||
|  |     this.height = height; | ||||||
|  |     this.imageFormat = imageFormat; | ||||||
|  |     // TODO: Validate bytebuffer size with width, height and image format | ||||||
|  |     this.timestamp = 0; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Sets value for {@link Image#getTimestamp()}. */ | ||||||
|  |   ByteBufferImageBuilder setTimestamp(long timestamp) { | ||||||
|  |     this.timestamp = timestamp; | ||||||
|  |     return this; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Builds an {@link Image} instance. */ | ||||||
|  |   public Image build() { | ||||||
|  |     return new Image(new ByteBufferImageContainer(buffer, imageFormat), timestamp, width, height); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,58 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import com.google.mediapipe.framework.image.Image.ImageFormat; | ||||||
|  | import java.nio.ByteBuffer; | ||||||
|  | 
 | ||||||
|  | class ByteBufferImageContainer implements ImageContainer { | ||||||
|  | 
 | ||||||
|  |   private final ByteBuffer buffer; | ||||||
|  |   private final ImageProperties properties; | ||||||
|  | 
 | ||||||
|  |   public ByteBufferImageContainer( | ||||||
|  |       ByteBuffer buffer, | ||||||
|  |       @ImageFormat int imageFormat) { | ||||||
|  |     this.buffer = buffer; | ||||||
|  |     this.properties = | ||||||
|  |         ImageProperties.builder() | ||||||
|  |             .setStorageType(Image.STORAGE_TYPE_BYTEBUFFER) | ||||||
|  |             .setImageFormat(imageFormat) | ||||||
|  |             .build(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   public ByteBuffer getByteBuffer() { | ||||||
|  |     return buffer; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public ImageProperties getImageProperties() { | ||||||
|  |     return properties; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Returns the image format. | ||||||
|  |    */ | ||||||
|  |   @ImageFormat | ||||||
|  |   public int getImageFormat() { | ||||||
|  |     return properties.getImageFormat(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public void close() { | ||||||
|  |     // No op for ByteBuffer. | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										241
									
								
								mediapipe/java/com/google/mediapipe/framework/image/Image.java
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										241
									
								
								mediapipe/java/com/google/mediapipe/framework/image/Image.java
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,241 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import androidx.annotation.IntDef; | ||||||
|  | import androidx.annotation.Nullable; | ||||||
|  | import java.io.Closeable; | ||||||
|  | import java.lang.annotation.Retention; | ||||||
|  | import java.lang.annotation.RetentionPolicy; | ||||||
|  | import java.util.Collections; | ||||||
|  | import java.util.HashMap; | ||||||
|  | import java.util.List; | ||||||
|  | import java.util.Map; | ||||||
|  | import java.util.Map.Entry; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * The wrapper class for image objects. | ||||||
|  |  * | ||||||
|  |  * <p>{@link Image} is designed to be an immutable image container, which could be shared | ||||||
|  |  * cross-platforms. | ||||||
|  |  * | ||||||
|  |  * <p>To construct an {@link Image}, use the provided builders: | ||||||
|  |  * | ||||||
|  |  * <ul> | ||||||
|  |  *   <li>{@link ByteBufferImageBuilder} | ||||||
|  |  *   <li>{@link BitmapImageBuilder} | ||||||
|  |  *   <li>{@link MediaImageBuilder} | ||||||
|  |  * </ul> | ||||||
|  |  * | ||||||
|  |  * <p>{@link Image} uses reference counting to maintain internal storage. When it is created the | ||||||
|  |  * reference count is 1. Developer can call {@link #close()} to reduce reference count to release | ||||||
|  |  * internal storage earlier, otherwise Java garbage collection will release the storage eventually. | ||||||
|  |  * | ||||||
|  |  * <p>To extract concrete image, first check {@link StorageType} and then use the provided | ||||||
|  |  * extractors: | ||||||
|  |  * | ||||||
|  |  * <ul> | ||||||
|  |  *   <li>{@link ByteBufferExtractor} | ||||||
|  |  *   <li>{@link BitmapExtractor} | ||||||
|  |  *   <li>{@link MediaImageExtractor} | ||||||
|  |  * </ul> | ||||||
|  |  */ | ||||||
|  | public class Image implements Closeable { | ||||||
|  | 
 | ||||||
|  |   /** Specifies the image format of an image. */ | ||||||
|  |   @IntDef({ | ||||||
|  |     IMAGE_FORMAT_UNKNOWN, | ||||||
|  |     IMAGE_FORMAT_RGBA, | ||||||
|  |     IMAGE_FORMAT_RGB, | ||||||
|  |     IMAGE_FORMAT_NV12, | ||||||
|  |     IMAGE_FORMAT_NV21, | ||||||
|  |     IMAGE_FORMAT_YV12, | ||||||
|  |     IMAGE_FORMAT_YV21, | ||||||
|  |     IMAGE_FORMAT_YUV_420_888, | ||||||
|  |     IMAGE_FORMAT_ALPHA, | ||||||
|  |     IMAGE_FORMAT_JPEG, | ||||||
|  |   }) | ||||||
|  |   @Retention(RetentionPolicy.SOURCE) | ||||||
|  |   public @interface ImageFormat {} | ||||||
|  | 
 | ||||||
|  |   public static final int IMAGE_FORMAT_UNKNOWN = 0; | ||||||
|  |   public static final int IMAGE_FORMAT_RGBA = 1; | ||||||
|  |   public static final int IMAGE_FORMAT_RGB = 2; | ||||||
|  |   public static final int IMAGE_FORMAT_NV12 = 3; | ||||||
|  |   public static final int IMAGE_FORMAT_NV21 = 4; | ||||||
|  |   public static final int IMAGE_FORMAT_YV12 = 5; | ||||||
|  |   public static final int IMAGE_FORMAT_YV21 = 6; | ||||||
|  |   public static final int IMAGE_FORMAT_YUV_420_888 = 7; | ||||||
|  |   public static final int IMAGE_FORMAT_ALPHA = 8; | ||||||
|  |   public static final int IMAGE_FORMAT_JPEG = 9; | ||||||
|  | 
 | ||||||
|  |   /** Specifies the image container type. Would be useful for choosing extractors. */ | ||||||
|  |   @IntDef({ | ||||||
|  |     STORAGE_TYPE_BITMAP, | ||||||
|  |     STORAGE_TYPE_BYTEBUFFER, | ||||||
|  |     STORAGE_TYPE_MEDIA_IMAGE, | ||||||
|  |     STORAGE_TYPE_IMAGE_PROXY, | ||||||
|  |   }) | ||||||
|  |   @Retention(RetentionPolicy.SOURCE) | ||||||
|  |   public @interface StorageType {} | ||||||
|  | 
 | ||||||
|  |   public static final int STORAGE_TYPE_BITMAP = 1; | ||||||
|  |   public static final int STORAGE_TYPE_BYTEBUFFER = 2; | ||||||
|  |   public static final int STORAGE_TYPE_MEDIA_IMAGE = 3; | ||||||
|  |   public static final int STORAGE_TYPE_IMAGE_PROXY = 4; | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Returns a list of supported image properties for this {@link Image}. | ||||||
|  |    * | ||||||
|  |    * <p>Currently {@link Image} only support single storage type so the size of return list will | ||||||
|  |    * always be 1. | ||||||
|  |    * | ||||||
|  |    * @see ImageProperties | ||||||
|  |    */ | ||||||
|  |   public List<ImageProperties> getContainedImageProperties() { | ||||||
|  |     return Collections.singletonList(getContainer().getImageProperties()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Returns the timestamp attached to the image. */ | ||||||
|  |   long getTimestamp() { | ||||||
|  |      return timestamp; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Returns the width of the image. */ | ||||||
|  |   public int getWidth() { | ||||||
|  |     return width; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Returns the height of the image. */ | ||||||
|  |   public int getHeight() { | ||||||
|  |     return height; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Acquires a reference on this {@link Image}. This will increase the reference count by 1. */ | ||||||
|  |   private synchronized void acquire() { | ||||||
|  |     referenceCount += 1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Removes a reference that was previously acquired or init. | ||||||
|  |    * | ||||||
|  |    * <p>When {@link Image} is created, it has 1 reference count. | ||||||
|  |    * | ||||||
|  |    * <p>When the reference count becomes 0, it will release the resource under the hood. | ||||||
|  |    */ | ||||||
|  |   @Override | ||||||
|  |   // TODO: Create an internal flag to indicate image is closed, or use referenceCount | ||||||
|  |   public synchronized void close() { | ||||||
|  |     referenceCount -= 1; | ||||||
|  |     if (referenceCount == 0) { | ||||||
|  |       for (ImageContainer imageContainer : containerMap.values()) { | ||||||
|  |         imageContainer.close(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Advanced API access for {@link Image}. */ | ||||||
|  |   static final class Internal { | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Acquires a reference on this {@link Image}. This will increase the reference count by 1. | ||||||
|  |      * | ||||||
|  |      * <p>This method is more useful for image consumer to acquire a reference so image resource | ||||||
|  |      * will not be closed accidentally. As image creator, normal developer doesn't need to call this | ||||||
|  |      * method. | ||||||
|  |      * | ||||||
|  |      * <p>The reference count is 1 when {@link Image} is created. Developer can call {@link | ||||||
|  |      * #close()} to indicate it doesn't need this {@link Image} anymore. | ||||||
|  |      * | ||||||
|  |      * @see #close() | ||||||
|  |      */ | ||||||
|  |     void acquire() { | ||||||
|  |       image.acquire(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private final Image image; | ||||||
|  | 
 | ||||||
|  |     // Only Image creates the internal helper. | ||||||
|  |     private Internal(Image image) { | ||||||
|  |       this.image = image; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Gets {@link Internal} object which contains internal APIs. */ | ||||||
|  |   Internal getInternal() { | ||||||
|  |     return new Internal(this); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   private final Map<ImageProperties, ImageContainer> containerMap; | ||||||
|  |   private final long timestamp; | ||||||
|  |   private final int width; | ||||||
|  |   private final int height; | ||||||
|  | 
 | ||||||
|  |   private int referenceCount; | ||||||
|  | 
 | ||||||
|  |   /** Constructs an {@link Image} with a built container. */ | ||||||
|  |   Image(ImageContainer container, long timestamp, int width, int height) { | ||||||
|  |     this.containerMap = new HashMap<>(); | ||||||
|  |     containerMap.put(container.getImageProperties(), container); | ||||||
|  |     this.timestamp = timestamp; | ||||||
|  |     this.width = width; | ||||||
|  |     this.height = height; | ||||||
|  |     this.referenceCount = 1; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Gets one available container. | ||||||
|  |    * | ||||||
|  |    * @return the current container. | ||||||
|  |    */ | ||||||
|  |   ImageContainer getContainer() { | ||||||
|  |     // According to the design, in the future we will support multiple containers in one image. | ||||||
|  |     // Currently just return the original container. | ||||||
|  |     // TODO: Cache multiple containers in Image. | ||||||
|  |     return containerMap.values().iterator().next(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Gets container from required {@code storageType}. Returns {@code null} if not existed. | ||||||
|  |    * | ||||||
|  |    * <p>If there are multiple containers with required {@code storageType}, returns the first one. | ||||||
|  |    */ | ||||||
|  |   @Nullable | ||||||
|  |   ImageContainer getContainer(@StorageType int storageType) { | ||||||
|  |     for (Entry<ImageProperties, ImageContainer> entry : containerMap.entrySet()) { | ||||||
|  |       if (entry.getKey().getStorageType() == storageType) { | ||||||
|  |         return entry.getValue(); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     return null; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Gets container from required {@code imageProperties}. Returns {@code null} if non existed. */ | ||||||
|  |   @Nullable | ||||||
|  |   ImageContainer getContainer(ImageProperties imageProperties) { | ||||||
|  |     return containerMap.get(imageProperties); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Adds a new container if it doesn't exist. Returns {@code true} if it succeeds. */ | ||||||
|  |   boolean addContainer(ImageContainer container) { | ||||||
|  |     ImageProperties imageProperties = container.getImageProperties(); | ||||||
|  |     if (containerMap.containsKey(imageProperties)) { | ||||||
|  |       return false; | ||||||
|  |     } | ||||||
|  |     containerMap.put(imageProperties, container); | ||||||
|  |     return true; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,27 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | /** Lightweight abstraction for an object that can receive {@link Image} */ | ||||||
|  | public interface ImageConsumer { | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Called when an {@link Image} is available. | ||||||
|  |    * | ||||||
|  |    * <p>The argument is only guaranteed to be available until this method returns. if you need to | ||||||
|  |    * extend its life time, acquire it, then release it when done. | ||||||
|  |    */ | ||||||
|  |   void onNewImage(Image image); | ||||||
|  | } | ||||||
|  | @ -0,0 +1,25 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | /** Manages internal image data storage. The interface is package-private. */ | ||||||
|  | interface ImageContainer { | ||||||
|  |   /** Returns the properties of the contained image. */ | ||||||
|  |   ImageProperties getImageProperties(); | ||||||
|  | 
 | ||||||
|  |   /** Close the image container and releases the image resource inside. */ | ||||||
|  |   void close(); | ||||||
|  | } | ||||||
|  | @ -0,0 +1,22 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | /** Lightweight abstraction for an object that produce {@link Image} */ | ||||||
|  | public interface ImageProducer { | ||||||
|  | 
 | ||||||
|  |   /** Sets the consumer that receives the {@link Image}. */ | ||||||
|  |   void setImageConsumer(ImageConsumer imageConsumer); | ||||||
|  | } | ||||||
|  | @ -0,0 +1,80 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import com.google.auto.value.AutoValue; | ||||||
|  | import com.google.auto.value.extension.memoized.Memoized; | ||||||
|  | import com.google.mediapipe.framework.image.Image.ImageFormat; | ||||||
|  | import com.google.mediapipe.framework.image.Image.StorageType; | ||||||
|  | 
 | ||||||
|  | /** Groups a set of properties to describe how an image is stored. */ | ||||||
|  | @AutoValue | ||||||
|  | public abstract class ImageProperties { | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Gets the pixel format of the image. | ||||||
|  |    * | ||||||
|  |    * @see Image.ImageFormat | ||||||
|  |    */ | ||||||
|  |   @ImageFormat | ||||||
|  |   public abstract int getImageFormat(); | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Gets the storage type of the image. | ||||||
|  |    * | ||||||
|  |    * @see Image.StorageType | ||||||
|  |    */ | ||||||
|  |   @StorageType | ||||||
|  |   public abstract int getStorageType(); | ||||||
|  | 
 | ||||||
|  |   @Memoized | ||||||
|  |   @Override | ||||||
|  |   public abstract int hashCode(); | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates a builder of {@link ImageProperties}. | ||||||
|  |    * | ||||||
|  |    * @see ImageProperties.Builder | ||||||
|  |    */ | ||||||
|  |   static Builder builder() { | ||||||
|  |     return new AutoValue_ImageProperties.Builder(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Builds a {@link ImageProperties}. */ | ||||||
|  |   @AutoValue.Builder | ||||||
|  |   abstract static class Builder { | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Sets the {@link Image.ImageFormat}. | ||||||
|  |      * | ||||||
|  |      * @see ImageProperties#getImageFormat | ||||||
|  |      */ | ||||||
|  |     abstract Builder setImageFormat(@ImageFormat int value); | ||||||
|  | 
 | ||||||
|  |     /** | ||||||
|  |      * Sets the {@link Image.StorageType}. | ||||||
|  |      * | ||||||
|  |      * @see ImageProperties#getStorageType | ||||||
|  |      */ | ||||||
|  |     abstract Builder setStorageType(@StorageType int value); | ||||||
|  | 
 | ||||||
|  |     /** Builds the {@link ImageProperties}. */ | ||||||
|  |     abstract ImageProperties build(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // Hide the constructor. | ||||||
|  |   ImageProperties() {} | ||||||
|  | } | ||||||
|  | @ -0,0 +1,62 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import android.os.Build.VERSION_CODES; | ||||||
|  | import androidx.annotation.RequiresApi; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Builds {@link Image} from {@link android.media.Image}. | ||||||
|  |  * | ||||||
|  |  * <p>Once {@link android.media.Image} is passed in, to keep data integrity you shouldn't modify | ||||||
|  |  * content in it. | ||||||
|  |  * | ||||||
|  |  * <p>Use {@link MediaImageExtractor} to get {@link android.media.Image} you passed in. | ||||||
|  |  */ | ||||||
|  | @RequiresApi(VERSION_CODES.KITKAT) | ||||||
|  | public class MediaImageBuilder { | ||||||
|  | 
 | ||||||
|  |   // Mandatory fields. | ||||||
|  |   private final android.media.Image mediaImage; | ||||||
|  | 
 | ||||||
|  |   // Optional fields. | ||||||
|  |   private long timestamp; | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Creates the builder with a mandatory {@link android.media.Image}. | ||||||
|  |    * | ||||||
|  |    * @param mediaImage image data object. | ||||||
|  |    */ | ||||||
|  |   public MediaImageBuilder(android.media.Image mediaImage) { | ||||||
|  |     this.mediaImage = mediaImage; | ||||||
|  |     this.timestamp = 0; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Sets value for {@link Image#getTimestamp()}. */ | ||||||
|  |   MediaImageBuilder setTimestamp(long timestamp) { | ||||||
|  |     this.timestamp = timestamp; | ||||||
|  |     return this; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /** Builds an {@link Image} instance. */ | ||||||
|  |   public Image build() { | ||||||
|  |     return new Image( | ||||||
|  |         new MediaImageContainer(mediaImage), | ||||||
|  |         timestamp, | ||||||
|  |         mediaImage.getWidth(), | ||||||
|  |         mediaImage.getHeight()); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,73 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import android.os.Build; | ||||||
|  | import android.os.Build.VERSION; | ||||||
|  | import android.os.Build.VERSION_CODES; | ||||||
|  | import androidx.annotation.RequiresApi; | ||||||
|  | import com.google.mediapipe.framework.image.Image.ImageFormat; | ||||||
|  | 
 | ||||||
|  | @RequiresApi(VERSION_CODES.KITKAT) | ||||||
|  | class MediaImageContainer implements ImageContainer { | ||||||
|  | 
 | ||||||
|  |   private final android.media.Image mediaImage; | ||||||
|  |   private final ImageProperties properties; | ||||||
|  | 
 | ||||||
|  |   public MediaImageContainer(android.media.Image mediaImage) { | ||||||
|  |     this.mediaImage = mediaImage; | ||||||
|  |     this.properties = | ||||||
|  |         ImageProperties.builder() | ||||||
|  |             .setStorageType(Image.STORAGE_TYPE_MEDIA_IMAGE) | ||||||
|  |             .setImageFormat(convertFormatCode(mediaImage.getFormat())) | ||||||
|  |             .build(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   public android.media.Image getImage() { | ||||||
|  |     return mediaImage; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public ImageProperties getImageProperties() { | ||||||
|  |     return properties; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @Override | ||||||
|  |   public void close() { | ||||||
|  |     mediaImage.close(); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   @ImageFormat | ||||||
|  |   static int convertFormatCode(int graphicsFormat) { | ||||||
|  |     // We only cover the format mentioned in | ||||||
|  |     // https://developer.android.com/reference/android/media/Image#getFormat() | ||||||
|  |     if (VERSION.SDK_INT >= Build.VERSION_CODES.M) { | ||||||
|  |       if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGBA_8888) { | ||||||
|  |         return Image.IMAGE_FORMAT_RGBA; | ||||||
|  |       } else if (graphicsFormat == android.graphics.ImageFormat.FLEX_RGB_888) { | ||||||
|  |         return Image.IMAGE_FORMAT_RGB; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     switch (graphicsFormat) { | ||||||
|  |       case android.graphics.ImageFormat.JPEG: | ||||||
|  |         return Image.IMAGE_FORMAT_JPEG; | ||||||
|  |       case android.graphics.ImageFormat.YUV_420_888: | ||||||
|  |         return Image.IMAGE_FORMAT_YUV_420_888; | ||||||
|  |       default: | ||||||
|  |         return Image.IMAGE_FORMAT_UNKNOWN; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,49 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | package com.google.mediapipe.framework.image; | ||||||
|  | 
 | ||||||
|  | import android.os.Build.VERSION_CODES; | ||||||
|  | import androidx.annotation.RequiresApi; | ||||||
|  | 
 | ||||||
|  | /** | ||||||
|  |  * Utility for extracting {@link android.media.Image} from {@link Image}. | ||||||
|  |  * | ||||||
|  |  * <p>Currently it only supports {@link Image} with {@link Image#STORAGE_TYPE_MEDIA_IMAGE}, | ||||||
|  |  * otherwise {@link IllegalArgumentException} will be thrown. | ||||||
|  |  */ | ||||||
|  | @RequiresApi(VERSION_CODES.KITKAT) | ||||||
|  | public class MediaImageExtractor { | ||||||
|  | 
 | ||||||
|  |   private MediaImageExtractor() {} | ||||||
|  | 
 | ||||||
|  |   /** | ||||||
|  |    * Extracts a {@link android.media.Image} from an {@link Image}. Currently it only works for | ||||||
|  |    * {@link Image} that built from {@link MediaImageBuilder}. | ||||||
|  |    * | ||||||
|  |    * @param image the image to extract {@link android.media.Image} from. | ||||||
|  |    * @return {@link android.media.Image} that stored in {@link Image}. | ||||||
|  |    * @throws IllegalArgumentException if the extraction failed. | ||||||
|  |    */ | ||||||
|  |   public static android.media.Image extract(Image image) { | ||||||
|  |     ImageContainer container; | ||||||
|  |     if ((container = image.getContainer(Image.STORAGE_TYPE_MEDIA_IMAGE)) != null) { | ||||||
|  |       return ((MediaImageContainer) container).getImage(); | ||||||
|  |     } | ||||||
|  |     throw new IllegalArgumentException( | ||||||
|  |         "Extract Media Image from an Image created by objects other than Media Image" | ||||||
|  |             + " is not supported"); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -73,9 +73,14 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD( | ||||||
|   // TODO: get the graph's main context from the packet context?
 |   // TODO: get the graph's main context from the packet context?
 | ||||||
|   // Or clean up in some other way?
 |   // Or clean up in some other way?
 | ||||||
|   if (context_for_deletion) { |   if (context_for_deletion) { | ||||||
|     token = new mediapipe::GlSyncToken( |     auto sync = mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext( | ||||||
|         mediapipe::GlContext::CreateSyncTokenForCurrentExternalContext( |         context_for_deletion); | ||||||
|             context_for_deletion)); |     // A Java handle to a token is a raw pointer to a std::shared_ptr on the
 | ||||||
|  |     // heap, cast to a long. If the shared_ptr itself is null, leave the token
 | ||||||
|  |     // null too.
 | ||||||
|  |     if (sync) { | ||||||
|  |       token = new mediapipe::GlSyncToken(std::move(sync)); | ||||||
|  |     } | ||||||
|   } |   } | ||||||
|   return reinterpret_cast<jlong>(token); |   return reinterpret_cast<jlong>(token); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -145,6 +145,7 @@ EOF | ||||||
|             "//mediapipe/java/com/google/mediapipe/components:android_components", |             "//mediapipe/java/com/google/mediapipe/components:android_components", | ||||||
|             "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", |             "//mediapipe/java/com/google/mediapipe/components:android_camerax_helper", | ||||||
|             "//mediapipe/java/com/google/mediapipe/framework:android_framework", |             "//mediapipe/java/com/google/mediapipe/framework:android_framework", | ||||||
|  |             "//mediapipe/java/com/google/mediapipe/framework/image", | ||||||
|             "//mediapipe/java/com/google/mediapipe/glutil", |             "//mediapipe/java/com/google/mediapipe/glutil", | ||||||
|             "//third_party:androidx_annotation", |             "//third_party:androidx_annotation", | ||||||
|             "//third_party:androidx_appcompat", |             "//third_party:androidx_appcompat", | ||||||
|  |  | ||||||
|  | @ -76,7 +76,8 @@ ConvertAudioClassifierOptionsToProto(AudioClassifierOptions* options) { | ||||||
|   options_proto->mutable_base_options()->Swap(base_options_proto.get()); |   options_proto->mutable_base_options()->Swap(base_options_proto.get()); | ||||||
|   options_proto->mutable_base_options()->set_use_stream_mode( |   options_proto->mutable_base_options()->set_use_stream_mode( | ||||||
|       options->running_mode == core::RunningMode::AUDIO_STREAM); |       options->running_mode == core::RunningMode::AUDIO_STREAM); | ||||||
|   auto classifier_options_proto = std::make_unique<tasks::ClassifierOptions>( |   auto classifier_options_proto = | ||||||
|  |       std::make_unique<tasks::components::proto::ClassifierOptions>( | ||||||
|           components::ConvertClassifierOptionsToProto( |           components::ConvertClassifierOptionsToProto( | ||||||
|               &(options->classifier_options))); |               &(options->classifier_options))); | ||||||
|   options_proto->mutable_classifier_options()->Swap( |   options_proto->mutable_classifier_options()->Swap( | ||||||
|  |  | ||||||
|  | @ -136,6 +136,11 @@ void ConfigureAudioToTensorCalculator( | ||||||
| //   options {
 | //   options {
 | ||||||
| //     [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
 | //     [mediapipe.tasks.audio.audio_classifier.proto.AudioClassifierOptions.ext]
 | ||||||
| //     {
 | //     {
 | ||||||
|  | //       base_options {
 | ||||||
|  | //         model_asset {
 | ||||||
|  | //           file_name: "/path/to/model.tflite"
 | ||||||
|  | //         }
 | ||||||
|  | //       }
 | ||||||
| //       max_results: 4
 | //       max_results: 4
 | ||||||
| //       score_threshold: 0.5
 | //       score_threshold: 0.5
 | ||||||
| //       category_allowlist: "foo"
 | //       category_allowlist: "foo"
 | ||||||
|  | @ -225,15 +230,17 @@ class AudioClassifierGraph : public core::ModelTaskGraph { | ||||||
| 
 | 
 | ||||||
|     // Adds inference subgraph and connects its input stream to the output
 |     // Adds inference subgraph and connects its input stream to the output
 | ||||||
|     // tensors produced by the AudioToTensorCalculator.
 |     // tensors produced by the AudioToTensorCalculator.
 | ||||||
|     auto& inference = AddInference(model_resources, graph); |     auto& inference = AddInference( | ||||||
|  |         model_resources, task_options.base_options().acceleration(), graph); | ||||||
|     audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag); |     audio_to_tensor.Out(kTensorsTag) >> inference.In(kTensorsTag); | ||||||
| 
 | 
 | ||||||
|     // Adds postprocessing calculators and connects them to the graph output.
 |     // Adds postprocessing calculators and connects them to the graph output.
 | ||||||
|     auto& postprocessing = |     auto& postprocessing = graph.AddNode( | ||||||
|         graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); |         "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); | ||||||
|     MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( |     MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( | ||||||
|         model_resources, task_options.classifier_options(), |         model_resources, task_options.classifier_options(), | ||||||
|         &postprocessing.GetOptions<ClassificationPostprocessingOptions>())); |         &postprocessing.GetOptions< | ||||||
|  |             tasks::components::ClassificationPostprocessingOptions>())); | ||||||
|     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); |     inference.Out(kTensorsTag) >> postprocessing.In(kTensorsTag); | ||||||
| 
 | 
 | ||||||
|     // Time aggregation is only needed for performing audio classification on
 |     // Time aggregation is only needed for performing audio classification on
 | ||||||
|  |  | ||||||
|  | @ -37,7 +37,6 @@ limitations under the License. | ||||||
| #include "mediapipe/tasks/cc/audio/core/running_mode.h" | #include "mediapipe/tasks/cc/audio/core/running_mode.h" | ||||||
| #include "mediapipe/tasks/cc/audio/utils/test_utils.h" | #include "mediapipe/tasks/cc/audio/utils/test_utils.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
| #include "mediapipe/tasks/cc/components/classifier_options.pb.h" |  | ||||||
| #include "mediapipe/tasks/cc/components/containers/category.pb.h" | #include "mediapipe/tasks/cc/components/containers/category.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" | ||||||
| #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | #include "tensorflow/lite/core/shims/cc/shims_test_util.h" | ||||||
|  | @ -168,7 +167,7 @@ class CreateFromOptionsTest : public tflite_shims::testing::Test {}; | ||||||
| TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) { | TEST_F(CreateFromOptionsTest, SucceedsForModelWithMetadata) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->classifier_options.max_results = 3; |   options->classifier_options.max_results = 3; | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -192,7 +191,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { | ||||||
| TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { | TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->classifier_options.max_results = 0; |   options->classifier_options.max_results = 0; | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = |   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = | ||||||
|       AudioClassifier::Create(std::move(options)); |       AudioClassifier::Create(std::move(options)); | ||||||
|  | @ -208,7 +207,7 @@ TEST_F(CreateFromOptionsTest, FailsWithInvalidMaxResults) { | ||||||
| 
 | 
 | ||||||
| TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) { | TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   options->classifier_options.category_allowlist.push_back("foo"); |   options->classifier_options.category_allowlist.push_back("foo"); | ||||||
|   options->classifier_options.category_denylist.push_back("bar"); |   options->classifier_options.category_denylist.push_back("bar"); | ||||||
|  | @ -226,7 +225,7 @@ TEST_F(CreateFromOptionsTest, FailsWithCombinedAllowlistAndDenylist) { | ||||||
| 
 | 
 | ||||||
| TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) { | TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); | ||||||
|   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = |   StatusOr<std::unique_ptr<AudioClassifier>> audio_classifier_or = | ||||||
|       AudioClassifier::Create(std::move(options)); |       AudioClassifier::Create(std::move(options)); | ||||||
|  | @ -242,7 +241,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingMetadata) { | ||||||
| 
 | 
 | ||||||
| TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) { | TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); | ||||||
|   options->running_mode = core::RunningMode::AUDIO_STREAM; |   options->running_mode = core::RunningMode::AUDIO_STREAM; | ||||||
|   options->sample_rate = 16000; |   options->sample_rate = 16000; | ||||||
|  | @ -260,7 +259,7 @@ TEST_F(CreateFromOptionsTest, FailsWithMissingCallback) { | ||||||
| 
 | 
 | ||||||
| TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { | TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); | ||||||
|   options->result_callback = |   options->result_callback = | ||||||
|       [](absl::StatusOr<ClassificationResult> status_or_result) {}; |       [](absl::StatusOr<ClassificationResult> status_or_result) {}; | ||||||
|  | @ -279,7 +278,7 @@ TEST_F(CreateFromOptionsTest, FailsWithUnnecessaryCallback) { | ||||||
| 
 | 
 | ||||||
| TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) { | TEST_F(CreateFromOptionsTest, FailsWithMissingDefaultInputAudioSampleRate) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithoutMetadata); | ||||||
|   options->running_mode = core::RunningMode::AUDIO_STREAM; |   options->running_mode = core::RunningMode::AUDIO_STREAM; | ||||||
|   options->result_callback = |   options->result_callback = | ||||||
|  | @ -301,7 +300,7 @@ class ClassifyTest : public tflite_shims::testing::Test {}; | ||||||
| TEST_F(ClassifyTest, Succeeds) { | TEST_F(ClassifyTest, Succeeds) { | ||||||
|   auto audio_buffer = GetAudioData(k16kTestWavFilename); |   auto audio_buffer = GetAudioData(k16kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -315,7 +314,7 @@ TEST_F(ClassifyTest, Succeeds) { | ||||||
| TEST_F(ClassifyTest, SucceedsWithResampling) { | TEST_F(ClassifyTest, SucceedsWithResampling) { | ||||||
|   auto audio_buffer = GetAudioData(k48kTestWavFilename); |   auto audio_buffer = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -330,7 +329,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { | ||||||
|   auto audio_buffer_16k_hz = GetAudioData(k16kTestWavFilename); |   auto audio_buffer_16k_hz = GetAudioData(k16kTestWavFilename); | ||||||
|   auto audio_buffer_48k_hz = GetAudioData(k48kTestWavFilename); |   auto audio_buffer_48k_hz = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -349,7 +348,7 @@ TEST_F(ClassifyTest, SucceedsWithInputsAtDifferentSampleRates) { | ||||||
| 
 | 
 | ||||||
| TEST_F(ClassifyTest, SucceedsWithInsufficientData) { | TEST_F(ClassifyTest, SucceedsWithInsufficientData) { | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -374,7 +373,7 @@ TEST_F(ClassifyTest, SucceedsWithInsufficientData) { | ||||||
| TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { | TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { | ||||||
|   auto audio_buffer = GetAudioData(k16kTestWavForTwoHeadsFilename); |   auto audio_buffer = GetAudioData(k16kTestWavForTwoHeadsFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -388,7 +387,7 @@ TEST_F(ClassifyTest, SucceedsWithMultiheadsModel) { | ||||||
| TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { | TEST_F(ClassifyTest, SucceedsWithMultiheadsModelAndResampling) { | ||||||
|   auto audio_buffer = GetAudioData(k44kTestWavForTwoHeadsFilename); |   auto audio_buffer = GetAudioData(k44kTestWavForTwoHeadsFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -404,7 +403,7 @@ TEST_F(ClassifyTest, | ||||||
|   auto audio_buffer_44k_hz = GetAudioData(k44kTestWavForTwoHeadsFilename); |   auto audio_buffer_44k_hz = GetAudioData(k44kTestWavForTwoHeadsFilename); | ||||||
|   auto audio_buffer_16k_hz = GetAudioData(k16kTestWavForTwoHeadsFilename); |   auto audio_buffer_16k_hz = GetAudioData(k16kTestWavForTwoHeadsFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kTwoHeadsModelWithMetadata); | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|                           AudioClassifier::Create(std::move(options))); |                           AudioClassifier::Create(std::move(options))); | ||||||
|  | @ -424,7 +423,7 @@ TEST_F(ClassifyTest, | ||||||
| TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { | TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { | ||||||
|   auto audio_buffer = GetAudioData(k48kTestWavFilename); |   auto audio_buffer = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   options->classifier_options.max_results = 1; |   options->classifier_options.max_results = 1; | ||||||
|   options->classifier_options.score_threshold = 0.35f; |   options->classifier_options.score_threshold = 0.35f; | ||||||
|  | @ -440,7 +439,7 @@ TEST_F(ClassifyTest, SucceedsWithMaxResultOption) { | ||||||
| TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { | TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { | ||||||
|   auto audio_buffer = GetAudioData(k48kTestWavFilename); |   auto audio_buffer = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   options->classifier_options.score_threshold = 0.35f; |   options->classifier_options.score_threshold = 0.35f; | ||||||
|   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, |   MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<AudioClassifier> audio_classifier, | ||||||
|  | @ -455,7 +454,7 @@ TEST_F(ClassifyTest, SucceedsWithScoreThresholdOption) { | ||||||
| TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { | TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { | ||||||
|   auto audio_buffer = GetAudioData(k48kTestWavFilename); |   auto audio_buffer = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   options->classifier_options.score_threshold = 0.1f; |   options->classifier_options.score_threshold = 0.1f; | ||||||
|   options->classifier_options.category_allowlist.push_back("Speech"); |   options->classifier_options.category_allowlist.push_back("Speech"); | ||||||
|  | @ -471,7 +470,7 @@ TEST_F(ClassifyTest, SucceedsWithCategoryAllowlist) { | ||||||
| TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { | TEST_F(ClassifyTest, SucceedsWithCategoryDenylist) { | ||||||
|   auto audio_buffer = GetAudioData(k48kTestWavFilename); |   auto audio_buffer = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   options->classifier_options.score_threshold = 0.9f; |   options->classifier_options.score_threshold = 0.9f; | ||||||
|   options->classifier_options.category_denylist.push_back("Speech"); |   options->classifier_options.category_denylist.push_back("Speech"); | ||||||
|  | @ -499,7 +498,7 @@ TEST_F(ClassifyAsyncTest, Succeeds) { | ||||||
|   constexpr int kSampleRateHz = 48000; |   constexpr int kSampleRateHz = 48000; | ||||||
|   auto audio_buffer = GetAudioData(k48kTestWavFilename); |   auto audio_buffer = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   options->classifier_options.max_results = 1; |   options->classifier_options.max_results = 1; | ||||||
|   options->classifier_options.score_threshold = 0.3f; |   options->classifier_options.score_threshold = 0.3f; | ||||||
|  | @ -529,7 +528,7 @@ TEST_F(ClassifyAsyncTest, SucceedsWithNonDeterministicNumAudioSamples) { | ||||||
|   constexpr int kSampleRateHz = 48000; |   constexpr int kSampleRateHz = 48000; | ||||||
|   auto audio_buffer = GetAudioData(k48kTestWavFilename); |   auto audio_buffer = GetAudioData(k48kTestWavFilename); | ||||||
|   auto options = std::make_unique<AudioClassifierOptions>(); |   auto options = std::make_unique<AudioClassifierOptions>(); | ||||||
|   options->base_options.model_file_name = |   options->base_options.model_asset_path = | ||||||
|       JoinPath("./", kTestDataDirectory, kModelWithMetadata); |       JoinPath("./", kTestDataDirectory, kModelWithMetadata); | ||||||
|   options->classifier_options.max_results = 1; |   options->classifier_options.max_results = 1; | ||||||
|   options->classifier_options.score_threshold = 0.3f; |   options->classifier_options.score_threshold = 0.3f; | ||||||
|  |  | ||||||
|  | @ -24,7 +24,7 @@ mediapipe_proto_library( | ||||||
|     deps = [ |     deps = [ | ||||||
|         "//mediapipe/framework:calculator_options_proto", |         "//mediapipe/framework:calculator_options_proto", | ||||||
|         "//mediapipe/framework:calculator_proto", |         "//mediapipe/framework:calculator_proto", | ||||||
|         "//mediapipe/tasks/cc/components:classifier_options_proto", |         "//mediapipe/tasks/cc/components/proto:classifier_options_proto", | ||||||
|         "//mediapipe/tasks/cc/core/proto:base_options_proto", |         "//mediapipe/tasks/cc/core/proto:base_options_proto", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | @ -18,7 +18,7 @@ syntax = "proto2"; | ||||||
| package mediapipe.tasks.audio.audio_classifier.proto; | package mediapipe.tasks.audio.audio_classifier.proto; | ||||||
| 
 | 
 | ||||||
| import "mediapipe/framework/calculator.proto"; | import "mediapipe/framework/calculator.proto"; | ||||||
| import "mediapipe/tasks/cc/components/classifier_options.proto"; | import "mediapipe/tasks/cc/components/proto/classifier_options.proto"; | ||||||
| import "mediapipe/tasks/cc/core/proto/base_options.proto"; | import "mediapipe/tasks/cc/core/proto/base_options.proto"; | ||||||
| 
 | 
 | ||||||
| message AudioClassifierOptions { | message AudioClassifierOptions { | ||||||
|  | @ -31,7 +31,7 @@ message AudioClassifierOptions { | ||||||
| 
 | 
 | ||||||
|   // Options for configuring the classifier behavior, such as score threshold, |   // Options for configuring the classifier behavior, such as score threshold, | ||||||
|   // number of results, etc. |   // number of results, etc. | ||||||
|   optional ClassifierOptions classifier_options = 2; |   optional components.proto.ClassifierOptions classifier_options = 2; | ||||||
| 
 | 
 | ||||||
|   // The default sample rate of the input audio. Must be set when the |   // The default sample rate of the input audio. Must be set when the | ||||||
|   // AudioClassifier is configured to process audio stream data. |   // AudioClassifier is configured to process audio stream data. | ||||||
|  |  | ||||||
|  | @ -35,6 +35,8 @@ cc_library( | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":image_preprocessing_options_cc_proto", |         ":image_preprocessing_options_cc_proto", | ||||||
|         "//mediapipe/calculators/core:pass_through_calculator", |         "//mediapipe/calculators/core:pass_through_calculator", | ||||||
|  |         "//mediapipe/calculators/image:image_clone_calculator", | ||||||
|  |         "//mediapipe/calculators/image:image_clone_calculator_cc_proto", | ||||||
|         "//mediapipe/calculators/image:image_properties_calculator", |         "//mediapipe/calculators/image:image_properties_calculator", | ||||||
|         "//mediapipe/calculators/tensor:image_to_tensor_calculator", |         "//mediapipe/calculators/tensor:image_to_tensor_calculator", | ||||||
|         "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", |         "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", | ||||||
|  | @ -56,21 +58,11 @@ cc_library( | ||||||
| 
 | 
 | ||||||
| # TODO: Enable this test | # TODO: Enable this test | ||||||
| 
 | 
 | ||||||
| mediapipe_proto_library( |  | ||||||
|     name = "segmenter_options_proto", |  | ||||||
|     srcs = ["segmenter_options.proto"], |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| cc_library( | cc_library( | ||||||
|     name = "classifier_options", |     name = "classifier_options", | ||||||
|     srcs = ["classifier_options.cc"], |     srcs = ["classifier_options.cc"], | ||||||
|     hdrs = ["classifier_options.h"], |     hdrs = ["classifier_options.h"], | ||||||
|     deps = [":classifier_options_cc_proto"], |     deps = ["//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto"], | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| mediapipe_proto_library( |  | ||||||
|     name = "classifier_options_proto", |  | ||||||
|     srcs = ["classifier_options.proto"], |  | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| mediapipe_proto_library( | mediapipe_proto_library( | ||||||
|  | @ -81,6 +73,7 @@ mediapipe_proto_library( | ||||||
|         "//mediapipe/framework:calculator_options_proto", |         "//mediapipe/framework:calculator_options_proto", | ||||||
|         "//mediapipe/framework:calculator_proto", |         "//mediapipe/framework:calculator_proto", | ||||||
|         "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", |         "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_proto", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -90,7 +83,6 @@ cc_library( | ||||||
|     hdrs = ["classification_postprocessing.h"], |     hdrs = ["classification_postprocessing.h"], | ||||||
|     deps = [ |     deps = [ | ||||||
|         ":classification_postprocessing_options_cc_proto", |         ":classification_postprocessing_options_cc_proto", | ||||||
|         ":classifier_options_cc_proto", |  | ||||||
|         "//mediapipe/calculators/core:split_vector_calculator", |         "//mediapipe/calculators/core:split_vector_calculator", | ||||||
|         "//mediapipe/calculators/core:split_vector_calculator_cc_proto", |         "//mediapipe/calculators/core:split_vector_calculator_cc_proto", | ||||||
|         "//mediapipe/calculators/tensor:tensors_dequantization_calculator", |         "//mediapipe/calculators/tensor:tensors_dequantization_calculator", | ||||||
|  | @ -104,7 +96,12 @@ cc_library( | ||||||
|         "//mediapipe/tasks/cc:common", |         "//mediapipe/tasks/cc:common", | ||||||
|         "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", |         "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator", | ||||||
|         "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", |         "//mediapipe/tasks/cc/components/calculators:classification_aggregation_calculator_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator", | ||||||
|  |         "//mediapipe/tasks/cc/components/calculators:score_calibration_calculator_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/calculators:score_calibration_utils", | ||||||
|         "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", |         "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/proto:classifier_options_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/utils:source_or_node_output", | ||||||
|         "//mediapipe/tasks/cc/core:model_resources", |         "//mediapipe/tasks/cc/core:model_resources", | ||||||
|         "//mediapipe/tasks/cc/metadata:metadata_extractor", |         "//mediapipe/tasks/cc/metadata:metadata_extractor", | ||||||
|         "//mediapipe/tasks/metadata:metadata_schema_cc", |         "//mediapipe/tasks/metadata:metadata_schema_cc", | ||||||
|  | @ -119,3 +116,38 @@ cc_library( | ||||||
|     ], |     ], | ||||||
|     alwayslink = 1, |     alwayslink = 1, | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "embedder_options", | ||||||
|  |     srcs = ["embedder_options.cc"], | ||||||
|  |     hdrs = ["embedder_options.h"], | ||||||
|  |     deps = ["//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto"], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "embedding_postprocessing_graph", | ||||||
|  |     srcs = ["embedding_postprocessing_graph.cc"], | ||||||
|  |     hdrs = ["embedding_postprocessing_graph.h"], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/calculators/tensor:tensors_dequantization_calculator", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework/api2:builder", | ||||||
|  |         "//mediapipe/framework/api2:port", | ||||||
|  |         "//mediapipe/framework/formats:tensor", | ||||||
|  |         "//mediapipe/framework/tool:options_map", | ||||||
|  |         "//mediapipe/tasks/cc:common", | ||||||
|  |         "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator", | ||||||
|  |         "//mediapipe/tasks/cc/components/calculators:tensors_to_embeddings_calculator_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/proto:embedding_postprocessing_graph_options_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/utils:source_or_node_output", | ||||||
|  |         "//mediapipe/tasks/cc/core:model_resources", | ||||||
|  |         "//mediapipe/tasks/cc/metadata:metadata_extractor", | ||||||
|  |         "@com_google_absl//absl/status", | ||||||
|  |         "@com_google_absl//absl/status:statusor", | ||||||
|  |         "@com_google_absl//absl/strings:str_format", | ||||||
|  |         "@org_tensorflow//tensorflow/lite/schema:schema_fbs", | ||||||
|  |     ], | ||||||
|  |     alwayslink = 1, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @ -113,3 +113,66 @@ cc_test( | ||||||
|         "@com_google_absl//absl/strings", |         "@com_google_absl//absl/strings", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "end_loop_calculator", | ||||||
|  |     srcs = ["end_loop_calculator.cc"], | ||||||
|  |     visibility = ["//visibility:public"], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/calculators/core:end_loop_calculator", | ||||||
|  |         "//mediapipe/framework:calculator_context", | ||||||
|  |         "//mediapipe/framework:calculator_contract", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework:collection_item_id", | ||||||
|  |         "//mediapipe/framework:packet", | ||||||
|  |         "//mediapipe/framework/port:integral_types", | ||||||
|  |         "//mediapipe/framework/port:ret_check", | ||||||
|  |         "//mediapipe/framework/port:status", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers:classifications_cc_proto", | ||||||
|  |     ], | ||||||
|  |     alwayslink = 1, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | mediapipe_proto_library( | ||||||
|  |     name = "tensors_to_embeddings_calculator_proto", | ||||||
|  |     srcs = ["tensors_to_embeddings_calculator.proto"], | ||||||
|  |     deps = [ | ||||||
|  |         "//mediapipe/framework:calculator_options_proto", | ||||||
|  |         "//mediapipe/framework:calculator_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/proto:embedder_options_proto", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_library( | ||||||
|  |     name = "tensors_to_embeddings_calculator", | ||||||
|  |     srcs = ["tensors_to_embeddings_calculator.cc"], | ||||||
|  |     deps = [ | ||||||
|  |         ":tensors_to_embeddings_calculator_cc_proto", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework/api2:node", | ||||||
|  |         "//mediapipe/framework/api2:port", | ||||||
|  |         "//mediapipe/framework/formats:tensor", | ||||||
|  |         "//mediapipe/framework/port:ret_check", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/proto:embedder_options_cc_proto", | ||||||
|  |         "@com_google_absl//absl/status", | ||||||
|  |         "@com_google_absl//absl/strings:str_format", | ||||||
|  |     ], | ||||||
|  |     alwayslink = 1, | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | cc_test( | ||||||
|  |     name = "tensors_to_embeddings_calculator_test", | ||||||
|  |     srcs = ["tensors_to_embeddings_calculator_test.cc"], | ||||||
|  |     deps = [ | ||||||
|  |         ":tensors_to_embeddings_calculator", | ||||||
|  |         "//mediapipe/framework:calculator_framework", | ||||||
|  |         "//mediapipe/framework:calculator_runner", | ||||||
|  |         "//mediapipe/framework:packet", | ||||||
|  |         "//mediapipe/framework/formats:tensor", | ||||||
|  |         "//mediapipe/framework/port:gtest_main", | ||||||
|  |         "//mediapipe/framework/port:parse_text_proto", | ||||||
|  |         "//mediapipe/tasks/cc/components/containers/proto:embeddings_cc_proto", | ||||||
|  |         "@com_google_absl//absl/status", | ||||||
|  |     ], | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | @ -0,0 +1,29 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved.
 | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/calculators/core/end_loop_calculator.h" | ||||||
|  | 
 | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" | ||||||
|  | 
 | ||||||
|  | // Specialized EndLoopCalculator for Tasks specific types.
 | ||||||
|  | namespace mediapipe::tasks { | ||||||
|  | 
 | ||||||
|  | typedef EndLoopCalculator<std::vector<ClassificationResult>> | ||||||
|  |     EndLoopClassificationResultCalculator; | ||||||
|  | REGISTER_CALCULATOR(::mediapipe::tasks::EndLoopClassificationResultCalculator); | ||||||
|  | 
 | ||||||
|  | }  // namespace mediapipe::tasks
 | ||||||
|  | @ -25,7 +25,7 @@ mediapipe_proto_library( | ||||||
|         "//mediapipe/framework:calculator_options_proto", |         "//mediapipe/framework:calculator_options_proto", | ||||||
|         "//mediapipe/framework:calculator_proto", |         "//mediapipe/framework:calculator_proto", | ||||||
|         "//mediapipe/framework/formats:image_format_proto", |         "//mediapipe/framework/formats:image_format_proto", | ||||||
|         "//mediapipe/tasks/cc/components:segmenter_options_proto", |         "//mediapipe/tasks/cc/components/proto:segmenter_options_proto", | ||||||
|         "//mediapipe/util:label_map_proto", |         "//mediapipe/util:label_map_proto", | ||||||
|     ], |     ], | ||||||
| ) | ) | ||||||
|  | @ -45,7 +45,7 @@ cc_library( | ||||||
|         "//mediapipe/framework/port:opencv_core", |         "//mediapipe/framework/port:opencv_core", | ||||||
|         "//mediapipe/framework/port:opencv_imgproc", |         "//mediapipe/framework/port:opencv_imgproc", | ||||||
|         "//mediapipe/framework/port:status", |         "//mediapipe/framework/port:status", | ||||||
|         "//mediapipe/tasks/cc/components:segmenter_options_cc_proto", |         "//mediapipe/tasks/cc/components/proto:segmenter_options_cc_proto", | ||||||
|         "//mediapipe/tasks/cc/vision/utils:image_utils", |         "//mediapipe/tasks/cc/vision/utils:image_utils", | ||||||
|         "//mediapipe/util:label_map_cc_proto", |         "//mediapipe/util:label_map_cc_proto", | ||||||
|         "@com_google_absl//absl/status", |         "@com_google_absl//absl/status", | ||||||
|  |  | ||||||
|  | @ -36,19 +36,22 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/port/opencv_imgproc_inc.h" | #include "mediapipe/framework/port/opencv_imgproc_inc.h" | ||||||
| #include "mediapipe/framework/port/status_macros.h" | #include "mediapipe/framework/port/status_macros.h" | ||||||
| #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" | #include "mediapipe/tasks/cc/components/calculators/tensor/tensors_to_segmentation_calculator.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/segmenter_options.pb.h" | #include "mediapipe/tasks/cc/components/proto/segmenter_options.pb.h" | ||||||
| #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | #include "mediapipe/tasks/cc/vision/utils/image_utils.h" | ||||||
| #include "mediapipe/util/label_map.pb.h" | #include "mediapipe/util/label_map.pb.h" | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| namespace api2 { | namespace tasks { | ||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| using ::mediapipe::Image; | using ::mediapipe::Image; | ||||||
| using ::mediapipe::ImageFrameSharedPtr; | using ::mediapipe::ImageFrameSharedPtr; | ||||||
| using ::mediapipe::tasks::SegmenterOptions; | using ::mediapipe::api2::Input; | ||||||
|  | using ::mediapipe::api2::Node; | ||||||
|  | using ::mediapipe::api2::Output; | ||||||
| using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; | using ::mediapipe::tasks::TensorsToSegmentationCalculatorOptions; | ||||||
|  | using ::mediapipe::tasks::components::proto::SegmenterOptions; | ||||||
| using ::mediapipe::tasks::vision::GetImageLikeTensorShape; | using ::mediapipe::tasks::vision::GetImageLikeTensorShape; | ||||||
| using ::mediapipe::tasks::vision::Shape; | using ::mediapipe::tasks::vision::Shape; | ||||||
| 
 | 
 | ||||||
|  | @ -254,7 +257,7 @@ std::vector<Image> TensorsToSegmentationCalculator::GetSegmentationResult( | ||||||
|   return segmented_masks; |   return segmented_masks; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| MEDIAPIPE_REGISTER_NODE(TensorsToSegmentationCalculator); | MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator); | ||||||
| 
 | 
 | ||||||
| }  // namespace api2
 | }  // namespace tasks
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
|  | @ -18,7 +18,7 @@ syntax = "proto2"; | ||||||
| package mediapipe.tasks; | package mediapipe.tasks; | ||||||
| 
 | 
 | ||||||
| import "mediapipe/framework/calculator.proto"; | import "mediapipe/framework/calculator.proto"; | ||||||
| import "mediapipe/tasks/cc/components/segmenter_options.proto"; | import "mediapipe/tasks/cc/components/proto/segmenter_options.proto"; | ||||||
| import "mediapipe/util/label_map.proto"; | import "mediapipe/util/label_map.proto"; | ||||||
| 
 | 
 | ||||||
| message TensorsToSegmentationCalculatorOptions { | message TensorsToSegmentationCalculatorOptions { | ||||||
|  | @ -26,7 +26,7 @@ message TensorsToSegmentationCalculatorOptions { | ||||||
|     optional TensorsToSegmentationCalculatorOptions ext = 458105876; |     optional TensorsToSegmentationCalculatorOptions ext = 458105876; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   optional SegmenterOptions segmenter_options = 1; |   optional components.proto.SegmenterOptions segmenter_options = 1; | ||||||
| 
 | 
 | ||||||
|   // Identifying information for each classification label. |   // Identifying information for each classification label. | ||||||
|   map<int64, mediapipe.LabelMapItem> label_items = 2; |   map<int64, mediapipe.LabelMapItem> label_items = 2; | ||||||
|  |  | ||||||
|  | @ -117,7 +117,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionOne) { | ||||||
|   CalculatorRunner runner( |   CalculatorRunner runner( | ||||||
|       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( |       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( | ||||||
|           R"pb( |           R"pb( | ||||||
|             calculator: "TensorsToSegmentationCalculator" |             calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" | ||||||
|             input_stream: "TENSORS:tensors" |             input_stream: "TENSORS:tensors" | ||||||
|             output_stream: "SEGMENTATION:segmentation" |             output_stream: "SEGMENTATION:segmentation" | ||||||
|             options { |             options { | ||||||
|  | @ -144,7 +144,7 @@ TEST(TensorsToSegmentationCalculatorTest, FailsInvalidTensorDimensionFive) { | ||||||
|   CalculatorRunner runner( |   CalculatorRunner runner( | ||||||
|       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( |       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( | ||||||
|           R"pb( |           R"pb( | ||||||
|             calculator: "TensorsToSegmentationCalculator" |             calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" | ||||||
|             input_stream: "TENSORS:tensors" |             input_stream: "TENSORS:tensors" | ||||||
|             output_stream: "SEGMENTATION:segmentation" |             output_stream: "SEGMENTATION:segmentation" | ||||||
|             options { |             options { | ||||||
|  | @ -172,7 +172,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSoftmax) { | ||||||
|   CalculatorRunner runner( |   CalculatorRunner runner( | ||||||
|       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( |       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( | ||||||
|           R"pb( |           R"pb( | ||||||
|             calculator: "TensorsToSegmentationCalculator" |             calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" | ||||||
|             input_stream: "TENSORS:tensors" |             input_stream: "TENSORS:tensors" | ||||||
|             output_stream: "SEGMENTATION:0:segmented_mask_0" |             output_stream: "SEGMENTATION:0:segmented_mask_0" | ||||||
|             output_stream: "SEGMENTATION:1:segmented_mask_1" |             output_stream: "SEGMENTATION:1:segmented_mask_1" | ||||||
|  | @ -217,7 +217,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithNone) { | ||||||
|   CalculatorRunner runner( |   CalculatorRunner runner( | ||||||
|       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( |       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( | ||||||
|           R"pb( |           R"pb( | ||||||
|             calculator: "TensorsToSegmentationCalculator" |             calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" | ||||||
|             input_stream: "TENSORS:tensors" |             input_stream: "TENSORS:tensors" | ||||||
|             output_stream: "SEGMENTATION:0:segmented_mask_0" |             output_stream: "SEGMENTATION:0:segmented_mask_0" | ||||||
|             output_stream: "SEGMENTATION:1:segmented_mask_1" |             output_stream: "SEGMENTATION:1:segmented_mask_1" | ||||||
|  | @ -258,7 +258,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsConfidenceMaskWithSigmoid) { | ||||||
|   CalculatorRunner runner( |   CalculatorRunner runner( | ||||||
|       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( |       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( | ||||||
|           R"pb( |           R"pb( | ||||||
|             calculator: "TensorsToSegmentationCalculator" |             calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" | ||||||
|             input_stream: "TENSORS:tensors" |             input_stream: "TENSORS:tensors" | ||||||
|             output_stream: "SEGMENTATION:0:segmented_mask_0" |             output_stream: "SEGMENTATION:0:segmented_mask_0" | ||||||
|             output_stream: "SEGMENTATION:1:segmented_mask_1" |             output_stream: "SEGMENTATION:1:segmented_mask_1" | ||||||
|  | @ -300,7 +300,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMask) { | ||||||
|   CalculatorRunner runner( |   CalculatorRunner runner( | ||||||
|       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( |       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( | ||||||
|           R"pb( |           R"pb( | ||||||
|             calculator: "TensorsToSegmentationCalculator" |             calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" | ||||||
|             input_stream: "TENSORS:tensors" |             input_stream: "TENSORS:tensors" | ||||||
|             output_stream: "SEGMENTATION:segmentation" |             output_stream: "SEGMENTATION:segmentation" | ||||||
|             options { |             options { | ||||||
|  | @ -333,7 +333,7 @@ TEST(TensorsToSegmentationCalculatorTest, SucceedsCategoryMaskResize) { | ||||||
|   CalculatorRunner runner( |   CalculatorRunner runner( | ||||||
|       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( |       mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>( | ||||||
|           R"pb( |           R"pb( | ||||||
|             calculator: "TensorsToSegmentationCalculator" |             calculator: "mediapipe.tasks.TensorsToSegmentationCalculator" | ||||||
|             input_stream: "TENSORS:tensors" |             input_stream: "TENSORS:tensors" | ||||||
|             input_stream: "OUTPUT_SIZE:size" |             input_stream: "OUTPUT_SIZE:size" | ||||||
|             output_stream: "SEGMENTATION:segmentation" |             output_stream: "SEGMENTATION:segmentation" | ||||||
|  |  | ||||||
|  | @ -0,0 +1,158 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include <math.h> | ||||||
|  | 
 | ||||||
|  | #include <algorithm> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "absl/status/status.h" | ||||||
|  | #include "absl/strings/str_format.h" | ||||||
|  | #include "mediapipe/framework/api2/node.h" | ||||||
|  | #include "mediapipe/framework/api2/port.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/formats/tensor.h" | ||||||
|  | #include "mediapipe/framework/port/ret_check.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/calculators/tensors_to_embeddings_calculator.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/proto/embedder_options.pb.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace api2 { | ||||||
|  | 
 | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | using ::mediapipe::tasks::components::containers::proto::EmbeddingEntry; | ||||||
|  | using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; | ||||||
|  | 
 | ||||||
|  | // Computes the inverse L2 norm of the provided array of values. Returns 1.0 in
 | ||||||
|  | // case all values are 0.
 | ||||||
|  | float GetInverseL2Norm(const float* values, int size) { | ||||||
|  |   float squared_l2_norm = 0.0f; | ||||||
|  |   for (int i = 0; i < size; ++i) { | ||||||
|  |     squared_l2_norm += values[i] * values[i]; | ||||||
|  |   } | ||||||
|  |   float inv_l2_norm = 1.0f; | ||||||
|  |   if (squared_l2_norm > 0.0f) { | ||||||
|  |     inv_l2_norm = 1.0f / std::sqrt(squared_l2_norm); | ||||||
|  |   } | ||||||
|  |   return inv_l2_norm; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | 
 | ||||||
|  | // Converts tensors into an EmbeddingResult object, performing optional
 | ||||||
|  | // L2-normalization and scalar-quantization on-the-fly if required through the
 | ||||||
|  | // options.
 | ||||||
|  | //
 | ||||||
|  | // Input:
 | ||||||
|  | //   TENSORS - std::vector<Tensor>
 | ||||||
|  | //     A vector of one or more Tensors of type kFloat32.
 | ||||||
|  | // Output:
 | ||||||
|  | //   EMBEDDINGS - EmbeddingResult
 | ||||||
|  | //     The contents of the input tensors converted into an EmbeddingResult
 | ||||||
|  | //     proto.
 | ||||||
|  | class TensorsToEmbeddingsCalculator : public Node { | ||||||
|  |  public: | ||||||
|  |   static constexpr Input<std::vector<Tensor>> kTensorsIn{"TENSORS"}; | ||||||
|  |   static constexpr Output<EmbeddingResult> kEmbeddingsOut{"EMBEDDING_RESULT"}; | ||||||
|  |   MEDIAPIPE_NODE_CONTRACT(kTensorsIn, kEmbeddingsOut); | ||||||
|  | 
 | ||||||
|  |   absl::Status Open(CalculatorContext* cc) override; | ||||||
|  |   absl::Status Process(CalculatorContext* cc) override; | ||||||
|  | 
 | ||||||
|  |  private: | ||||||
|  |   bool l2_normalize_; | ||||||
|  |   bool quantize_; | ||||||
|  |   std::vector<std::string> head_names_; | ||||||
|  | 
 | ||||||
|  |   void FillFloatEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); | ||||||
|  |   void FillQuantizedEmbeddingEntry(const Tensor& tensor, EmbeddingEntry* entry); | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | absl::Status TensorsToEmbeddingsCalculator::Open(CalculatorContext* cc) { | ||||||
|  |   auto options = cc->Options<mediapipe::TensorsToEmbeddingsCalculatorOptions>(); | ||||||
|  |   l2_normalize_ = options.embedder_options().l2_normalize(); | ||||||
|  |   quantize_ = options.embedder_options().quantize(); | ||||||
|  |   if (!options.head_names().empty()) { | ||||||
|  |     head_names_.assign(options.head_names().begin(), | ||||||
|  |                        options.head_names().end()); | ||||||
|  |   } | ||||||
|  |   return absl::OkStatus(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | absl::Status TensorsToEmbeddingsCalculator::Process(CalculatorContext* cc) { | ||||||
|  |   EmbeddingResult result; | ||||||
|  |   const auto& tensors = *kTensorsIn(cc); | ||||||
|  |   if (!head_names_.empty() && tensors.size() != head_names_.size()) { | ||||||
|  |     return absl::InvalidArgumentError(absl::StrFormat( | ||||||
|  |         "Mismatch between number of provided head names (%d) and number " | ||||||
|  |         "of input tensors (%d).", | ||||||
|  |         head_names_.size(), tensors.size())); | ||||||
|  |   } | ||||||
|  |   for (int i = 0; i < tensors.size(); ++i) { | ||||||
|  |     const auto& tensor = tensors[i]; | ||||||
|  |     RET_CHECK(tensor.element_type() == Tensor::ElementType::kFloat32); | ||||||
|  |     auto* embeddings = result.add_embeddings(); | ||||||
|  |     embeddings->set_head_index(i); | ||||||
|  |     if (!head_names_.empty()) { | ||||||
|  |       embeddings->set_head_name(head_names_[i]); | ||||||
|  |     } | ||||||
|  |     if (quantize_) { | ||||||
|  |       FillQuantizedEmbeddingEntry(tensor, embeddings->add_entries()); | ||||||
|  |     } else { | ||||||
|  |       FillFloatEmbeddingEntry(tensor, embeddings->add_entries()); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   kEmbeddingsOut(cc).Send(result); | ||||||
|  |   return absl::OkStatus(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void TensorsToEmbeddingsCalculator::FillFloatEmbeddingEntry( | ||||||
|  |     const Tensor& tensor, EmbeddingEntry* entry) { | ||||||
|  |   int size = tensor.shape().num_elements(); | ||||||
|  |   auto tensor_view = tensor.GetCpuReadView(); | ||||||
|  |   const float* tensor_buffer = tensor_view.buffer<float>(); | ||||||
|  |   float inv_l2_norm = | ||||||
|  |       l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; | ||||||
|  |   auto* float_embedding = entry->mutable_float_embedding(); | ||||||
|  |   for (int i = 0; i < size; ++i) { | ||||||
|  |     float_embedding->add_values(tensor_buffer[i] * inv_l2_norm); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | void TensorsToEmbeddingsCalculator::FillQuantizedEmbeddingEntry( | ||||||
|  |     const Tensor& tensor, EmbeddingEntry* entry) { | ||||||
|  |   int size = tensor.shape().num_elements(); | ||||||
|  |   auto tensor_view = tensor.GetCpuReadView(); | ||||||
|  |   const float* tensor_buffer = tensor_view.buffer<float>(); | ||||||
|  |   float inv_l2_norm = | ||||||
|  |       l2_normalize_ ? GetInverseL2Norm(tensor_buffer, size) : 1.0f; | ||||||
|  |   auto* values = entry->mutable_quantized_embedding()->mutable_values(); | ||||||
|  |   values->resize(size); | ||||||
|  |   for (int i = 0; i < size; ++i) { | ||||||
|  |     // Normalize.
 | ||||||
|  |     float normalized = tensor_buffer[i] * inv_l2_norm; | ||||||
|  |     // Quantize.
 | ||||||
|  |     int unclamped_value = static_cast<int>(roundf(normalized * 128)); | ||||||
|  |     // Clamp and assign.
 | ||||||
|  |     (*values)[i] = | ||||||
|  |         static_cast<char>(std::max(-128, std::min(unclamped_value, 127))); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | MEDIAPIPE_REGISTER_NODE(TensorsToEmbeddingsCalculator); | ||||||
|  | 
 | ||||||
|  | }  // namespace api2
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -0,0 +1,35 @@ | ||||||
|  | /* Copyright 2022 The MediaPipe Authors. All Rights Reserved. | ||||||
|  | 
 | ||||||
|  | Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | you may not use this file except in compliance with the License. | ||||||
|  | You may obtain a copy of the License at | ||||||
|  | 
 | ||||||
|  |     http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | 
 | ||||||
|  | Unless required by applicable law or agreed to in writing, software | ||||||
|  | distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | See the License for the specific language governing permissions and | ||||||
|  | limitations under the License. | ||||||
|  | ==============================================================================*/ | ||||||
|  | 
 | ||||||
|  | syntax = "proto2"; | ||||||
|  | 
 | ||||||
|  | package mediapipe; | ||||||
|  | 
 | ||||||
|  | import "mediapipe/framework/calculator.proto"; | ||||||
|  | import "mediapipe/tasks/cc/components/proto/embedder_options.proto"; | ||||||
|  | 
 | ||||||
|  | message TensorsToEmbeddingsCalculatorOptions { | ||||||
|  |   extend mediapipe.CalculatorOptions { | ||||||
|  |     optional TensorsToEmbeddingsCalculatorOptions ext = 474762326; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // The embedder options defining whether to L2-normalize or scalar-quantize | ||||||
|  |   // the outputs. | ||||||
|  |   optional mediapipe.tasks.components.proto.EmbedderOptions embedder_options = | ||||||
|  |       1; | ||||||
|  | 
 | ||||||
|  |   // The embedder head names. | ||||||
|  |   repeated string head_names = 2; | ||||||
|  | } | ||||||
|  | @ -0,0 +1,249 @@ | ||||||
|  | // Copyright 2022 The MediaPipe Authors.
 | ||||||
|  | //
 | ||||||
|  | // Licensed under the Apache License, Version 2.0 (the "License");
 | ||||||
|  | // you may not use this file except in compliance with the License.
 | ||||||
|  | // You may obtain a copy of the License at
 | ||||||
|  | //
 | ||||||
|  | //      http://www.apache.org/licenses/LICENSE-2.0
 | ||||||
|  | //
 | ||||||
|  | // Unless required by applicable law or agreed to in writing, software
 | ||||||
|  | // distributed under the License is distributed on an "AS IS" BASIS,
 | ||||||
|  | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | ||||||
|  | // See the License for the specific language governing permissions and
 | ||||||
|  | // limitations under the License.
 | ||||||
|  | 
 | ||||||
|  | #include <memory> | ||||||
|  | #include <vector> | ||||||
|  | 
 | ||||||
|  | #include "absl/status/status.h" | ||||||
|  | #include "mediapipe/framework/calculator_framework.h" | ||||||
|  | #include "mediapipe/framework/calculator_runner.h" | ||||||
|  | #include "mediapipe/framework/formats/tensor.h" | ||||||
|  | #include "mediapipe/framework/packet.h" | ||||||
|  | #include "mediapipe/framework/port/gmock.h" | ||||||
|  | #include "mediapipe/framework/port/gtest.h" | ||||||
|  | #include "mediapipe/framework/port/parse_text_proto.h" | ||||||
|  | #include "mediapipe/framework/port/status_matchers.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/containers/proto/embeddings.pb.h" | ||||||
|  | 
 | ||||||
|  | namespace mediapipe { | ||||||
|  | namespace { | ||||||
|  | 
 | ||||||
|  | using ::mediapipe::tasks::components::containers::proto::EmbeddingResult; | ||||||
|  | using ::testing::HasSubstr; | ||||||
|  | using Node = ::mediapipe::CalculatorGraphConfig::Node; | ||||||
|  | 
 | ||||||
|  | // Builds the graph and feeds inputs.
 | ||||||
|  | void BuildGraph(CalculatorRunner* runner, | ||||||
|  |                 std::vector<std::vector<float>> tensors) { | ||||||
|  |   auto inputs = std::make_unique<std::vector<Tensor>>(); | ||||||
|  |   for (const auto& tensor : tensors) { | ||||||
|  |     inputs->emplace_back(Tensor::ElementType::kFloat32, | ||||||
|  |                          Tensor::Shape{1, static_cast<int>(tensor.size())}); | ||||||
|  |     auto view = inputs->back().GetCpuWriteView(); | ||||||
|  |     float* buffer = view.buffer<float>(); | ||||||
|  |     ASSERT_NE(buffer, nullptr); | ||||||
|  |     for (int i = 0; i < tensor.size(); ++i) { | ||||||
|  |       buffer[i] = tensor[i]; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   auto& input_packets = runner->MutableInputs()->Tag("TENSORS").packets; | ||||||
|  |   input_packets.push_back(Adopt(inputs.release()).At(Timestamp(0))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TensorsToEmbeddingsCalculatorTest, FailsWithInvalidHeadNamesNumber) { | ||||||
|  |   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( | ||||||
|  |     calculator: "TensorsToEmbeddingsCalculator" | ||||||
|  |     input_stream: "TENSORS:tensors" | ||||||
|  |     output_stream: "EMBEDDING_RESULT:embeddings" | ||||||
|  |     options { | ||||||
|  |       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { head_names: "foo" } | ||||||
|  |     } | ||||||
|  |   )pb")); | ||||||
|  | 
 | ||||||
|  |   BuildGraph(&runner, {{0.1, 0.2}, {0.2, 0.3}}); | ||||||
|  |   auto status = runner.Run(); | ||||||
|  | 
 | ||||||
|  |   EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); | ||||||
|  |   EXPECT_THAT(status.message(), | ||||||
|  |               HasSubstr("Mismatch between number of provided head names")); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithoutHeadNames) { | ||||||
|  |   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( | ||||||
|  |     calculator: "TensorsToEmbeddingsCalculator" | ||||||
|  |     input_stream: "TENSORS:tensors" | ||||||
|  |     output_stream: "EMBEDDING_RESULT:embeddings" | ||||||
|  |     options { | ||||||
|  |       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { | ||||||
|  |         embedder_options { l2_normalize: false quantize: false } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb")); | ||||||
|  | 
 | ||||||
|  |   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); | ||||||
|  |   MP_ASSERT_OK(runner.Run()); | ||||||
|  | 
 | ||||||
|  |   const EmbeddingResult& result = runner.Outputs() | ||||||
|  |                                       .Get("EMBEDDING_RESULT", 0) | ||||||
|  |                                       .packets[0] | ||||||
|  |                                       .Get<EmbeddingResult>(); | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       result, | ||||||
|  |       EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( | ||||||
|  |           R"pb(embeddings { | ||||||
|  |                  entries { float_embedding { values: 0.1 values: 0.2 } } | ||||||
|  |                  head_index: 0 | ||||||
|  |                } | ||||||
|  |                embeddings { | ||||||
|  |                  entries { float_embedding { values: -0.2 values: -0.3 } } | ||||||
|  |                  head_index: 1 | ||||||
|  |                })pb"))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithHeadNames) { | ||||||
|  |   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( | ||||||
|  |     calculator: "TensorsToEmbeddingsCalculator" | ||||||
|  |     input_stream: "TENSORS:tensors" | ||||||
|  |     output_stream: "EMBEDDING_RESULT:embeddings" | ||||||
|  |     options { | ||||||
|  |       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { | ||||||
|  |         embedder_options { l2_normalize: false quantize: false } | ||||||
|  |         head_names: "foo" | ||||||
|  |         head_names: "bar" | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb")); | ||||||
|  | 
 | ||||||
|  |   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); | ||||||
|  |   MP_ASSERT_OK(runner.Run()); | ||||||
|  | 
 | ||||||
|  |   const EmbeddingResult& result = runner.Outputs() | ||||||
|  |                                       .Get("EMBEDDING_RESULT", 0) | ||||||
|  |                                       .packets[0] | ||||||
|  |                                       .Get<EmbeddingResult>(); | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       result, | ||||||
|  |       EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( | ||||||
|  |           R"pb(embeddings { | ||||||
|  |                  entries { float_embedding { values: 0.1 values: 0.2 } } | ||||||
|  |                  head_index: 0 | ||||||
|  |                  head_name: "foo" | ||||||
|  |                } | ||||||
|  |                embeddings { | ||||||
|  |                  entries { float_embedding { values: -0.2 values: -0.3 } } | ||||||
|  |                  head_index: 1 | ||||||
|  |                  head_name: "bar" | ||||||
|  |                })pb"))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithNormalization) { | ||||||
|  |   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( | ||||||
|  |     calculator: "TensorsToEmbeddingsCalculator" | ||||||
|  |     input_stream: "TENSORS:tensors" | ||||||
|  |     output_stream: "EMBEDDING_RESULT:embeddings" | ||||||
|  |     options { | ||||||
|  |       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { | ||||||
|  |         embedder_options { l2_normalize: true quantize: false } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb")); | ||||||
|  | 
 | ||||||
|  |   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); | ||||||
|  |   MP_ASSERT_OK(runner.Run()); | ||||||
|  | 
 | ||||||
|  |   const EmbeddingResult& result = runner.Outputs() | ||||||
|  |                                       .Get("EMBEDDING_RESULT", 0) | ||||||
|  |                                       .packets[0] | ||||||
|  |                                       .Get<EmbeddingResult>(); | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       result, | ||||||
|  |       EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( | ||||||
|  |           R"pb(embeddings { | ||||||
|  |                  entries { | ||||||
|  |                    float_embedding { values: 0.44721356 values: 0.8944271 } | ||||||
|  |                  } | ||||||
|  |                  head_index: 0 | ||||||
|  |                } | ||||||
|  |                embeddings { | ||||||
|  |                  entries { | ||||||
|  |                    float_embedding { values: -0.5547002 values: -0.8320503 } | ||||||
|  |                  } | ||||||
|  |                  head_index: 1 | ||||||
|  |                })pb"))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TensorsToEmbeddingsCalculatorTest, SucceedsWithQuantization) { | ||||||
|  |   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( | ||||||
|  |     calculator: "TensorsToEmbeddingsCalculator" | ||||||
|  |     input_stream: "TENSORS:tensors" | ||||||
|  |     output_stream: "EMBEDDING_RESULT:embeddings" | ||||||
|  |     options { | ||||||
|  |       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { | ||||||
|  |         embedder_options { l2_normalize: false quantize: true } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb")); | ||||||
|  | 
 | ||||||
|  |   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); | ||||||
|  |   MP_ASSERT_OK(runner.Run()); | ||||||
|  | 
 | ||||||
|  |   const EmbeddingResult& result = runner.Outputs() | ||||||
|  |                                       .Get("EMBEDDING_RESULT", 0) | ||||||
|  |                                       .packets[0] | ||||||
|  |                                       .Get<EmbeddingResult>(); | ||||||
|  |   EXPECT_THAT(result, | ||||||
|  |               EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( | ||||||
|  |                   R"pb(embeddings { | ||||||
|  |                          entries { | ||||||
|  |                            quantized_embedding { values: "\x0d\x1a" }  # 13,26 | ||||||
|  |                          } | ||||||
|  |                          head_index: 0 | ||||||
|  |                        } | ||||||
|  |                        embeddings { | ||||||
|  |                          entries { | ||||||
|  |                            quantized_embedding { values: "\xe6\xda" }  # -26,-38 | ||||||
|  |                          } | ||||||
|  |                          head_index: 1 | ||||||
|  |                        })pb"))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | TEST(TensorsToEmbeddingsCalculatorTest, | ||||||
|  |      SucceedsWithNormalizationAndQuantization) { | ||||||
|  |   CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb( | ||||||
|  |     calculator: "TensorsToEmbeddingsCalculator" | ||||||
|  |     input_stream: "TENSORS:tensors" | ||||||
|  |     output_stream: "EMBEDDING_RESULT:embeddings" | ||||||
|  |     options { | ||||||
|  |       [mediapipe.TensorsToEmbeddingsCalculatorOptions.ext] { | ||||||
|  |         embedder_options { l2_normalize: true quantize: true } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   )pb")); | ||||||
|  | 
 | ||||||
|  |   BuildGraph(&runner, {{0.1, 0.2}, {-0.2, -0.3}}); | ||||||
|  |   MP_ASSERT_OK(runner.Run()); | ||||||
|  | 
 | ||||||
|  |   const EmbeddingResult& result = runner.Outputs() | ||||||
|  |                                       .Get("EMBEDDING_RESULT", 0) | ||||||
|  |                                       .packets[0] | ||||||
|  |                                       .Get<EmbeddingResult>(); | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       result, | ||||||
|  |       EqualsProto(ParseTextProtoOrDie<EmbeddingResult>( | ||||||
|  |           R"pb(embeddings { | ||||||
|  |                  entries { | ||||||
|  |                    quantized_embedding { values: "\x39\x72" }  # 57,114 | ||||||
|  |                  } | ||||||
|  |                  head_index: 0 | ||||||
|  |                } | ||||||
|  |                embeddings { | ||||||
|  |                  entries { | ||||||
|  |                    quantized_embedding { values: "\xb9\x95" }  # -71,-107 | ||||||
|  |                  } | ||||||
|  |                  head_index: 1 | ||||||
|  |                })pb"))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | }  // namespace
 | ||||||
|  | }  // namespace mediapipe
 | ||||||
|  | @ -35,9 +35,12 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/formats/tensor.h" | #include "mediapipe/framework/formats/tensor.h" | ||||||
| #include "mediapipe/tasks/cc/common.h" | #include "mediapipe/tasks/cc/common.h" | ||||||
| #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" | #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/calculators/score_calibration_utils.h" | ||||||
| #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" | #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/classifier_options.pb.h" |  | ||||||
| #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/utils/source_or_node_output.h" | ||||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | #include "mediapipe/tasks/cc/core/model_resources.h" | ||||||
| #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" | #include "mediapipe/tasks/cc/metadata/metadata_extractor.h" | ||||||
| #include "mediapipe/tasks/metadata/metadata_schema_generated.h" | #include "mediapipe/tasks/metadata/metadata_schema_generated.h" | ||||||
|  | @ -47,6 +50,7 @@ limitations under the License. | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| namespace tasks { | namespace tasks { | ||||||
|  | namespace components { | ||||||
| 
 | 
 | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
|  | @ -57,18 +61,21 @@ using ::mediapipe::api2::Timestamp; | ||||||
| using ::mediapipe::api2::builder::GenericNode; | using ::mediapipe::api2::builder::GenericNode; | ||||||
| using ::mediapipe::api2::builder::Graph; | using ::mediapipe::api2::builder::Graph; | ||||||
| using ::mediapipe::api2::builder::Source; | using ::mediapipe::api2::builder::Source; | ||||||
|  | using ::mediapipe::tasks::components::proto::ClassifierOptions; | ||||||
| using ::mediapipe::tasks::core::ModelResources; | using ::mediapipe::tasks::core::ModelResources; | ||||||
| using ::mediapipe::tasks::metadata::ModelMetadataExtractor; | using ::mediapipe::tasks::metadata::ModelMetadataExtractor; | ||||||
| using ::tflite::ProcessUnit; | using ::tflite::ProcessUnit; | ||||||
| using ::tflite::ProcessUnitOptions_ScoreThresholdingOptions; |  | ||||||
| using ::tflite::TensorMetadata; | using ::tflite::TensorMetadata; | ||||||
| using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; | using LabelItems = mediapipe::proto_ns::Map<int64, ::mediapipe::LabelMapItem>; | ||||||
|  | using TensorsSource = mediapipe::tasks::SourceOrNodeOutput<std::vector<Tensor>>; | ||||||
| 
 | 
 | ||||||
| constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); | constexpr float kDefaultScoreThreshold = std::numeric_limits<float>::lowest(); | ||||||
| 
 | 
 | ||||||
| constexpr char kTensorsTag[] = "TENSORS"; | constexpr char kCalibratedScoresTag[] = "CALIBRATED_SCORES"; | ||||||
| constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | constexpr char kClassificationResultTag[] = "CLASSIFICATION_RESULT"; | ||||||
| constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; | constexpr char kClassificationsTag[] = "CLASSIFICATIONS"; | ||||||
|  | constexpr char kScoresTag[] = "SCORES"; | ||||||
|  | constexpr char kTensorsTag[] = "TENSORS"; | ||||||
| constexpr char kTimestampsTag[] = "TIMESTAMPS"; | constexpr char kTimestampsTag[] = "TIMESTAMPS"; | ||||||
| 
 | 
 | ||||||
| // Performs sanity checks on provided ClassifierOptions.
 | // Performs sanity checks on provided ClassifierOptions.
 | ||||||
|  | @ -183,10 +190,10 @@ absl::StatusOr<LabelItems> GetLabelItemsIfAny( | ||||||
| absl::StatusOr<float> GetScoreThreshold( | absl::StatusOr<float> GetScoreThreshold( | ||||||
|     const ModelMetadataExtractor& metadata_extractor, |     const ModelMetadataExtractor& metadata_extractor, | ||||||
|     const TensorMetadata& tensor_metadata) { |     const TensorMetadata& tensor_metadata) { | ||||||
|   ASSIGN_OR_RETURN( |   ASSIGN_OR_RETURN(const ProcessUnit* score_thresholding_process_unit, | ||||||
|       const ProcessUnit* score_thresholding_process_unit, |  | ||||||
|                    metadata_extractor.FindFirstProcessUnit( |                    metadata_extractor.FindFirstProcessUnit( | ||||||
|           tensor_metadata, ProcessUnitOptions_ScoreThresholdingOptions)); |                        tensor_metadata, | ||||||
|  |                        tflite::ProcessUnitOptions_ScoreThresholdingOptions)); | ||||||
|   if (score_thresholding_process_unit == nullptr) { |   if (score_thresholding_process_unit == nullptr) { | ||||||
|     return kDefaultScoreThreshold; |     return kDefaultScoreThreshold; | ||||||
|   } |   } | ||||||
|  | @ -230,8 +237,51 @@ absl::StatusOr<absl::flat_hash_set<int>> GetAllowOrDenyCategoryIndicesIfAny( | ||||||
|   return category_indices; |   return category_indices; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Fills in the TensorsToClassificationCalculatorOptions based on the classifier
 | absl::Status ConfigureScoreCalibrationIfAny( | ||||||
| // options and the (optional) output tensor metadata.
 |     const ModelMetadataExtractor& metadata_extractor, int tensor_index, | ||||||
|  |     ClassificationPostprocessingOptions* options) { | ||||||
|  |   const auto* tensor_metadata = | ||||||
|  |       metadata_extractor.GetOutputTensorMetadata(tensor_index); | ||||||
|  |   if (tensor_metadata == nullptr) { | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  |   // Get ScoreCalibrationOptions, if any.
 | ||||||
|  |   ASSIGN_OR_RETURN(const ProcessUnit* score_calibration_process_unit, | ||||||
|  |                    metadata_extractor.FindFirstProcessUnit( | ||||||
|  |                        *tensor_metadata, | ||||||
|  |                        tflite::ProcessUnitOptions_ScoreCalibrationOptions)); | ||||||
|  |   if (score_calibration_process_unit == nullptr) { | ||||||
|  |     return absl::OkStatus(); | ||||||
|  |   } | ||||||
|  |   auto* score_calibration_options = | ||||||
|  |       score_calibration_process_unit->options_as_ScoreCalibrationOptions(); | ||||||
|  |   // Get corresponding AssociatedFile.
 | ||||||
|  |   auto score_calibration_filename = | ||||||
|  |       metadata_extractor.FindFirstAssociatedFileName( | ||||||
|  |           *tensor_metadata, | ||||||
|  |           tflite::AssociatedFileType_TENSOR_AXIS_SCORE_CALIBRATION); | ||||||
|  |   if (score_calibration_filename.empty()) { | ||||||
|  |     return CreateStatusWithPayload( | ||||||
|  |         absl::StatusCode::kNotFound, | ||||||
|  |         "Found ScoreCalibrationOptions but missing required associated " | ||||||
|  |         "parameters file with type TENSOR_AXIS_SCORE_CALIBRATION.", | ||||||
|  |         MediaPipeTasksStatus::kMetadataAssociatedFileNotFoundError); | ||||||
|  |   } | ||||||
|  |   ASSIGN_OR_RETURN( | ||||||
|  |       absl::string_view score_calibration_file, | ||||||
|  |       metadata_extractor.GetAssociatedFile(score_calibration_filename)); | ||||||
|  |   ScoreCalibrationCalculatorOptions calculator_options; | ||||||
|  |   MP_RETURN_IF_ERROR(ConfigureScoreCalibration( | ||||||
|  |       score_calibration_options->score_transformation(), | ||||||
|  |       score_calibration_options->default_score(), score_calibration_file, | ||||||
|  |       &calculator_options)); | ||||||
|  |   (*options->mutable_score_calibration_options())[tensor_index] = | ||||||
|  |       calculator_options; | ||||||
|  |   return absl::OkStatus(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Fills in the TensorsToClassificationCalculatorOptions based on the
 | ||||||
|  | // classifier options and the (optional) output tensor metadata.
 | ||||||
| absl::Status ConfigureTensorsToClassificationCalculator( | absl::Status ConfigureTensorsToClassificationCalculator( | ||||||
|     const ClassifierOptions& options, |     const ClassifierOptions& options, | ||||||
|     const ModelMetadataExtractor& metadata_extractor, int tensor_index, |     const ModelMetadataExtractor& metadata_extractor, int tensor_index, | ||||||
|  | @ -303,6 +353,8 @@ absl::Status ConfigureClassificationPostprocessing( | ||||||
|   ASSIGN_OR_RETURN(const auto heads_properties, |   ASSIGN_OR_RETURN(const auto heads_properties, | ||||||
|                    GetClassificationHeadsProperties(model_resources)); |                    GetClassificationHeadsProperties(model_resources)); | ||||||
|   for (int i = 0; i < heads_properties.num_heads; ++i) { |   for (int i = 0; i < heads_properties.num_heads; ++i) { | ||||||
|  |     MP_RETURN_IF_ERROR(ConfigureScoreCalibrationIfAny( | ||||||
|  |         *model_resources.GetMetadataExtractor(), i, options)); | ||||||
|     MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( |     MP_RETURN_IF_ERROR(ConfigureTensorsToClassificationCalculator( | ||||||
|         classifier_options, *model_resources.GetMetadataExtractor(), i, |         classifier_options, *model_resources.GetMetadataExtractor(), i, | ||||||
|         options->add_tensors_to_classifications_options())); |         options->add_tensors_to_classifications_options())); | ||||||
|  | @ -314,8 +366,8 @@ absl::Status ConfigureClassificationPostprocessing( | ||||||
|   return absl::OkStatus(); |   return absl::OkStatus(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // A "mediapipe.tasks.ClassificationPostprocessingSubgraph" converts raw
 | // A "mediapipe.tasks.components.ClassificationPostprocessingSubgraph" converts
 | ||||||
| // tensors into ClassificationResult objects.
 | // raw tensors into ClassificationResult objects.
 | ||||||
| // - Accepts CPU input tensors.
 | // - Accepts CPU input tensors.
 | ||||||
| //
 | //
 | ||||||
| // Inputs:
 | // Inputs:
 | ||||||
|  | @ -376,18 +428,21 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // If output tensors are quantized, they must be dequantized first.
 |     // If output tensors are quantized, they must be dequantized first.
 | ||||||
|     GenericNode* tensors_dequantization_node; |     TensorsSource dequantized_tensors(&tensors_in); | ||||||
|     if (options.has_quantized_outputs()) { |     if (options.has_quantized_outputs()) { | ||||||
|       tensors_dequantization_node = |       GenericNode* tensors_dequantization_node = | ||||||
|           &graph.AddNode("TensorsDequantizationCalculator"); |           &graph.AddNode("TensorsDequantizationCalculator"); | ||||||
|       tensors_in >> tensors_dequantization_node->In(kTensorsTag); |       tensors_in >> tensors_dequantization_node->In(kTensorsTag); | ||||||
|  |       dequantized_tensors = {tensors_dequantization_node, kTensorsTag}; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // If there are multiple classification heads, the output tensors need to be
 |     // If there are multiple classification heads, the output tensors need to be
 | ||||||
|     // split.
 |     // split.
 | ||||||
|     GenericNode* split_tensor_vector_node; |     std::vector<TensorsSource> split_tensors; | ||||||
|  |     split_tensors.reserve(num_heads); | ||||||
|     if (num_heads > 1) { |     if (num_heads > 1) { | ||||||
|       split_tensor_vector_node = &graph.AddNode("SplitTensorVectorCalculator"); |       GenericNode* split_tensor_vector_node = | ||||||
|  |           &graph.AddNode("SplitTensorVectorCalculator"); | ||||||
|       auto& split_tensor_vector_options = |       auto& split_tensor_vector_options = | ||||||
|           split_tensor_vector_node |           split_tensor_vector_node | ||||||
|               ->GetOptions<mediapipe::SplitVectorCalculatorOptions>(); |               ->GetOptions<mediapipe::SplitVectorCalculatorOptions>(); | ||||||
|  | @ -395,12 +450,27 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { | ||||||
|         auto* range = split_tensor_vector_options.add_ranges(); |         auto* range = split_tensor_vector_options.add_ranges(); | ||||||
|         range->set_begin(i); |         range->set_begin(i); | ||||||
|         range->set_end(i + 1); |         range->set_end(i + 1); | ||||||
|  |         split_tensors.emplace_back(split_tensor_vector_node, i); | ||||||
|       } |       } | ||||||
|       if (options.has_quantized_outputs()) { |       dequantized_tensors >> split_tensor_vector_node->In(0); | ||||||
|         tensors_dequantization_node->Out(kTensorsTag) >> |  | ||||||
|             split_tensor_vector_node->In(0); |  | ||||||
|     } else { |     } else { | ||||||
|         tensors_in >> split_tensor_vector_node->In(0); |       split_tensors.emplace_back(dequantized_tensors); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Adds score calibration for heads that specify it, if any.
 | ||||||
|  |     std::vector<TensorsSource> calibrated_tensors; | ||||||
|  |     calibrated_tensors.reserve(num_heads); | ||||||
|  |     for (int i = 0; i < num_heads; ++i) { | ||||||
|  |       if (options.score_calibration_options().contains(i)) { | ||||||
|  |         GenericNode* score_calibration_node = | ||||||
|  |             &graph.AddNode("ScoreCalibrationCalculator"); | ||||||
|  |         score_calibration_node->GetOptions<ScoreCalibrationCalculatorOptions>() | ||||||
|  |             .CopyFrom(options.score_calibration_options().at(i)); | ||||||
|  |         split_tensors[i] >> score_calibration_node->In(kScoresTag); | ||||||
|  |         calibrated_tensors.emplace_back(score_calibration_node, | ||||||
|  |                                         kCalibratedScoresTag); | ||||||
|  |       } else { | ||||||
|  |         calibrated_tensors.emplace_back(split_tensors[i]); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -413,17 +483,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { | ||||||
|       tensors_to_classification_nodes.back() |       tensors_to_classification_nodes.back() | ||||||
|           ->GetOptions<TensorsToClassificationCalculatorOptions>() |           ->GetOptions<TensorsToClassificationCalculatorOptions>() | ||||||
|           .CopyFrom(options.tensors_to_classifications_options(i)); |           .CopyFrom(options.tensors_to_classifications_options(i)); | ||||||
|       if (num_heads == 1) { |       calibrated_tensors[i] >> | ||||||
|         if (options.has_quantized_outputs()) { |  | ||||||
|           tensors_dequantization_node->Out(kTensorsTag) >> |  | ||||||
|           tensors_to_classification_nodes.back()->In(kTensorsTag); |           tensors_to_classification_nodes.back()->In(kTensorsTag); | ||||||
|         } else { |  | ||||||
|           tensors_in >> tensors_to_classification_nodes.back()->In(kTensorsTag); |  | ||||||
|         } |  | ||||||
|       } else { |  | ||||||
|         split_tensor_vector_node->Out(i) >> |  | ||||||
|             tensors_to_classification_nodes.back()->In(kTensorsTag); |  | ||||||
|       } |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // Aggregates Classifications into a single ClassificationResult.
 |     // Aggregates Classifications into a single ClassificationResult.
 | ||||||
|  | @ -444,7 +505,8 @@ class ClassificationPostprocessingSubgraph : public mediapipe::Subgraph { | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| REGISTER_MEDIAPIPE_GRAPH( | REGISTER_MEDIAPIPE_GRAPH( | ||||||
|     ::mediapipe::tasks::ClassificationPostprocessingSubgraph); |     ::mediapipe::tasks::components::ClassificationPostprocessingSubgraph); | ||||||
| 
 | 
 | ||||||
|  | }  // namespace components
 | ||||||
| }  // namespace tasks
 | }  // namespace tasks
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
|  | @ -18,11 +18,12 @@ limitations under the License. | ||||||
| 
 | 
 | ||||||
| #include "absl/status/status.h" | #include "absl/status/status.h" | ||||||
| #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" | #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/classifier_options.pb.h" | #include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" | ||||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | #include "mediapipe/tasks/cc/core/model_resources.h" | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| namespace tasks { | namespace tasks { | ||||||
|  | namespace components { | ||||||
| 
 | 
 | ||||||
| // Configures a ClassificationPostprocessing subgraph using the provided model
 | // Configures a ClassificationPostprocessing subgraph using the provided model
 | ||||||
| // resources and ClassifierOptions.
 | // resources and ClassifierOptions.
 | ||||||
|  | @ -31,7 +32,7 @@ namespace tasks { | ||||||
| // Example usage:
 | // Example usage:
 | ||||||
| //
 | //
 | ||||||
| //   auto& postprocessing =
 | //   auto& postprocessing =
 | ||||||
| //       graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph");
 | //       graph.AddNode("mediapipe.tasks.components.ClassificationPostprocessingSubgraph");
 | ||||||
| //   MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
 | //   MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing(
 | ||||||
| //       model_resources,
 | //       model_resources,
 | ||||||
| //       classifier_options,
 | //       classifier_options,
 | ||||||
|  | @ -49,10 +50,11 @@ namespace tasks { | ||||||
| //   CLASSIFICATION_RESULT - ClassificationResult
 | //   CLASSIFICATION_RESULT - ClassificationResult
 | ||||||
| //     The output aggregated classification results.
 | //     The output aggregated classification results.
 | ||||||
| absl::Status ConfigureClassificationPostprocessing( | absl::Status ConfigureClassificationPostprocessing( | ||||||
|     const core::ModelResources& model_resources, |     const tasks::core::ModelResources& model_resources, | ||||||
|     const ClassifierOptions& classifier_options, |     const tasks::components::proto::ClassifierOptions& classifier_options, | ||||||
|     ClassificationPostprocessingOptions* options); |     ClassificationPostprocessingOptions* options); | ||||||
| 
 | 
 | ||||||
|  | }  // namespace components
 | ||||||
| }  // namespace tasks
 | }  // namespace tasks
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -15,17 +15,22 @@ limitations under the License. | ||||||
| 
 | 
 | ||||||
| syntax = "proto2"; | syntax = "proto2"; | ||||||
| 
 | 
 | ||||||
| package mediapipe.tasks; | package mediapipe.tasks.components; | ||||||
| 
 | 
 | ||||||
| import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; | import "mediapipe/calculators/tensor/tensors_to_classification_calculator.proto"; | ||||||
| import "mediapipe/framework/calculator.proto"; | import "mediapipe/framework/calculator.proto"; | ||||||
| import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; | import "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.proto"; | ||||||
|  | import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.proto"; | ||||||
| 
 | 
 | ||||||
| message ClassificationPostprocessingOptions { | message ClassificationPostprocessingOptions { | ||||||
|   extend mediapipe.CalculatorOptions { |   extend mediapipe.CalculatorOptions { | ||||||
|     optional ClassificationPostprocessingOptions ext = 460416950; |     optional ClassificationPostprocessingOptions ext = 460416950; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   // Optional mapping between output tensor index and corresponding score | ||||||
|  |   // calibration options. | ||||||
|  |   map<int32, ScoreCalibrationCalculatorOptions> score_calibration_options = 4; | ||||||
|  | 
 | ||||||
|   // Options for the TensorsToClassification calculators (one per classification |   // Options for the TensorsToClassification calculators (one per classification | ||||||
|   // head) encapsulated by the ClassificationPostprocessing subgraph. |   // head) encapsulated by the ClassificationPostprocessing subgraph. | ||||||
|   repeated mediapipe.TensorsToClassificationCalculatorOptions |   repeated mediapipe.TensorsToClassificationCalculatorOptions | ||||||
|  |  | ||||||
|  | @ -41,9 +41,10 @@ limitations under the License. | ||||||
| #include "mediapipe/framework/port/status_matchers.h" | #include "mediapipe/framework/port/status_matchers.h" | ||||||
| #include "mediapipe/framework/timestamp.h" | #include "mediapipe/framework/timestamp.h" | ||||||
| #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" | #include "mediapipe/tasks/cc/components/calculators/classification_aggregation_calculator.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" | #include "mediapipe/tasks/cc/components/classification_postprocessing_options.pb.h" | ||||||
| #include "mediapipe/tasks/cc/components/classifier_options.pb.h" |  | ||||||
| #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" | #include "mediapipe/tasks/cc/components/containers/classifications.pb.h" | ||||||
|  | #include "mediapipe/tasks/cc/components/proto/classifier_options.pb.h" | ||||||
| #include "mediapipe/tasks/cc/core/model_resources.h" | #include "mediapipe/tasks/cc/core/model_resources.h" | ||||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||||
| #include "mediapipe/util/label_map.pb.h" | #include "mediapipe/util/label_map.pb.h" | ||||||
|  | @ -51,6 +52,7 @@ limitations under the License. | ||||||
| 
 | 
 | ||||||
| namespace mediapipe { | namespace mediapipe { | ||||||
| namespace tasks { | namespace tasks { | ||||||
|  | namespace components { | ||||||
| namespace { | namespace { | ||||||
| 
 | 
 | ||||||
| using ::mediapipe::api2::Input; | using ::mediapipe::api2::Input; | ||||||
|  | @ -58,6 +60,7 @@ using ::mediapipe::api2::Output; | ||||||
| using ::mediapipe::api2::builder::Graph; | using ::mediapipe::api2::builder::Graph; | ||||||
| using ::mediapipe::api2::builder::Source; | using ::mediapipe::api2::builder::Source; | ||||||
| using ::mediapipe::file::JoinPath; | using ::mediapipe::file::JoinPath; | ||||||
|  | using ::mediapipe::tasks::components::proto::ClassifierOptions; | ||||||
| using ::mediapipe::tasks::core::ModelResources; | using ::mediapipe::tasks::core::ModelResources; | ||||||
| using ::testing::HasSubstr; | using ::testing::HasSubstr; | ||||||
| using ::testing::proto::Approximately; | using ::testing::proto::Approximately; | ||||||
|  | @ -65,6 +68,8 @@ using ::testing::proto::Approximately; | ||||||
| constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; | constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/"; | ||||||
| constexpr char kQuantizedImageClassifierWithMetadata[] = | constexpr char kQuantizedImageClassifierWithMetadata[] = | ||||||
|     "vision/mobilenet_v1_0.25_224_quant.tflite"; |     "vision/mobilenet_v1_0.25_224_quant.tflite"; | ||||||
|  | constexpr char kQuantizedImageClassifierWithDummyScoreCalibration[] = | ||||||
|  |     "vision/mobilenet_v1_0.25_224_quant_with_dummy_score_calibration.tflite"; | ||||||
| constexpr char kQuantizedImageClassifierWithoutMetadata[] = | constexpr char kQuantizedImageClassifierWithoutMetadata[] = | ||||||
|     "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; |     "vision/mobilenet_v1_0.25_192_quantized_1_default_1.tflite"; | ||||||
| constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] = | constexpr char kFloatTwoHeadsAudioClassifierWithMetadata[] = | ||||||
|  | @ -147,11 +152,12 @@ TEST_F(ConfigureTest, SucceedsWithoutMetadata) { | ||||||
|   ClassifierOptions options_in; |   ClassifierOptions options_in; | ||||||
| 
 | 
 | ||||||
|   ClassificationPostprocessingOptions options_out; |   ClassificationPostprocessingOptions options_out; | ||||||
|   MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|                                                      options_in, &options_out)); |                                                      options_in, &options_out)); | ||||||
| 
 | 
 | ||||||
|   EXPECT_THAT(options_out, Approximately(EqualsProto( |   EXPECT_THAT(options_out, Approximately(EqualsProto( | ||||||
|                                R"pb(tensors_to_classifications_options { |                                R"pb(score_calibration_options: [] | ||||||
|  |                                     tensors_to_classifications_options { | ||||||
|                                       min_score_threshold: -3.4028235e+38 |                                       min_score_threshold: -3.4028235e+38 | ||||||
|                                       top_k: -1 |                                       top_k: -1 | ||||||
|                                       sort_by_descending_score: true |                                       sort_by_descending_score: true | ||||||
|  | @ -169,11 +175,12 @@ TEST_F(ConfigureTest, SucceedsWithMaxResults) { | ||||||
|   options_in.set_max_results(3); |   options_in.set_max_results(3); | ||||||
| 
 | 
 | ||||||
|   ClassificationPostprocessingOptions options_out; |   ClassificationPostprocessingOptions options_out; | ||||||
|   MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|                                                      options_in, &options_out)); |                                                      options_in, &options_out)); | ||||||
| 
 | 
 | ||||||
|   EXPECT_THAT(options_out, Approximately(EqualsProto( |   EXPECT_THAT(options_out, Approximately(EqualsProto( | ||||||
|                                R"pb(tensors_to_classifications_options { |                                R"pb(score_calibration_options: [] | ||||||
|  |                                     tensors_to_classifications_options { | ||||||
|                                       min_score_threshold: -3.4028235e+38 |                                       min_score_threshold: -3.4028235e+38 | ||||||
|                                       top_k: 3 |                                       top_k: 3 | ||||||
|                                       sort_by_descending_score: true |                                       sort_by_descending_score: true | ||||||
|  | @ -191,11 +198,12 @@ TEST_F(ConfigureTest, SucceedsWithScoreThreshold) { | ||||||
|   options_in.set_score_threshold(0.5); |   options_in.set_score_threshold(0.5); | ||||||
| 
 | 
 | ||||||
|   ClassificationPostprocessingOptions options_out; |   ClassificationPostprocessingOptions options_out; | ||||||
|   MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|                                                      options_in, &options_out)); |                                                      options_in, &options_out)); | ||||||
| 
 | 
 | ||||||
|   EXPECT_THAT(options_out, Approximately(EqualsProto( |   EXPECT_THAT(options_out, Approximately(EqualsProto( | ||||||
|                                R"pb(tensors_to_classifications_options { |                                R"pb(score_calibration_options: [] | ||||||
|  |                                     tensors_to_classifications_options { | ||||||
|                                       min_score_threshold: 0.5 |                                       min_score_threshold: 0.5 | ||||||
|                                       top_k: -1 |                                       top_k: -1 | ||||||
|                                       sort_by_descending_score: true |                                       sort_by_descending_score: true | ||||||
|  | @ -212,7 +220,7 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { | ||||||
|   ClassifierOptions options_in; |   ClassifierOptions options_in; | ||||||
| 
 | 
 | ||||||
|   ClassificationPostprocessingOptions options_out; |   ClassificationPostprocessingOptions options_out; | ||||||
|   MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|                                                      options_in, &options_out)); |                                                      options_in, &options_out)); | ||||||
| 
 | 
 | ||||||
|   // Check label map size and two first elements.
 |   // Check label map size and two first elements.
 | ||||||
|  | @ -229,7 +237,8 @@ TEST_F(ConfigureTest, SucceedsWithMetadata) { | ||||||
|   options_out.mutable_tensors_to_classifications_options(0) |   options_out.mutable_tensors_to_classifications_options(0) | ||||||
|       ->clear_label_items(); |       ->clear_label_items(); | ||||||
|   EXPECT_THAT(options_out, Approximately(EqualsProto( |   EXPECT_THAT(options_out, Approximately(EqualsProto( | ||||||
|                                R"pb(tensors_to_classifications_options { |                                R"pb(score_calibration_options: [] | ||||||
|  |                                     tensors_to_classifications_options { | ||||||
|                                       min_score_threshold: -3.4028235e+38 |                                       min_score_threshold: -3.4028235e+38 | ||||||
|                                       top_k: -1 |                                       top_k: -1 | ||||||
|                                       sort_by_descending_score: true |                                       sort_by_descending_score: true | ||||||
|  | @ -249,14 +258,15 @@ TEST_F(ConfigureTest, SucceedsWithAllowlist) { | ||||||
|   options_in.add_category_allowlist("tench"); |   options_in.add_category_allowlist("tench"); | ||||||
| 
 | 
 | ||||||
|   ClassificationPostprocessingOptions options_out; |   ClassificationPostprocessingOptions options_out; | ||||||
|   MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|                                                      options_in, &options_out)); |                                                      options_in, &options_out)); | ||||||
| 
 | 
 | ||||||
|   // Clear label map and compare the rest of the options.
 |   // Clear label map and compare the rest of the options.
 | ||||||
|   options_out.mutable_tensors_to_classifications_options(0) |   options_out.mutable_tensors_to_classifications_options(0) | ||||||
|       ->clear_label_items(); |       ->clear_label_items(); | ||||||
|   EXPECT_THAT(options_out, Approximately(EqualsProto( |   EXPECT_THAT(options_out, Approximately(EqualsProto( | ||||||
|                                R"pb(tensors_to_classifications_options { |                                R"pb(score_calibration_options: [] | ||||||
|  |                                     tensors_to_classifications_options { | ||||||
|                                       min_score_threshold: -3.4028235e+38 |                                       min_score_threshold: -3.4028235e+38 | ||||||
|                                       top_k: -1 |                                       top_k: -1 | ||||||
|                                       sort_by_descending_score: true |                                       sort_by_descending_score: true | ||||||
|  | @ -277,14 +287,15 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { | ||||||
|   options_in.add_category_denylist("background"); |   options_in.add_category_denylist("background"); | ||||||
| 
 | 
 | ||||||
|   ClassificationPostprocessingOptions options_out; |   ClassificationPostprocessingOptions options_out; | ||||||
|   MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|                                                      options_in, &options_out)); |                                                      options_in, &options_out)); | ||||||
| 
 | 
 | ||||||
|   // Clear label map and compare the rest of the options.
 |   // Clear label map and compare the rest of the options.
 | ||||||
|   options_out.mutable_tensors_to_classifications_options(0) |   options_out.mutable_tensors_to_classifications_options(0) | ||||||
|       ->clear_label_items(); |       ->clear_label_items(); | ||||||
|   EXPECT_THAT(options_out, Approximately(EqualsProto( |   EXPECT_THAT(options_out, Approximately(EqualsProto( | ||||||
|                                R"pb(tensors_to_classifications_options { |                                R"pb(score_calibration_options: [] | ||||||
|  |                                     tensors_to_classifications_options { | ||||||
|                                       min_score_threshold: -3.4028235e+38 |                                       min_score_threshold: -3.4028235e+38 | ||||||
|                                       top_k: -1 |                                       top_k: -1 | ||||||
|                                       sort_by_descending_score: true |                                       sort_by_descending_score: true | ||||||
|  | @ -297,6 +308,56 @@ TEST_F(ConfigureTest, SucceedsWithDenylist) { | ||||||
|                                )pb"))); |                                )pb"))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(ConfigureTest, SucceedsWithScoreCalibration) { | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       auto model_resources, | ||||||
|  |       CreateModelResourcesForModel( | ||||||
|  |           kQuantizedImageClassifierWithDummyScoreCalibration)); | ||||||
|  |   ClassifierOptions options_in; | ||||||
|  | 
 | ||||||
|  |   ClassificationPostprocessingOptions options_out; | ||||||
|  |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|  |                                                      options_in, &options_out)); | ||||||
|  | 
 | ||||||
|  |   // Check label map size and two first elements.
 | ||||||
|  |   EXPECT_EQ( | ||||||
|  |       options_out.tensors_to_classifications_options(0).label_items_size(), | ||||||
|  |       kMobileNetNumClasses); | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       options_out.tensors_to_classifications_options(0).label_items().at(0), | ||||||
|  |       EqualsProto(R"pb(name: "background")pb")); | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       options_out.tensors_to_classifications_options(0).label_items().at(1), | ||||||
|  |       EqualsProto(R"pb(name: "tench")pb")); | ||||||
|  |   // Clear label map.
 | ||||||
|  |   options_out.mutable_tensors_to_classifications_options(0) | ||||||
|  |       ->clear_label_items(); | ||||||
|  |   // Check sigmoids size and first element.
 | ||||||
|  |   EXPECT_EQ(options_out.score_calibration_options_size(), 1); | ||||||
|  |   auto score_calibration_options = | ||||||
|  |       options_out.score_calibration_options().at(0); | ||||||
|  |   EXPECT_EQ(score_calibration_options.sigmoids_size(), kMobileNetNumClasses); | ||||||
|  |   EXPECT_THAT(score_calibration_options.sigmoids(0), | ||||||
|  |               EqualsProto(R"pb(scale: 1.0 slope: 1.0 offset: 0.0)pb")); | ||||||
|  |   options_out.mutable_score_calibration_options()->at(0).clear_sigmoids(); | ||||||
|  |   // Compare the rest of the options.
 | ||||||
|  |   EXPECT_THAT( | ||||||
|  |       options_out, | ||||||
|  |       Approximately(EqualsProto( | ||||||
|  |           R"pb(score_calibration_options { | ||||||
|  |                  key: 0 | ||||||
|  |                  value { score_transformation: IDENTITY default_score: 0.5 } | ||||||
|  |                } | ||||||
|  |                tensors_to_classifications_options { | ||||||
|  |                  min_score_threshold: -3.4028235e+38 | ||||||
|  |                  top_k: -1 | ||||||
|  |                  sort_by_descending_score: true | ||||||
|  |                } | ||||||
|  |                classification_aggregation_options { head_names: "probability" } | ||||||
|  |                has_quantized_outputs: true | ||||||
|  |           )pb"))); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { | TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { | ||||||
|   MP_ASSERT_OK_AND_ASSIGN( |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|       auto model_resources, |       auto model_resources, | ||||||
|  | @ -304,7 +365,7 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { | ||||||
|   ClassifierOptions options_in; |   ClassifierOptions options_in; | ||||||
| 
 | 
 | ||||||
|   ClassificationPostprocessingOptions options_out; |   ClassificationPostprocessingOptions options_out; | ||||||
|   MP_EXPECT_OK(ConfigureClassificationPostprocessing(*model_resources, |   MP_ASSERT_OK(ConfigureClassificationPostprocessing(*model_resources, | ||||||
|                                                      options_in, &options_out)); |                                                      options_in, &options_out)); | ||||||
|   // Check label maps sizes and first two elements.
 |   // Check label maps sizes and first two elements.
 | ||||||
|   EXPECT_EQ( |   EXPECT_EQ( | ||||||
|  | @ -331,7 +392,8 @@ TEST_F(ConfigureTest, SucceedsWithMultipleHeads) { | ||||||
|   options_out.mutable_tensors_to_classifications_options(1) |   options_out.mutable_tensors_to_classifications_options(1) | ||||||
|       ->clear_label_items(); |       ->clear_label_items(); | ||||||
|   EXPECT_THAT(options_out, Approximately(EqualsProto( |   EXPECT_THAT(options_out, Approximately(EqualsProto( | ||||||
|                                R"pb(tensors_to_classifications_options { |                                R"pb(score_calibration_options: [] | ||||||
|  |                                     tensors_to_classifications_options { | ||||||
|                                       min_score_threshold: -3.4028235e+38 |                                       min_score_threshold: -3.4028235e+38 | ||||||
|                                       top_k: -1 |                                       top_k: -1 | ||||||
|                                       sort_by_descending_score: true |                                       sort_by_descending_score: true | ||||||
|  | @ -358,8 +420,8 @@ class PostprocessingTest : public tflite_shims::testing::Test { | ||||||
|                      CreateModelResourcesForModel(model_name)); |                      CreateModelResourcesForModel(model_name)); | ||||||
| 
 | 
 | ||||||
|     Graph graph; |     Graph graph; | ||||||
|     auto& postprocessing = |     auto& postprocessing = graph.AddNode( | ||||||
|         graph.AddNode("mediapipe.tasks.ClassificationPostprocessingSubgraph"); |         "mediapipe.tasks.components.ClassificationPostprocessingSubgraph"); | ||||||
|     MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( |     MP_RETURN_IF_ERROR(ConfigureClassificationPostprocessing( | ||||||
|         *model_resources, options, |         *model_resources, options, | ||||||
|         &postprocessing.GetOptions<ClassificationPostprocessingOptions>())); |         &postprocessing.GetOptions<ClassificationPostprocessingOptions>())); | ||||||
|  | @ -503,6 +565,52 @@ TEST_F(PostprocessingTest, SucceedsWithMetadata) { | ||||||
|                })pb")); |                })pb")); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | TEST_F(PostprocessingTest, SucceedsWithScoreCalibration) { | ||||||
|  |   // Build graph.
 | ||||||
|  |   ClassifierOptions options; | ||||||
|  |   options.set_max_results(3); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN( | ||||||
|  |       auto poller, | ||||||
|  |       BuildGraph(kQuantizedImageClassifierWithDummyScoreCalibration, options)); | ||||||
|  |   // Build input tensors.
 | ||||||
|  |   std::vector<uint8> tensor(kMobileNetNumClasses, 0); | ||||||
|  |   tensor[1] = 12; | ||||||
|  |   tensor[2] = 14; | ||||||
|  |   tensor[3] = 16; | ||||||
|  |   tensor[4] = 18; | ||||||
|  | 
 | ||||||
|  |   // Send tensors and get results.
 | ||||||
|  |   AddTensor(tensor, Tensor::ElementType::kUInt8, | ||||||
|  |             /*quantization_parameters=*/{0.1, 10}); | ||||||
|  |   MP_ASSERT_OK(Run()); | ||||||
|  |   MP_ASSERT_OK_AND_ASSIGN(auto results, GetClassificationResult(poller)); | ||||||
|  | 
 | ||||||
|  |   // Validate results.
 | ||||||
|  |   EXPECT_THAT(results, EqualsProto( | ||||||
|  |                            R"pb(classifications { | ||||||
|  |                                   entries { | ||||||
|  |                                     categories { | ||||||
|  |                                       index: 4 | ||||||
|  |                                       score: 0.6899744811 | ||||||
|  |                                       category_name: "tiger shark" | ||||||
|  |                                     } | ||||||
|  |                                     categories { | ||||||
|  |                                       index: 3 | ||||||
|  |                                       score: 0.6456563062 | ||||||
|  |                                       category_name: "great white shark" | ||||||
|  |                                     } | ||||||
|  |                                     categories { | ||||||
|  |                                       index: 2 | ||||||
|  |                                       score: 0.5986876601 | ||||||
|  |                                       category_name: "goldfish" | ||||||
|  |                                     } | ||||||
|  |                                     timestamp_ms: 0 | ||||||
|  |                                   } | ||||||
|  |                                   head_index: 0 | ||||||
|  |                                   head_name: "probability" | ||||||
|  |                                 })pb")); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { | TEST_F(PostprocessingTest, SucceedsWithMultipleHeads) { | ||||||
|   // Build graph.
 |   // Build graph.
 | ||||||
|   ClassifierOptions options; |   ClassifierOptions options; | ||||||
|  | @ -621,5 +729,6 @@ TEST_F(PostprocessingTest, SucceedsWithTimestamps) { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| }  // namespace
 | }  // namespace
 | ||||||
|  | }  // namespace components
 | ||||||
| }  // namespace tasks
 | }  // namespace tasks
 | ||||||
| }  // namespace mediapipe
 | }  // namespace mediapipe
 | ||||||
|  |  | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
		Reference in New Issue
	
	Block a user