diff --git a/WORKSPACE b/WORKSPACE index e77a0e79d..aacf856c2 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -157,11 +157,11 @@ http_archive( http_archive( name = "pybind11", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.4.3.tar.gz", - "https://github.com/pybind/pybind11/archive/v2.4.3.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11/archive/v2.7.1.tar.gz", + "https://github.com/pybind/pybind11/archive/v2.7.1.tar.gz", ], - sha256 = "1eed57bc6863190e35637290f97a20c81cfe4d9090ac0a24f3bbf08f265eb71d", - strip_prefix = "pybind11-2.4.3", + sha256 = "616d1c42e4cf14fa27b2a4ff759d7d7b33006fdc5ad8fd603bb2c22622f27020", + strip_prefix = "pybind11-2.7.1", build_file = "@pybind11_bazel//:pybind11.BUILD", ) diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 9f9e7c979..4591b5f33 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -113,6 +113,10 @@ bazel to build the iOS application. The content of the 5. `Main.storyboard` and `Launch.storyboard` 6. `Assets.xcassets` directory. +Note: In newer versions of Xcode, you may see additional files `SceneDelegate.h` +and `SceneDelegate.m`. Make sure to copy them too and add them to the `BUILD` +file mentioned below. + Copy these files to a directory named `HelloWorld` to a location that can access the MediaPipe source code. For example, the source code of the application that we will build in this tutorial is located in @@ -247,6 +251,12 @@ We need to get frames from the `_cameraSource` into our application `MPPInputSourceDelegate`. So our application `ViewController` can be a delegate of `_cameraSource`. +Update the interface definition of `ViewController` accordingly: + +``` +@interface ViewController () +``` + To handle camera setup and process incoming frames, we should use a queue different from the main queue. Add the following to the implementation block of the `ViewController`: @@ -288,6 +298,12 @@ utility called `MPPLayerRenderer` to display images on the screen. This utility can be used to display `CVPixelBufferRef` objects, which is the type of the images provided by `MPPCameraInputSource` to its delegates. +In `ViewController.m`, add the following import line: + +``` +#import "mediapipe/objc/MPPLayerRenderer.h" +``` + To display images of the screen, we need to add a new `UIView` object called `_liveView` to the `ViewController`. @@ -411,6 +427,12 @@ Objective-C++. ### Use the graph in `ViewController` +In `ViewController.m`, add the following import line: + +``` +#import "mediapipe/objc/MPPGraph.h" +``` + Declare a static constant with the name of the graph, the input stream and the output stream: @@ -549,6 +571,12 @@ method to receive packets on this output stream and display them on the screen: } ``` +Update the interface definition of `ViewController` with `MPPGraphDelegate`: + +``` +@interface ViewController () +``` + And that is all! Build and run the app on your iOS device. You should see the results of running the edge detection graph on a live video feed. Congrats! @@ -560,5 +588,5 @@ appropriate `BUILD` file dependencies for the edge detection graph. [Bazel]:https://bazel.build/ [`edge_detection_mobile_gpu.pbtxt`]:https://github.com/google/mediapipe/tree/master/mediapipe/graphs/edge_detection/edge_detection_mobile_gpu.pbtxt -[common]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common) -[helloworld]:(https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld) +[common]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/common +[helloworld]:https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/helloworld diff --git a/docs/getting_started/install.md b/docs/getting_started/install.md index b90c0f3bd..bb2539d33 100644 --- a/docs/getting_started/install.md +++ b/docs/getting_started/install.md @@ -796,7 +796,7 @@ This will use a Docker image that will isolate mediapipe's installation from the ```bash $ docker run -it --name mediapipe mediapipe:latest - root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazel run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world + root@bca08b91ff63:/mediapipe# GLOG_logtostderr=1 bazelisk run --define MEDIAPIPE_DISABLE_GPU=1 mediapipe/examples/desktop/hello_world:hello_world # Should print: # Hello World! diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index 457890372..d7dc8f045 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -529,7 +529,7 @@ Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](http > ``` > and then run > -> ```build +> ```bash > bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] > ``` > INPUT_DIR should be the folder with initial asset .obj files to be processed, diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 696f71943..f9bbac1a1 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -141,7 +141,7 @@ Optionally, MediaPipe Pose can predicts a full-body Please find more detail in the [BlazePose Google AI Blog](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html), this [paper](https://arxiv.org/abs/2006.10204), -[the model card](./models.md#pose) and the [Output](#Output) section below. +[the model card](./models.md#pose) and the [Output](#output) section below. ## Solution APIs @@ -281,8 +281,8 @@ with mp_pose.Pose( continue print( f'Nose coordinates: (' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, ' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})' ) annotated_image = image.copy() @@ -369,6 +369,7 @@ Supported configuration options:
+
diff --git a/docs/solutions/selfie_segmentation.md b/docs/solutions/selfie_segmentation.md index fc063d1e7..2cb155fb3 100644 --- a/docs/solutions/selfie_segmentation.md +++ b/docs/solutions/selfie_segmentation.md @@ -262,7 +262,7 @@ to visualize its associated subgraphs, please see [(or download prebuilt ARM64 APK)](https://drive.google.com/file/d/1DoeyGzMmWUsjfVgZfGGecrn7GKzYcEAo/view?usp=sharing) [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu:selfiesegmentationgpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/selfiesegmentationgpu/BUILD) * iOS target: - [`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](http:/mediapipe/examples/ios/selfiesegmentationgpu/BUILD) + [`mediapipe/examples/ios/selfiesegmentationgpu:SelfieSegmentationGpuApp`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/selfiesegmentationgpu/BUILD) ### Desktop diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index 5fffd5e4b..e9e4cdc38 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -13,6 +13,9 @@ has_toc: false {:toc} --- +MediaPipe offers open source cross-platform, customizable ML solutions for live +and streaming media. + diff --git a/mediapipe/calculators/core/begin_loop_calculator.cc b/mediapipe/calculators/core/begin_loop_calculator.cc index 1d0f7824d..7a5a9d5e9 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.cc +++ b/mediapipe/calculators/core/begin_loop_calculator.cc @@ -42,4 +42,9 @@ REGISTER_CALCULATOR(BeginLoopDetectionCalculator); typedef BeginLoopCalculator> BeginLoopMatrixCalculator; REGISTER_CALCULATOR(BeginLoopMatrixCalculator); +// A calculator to process std::vector>. +typedef BeginLoopCalculator>> + BeginLoopMatrixVectorCalculator; +REGISTER_CALCULATOR(BeginLoopMatrixVectorCalculator); + } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc index 4995d9d4e..1a1912017 100644 --- a/mediapipe/calculators/tensor/inference_calculator_cpu.cc +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -73,6 +73,7 @@ class InferenceCalculatorCpuImpl private: absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. Packet model_packet_; @@ -91,8 +92,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract( absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(LoadModel(cc)); - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - return absl::OkStatus(); + return LoadDelegateAndAllocateTensors(cc); } absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { @@ -156,11 +156,19 @@ absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc index 4072e1d87..5ca673c25 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -53,6 +53,7 @@ class InferenceCalculatorGlImpl absl::Status WriteKernelsToFile(); absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); absl::Status InitTFLiteGPURunner(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. @@ -119,10 +120,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { } MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, - &cc]() -> ::mediapipe::Status { - return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc); - })); + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) + : LoadDelegateAndAllocateTensors(cc); + })); return absl::OkStatus(); } @@ -324,11 +326,19 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { cc->Options().cpu_num_thread()); #endif // __EMSCRIPTEN__ + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc index 4bf3525e4..49e042290 100644 --- a/mediapipe/calculators/tensor/inference_calculator_metal.cc +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -92,6 +92,7 @@ class InferenceCalculatorMetalImpl private: absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); // TfLite requires us to keep the model alive as long as the interpreter is. Packet model_packet_; @@ -130,8 +131,7 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) { gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - return absl::OkStatus(); + return LoadDelegateAndAllocateTensors(cc); } absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { @@ -212,11 +212,19 @@ absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { interpreter_->SetNumThreads( cc->Options().cpu_num_thread()); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors( + CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + + // AllocateTensors() can be called only after ModifyGraphWithDelegate. RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); // TODO: Support quantized tensors. - CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - + RET_CHECK_NE( + interpreter_->tensor(interpreter_->inputs()[0])->quantization.type, + kTfLiteAffineQuantization); return absl::OkStatus(); } @@ -236,6 +244,7 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); + id device = gpu_helper_.mtlDevice; // Get input image sizes. diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index f161127f5..498036c12 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -670,7 +670,8 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections( detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], detection_scores[i], detection_classes[i], options_.flip_vertically()); const auto& bbox = detection.location_data().relative_bounding_box(); - if (bbox.width() < 0 || bbox.height() < 0) { + if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || + std::isnan(bbox.height())) { // Decoded detection boxes could have negative values for width/height due // to model prediction. Filter out those boxes since some downstream // calculators may assume non-negative values. (b/171391719) diff --git a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc index 45e242f3c..ffc96b2e4 100644 --- a/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_segmentation_calculator.cc @@ -138,7 +138,6 @@ using ::tflite::gpu::gl::GlShader; // } // } // -// Currently only OpenGLES 3.1 and CPU backends supported. // TODO Refactor and add support for other backends/platforms. // class TensorsToSegmentationCalculator : public CalculatorBase { diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc index d7b20ad56..d12f91741 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator_test.cc @@ -56,6 +56,8 @@ constexpr char kBboxTag[] = "BBOX"; constexpr char kForwardFlowEncodedTag[] = "FORWARD_FLOW_ENCODED"; constexpr char kImagePrefixTag[] = "IMAGE_PREFIX"; constexpr char kImageTag[] = "IMAGE"; +constexpr char kFloatContextFeatureOtherTag[] = "FLOAT_CONTEXT_FEATURE_OTHER"; +constexpr char kFloatContextFeatureTestTag[] = "FLOAT_CONTEXT_FEATURE_TEST"; constexpr char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; class UnpackMediaSequenceCalculatorTest : public ::testing::Test { diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index c4038f3ce..4aab3b676 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -175,7 +175,8 @@ absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { auto* text = label_annotation->mutable_text(); std::string display_text = labels[i]; - if (cc->Inputs().HasTag(kScoresTag)) { + if (cc->Inputs().HasTag(kScoresTag) || + options_.display_classification_score()) { absl::StrAppend(&display_text, ":", scores[i]); } text->set_display_text(display_text); diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.proto b/mediapipe/calculators/util/labels_to_render_data_calculator.proto index c5012ce85..cf0ada9c2 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.proto +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.proto @@ -62,4 +62,7 @@ message LabelsToRenderDataCalculatorOptions { // Uses Classification.display_name field instead of Classification.label. optional bool use_display_name = 9 [default = false]; + + // Displays Classification score if enabled. + optional bool display_classification_score = 10 [default = false]; } diff --git a/mediapipe/framework/api2/node.h b/mediapipe/framework/api2/node.h index b5f7586e7..7061afcae 100644 --- a/mediapipe/framework/api2/node.h +++ b/mediapipe/framework/api2/node.h @@ -223,24 +223,23 @@ class SubgraphImpl : public Subgraph, public Intf { // This macro is used to register a calculator that does not use automatic // registration. Deprecated. -#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \ - static mediapipe::NoDestructor \ - REGISTRY_STATIC_VAR(calculator_registration, __LINE__)( \ - mediapipe::CalculatorBaseRegistry::Register( \ - Impl::kCalculatorName, \ - absl::make_unique< \ - mediapipe::internal::CalculatorBaseFactoryFor>)) +#define MEDIAPIPE_NODE_IMPLEMENTATION(Impl) \ + static mediapipe::NoDestructor \ + REGISTRY_STATIC_VAR(calculator_registration, \ + __LINE__)(mediapipe::CalculatorBaseRegistry::Register( \ + Impl::kCalculatorName, \ + absl::make_unique>)) // This macro is used to register a non-split-contract calculator. Deprecated. #define MEDIAPIPE_REGISTER_NODE(name) REGISTER_CALCULATOR(name) // This macro is used to define a subgraph that does not use automatic // registration. Deprecated. -#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \ - static mediapipe::NoDestructor \ - REGISTRY_STATIC_VAR(subgraph_registration, \ - __LINE__)(mediapipe::SubgraphRegistry::Register( \ - Impl::kCalculatorName, absl::make_unique)) +#define MEDIAPIPE_SUBGRAPH_IMPLEMENTATION(Impl) \ + static mediapipe::NoDestructor \ + REGISTRY_STATIC_VAR(subgraph_registration, \ + __LINE__)(mediapipe::SubgraphRegistry::Register( \ + Impl::kCalculatorName, absl::make_unique)) } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index cab39abdb..dbd15cc68 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -454,12 +454,12 @@ class OutputShardAccessBase { if (output_) output_->SetNextTimestampBound(timestamp); } - bool IsClosed() { return output_ ? output_->IsClosed() : true; } + bool IsClosed() const { return output_ ? output_->IsClosed() : true; } void Close() { if (output_) output_->Close(); } - bool IsConnected() { return output_ != nullptr; } + bool IsConnected() const { return output_ != nullptr; } protected: const CalculatorContext& context_; @@ -559,7 +559,7 @@ class InputShardAccess : public Packet { PacketBase packet() const&& { return *this; } bool IsDone() const { return stream_->IsDone(); } - bool IsConnected() { return stream_ != nullptr; } + bool IsConnected() const { return stream_ != nullptr; } PacketBase Header() const { return FromOldPacket(stream_->Header()); } @@ -619,7 +619,7 @@ class InputSidePacketAccess : public Packet { const PacketBase& packet() const& { return *this; } PacketBase packet() const&& { return *this; } - bool IsConnected() { return connected_; } + bool IsConnected() const { return connected_; } private: InputSidePacketAccess(const mediapipe::Packet* packet) @@ -639,8 +639,8 @@ class InputShardOrSideAccess : public Packet { PacketBase packet() const&& { return *this; } bool IsDone() const { return stream_->IsDone(); } - bool IsConnected() { return connected_; } - bool IsStream() { return stream_ != nullptr; } + bool IsConnected() const { return connected_; } + bool IsStream() const { return stream_ != nullptr; } PacketBase Header() const { return FromOldPacket(stream_->Header()); } @@ -662,7 +662,7 @@ class InputShardOrSideAccess : public Packet { class PacketTypeAccess { public: - bool IsConnected() { return packet_type_ != nullptr; } + bool IsConnected() const { return packet_type_ != nullptr; } protected: PacketTypeAccess(PacketType* pt) : packet_type_(pt) {} @@ -675,7 +675,7 @@ class PacketTypeAccess { class PacketTypeAccessFallback : public PacketTypeAccess { public: - bool IsStream() { return is_stream_; } + bool IsStream() const { return is_stream_; } private: PacketTypeAccessFallback(PacketType* pt, bool is_stream) diff --git a/mediapipe/framework/calculator.proto b/mediapipe/framework/calculator.proto index 283f49d96..7c5e8b144 100644 --- a/mediapipe/framework/calculator.proto +++ b/mediapipe/framework/calculator.proto @@ -321,6 +321,8 @@ message CalculatorGraphConfig { // The maximum number of invocations that can be executed in parallel. // If not specified, the limit is one invocation. int32 max_in_flight = 16; + // Defines an option value for this Node from graph options or packets. + repeated string option_value = 17; // DEPRECATED: For backwards compatibility we allow users to // specify the old name for "input_side_packet" in proto configs. // These are automatically converted to input_side_packets during diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index d95d6d32b..6cff78dd3 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -465,7 +465,7 @@ absl::Status CalculatorGraph::ObserveOutputStream( } absl::StatusOr CalculatorGraph::AddOutputStreamPoller( - const std::string& stream_name) { + const std::string& stream_name, bool observe_timestamp_bounds) { RET_CHECK(initialized_).SetNoLogging() << "CalculatorGraph is not initialized."; int output_stream_index = validated_graph_->OutputStreamIndex(stream_name); @@ -479,7 +479,7 @@ absl::StatusOr CalculatorGraph::AddOutputStreamPoller( stream_name, &any_packet_type_, std::bind(&CalculatorGraph::UpdateThrottledNodes, this, std::placeholders::_1, std::placeholders::_2), - &output_stream_managers_[output_stream_index])); + &output_stream_managers_[output_stream_index], observe_timestamp_bounds)); OutputStreamPoller poller(internal_poller); graph_output_streams_.push_back(std::move(internal_poller)); return std::move(poller); diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index fb0bb6971..0e6d53b6a 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -164,7 +164,8 @@ class CalculatorGraph { // polling API for accessing a stream's output. Should only be called before // Run() or StartRun(). For asynchronous output, use ObserveOutputStream. See // also the helpers in tool/sink.h. - StatusOrPoller AddOutputStreamPoller(const std::string& stream_name); + StatusOrPoller AddOutputStreamPoller(const std::string& stream_name, + bool observe_timestamp_bounds = false); // Gets output side packet by name after the graph is done. However, base // packets (generated by PacketGenerators) can be retrieved before diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index a451b987d..26d2f484c 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -4348,5 +4348,349 @@ TEST(CalculatorGraph, GraphInputStreamWithTag) { ASSERT_EQ(5, packet_dump.size()); } +TEST(CalculatorGraph, GraphInputStreamBeforeStartRun) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "VIDEO_METADATA:video_metadata" + input_stream: "max_count" + node { + calculator: "PassThroughCalculator" + input_stream: "FIRST_INPUT:video_metadata" + input_stream: "max_count" + output_stream: "FIRST_INPUT:output_0" + output_stream: "output_1" + } + )pb"); + std::vector packet_dump; + tool::AddVectorSink("output_0", &config, &packet_dump); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + ASSERT_EQ(graph + .AddPacketToInputStream("video_metadata", + MakePacket(0).At(Timestamp(0))) + .code(), + absl::StatusCode::kFailedPrecondition); +} + +// Returns the first packet of the input stream. +class FirstPacketFilterCalculator : public CalculatorBase { + public: + FirstPacketFilterCalculator() {} + ~FirstPacketFilterCalculator() override {} + + static absl::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + if (!seen_first_packet_) { + cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); + cc->Outputs().Index(0).Close(); + seen_first_packet_ = true; + } + return absl::OkStatus(); + } + + private: + bool seen_first_packet_ = false; +}; +REGISTER_CALCULATOR(FirstPacketFilterCalculator); +constexpr int kDefaultMaxCount = 1000; + +TEST(CalculatorGraph, TestPollPacket) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node = config.add_node(); + node->set_calculator("CountingSourceCalculator"); + node->add_output_stream("output"); + node->add_input_side_packet("MAX_COUNT:max_count"); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + auto status_or_poller = graph.AddOutputStreamPoller("output"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.value()); + MP_ASSERT_OK( + graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); + Packet packet; + int num_packets = 0; + while (poller.Next(&packet)) { + EXPECT_EQ(num_packets, packet.Get()); + ++num_packets; + } + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_FALSE(poller.Next(&packet)); + EXPECT_EQ(kDefaultMaxCount, num_packets); +} + +TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node = config.add_node(); + node->set_calculator("CountingSourceCalculator"); + node->add_output_stream("output"); + node->add_input_side_packet("MAX_COUNT:max_count"); + + for (int queue_size = 1; queue_size < 10; ++queue_size) { + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + auto status_or_poller = graph.AddOutputStreamPoller("output"); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.value()); + poller.SetMaxQueueSize(queue_size); + MP_ASSERT_OK( + graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); + Packet packet; + int num_packets = 0; + while (poller.Next(&packet)) { + EXPECT_EQ(num_packets, packet.Get()); + ++num_packets; + } + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_FALSE(poller.Next(&packet)); + EXPECT_EQ(kDefaultMaxCount, num_packets); + } +} + +TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) { + CalculatorGraphConfig config; + CalculatorGraphConfig::Node* node1 = config.add_node(); + node1->set_calculator("CountingSourceCalculator"); + node1->add_output_stream("stream1"); + node1->add_input_side_packet("MAX_COUNT:max_count"); + CalculatorGraphConfig::Node* node2 = config.add_node(); + node2->set_calculator("PassThroughCalculator"); + node2->add_input_stream("stream1"); + node2->add_output_stream("stream2"); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + auto status_or_poller1 = graph.AddOutputStreamPoller("stream1"); + ASSERT_TRUE(status_or_poller1.ok()); + OutputStreamPoller poller1 = std::move(status_or_poller1.value()); + auto status_or_poller2 = graph.AddOutputStreamPoller("stream2"); + ASSERT_TRUE(status_or_poller2.ok()); + OutputStreamPoller poller2 = std::move(status_or_poller2.value()); + MP_ASSERT_OK( + graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); + Packet packet1; + Packet packet2; + int num_packets1 = 0; + int num_packets2 = 0; + int running_pollers = 2; + while (running_pollers > 0) { + if (poller1.Next(&packet1)) { + EXPECT_EQ(num_packets1++, packet1.Get()); + } else { + --running_pollers; + } + if (poller2.Next(&packet2)) { + EXPECT_EQ(num_packets2++, packet2.Get()); + } else { + --running_pollers; + } + } + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + EXPECT_FALSE(poller1.Next(&packet1)); + EXPECT_FALSE(poller2.Next(&packet2)); + EXPECT_EQ(kDefaultMaxCount, num_packets1); + EXPECT_EQ(kDefaultMaxCount, num_packets2); +} + +class TimestampBoundTestCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Outputs().Index(0).Set(); + return absl::OkStatus(); + } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } + absl::Status Process(CalculatorContext* cc) final { + if (count_ % 50 == 1) { + // Outputs packets at t10 and t60. + cc->Outputs().Index(0).AddPacket( + MakePacket(count_).At(Timestamp(count_))); + } else if (count_ % 15 == 7) { + cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(count_)); + } + absl::SleepFor(absl::Milliseconds(3)); + ++count_; + if (count_ == 110) { + return tool::StatusStop(); + } + return absl::OkStatus(); + } + + private: + int count_ = 0; +}; +REGISTER_CALCULATOR(TimestampBoundTestCalculator); + +TEST(CalculatorGraph, TestPollPacketsWithTimestampNotification) { + std::string config_str = R"( + node { + calculator: "TimestampBoundTestCalculator" + output_stream: "foo" + } + )"; + CalculatorGraphConfig graph_config = + mediapipe::ParseTextProtoOrDie(config_str); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + auto status_or_poller = + graph.AddOutputStreamPoller("foo", /*observe_timestamp_bounds=*/true); + ASSERT_TRUE(status_or_poller.ok()); + OutputStreamPoller poller = std::move(status_or_poller.value()); + Packet packet; + std::vector timestamps; + std::vector values; + MP_ASSERT_OK(graph.StartRun({})); + while (poller.Next(&packet)) { + if (packet.IsEmpty()) { + timestamps.push_back(packet.Timestamp().Value()); + } else { + values.push_back(packet.Get()); + } + } + MP_ASSERT_OK(graph.WaitUntilDone()); + ASSERT_FALSE(poller.Next(&packet)); + ASSERT_FALSE(timestamps.empty()); + int prev_t = 0; + for (auto t : timestamps) { + EXPECT_TRUE(t > prev_t && t < 110); + prev_t = t; + } + ASSERT_EQ(3, values.size()); + EXPECT_EQ(1, values[0]); + EXPECT_EQ(51, values[1]); + EXPECT_EQ(101, values[2]); +} + +// Ensure that when a custom input stream handler is used to handle packets from +// input streams, an error message is outputted with the appropriate link to +// resolve the issue when the calculator doesn't handle inputs in monotonically +// increasing order of timestamps. +TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) { + CalculatorGraph graph; + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: 'input0' + input_stream: 'input1' + node { + calculator: 'SimpleMuxCalculator' + input_stream: 'input0' + input_stream: 'input1' + input_stream_handler { + input_stream_handler: "ImmediateInputStreamHandler" + } + output_stream: 'output' + } + )pb"); + std::vector packet_dump; + tool::AddVectorSink("output", &config, &packet_dump); + + MP_ASSERT_OK(graph.Initialize(config)); + MP_ASSERT_OK(graph.StartRun({})); + + // Send packets to input stream "input0" at timestamps 0 and 1 consecutively. + Timestamp input0_timestamp = Timestamp(0); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "input0", MakePacket(1).At(input0_timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + EXPECT_EQ(1, packet_dump[0].Get()); + + ++input0_timestamp; + MP_EXPECT_OK(graph.AddPacketToInputStream( + "input0", MakePacket(3).At(input0_timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(2, packet_dump.size()); + EXPECT_EQ(3, packet_dump[1].Get()); + + // Send a packet to input stream "input1" at timestamp 0 after sending two + // packets at timestamps 0 and 1 to input stream "input0". This will result + // in a mismatch in timestamps as the SimpleMuxCalculator doesn't handle + // inputs from all streams in monotonically increasing order of timestamps. + Timestamp input1_timestamp = Timestamp(0); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "input1", MakePacket(2).At(input1_timestamp))); + absl::Status run_status = graph.WaitUntilIdle(); + EXPECT_THAT( + run_status.ToString(), + testing::AllOf( + // The core problem. + testing::HasSubstr("timestamp mismatch on a calculator"), + testing::HasSubstr( + "timestamps that are not strictly monotonically increasing"), + // Link to the possible solution. + testing::HasSubstr("ImmediateInputStreamHandler class comment"))); +} + +void DoTestMultipleGraphRuns(absl::string_view input_stream_handler, + bool select_packet) { + std::string graph_proto = absl::StrFormat(R"( + input_stream: 'input' + input_stream: 'select' + node { + calculator: 'PassThroughCalculator' + input_stream: 'input' + input_stream: 'select' + input_stream_handler { + input_stream_handler: "%s" + } + output_stream: 'output' + output_stream: 'select_out' + } + )", + input_stream_handler.data()); + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(graph_proto); + std::vector packet_dump; + tool::AddVectorSink("output", &config, &packet_dump); + + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(config)); + + struct Run { + Timestamp timestamp; + int value; + }; + std::vector runs = {{.timestamp = Timestamp(2000), .value = 2}, + {.timestamp = Timestamp(1000), .value = 1}}; + for (const Run& run : runs) { + MP_ASSERT_OK(graph.StartRun({})); + + if (select_packet) { + MP_EXPECT_OK(graph.AddPacketToInputStream( + "select", MakePacket(0).At(run.timestamp))); + } + MP_EXPECT_OK(graph.AddPacketToInputStream( + "input", MakePacket(run.value).At(run.timestamp))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, packet_dump.size()); + EXPECT_EQ(run.value, packet_dump[0].Get()); + EXPECT_EQ(run.timestamp, packet_dump[0].Timestamp()); + + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); + + packet_dump.clear(); + } +} + +TEST(CalculatorGraph, MultipleRunsWithDifferentInputStreamHandlers) { + DoTestMultipleGraphRuns("BarrierInputStreamHandler", true); + DoTestMultipleGraphRuns("DefaultInputStreamHandler", true); + DoTestMultipleGraphRuns("EarlyCloseInputStreamHandler", true); + DoTestMultipleGraphRuns("FixedSizeInputStreamHandler", true); + DoTestMultipleGraphRuns("ImmediateInputStreamHandler", false); + DoTestMultipleGraphRuns("MuxInputStreamHandler", true); + DoTestMultipleGraphRuns("SyncSetInputStreamHandler", true); + DoTestMultipleGraphRuns("TimestampAlignInputStreamHandler", true); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/deps/file_path.cc b/mediapipe/framework/deps/file_path.cc index 19ebbe500..45cb9bab8 100644 --- a/mediapipe/framework/deps/file_path.cc +++ b/mediapipe/framework/deps/file_path.cc @@ -45,7 +45,7 @@ std::string JoinPathImpl(bool honor_abs, // This size calculation is worst-case: it assumes one extra "/" for every // path other than the first. size_t total_size = paths.size() - 1; - for (const absl::string_view path : paths) total_size += path.size(); + for (const absl::string_view& path : paths) total_size += path.size(); result.resize(total_size); auto begin = result.begin(); diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index f91fb5d6f..f79e6aa43 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -81,6 +81,12 @@ mediapipe_proto_library( deps = ["//mediapipe/framework/formats/annotation:rasterization_proto"], ) +mediapipe_proto_library( + name = "affine_transform_data_proto", + srcs = ["affine_transform_data.proto"], + visibility = ["//visibility:public"], +) + mediapipe_proto_library( name = "time_series_header_proto", srcs = ["time_series_header.proto"], @@ -119,6 +125,31 @@ cc_library( ], ) +cc_library( + name = "affine_transform", + srcs = ["affine_transform.cc"], + hdrs = ["affine_transform.h"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//mediapipe/framework:port", + "//mediapipe/framework:type_map", + "//mediapipe/framework/formats:affine_transform_data_cc_proto", + "//mediapipe/framework/port:integral_types", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:point", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:statusor", + "//mediapipe/framework/tool:status_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + cc_library( name = "image_frame", srcs = ["image_frame.cc"], diff --git a/mediapipe/framework/formats/affine_transform.cc b/mediapipe/framework/formats/affine_transform.cc new file mode 100644 index 000000000..a2c0d0e1a --- /dev/null +++ b/mediapipe/framework/formats/affine_transform.cc @@ -0,0 +1,228 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/affine_transform.h" + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/substitute.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/point2.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" +#include "mediapipe/framework/tool/status_util.h" +#include "mediapipe/framework/type_map.h" + +namespace mediapipe { +using ::mediapipe::AffineTransformData; + +AffineTransform::AffineTransform() { SetScale(Point2_f(1, 1)); } + +AffineTransform::AffineTransform( + const AffineTransformData& affine_transform_data) + : affine_transform_data_(affine_transform_data), is_dirty_(true) { + // make sure scale is set to default (1, 1) when none provided + if (!affine_transform_data_.has_scale()) { + SetScale(Point2_f(1, 1)); + } +} + +AffineTransform AffineTransform::Create(const Point2_f& translation, + const Point2_f& scale, float rotation, + const Point2_f& shear) { + AffineTransformData affine_transform_data; + + auto* t = affine_transform_data.mutable_translation(); + t->set_x(translation.x()); + t->set_y(translation.y()); + + auto* s = affine_transform_data.mutable_scale(); + s->set_x(scale.x()); + s->set_y(scale.y()); + + s = affine_transform_data.mutable_shear(); + s->set_x(shear.x()); + s->set_y(shear.y()); + + affine_transform_data.set_rotation(rotation); + + return AffineTransform(affine_transform_data); +} + +// Accessor for the composition matrix +std::vector AffineTransform::GetCompositionMatrix() { + float r = affine_transform_data_.rotation(); + const auto t = affine_transform_data_.translation(); + const auto sc = affine_transform_data_.scale(); + const auto sh = affine_transform_data_.shear(); + + if (is_dirty_) { + // Composition matrix M = T*R*Sh*Sc + // Column based to match GL matrix store order + float cos_r = std::cos(r); + float sin_r = std::sin(r); + matrix_[0] = (cos_r + sin_r * -sh.y()) * sc.x(); + matrix_[1] = (-sin_r + cos_r * -sh.y()) * sc.x(); + matrix_[2] = 0; + matrix_[3] = (cos_r * -sh.x() + sin_r) * sc.y(); + matrix_[4] = (-sin_r * -sh.x() + cos_r) * sc.y(); + matrix_[5] = 0; + matrix_[6] = t.x(); + matrix_[7] = -t.y(); + matrix_[8] = 1; + is_dirty_ = false; + } + + return matrix_; +} + +Point2_f AffineTransform::GetScale() const { + return Point2_f(affine_transform_data_.scale().x(), + affine_transform_data_.scale().y()); +} + +Point2_f AffineTransform::GetTranslation() const { + return Point2_f(affine_transform_data_.translation().x(), + affine_transform_data_.translation().y()); +} + +Point2_f AffineTransform::GetShear() const { + return Point2_f(affine_transform_data_.shear().x(), + affine_transform_data_.shear().y()); +} + +float AffineTransform::GetRotation() const { + return affine_transform_data_.rotation(); +} + +void AffineTransform::SetScale(const Point2_f& scale) { + auto* s = affine_transform_data_.mutable_scale(); + s->set_x(scale.x()); + s->set_y(scale.y()); + is_dirty_ = true; +} + +void AffineTransform::SetTranslation(const Point2_f& translation) { + auto* t = affine_transform_data_.mutable_translation(); + t->set_x(translation.x()); + t->set_y(translation.y()); + is_dirty_ = true; +} + +void AffineTransform::SetShear(const Point2_f& shear) { + auto* s = affine_transform_data_.mutable_shear(); + s->set_x(shear.x()); + s->set_y(shear.y()); + is_dirty_ = true; +} + +void AffineTransform::SetRotation(float rotationInRadians) { + affine_transform_data_.set_rotation(rotationInRadians); + is_dirty_ = true; +} + +void AffineTransform::AddScale(const Point2_f& scale) { + auto* s = affine_transform_data_.mutable_scale(); + s->set_x(s->x() + scale.x()); + s->set_y(s->y() + scale.y()); + is_dirty_ = true; +} + +void AffineTransform::AddTranslation(const Point2_f& translation) { + auto* t = affine_transform_data_.mutable_translation(); + t->set_x(t->x() + translation.x()); + t->set_y(t->y() + translation.y()); + is_dirty_ = true; +} + +void AffineTransform::AddShear(const Point2_f& shear) { + auto* s = affine_transform_data_.mutable_shear(); + s->set_x(s->x() + shear.x()); + s->set_y(s->y() + shear.y()); + is_dirty_ = true; +} + +void AffineTransform::AddRotation(float rotationInRadians) { + affine_transform_data_.set_rotation(affine_transform_data_.rotation() + + rotationInRadians); + is_dirty_ = true; +} + +void AffineTransform::SetFromProto(const AffineTransformData& proto) { + affine_transform_data_ = proto; +} + +void AffineTransform::ConvertToProto(AffineTransformData* proto) const { + *proto = affine_transform_data_; +} + +AffineTransformData AffineTransform::ConvertToProto() const { + AffineTransformData affine_transform_data; + ConvertToProto(&affine_transform_data); + return affine_transform_data; +} + +bool compare(float lhs, float rhs, float epsilon = 0.001f) { + return std::fabs(lhs - rhs) < epsilon; +} + +bool AffineTransform::Equals(const AffineTransform& other, + float epsilon) const { + auto trans1 = GetTranslation(); + auto trans2 = other.GetTranslation(); + + if (!(compare(trans1.x(), trans2.x(), epsilon) && + compare(trans1.y(), trans2.y(), epsilon))) + return false; + + auto scale1 = GetScale(); + auto scale2 = other.GetScale(); + + if (!(compare(scale1.x(), scale2.x(), epsilon) && + compare(scale1.y(), scale2.y(), epsilon))) + return false; + + auto shear1 = GetShear(); + auto shear2 = other.GetShear(); + + if (!(compare(shear1.x(), shear2.x(), epsilon) && + compare(shear1.y(), shear2.y(), epsilon))) + return false; + + auto rot1 = GetRotation(); + auto rot2 = other.GetRotation(); + + if (!compare(rot1, rot2, epsilon)) { + return false; + } + + return true; +} + +bool AffineTransform::Equal(const AffineTransform& lhs, + const AffineTransform& rhs, float epsilon) { + return lhs.Equals(rhs, epsilon); +} + +MEDIAPIPE_REGISTER_TYPE(mediapipe::AffineTransform, + "::mediapipe::AffineTransform", nullptr, nullptr); + +} // namespace mediapipe diff --git a/mediapipe/framework/formats/affine_transform.h b/mediapipe/framework/formats/affine_transform.h new file mode 100644 index 000000000..5dd3f1ffe --- /dev/null +++ b/mediapipe/framework/formats/affine_transform.h @@ -0,0 +1,86 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// A container for affine transform data +// This wrapper provides two functionalities: +// 1. Factory methods for creation of Transform objects and thus +// AffineTransformData protocol buffers. These methods guarantee a valid +// affine transform data and are the preferred way of creating such. +// 2. Accessors which allow for access of the data and the convertion to proto +// format + +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_ + +#include +#include + +#include "mediapipe/framework/formats/affine_transform_data.pb.h" +#include "mediapipe/framework/port.h" +#include "mediapipe/framework/port/point2.h" + +namespace mediapipe { + +class AffineTransform { + public: + // CREATION METHODS. + AffineTransform(); + + // Constructs a affine transform wrapping the specified affine transform data. + // Checks the validity of the input and crashes upon failure. + explicit AffineTransform(const AffineTransformData& transform_data); + + static AffineTransform Create(const Point2_f& translation = Point2_f(0, 0), + const Point2_f& scale = Point2_f(1, 1), + float rotation = 0, + const Point2_f& shear = Point2_f(0, 0)); + + // ACCESSORS + // Accessor for the composition matrix + std::vector GetCompositionMatrix(); + + Point2_f GetScale() const; + Point2_f GetTranslation() const; + Point2_f GetShear() const; + float GetRotation() const; + + void SetScale(const Point2_f& scale); + void SetTranslation(const Point2_f& translation); + void SetShear(const Point2_f& shear); + void SetRotation(float rotation); + + void AddScale(const Point2_f& scale); + void AddTranslation(const Point2_f& translation); + void AddShear(const Point2_f& shear); + void AddRotation(float rotation); + + // Serializes and deserializes the affine transform object. + void ConvertToProto(AffineTransformData* proto) const; + AffineTransformData ConvertToProto() const; + void SetFromProto(const AffineTransformData& proto); + + bool Equals(const AffineTransform& other, float epsilon = 0.001f) const; + + static bool Equal(const AffineTransform& lhs, const AffineTransform& rhs, + float epsilon = 0.001f); + + private: + // The wrapped transform data. + AffineTransformData affine_transform_data_; + std::vector matrix_ = {1, 0, 0, 0, 1, 0, 0, 0, 1}; + bool is_dirty_ = false; +}; +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_AFFINE_TRANSFORM_H_ diff --git a/mediapipe/framework/formats/affine_transform_data.proto b/mediapipe/framework/formats/affine_transform_data.proto new file mode 100644 index 000000000..4745ce443 --- /dev/null +++ b/mediapipe/framework/formats/affine_transform_data.proto @@ -0,0 +1,33 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +option objc_class_prefix = "MediaPipe"; + +// Proto for serializing Vector2 data +message Vector2Data { + optional float x = 1; + optional float y = 2; +} + +// Proto for serializing Affine Transform data. +message AffineTransformData { + optional Vector2Data translation = 1; + optional Vector2Data scale = 2; + optional Vector2Data shear = 3; + optional float rotation = 4; // in radians +} diff --git a/mediapipe/framework/formats/affine_transform_test.cc b/mediapipe/framework/formats/affine_transform_test.cc new file mode 100644 index 000000000..6ffd1b947 --- /dev/null +++ b/mediapipe/framework/formats/affine_transform_test.cc @@ -0,0 +1,94 @@ +#include "mediapipe/framework/formats/affine_transform.h" + +#include + +#include "base/logging.h" +#include "mediapipe/framework/formats/affine_transform_data.pb.h" +#include "mediapipe/framework/port/point2.h" +#include "testing/base/public/gmock.h" +#include "testing/base/public/gunit.h" + +namespace mediapipe { + +TEST(AffineTransformTest, TraslationTest) { + AffineTransform transform; + transform.SetTranslation(Point2_f(10, -3)); + + auto trans = transform.GetTranslation(); + EXPECT_FLOAT_EQ(10, trans.x()); + EXPECT_FLOAT_EQ(-3, trans.y()); + + transform.AddTranslation(Point2_f(-10, 3)); + + trans = transform.GetTranslation(); + EXPECT_FLOAT_EQ(0, trans.x()); + EXPECT_FLOAT_EQ(0, trans.y()); +} + +TEST(AffineTransformTest, ScaleTest) { + AffineTransform transform; + transform.SetScale(Point2_f(10, -3)); + + auto scale = transform.GetScale(); + EXPECT_FLOAT_EQ(10, scale.x()); + EXPECT_FLOAT_EQ(-3, scale.y()); + + transform.AddScale(Point2_f(-10, 3)); + + scale = transform.GetScale(); + EXPECT_FLOAT_EQ(0, scale.x()); + EXPECT_FLOAT_EQ(0, scale.y()); +} + +TEST(AffineTransformTest, RotationTest) { + AffineTransform transform; + transform.SetRotation(0.7); + + float rot = transform.GetRotation(); + EXPECT_FLOAT_EQ(0.7, rot); + + transform.AddRotation(-0.7); + rot = transform.GetRotation(); + EXPECT_FLOAT_EQ(0, rot); +} + +TEST(AffineTransformTest, ShearTest) { + AffineTransform transform; + transform.SetShear(Point2_f(10, -3)); + + auto shear = transform.GetShear(); + EXPECT_FLOAT_EQ(10, shear.x()); + EXPECT_FLOAT_EQ(-3, shear.y()); + + transform.AddShear(Point2_f(-10, 3)); + + shear = transform.GetShear(); + EXPECT_FLOAT_EQ(0, shear.x()); + EXPECT_FLOAT_EQ(0, shear.y()); +} + +TEST(AffineTransformTest, TransformTest) { + AffineTransform transform1; + transform1 = AffineTransform::Create(Point2_f(0.1, -0.2), Point2_f(0.3, -0.4), + 0.5, Point2_f(0.6, -0.7)); + + AffineTransform transform2; + transform2 = AffineTransform::Create(Point2_f(0.1, -0.2), Point2_f(0.3, -0.4), + 0.5, Point2_f(0.6, -0.7)); + + EXPECT_THAT(true, transform1.Equals(transform2)); + EXPECT_THAT(true, AffineTransform::Equal(transform1, transform2)); + + transform1 = AffineTransform::Create(Point2_f(0.00001, -0.00002), + Point2_f(0.00003, -0.00004), 0.00005, + Point2_f(0.00006, -0.00007)); + + transform2 = AffineTransform::Create(Point2_f(0.00001, -0.00002), + Point2_f(0.00003, -0.00004), 0.00005, + Point2_f(0.00006, -0.00007)); + + EXPECT_THAT(true, transform1.Equals(transform2, 0.000001)); + EXPECT_THAT(true, AffineTransform::Equal(transform1, transform2, 0.000001)); +} + +} // namespace mediapipe diff --git a/mediapipe/framework/graph_output_stream.cc b/mediapipe/framework/graph_output_stream.cc index 6639bb8bf..de024dfe5 100644 --- a/mediapipe/framework/graph_output_stream.cc +++ b/mediapipe/framework/graph_output_stream.cc @@ -125,9 +125,10 @@ absl::Status OutputStreamObserver::Notify() { absl::Status OutputStreamPollerImpl::Initialize( const std::string& stream_name, const PacketType* packet_type, std::function queue_size_callback, - OutputStreamManager* output_stream_manager) { + OutputStreamManager* output_stream_manager, bool observe_timestamp_bounds) { MP_RETURN_IF_ERROR(GraphOutputStream::Initialize(stream_name, packet_type, - output_stream_manager)); + output_stream_manager, + observe_timestamp_bounds)); input_stream_handler_->SetQueueSizeCallbacks(queue_size_callback, queue_size_callback); return absl::OkStatus(); @@ -176,11 +177,17 @@ void OutputStreamPollerImpl::NotifyError() { bool OutputStreamPollerImpl::Next(Packet* packet) { CHECK(packet); bool empty_queue = true; + bool timestamp_bound_changed = false; Timestamp min_timestamp = Timestamp::Unset(); mutex_.Lock(); while (true) { min_timestamp = input_stream_->MinTimestampOrBound(&empty_queue); - if (graph_has_error_ || !empty_queue || + if (empty_queue) { + timestamp_bound_changed = + input_stream_handler_->ProcessTimestampBounds() && + output_timestamp_ < min_timestamp.PreviousAllowedInStream(); + } + if (graph_has_error_ || !empty_queue || timestamp_bound_changed || min_timestamp == Timestamp::Done()) { break; } else { @@ -191,17 +198,26 @@ bool OutputStreamPollerImpl::Next(Packet* packet) { mutex_.Unlock(); return false; } + if (empty_queue) { + output_timestamp_ = min_timestamp.PreviousAllowedInStream(); + } else { + output_timestamp_ = min_timestamp; + } mutex_.Unlock(); if (min_timestamp == Timestamp::Done()) { return false; } - int num_packets_dropped = 0; - bool stream_is_done = false; - *packet = input_stream_->PopPacketAtTimestamp( - min_timestamp, &num_packets_dropped, &stream_is_done); - CHECK_EQ(num_packets_dropped, 0) - << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", - num_packets_dropped, input_stream_->Name()); + if (!empty_queue) { + int num_packets_dropped = 0; + bool stream_is_done = false; + *packet = input_stream_->PopPacketAtTimestamp( + min_timestamp, &num_packets_dropped, &stream_is_done); + CHECK_EQ(num_packets_dropped, 0) + << absl::Substitute("Dropped $0 packet(s) on input stream \"$1\".", + num_packets_dropped, input_stream_->Name()); + } else if (timestamp_bound_changed) { + *packet = Packet().At(min_timestamp.PreviousAllowedInStream()); + } return true; } diff --git a/mediapipe/framework/graph_output_stream.h b/mediapipe/framework/graph_output_stream.h index 393407aa3..b541aec12 100644 --- a/mediapipe/framework/graph_output_stream.h +++ b/mediapipe/framework/graph_output_stream.h @@ -143,7 +143,8 @@ class OutputStreamPollerImpl : public GraphOutputStream { absl::Status Initialize( const std::string& stream_name, const PacketType* packet_type, std::function queue_size_callback, - OutputStreamManager* output_stream_manager); + OutputStreamManager* output_stream_manager, + bool observe_timestamp_bounds = false); void PrepareForRun(std::function notification_callback, std::function error_callback) override; @@ -170,6 +171,7 @@ class OutputStreamPollerImpl : public GraphOutputStream { absl::Mutex mutex_; absl::CondVar handler_condvar_ ABSL_GUARDED_BY(mutex_); bool graph_has_error_ ABSL_GUARDED_BY(mutex_); + Timestamp output_timestamp_ ABSL_GUARDED_BY(mutex_) = Timestamp::Min(); }; } // namespace internal diff --git a/mediapipe/framework/testdata/night_light_calculator.proto b/mediapipe/framework/testdata/night_light_calculator.proto index 36180439a..2f1fb7db6 100644 --- a/mediapipe/framework/testdata/night_light_calculator.proto +++ b/mediapipe/framework/testdata/night_light_calculator.proto @@ -51,4 +51,17 @@ message NightLightCalculatorOptions { // Format string used by string::Substitute to construct the output. optional string format_string = 9; + + message LightBundle { + optional string room_id = 1; + repeated NightLightCalculatorOptions room_lights = 2; + } + + repeated LightBundle bundle = 10; + + // The number of night-lights. + repeated int32 num_lights = 11; + + // Options for nested night-lights. + optional NightLightCalculatorOptions sub_options = 12; } diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 00d3648e5..db92cfd38 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -180,15 +180,66 @@ cc_library( ], ) +mediapipe_proto_library( + name = "field_data_proto", + srcs = ["field_data.proto"], + visibility = ["//visibility:public"], + deps = ["@com_google_protobuf//:any_proto"], +) + +cc_library( + name = "options_field_util", + srcs = ["options_field_util.cc"], + hdrs = ["options_field_util.h"], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":field_data_cc_proto", + ":name_util", + ":options_registry", + ":proto_util_lite", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet", + "//mediapipe/framework:packet_type", + "//mediapipe/framework/port:advanced_proto", + "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "options_syntax_util", + srcs = ["options_syntax_util.cc"], + hdrs = ["options_syntax_util.h"], + visibility = ["//mediapipe/framework:mediapipe_internal"], + deps = [ + ":name_util", + ":options_field_util", + ":options_registry", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:packet", + "//mediapipe/framework:packet_type", + "//mediapipe/framework/port:advanced_proto", + "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "options_util", srcs = ["options_util.cc"], hdrs = ["options_util.h"], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ + ":options_field_util", ":options_map", + ":options_registry", + ":options_syntax_util", ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_context", "//mediapipe/framework:collection", "//mediapipe/framework:input_stream_shard", "//mediapipe/framework:output_side_packet", @@ -199,7 +250,7 @@ cc_library( "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", "//mediapipe/framework/port:status", - "//mediapipe/framework/tool:type_util", + "//mediapipe/framework/tool:name_util", "@com_google_absl//absl/strings", ], ) @@ -227,6 +278,8 @@ cc_library( "//mediapipe/framework/deps:registration", "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/synchronization", ], ) @@ -246,11 +299,13 @@ mediapipe_cc_test( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:validated_graph_config", + "//mediapipe/framework/deps:message_matchers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:status", "//mediapipe/framework/testdata:night_light_calculator_options_lib", "//mediapipe/framework/tool:node_chain_subgraph_options_lib", + "//mediapipe/framework/tool:options_syntax_util", "//mediapipe/util:header_util", ], ) diff --git a/mediapipe/framework/tool/field_data.proto b/mediapipe/framework/tool/field_data.proto new file mode 100644 index 000000000..c8713c2e6 --- /dev/null +++ b/mediapipe/framework/tool/field_data.proto @@ -0,0 +1,47 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from mediapipe/framework/tool/source.proto. +// The forked proto must remain identical to the original proto and should be +// ONLY used by mediapipe open source project. + +syntax = "proto2"; + +package mediapipe; + +// `MessageData`, like protobuf.Any, contains an arbitrary serialized protbuf +// along with a URL that describes the type of the serialized message. +message MessageData { + // A URL/resource name that identifies the type of serialized protbuf. + optional string type_url = 1; + + // Must be a valid serialized protocol buffer of the above specified type. + optional bytes value = 2; +} + +// Data for one Protobuf field or one MediaPipe packet. +message FieldData { + oneof value { + sint32 int32_value = 1; + sint64 int64_value = 2; + uint32 uint32_value = 3; + uint64 uint64_value = 4; + double double_value = 5; + float float_value = 6; + bool bool_value = 7; + sint32 enum_value = 8; + string string_value = 9; + MessageData message_value = 10; + } +} diff --git a/mediapipe/framework/tool/options_field_util.cc b/mediapipe/framework/tool/options_field_util.cc new file mode 100644 index 000000000..043015b2a --- /dev/null +++ b/mediapipe/framework/tool/options_field_util.cc @@ -0,0 +1,495 @@ + +#include "mediapipe/framework/tool/options_field_util.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/any_proto.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/name_util.h" +#include "mediapipe/framework/tool/proto_util_lite.h" + +namespace mediapipe { +namespace tool { +namespace options_field_util { + +using ::mediapipe::proto_ns::internal::WireFormatLite; +using FieldType = WireFormatLite::FieldType; +using ::mediapipe::proto_ns::io::ArrayInputStream; +using ::mediapipe::proto_ns::io::CodedInputStream; +using ::mediapipe::proto_ns::io::CodedOutputStream; +using ::mediapipe::proto_ns::io::StringOutputStream; + +// Utility functions for OptionsFieldUtil. +namespace { + +// Converts a FieldDescriptor::Type to the corresponding FieldType. +FieldType AsFieldType(proto_ns::FieldDescriptorProto::Type type) { + return static_cast(type); +} + +absl::Status WriteValue(const FieldData& value, FieldType field_type, + std::string* field_bytes) { + StringOutputStream sos(field_bytes); + CodedOutputStream out(&sos); + switch (field_type) { + case WireFormatLite::TYPE_INT32: + WireFormatLite::WriteInt32NoTag(value.int32_value(), &out); + break; + case WireFormatLite::TYPE_SINT32: + WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out); + break; + case WireFormatLite::TYPE_INT64: + WireFormatLite::WriteInt64NoTag(value.int64_value(), &out); + break; + case WireFormatLite::TYPE_SINT64: + WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out); + break; + case WireFormatLite::TYPE_UINT32: + WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out); + break; + case WireFormatLite::TYPE_UINT64: + WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out); + break; + case WireFormatLite::TYPE_DOUBLE: + WireFormatLite::WriteDoubleNoTag(value.uint64_value(), &out); + break; + case WireFormatLite::TYPE_FLOAT: + WireFormatLite::WriteFloatNoTag(value.float_value(), &out); + break; + case WireFormatLite::TYPE_BOOL: + WireFormatLite::WriteBoolNoTag(value.bool_value(), &out); + break; + case WireFormatLite::TYPE_ENUM: + WireFormatLite::WriteEnumNoTag(value.enum_value(), &out); + break; + case WireFormatLite::TYPE_STRING: + out.WriteString(value.string_value()); + break; + case WireFormatLite::TYPE_MESSAGE: + out.WriteString(value.message_value().value()); + break; + default: + return absl::UnimplementedError( + absl::StrCat("Cannot write type: ", field_type)); + } + return mediapipe::OkStatus(); +} + +// Serializes a packet value. +absl::Status WriteField(const FieldData& packet, const FieldDescriptor* field, + std::string* result) { + FieldType field_type = AsFieldType(field->type()); + return WriteValue(packet, field_type, result); +} + +template +static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) { + ArrayInputStream ais(field_bytes.data(), field_bytes.size()); + CodedInputStream input(&ais); + ValueT result; + if (!WireFormatLite::ReadPrimitive(&input, &result)) { + status->Update(mediapipe::InvalidArgumentError(absl::StrCat( + "Bad serialized value: ", MediaPipeTypeStringOrDemangled(), + "."))); + } + return result; +} + +absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type, + absl::string_view message_type, FieldData* result) { + absl::Status status; + result->Clear(); + switch (field_type) { + case WireFormatLite::TYPE_INT32: + result->set_int32_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_SINT32: + result->set_int32_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_INT64: + result->set_int64_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_SINT64: + result->set_int64_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_UINT32: + result->set_uint32_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_UINT64: + result->set_uint64_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_DOUBLE: + result->set_double_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_FLOAT: + result->set_float_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_BOOL: + result->set_bool_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_ENUM: + result->set_enum_value( + ReadValue(field_bytes, &status)); + break; + case WireFormatLite::TYPE_STRING: + result->set_string_value(std::string(field_bytes)); + break; + case WireFormatLite::TYPE_MESSAGE: + result->mutable_message_value()->set_value(std::string(field_bytes)); + result->mutable_message_value()->set_type_url(TypeUrl(message_type)); + break; + default: + status = absl::UnimplementedError( + absl::StrCat("Cannot read type: ", field_type)); + break; + } + return status; +} + +// Deserializes a packet from a protobuf field. +absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field, + FieldData* result) { + FieldType field_type = AsFieldType(field->type()); + std::string message_type = (field_type == WireFormatLite::TYPE_MESSAGE) + ? field->message_type()->full_name() + : ""; + return ReadValue(bytes, field_type, message_type, result); +} + +// Converts a chain of fields and indexes into field-numbers and indexes. +ProtoUtilLite::ProtoPath AsProtoPath(const FieldPath& field_path) { + ProtoUtilLite::ProtoPath result; + for (auto field : field_path) { + result.push_back({field.first->number(), field.second}); + } + return result; +} + +// Returns the options protobuf for a subgraph. +// TODO: Ensure that this works with multiple options protobufs. +absl::Status GetOptionsMessage( + const proto_ns::RepeatedPtrField& options_any, + const proto_ns::MessageLite& options_ext, FieldData* result) { + // Read the "graph_options" or "node_options" field. + for (const auto& options : options_any) { + if (options.type_url().empty()) { + continue; + } + result->mutable_message_value()->set_type_url(options.type_url()); + result->mutable_message_value()->set_value(std::string(options.value())); + return mediapipe::OkStatus(); + } + + // Read the "options" field. + FieldData message_data; + *message_data.mutable_message_value()->mutable_value() = + options_ext.SerializeAsString(); + message_data.mutable_message_value()->set_type_url(options_ext.GetTypeName()); + std::vector ext_fields; + OptionsRegistry::FindAllExtensions(options_ext.GetTypeName(), &ext_fields); + for (auto ext_field : ext_fields) { + absl::Status status = GetField({{ext_field, 0}}, message_data, result); + if (!status.ok()) { + return status; + } + if (result->has_message_value()) { + return status; + } + } + return mediapipe::OkStatus(); +} + +// Sets a protobuf in a repeated protobuf::Any field. +void SetOptionsMessage( + const FieldData& node_options, + proto_ns::RepeatedPtrField* result) { + protobuf::Any* options_any = nullptr; + for (auto& any : *result) { + if (any.type_url() == node_options.message_value().type_url()) { + options_any = &any; + } + } + if (!options_any) { + options_any = result->Add(); + options_any->set_type_url(node_options.message_value().type_url()); + } + *options_any->mutable_value() = node_options.message_value().value(); +} + +} // anonymous namespace + +// Deserializes a packet containing a MessageLite value. +absl::Status ReadMessage(const std::string& value, const std::string& type_name, + Packet* result) { + auto packet = packet_internal::PacketFromDynamicProto(type_name, value); + if (packet.ok()) { + *result = *packet; + } + return packet.status(); +} + +// Merge two options FieldData values. +absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over, + FieldData* result) { + absl::Status status; + if (over.value_case() == FieldData::VALUE_NOT_SET) { + *result = base; + return status; + } + if (base.value_case() == FieldData::VALUE_NOT_SET) { + *result = over; + return status; + } + if (over.value_case() != base.value_case()) { + return absl::InvalidArgumentError(absl::StrCat( + "Cannot merge field data with data types: ", base.value_case(), ", ", + over.value_case())); + } + if (over.message_value().type_url() != base.message_value().type_url()) { + return absl::InvalidArgumentError( + absl::StrCat("Cannot merge field data with message types: ", + base.message_value().type_url(), ", ", + over.message_value().type_url())); + } + absl::Cord merged_value; + merged_value.Append(base.message_value().value()); + merged_value.Append(over.message_value().value()); + result->mutable_message_value()->set_type_url( + base.message_value().type_url()); + result->mutable_message_value()->set_value(std::string(merged_value)); + return status; +} + +// Writes a FieldData value into protobuf field. +absl::Status SetField(const FieldPath& field_path, const FieldData& value, + FieldData* message_data) { + if (field_path.empty()) { + *message_data->mutable_message_value() = value.message_value(); + return mediapipe::OkStatus(); + } + ProtoUtilLite proto_util; + const FieldDescriptor* field = field_path.back().first; + FieldType field_type = AsFieldType(field->type()); + std::string field_value; + MP_RETURN_IF_ERROR(WriteField(value, field, &field_value)); + ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path); + std::string* message_bytes = + message_data->mutable_message_value()->mutable_value(); + int field_count; + MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path, + field_type, &field_count)); + MP_RETURN_IF_ERROR( + proto_util.ReplaceFieldRange(message_bytes, AsProtoPath(field_path), + field_count, field_type, {field_value})); + return mediapipe::OkStatus(); +} + +// Merges a packet value into nested protobuf Message. +absl::Status MergeField(const FieldPath& field_path, const FieldData& value, + FieldData* message_data) { + absl::Status status; + FieldType field_type = field_path.empty() + ? FieldType::TYPE_MESSAGE + : AsFieldType(field_path.back().first->type()); + std::string message_type = + (value.has_message_value()) + ? ParseTypeUrl(std::string(value.message_value().type_url())) + : ""; + FieldData v = value; + if (field_type == FieldType::TYPE_MESSAGE) { + FieldData b; + status.Update(GetField(field_path, *message_data, &b)); + status.Update(MergeOptionsMessages(b, v, &v)); + } + status.Update(SetField(field_path, v, message_data)); + return status; +} + +// Reads a packet value from a protobuf field. +absl::Status GetField(const FieldPath& field_path, + const FieldData& message_data, FieldData* result) { + if (field_path.empty()) { + *result->mutable_message_value() = message_data.message_value(); + return mediapipe::OkStatus(); + } + ProtoUtilLite proto_util; + const FieldDescriptor* field = field_path.back().first; + FieldType field_type = AsFieldType(field->type()); + std::vector field_values; + ProtoUtilLite::ProtoPath proto_path = AsProtoPath(field_path); + const std::string& message_bytes = message_data.message_value().value(); + int field_count; + MP_RETURN_IF_ERROR(proto_util.GetFieldCount(message_bytes, proto_path, + field_type, &field_count)); + if (field_count == 0) { + return mediapipe::OkStatus(); + } + MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1, + field_type, &field_values)); + MP_RETURN_IF_ERROR(ReadField(field_values.front(), field, result)); + return mediapipe::OkStatus(); +} + +// Returns the options protobuf for a graph. +absl::Status GetOptionsMessage(const CalculatorGraphConfig& config, + FieldData* result) { + return GetOptionsMessage(config.graph_options(), config.options(), result); +} + +// Returns the options protobuf for a node. +absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node, + FieldData* result) { + return GetOptionsMessage(node.node_options(), node.options(), result); +} + +// Sets the node_options field in a Node, and clears the options field. +void SetOptionsMessage(const FieldData& node_options, + CalculatorGraphConfig::Node* node) { + SetOptionsMessage(node_options, node->mutable_node_options()); + node->clear_options(); +} + +// Represents a protobuf enum value stored in a Packet. +struct ProtoEnum { + ProtoEnum(int32 v) : value(v) {} + int32 value; +}; + +absl::Status AsPacket(const FieldData& data, Packet* result) { + switch (data.value_case()) { + case FieldData::ValueCase::kInt32Value: + *result = MakePacket(data.int32_value()); + break; + case FieldData::ValueCase::kInt64Value: + *result = MakePacket(data.int64_value()); + break; + case FieldData::ValueCase::kUint32Value: + *result = MakePacket(data.uint32_value()); + break; + case FieldData::ValueCase::kUint64Value: + *result = MakePacket(data.uint64_value()); + break; + case FieldData::ValueCase::kDoubleValue: + *result = MakePacket(data.double_value()); + break; + case FieldData::ValueCase::kFloatValue: + *result = MakePacket(data.float_value()); + break; + case FieldData::ValueCase::kBoolValue: + *result = MakePacket(data.bool_value()); + break; + case FieldData::ValueCase::kEnumValue: + *result = MakePacket(data.enum_value()); + break; + case FieldData::ValueCase::kStringValue: + *result = MakePacket(data.string_value()); + break; + case FieldData::ValueCase::kMessageValue: { + auto r = packet_internal::PacketFromDynamicProto( + ParseTypeUrl(std::string(data.message_value().type_url())), + std::string(data.message_value().value())); + if (!r.ok()) { + return r.status(); + } + *result = r.value(); + break; + } + case FieldData::VALUE_NOT_SET: + *result = Packet(); + } + return mediapipe::OkStatus(); +} + +absl::Status AsFieldData(Packet packet, FieldData* result) { + static const auto* kTypeIds = new std::map{ + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_INT32}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_INT64}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_UINT32}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_UINT64}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_DOUBLE}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_FLOAT}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_BOOL}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_ENUM}, + {tool::GetTypeHash(), WireFormatLite::CPPTYPE_STRING}, + }; + + if (packet.ValidateAsProtoMessageLite().ok()) { + result->mutable_message_value()->set_value( + packet.GetProtoMessageLite().SerializeAsString()); + result->mutable_message_value()->set_type_url( + TypeUrl(packet.GetProtoMessageLite().GetTypeName())); + return mediapipe::OkStatus(); + } + + if (kTypeIds->count(packet.GetTypeId()) == 0) { + return absl::UnimplementedError(absl::StrCat( + "Cannot construct FieldData for: ", packet.DebugTypeName())); + } + + switch (kTypeIds->at(packet.GetTypeId())) { + case WireFormatLite::CPPTYPE_INT32: + result->set_int32_value(packet.Get()); + break; + case WireFormatLite::CPPTYPE_INT64: + result->set_int64_value(packet.Get()); + break; + case WireFormatLite::CPPTYPE_UINT32: + result->set_uint32_value(packet.Get()); + break; + case WireFormatLite::CPPTYPE_UINT64: + result->set_uint64_value(packet.Get()); + break; + case WireFormatLite::CPPTYPE_DOUBLE: + result->set_double_value(packet.Get()); + break; + case WireFormatLite::CPPTYPE_FLOAT: + result->set_float_value(packet.Get()); + break; + case WireFormatLite::CPPTYPE_BOOL: + result->set_bool_value(packet.Get()); + break; + case WireFormatLite::CPPTYPE_ENUM: + result->set_enum_value(packet.Get().value); + break; + case WireFormatLite::CPPTYPE_STRING: + result->set_string_value(packet.Get()); + break; + } + return mediapipe::OkStatus(); +} + +std::string TypeUrl(absl::string_view type_name) { + constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/"; + return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name)); +} + +std::string ParseTypeUrl(absl::string_view type_url) { + constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/"; + if (std::string(type_url).rfind(kTypeUrlPrefix, 0) == 0) { + return std::string( + type_url.substr(kTypeUrlPrefix.length(), std::string::npos)); + } + return std::string(type_url); +} + +} // namespace options_field_util +} // namespace tool +} // namespace mediapipe diff --git a/mediapipe/framework/tool/options_field_util.h b/mediapipe/framework/tool/options_field_util.h new file mode 100644 index 000000000..2dda09ca3 --- /dev/null +++ b/mediapipe/framework/tool/options_field_util.h @@ -0,0 +1,73 @@ +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_ + +#include +#include + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/tool/field_data.pb.h" +#include "mediapipe/framework/tool/options_registry.h" + +namespace mediapipe { + +namespace tool { + +// Utility to read and write Packet data from protobuf fields. +namespace options_field_util { + +// A chain of nested fields and indexes. +using FieldPath = std::vector>; + +// Writes a field value into protobuf field. +absl::Status SetField(const FieldPath& field_path, const FieldData& value, + FieldData* message_data); + +// Reads a field value from a protobuf field. +absl::Status GetField(const FieldPath& field_path, + const FieldData& message_data, FieldData* result); + +// Merges a field value into nested protobuf Message. +absl::Status MergeField(const FieldPath& field_path, const FieldData& value, + FieldData* message_data); + +// Deserializes a packet containing a MessageLite value. +absl::Status ReadMessage(const std::string& value, const std::string& type_name, + Packet* result); + +// Merge two options protobuf field values. +absl::Status MergeOptionsMessages(const FieldData& base, const FieldData& over, + FieldData* result); + +// Returns the options protobuf for a graph. +absl::Status GetOptionsMessage(const CalculatorGraphConfig& config, + FieldData* result); + +// Returns the options protobuf for a node. +absl::Status GetOptionsMessage(const CalculatorGraphConfig::Node& node, + FieldData* result); + +// Sets the node_options field in a Node, and clears the options field. +void SetOptionsMessage(const FieldData& node_options, + CalculatorGraphConfig::Node* node); + +// Constructs a Packet for a FieldData proto. +absl::Status AsPacket(const FieldData& data, Packet* result); + +// Constructs a FieldData proto for a Packet. +absl::Status AsFieldData(Packet packet, FieldData* result); + +// Returns the protobuf type-url for a protobuf type-name. +std::string TypeUrl(absl::string_view type_name); + +// Returns the protobuf type-name for a protobuf type-url. +std::string ParseTypeUrl(absl::string_view type_url); + +} // namespace options_field_util +} // namespace tool +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_FIELD_UTIL_H_ diff --git a/mediapipe/framework/tool/options_registry.cc b/mediapipe/framework/tool/options_registry.cc index be26c27e1..b65cc9fed 100644 --- a/mediapipe/framework/tool/options_registry.cc +++ b/mediapipe/framework/tool/options_registry.cc @@ -1,47 +1,112 @@ #include "mediapipe/framework/tool/options_registry.h" +#include "absl/synchronization/mutex.h" + namespace mediapipe { namespace tool { -proto_ns::DescriptorPool* OptionsRegistry::options_descriptor_pool() { - static proto_ns::DescriptorPool* result = new proto_ns::DescriptorPool(); - return result; +namespace { + +// Returns a canonical message type name, with any leading "." removed. +std::string CanonicalTypeName(const std::string& type_name) { + return (type_name.rfind('.', 0) == 0) ? type_name.substr(1) : type_name; } +} // namespace + RegistrationToken OptionsRegistry::Register( const proto_ns::FileDescriptorSet& files) { + absl::MutexLock lock(&mutex()); for (auto& file : files.file()) { - options_descriptor_pool()->BuildFile(file); + for (auto& message_type : file.message_type()) { + Register(message_type, file.package()); + } } return RegistrationToken([]() {}); } -const proto_ns::Descriptor* OptionsRegistry::GetProtobufDescriptor( - const std::string& type_name) { - const proto_ns::Descriptor* result = - proto_ns::DescriptorPool::generated_pool()->FindMessageTypeByName( - type_name); - if (!result) { - result = options_descriptor_pool()->FindMessageTypeByName(type_name); +void OptionsRegistry::Register(const proto_ns::DescriptorProto& message_type, + const std::string& parent_name) { + auto full_name = absl::StrCat(parent_name, ".", message_type.name()); + descriptors()[full_name] = Descriptor(message_type, full_name); + for (auto& nested : message_type.nested_type()) { + Register(nested, full_name); } - return result; + for (auto& extension : message_type.extension()) { + extensions()[CanonicalTypeName(extension.extendee())].push_back( + FieldDescriptor(extension)); + } +} + +const Descriptor* OptionsRegistry::GetProtobufDescriptor( + const std::string& type_name) { + absl::ReaderMutexLock lock(&mutex()); + auto it = descriptors().find(CanonicalTypeName(type_name)); + return (it == descriptors().end()) ? nullptr : &it->second; } void OptionsRegistry::FindAllExtensions( - const proto_ns::Descriptor& extendee, - std::vector* result) { - using proto_ns::DescriptorPool; - std::vector extensions; - DescriptorPool::generated_pool()->FindAllExtensions(&extendee, &extensions); - options_descriptor_pool()->FindAllExtensions(&extendee, &extensions); - absl::flat_hash_set numbers; - for (const proto_ns::FieldDescriptor* extension : extensions) { - bool inserted = numbers.insert(extension->number()).second; - if (inserted) { - result->push_back(extension); + absl::string_view extendee, std::vector* result) { + absl::ReaderMutexLock lock(&mutex()); + result->clear(); + if (extensions().count(extendee) > 0) { + for (const FieldDescriptor& field : extensions().at(extendee)) { + result->push_back(&field); } } } +absl::flat_hash_map& OptionsRegistry::descriptors() { + static auto* descriptors = new absl::flat_hash_map(); + return *descriptors; +} + +absl::flat_hash_map>& +OptionsRegistry::extensions() { + static auto* extensions = + new absl::flat_hash_map>(); + return *extensions; +} + +absl::Mutex& OptionsRegistry::mutex() { + static auto* mutex = new absl::Mutex(); + return *mutex; +} + +Descriptor::Descriptor(const proto_ns::DescriptorProto& proto, + const std::string& full_name) + : full_name_(full_name) { + for (auto& field : proto.field()) { + fields_[field.name()] = FieldDescriptor(field); + } +} + +const std::string& Descriptor::full_name() const { return full_name_; } + +const FieldDescriptor* Descriptor::FindFieldByName( + const std::string& name) const { + auto it = fields_.find(name); + return (it != fields_.end()) ? &it->second : nullptr; +} + +FieldDescriptor::FieldDescriptor(const proto_ns::FieldDescriptorProto& proto) { + name_ = proto.name(); + message_type_ = CanonicalTypeName(proto.type_name()); + type_ = proto.type(); + number_ = proto.number(); +} + +const std::string& FieldDescriptor::name() const { return name_; } + +int FieldDescriptor::number() const { return number_; } + +proto_ns::FieldDescriptorProto::Type FieldDescriptor::type() const { + return type_; +} + +const Descriptor* FieldDescriptor::message_type() const { + return OptionsRegistry::GetProtobufDescriptor(message_type_); +} + } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/options_registry.h b/mediapipe/framework/tool/options_registry.h index d798fe29d..34f04ede6 100644 --- a/mediapipe/framework/tool/options_registry.h +++ b/mediapipe/framework/tool/options_registry.h @@ -1,12 +1,16 @@ #ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ +#include "absl/container/flat_hash_map.h" #include "mediapipe/framework/deps/registration.h" #include "mediapipe/framework/port/advanced_proto_inc.h" namespace mediapipe { namespace tool { +class Descriptor; +class FieldDescriptor; + // A static registry that stores descriptors for protobufs used in MediaPipe // calculator options. Lite-proto builds do not normally include descriptors. // These registered descriptors allow individual protobuf fields to be @@ -17,23 +21,60 @@ class OptionsRegistry { static RegistrationToken Register(const proto_ns::FileDescriptorSet& files); // Finds the descriptor for a protobuf. - static const proto_ns::Descriptor* GetProtobufDescriptor( - const std::string& type_name); + static const Descriptor* GetProtobufDescriptor(const std::string& type_name); // Returns all known proto2 extensions to a type. - static void FindAllExtensions( - const proto_ns::Descriptor& extendee, - std::vector* result); + static void FindAllExtensions(absl::string_view extendee, + std::vector* result); private: - // Stores the descriptors for each options protobuf type. - static proto_ns::DescriptorPool* options_descriptor_pool(); + // Registers protobuf descriptors a MessageLite and nested types. + static void Register(const proto_ns::DescriptorProto& message_type, + const std::string& parent_name); + + static absl::flat_hash_map& descriptors(); + static absl::flat_hash_map>& + extensions(); + static absl::Mutex& mutex(); // Registers the descriptors for each options protobuf type. template static const RegistrationToken registration_token; }; +// A custom implementation proto_ns::Descriptor. This implementation +// avoids a code size problem introduced by proto_ns::FieldDescriptor. +class Descriptor { + public: + Descriptor() {} + Descriptor(const proto_ns::DescriptorProto& proto, + const std::string& full_name); + const std::string& full_name() const; + const FieldDescriptor* FindFieldByName(const std::string& name) const; + + private: + std::string full_name_; + absl::flat_hash_map fields_; +}; + +// A custom implementation proto_ns::FieldDescriptor. This implementation +// avoids a code size problem introduced by proto_ns::FieldDescriptor. +class FieldDescriptor { + public: + FieldDescriptor() {} + FieldDescriptor(const proto_ns::FieldDescriptorProto& proto); + const std::string& name() const; + int number() const; + proto_ns::FieldDescriptorProto::Type type() const; + const Descriptor* message_type() const; + + private: + std::string name_; + std::string message_type_; + proto_ns::FieldDescriptorProto::Type type_; + int number_; +}; + } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/options_syntax_util.cc b/mediapipe/framework/tool/options_syntax_util.cc new file mode 100644 index 000000000..0112189fb --- /dev/null +++ b/mediapipe/framework/tool/options_syntax_util.cc @@ -0,0 +1,143 @@ +#include "mediapipe/framework/tool/options_syntax_util.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/any_proto.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/name_util.h" + +namespace mediapipe { +namespace tool { + +// Helper functions for parsing the graph options syntax. +class OptionsSyntaxUtil::OptionsSyntaxHelper { + public: + // The usual graph options syntax tokens. + OptionsSyntaxHelper() : syntax_{"OPTIONS", "options", "/"} {} + + // Returns the tag name for an option protobuf field. + std::string OptionFieldTag(const std::string& name) { return name; } + + // Returns the packet name for an option protobuf field. + absl::string_view OptionFieldPacket(absl::string_view name) { return name; } + + // Returns the option protobuf field name for a tag or packet name. + absl::string_view OptionFieldName(absl::string_view name) { return name; } + + // Returns the field-path for an option stream-tag. + FieldPath OptionFieldPath(const std::string& tag, + const Descriptor* descriptor) { + int prefix = syntax_.tag_name.length() + syntax_.separator.length(); + std::string suffix = tag.substr(prefix); + std::vector name_tags = + absl::StrSplit(suffix, syntax_.separator); + FieldPath result; + for (absl::string_view name_tag : name_tags) { + if (name_tag.empty()) { + continue; + } + absl::string_view option_name = OptionFieldName(name_tag); + int index; + if (absl::SimpleAtoi(option_name, &index)) { + result.back().second = index; + } else { + auto field = descriptor->FindFieldByName(std::string(option_name)); + descriptor = field ? field->message_type() : nullptr; + result.push_back({std::move(field), 0}); + } + } + return result; + } + + // Returns the option field name for a graph options packet name. + std::string GraphOptionFieldName(const std::string& graph_option_name) { + int prefix = syntax_.packet_name.length() + syntax_.separator.length(); + std::string result = graph_option_name; + result.erase(0, prefix); + return result; + } + + // Returns the graph options packet name for an option field name. + std::string GraphOptionName(const std::string& option_name) { + std::string packet_prefix = + syntax_.packet_name + absl::AsciiStrToLower(syntax_.separator); + return absl::StrCat(packet_prefix, option_name); + } + + // Returns the tag name for a graph option. + std::string OptionTagName(const std::string& option_name) { + return absl::StrCat(syntax_.tag_name, syntax_.separator, + OptionFieldTag(option_name)); + } + + // Converts slash-separated field names into a tag name. + std::string OptionFieldsTag(const std::string& option_names) { + std::string tag_prefix = syntax_.tag_name + syntax_.separator; + std::vector names = absl::StrSplit(option_names, '/'); + if (!names.empty() && names[0] == syntax_.tag_name) { + names.erase(names.begin()); + } + if (!names.empty() && names[0] == syntax_.packet_name) { + names.erase(names.begin()); + } + std::string result; + std::string sep = ""; + for (absl::string_view v : names) { + absl::StrAppend(&result, sep, OptionFieldTag(std::string(v))); + sep = syntax_.separator; + } + result = tag_prefix + result; + return result; + } + + // Token definitions for the graph options syntax. + struct OptionsSyntax { + // The tag name for an options protobuf. + std::string tag_name; + // The packet name for an options protobuf. + std::string packet_name; + // The separator between nested options fields. + std::string separator; + }; + + OptionsSyntax syntax_; +}; // class OptionsSyntaxHelper + +OptionsSyntaxUtil::OptionsSyntaxUtil() + : syntax_helper_(std::make_unique()) {} + +OptionsSyntaxUtil::OptionsSyntaxUtil(const std::string& tag_name) + : OptionsSyntaxUtil() { + syntax_helper_->syntax_.tag_name = tag_name; +} + +OptionsSyntaxUtil::OptionsSyntaxUtil(const std::string& tag_name, + const std::string& packet_name, + const std::string& separator) + : OptionsSyntaxUtil() { + syntax_helper_->syntax_.tag_name = tag_name; + syntax_helper_->syntax_.packet_name = packet_name; + syntax_helper_->syntax_.separator = separator; +} + +OptionsSyntaxUtil::~OptionsSyntaxUtil() {} + +std::string OptionsSyntaxUtil::OptionFieldsTag( + const std::string& option_names) { + return syntax_helper_->OptionFieldsTag(option_names); +} + +OptionsSyntaxUtil::FieldPath OptionsSyntaxUtil::OptionFieldPath( + const std::string& tag, const Descriptor* descriptor) { + return syntax_helper_->OptionFieldPath(tag, descriptor); +} + +} // namespace tool +} // namespace mediapipe diff --git a/mediapipe/framework/tool/options_syntax_util.h b/mediapipe/framework/tool/options_syntax_util.h new file mode 100644 index 000000000..a75341b97 --- /dev/null +++ b/mediapipe/framework/tool/options_syntax_util.h @@ -0,0 +1,45 @@ +#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_ +#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_ + +#include +#include +#include + +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/proto_ns.h" +#include "mediapipe/framework/tool/options_field_util.h" +#include "mediapipe/framework/tool/options_registry.h" + +namespace mediapipe { +namespace tool { + +// Utility to parse the graph options syntax used in "option_value", +// "side_packet", and "stream". +class OptionsSyntaxUtil { + public: + using FieldPath = options_field_util::FieldPath; + OptionsSyntaxUtil(); + OptionsSyntaxUtil(const std::string& tag_name); + OptionsSyntaxUtil(const std::string& tag_name, const std::string& packet_name, + const std::string& separator); + ~OptionsSyntaxUtil(); + + // Converts slash-separated field names into a tag name. + std::string OptionFieldsTag(const std::string& option_names); + + // Returns the field-path for an option stream-tag. + FieldPath OptionFieldPath(const std::string& tag, + const Descriptor* descriptor); + + private: + class OptionsSyntaxHelper; + std::unique_ptr syntax_helper_; +}; + +} // namespace tool +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_SYNTAX_UTIL_H_ diff --git a/mediapipe/framework/tool/options_util.cc b/mediapipe/framework/tool/options_util.cc index 20734b953..5d7c64b75 100644 --- a/mediapipe/framework/tool/options_util.cc +++ b/mediapipe/framework/tool/options_util.cc @@ -1,16 +1,82 @@ #include "mediapipe/framework/tool/options_util.h" -#include "mediapipe/framework/port/proto_ns.h" +#include +#include +#include + +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/input_stream_shard.h" +#include "mediapipe/framework/output_side_packet.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/packet_set.h" +#include "mediapipe/framework/packet_type.h" +#include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/tool/name_util.h" +#include "mediapipe/framework/tool/options_field_util.h" +#include "mediapipe/framework/tool/options_registry.h" +#include "mediapipe/framework/tool/options_syntax_util.h" +#include "mediapipe/framework/tool/proto_util_lite.h" namespace mediapipe { namespace tool { -// TODO: Return registered protobuf Descriptors when available. -const proto_ns::Descriptor* GetProtobufDescriptor( - const std::string& type_name) { - return proto_ns::DescriptorPool::generated_pool()->FindMessageTypeByName( - type_name); +// Copy literal options from graph_options to node_options. +absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node, + CalculatorGraphConfig* config) { + Status status; + FieldData config_options, parent_node_options, graph_options; + status.Update( + options_field_util::GetOptionsMessage(*config, &config_options)); + status.Update( + options_field_util::GetOptionsMessage(parent_node, &parent_node_options)); + status.Update(options_field_util::MergeOptionsMessages( + config_options, parent_node_options, &graph_options)); + const Descriptor* options_descriptor = + OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl( + std::string(graph_options.message_value().type_url()))); + if (!options_descriptor) { + return status; + } + + OptionsSyntaxUtil syntax_util; + for (auto& node : *config->mutable_node()) { + FieldData node_data; + status.Update(options_field_util::GetOptionsMessage(node, &node_data)); + if (!node_data.has_message_value() || node.option_value_size() == 0) { + continue; + } + const Descriptor* node_options_descriptor = + OptionsRegistry::GetProtobufDescriptor(options_field_util::ParseTypeUrl( + std::string(node_data.message_value().type_url()))); + if (!node_options_descriptor) { + continue; + } + for (const std::string& option_def : node.option_value()) { + std::vector tag_and_name = absl::StrSplit(option_def, ':'); + std::string graph_tag = syntax_util.OptionFieldsTag(tag_and_name[1]); + std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]); + FieldData packet_data; + status.Update(options_field_util::GetField( + syntax_util.OptionFieldPath(graph_tag, options_descriptor), + graph_options, &packet_data)); + status.Update(options_field_util::MergeField( + syntax_util.OptionFieldPath(node_tag, node_options_descriptor), + packet_data, &node_data)); + } + options_field_util::SetOptionsMessage(node_data, &node); + } + return status; +} + +// Makes all configuration modifications needed for graph options. +absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node, + CalculatorGraphConfig* config) { + MP_RETURN_IF_ERROR(CopyLiteralOptions(parent_node, config)); + return mediapipe::OkStatus(); } } // namespace tool diff --git a/mediapipe/framework/tool/options_util.h b/mediapipe/framework/tool/options_util.h index 520e92a22..27fc004d9 100644 --- a/mediapipe/framework/tool/options_util.h +++ b/mediapipe/framework/tool/options_util.h @@ -21,7 +21,6 @@ #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/port/any_proto.h" #include "mediapipe/framework/tool/options_map.h" -#include "mediapipe/framework/tool/type_util.h" namespace mediapipe { @@ -75,8 +74,9 @@ inline T RetrieveOptions(const T& base, const InputStreamShardSet& stream_set, return base; } -// Finds the descriptor for a protobuf. -const proto_ns::Descriptor* GetProtobufDescriptor(const std::string& type_name); +// Copy literal options from enclosing graphs. +absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node, + CalculatorGraphConfig* config); } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/options_util_test.cc b/mediapipe/framework/tool/options_util_test.cc index 62be334da..55263d00e 100644 --- a/mediapipe/framework/tool/options_util_test.cc +++ b/mediapipe/framework/tool/options_util_test.cc @@ -16,15 +16,41 @@ #include #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/message_matchers.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/testdata/night_light_calculator.pb.h" +#include "mediapipe/framework/tool/node_chain_subgraph.pb.h" #include "mediapipe/framework/tool/options_registry.h" +#include "mediapipe/framework/tool/options_syntax_util.h" namespace mediapipe { namespace { +using ::mediapipe::proto_ns::FieldDescriptorProto; +using FieldType = ::mediapipe::proto_ns::FieldDescriptorProto::Type; + +// A test Calculator using DeclareOptions and DefineOptions. +class NightLightCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + return mediapipe::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) final { + return mediapipe::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) final { + return mediapipe::OkStatus(); + } + + private: + NightLightCalculatorOptions options_; +}; +REGISTER_CALCULATOR(NightLightCalculator); + // Tests for calculator and graph options. // class OptionsUtilTest : public ::testing::Test { @@ -35,21 +61,108 @@ class OptionsUtilTest : public ::testing::Test { // Retrieves the description of a protobuf. TEST_F(OptionsUtilTest, GetProtobufDescriptor) { - const proto_ns::Descriptor* descriptor = - tool::GetProtobufDescriptor("mediapipe.CalculatorGraphConfig"); -#ifndef MEDIAPIPE_MOBILE + const tool::Descriptor* descriptor = + tool::OptionsRegistry::GetProtobufDescriptor( + "mediapipe.CalculatorGraphConfig"); EXPECT_NE(nullptr, descriptor); -#else - EXPECT_EQ(nullptr, descriptor); -#endif } -// Retrieves the description of a protobuf from the OptionsRegistry. +// Shows a calculator node deriving options from graph options. +// The subgraph specifies "graph_options" as "NodeChainSubgraphOptions". +// The calculator specifies "node_options as "NightLightCalculatorOptions". +TEST_F(OptionsUtilTest, CopyLiteralOptions) { + CalculatorGraphConfig subgraph_config; + + auto node = subgraph_config.add_node(); + *node->mutable_calculator() = "NightLightCalculator"; + *node->add_option_value() = "num_lights:options/chain_length"; + + // The options framework requires at least an empty options protobuf + // as an indication the options protobuf type expected by the node. + NightLightCalculatorOptions node_options; + node->add_node_options()->PackFrom(node_options); + + NodeChainSubgraphOptions options; + options.set_chain_length(8); + subgraph_config.add_graph_options()->PackFrom(options); + subgraph_config.set_type("NightSubgraph"); + + CalculatorGraphConfig graph_config; + node = graph_config.add_node(); + *node->mutable_calculator() = "NightSubgraph"; + + CalculatorGraph graph; + graph_config.set_num_threads(4); + MP_EXPECT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {})); + + CalculatorGraphConfig expanded_config = graph.Config(); + expanded_config.clear_executor(); + CalculatorGraphConfig::Node actual_node; + actual_node = expanded_config.node(0); + + CalculatorGraphConfig::Node expected_node; + expected_node.set_name("nightsubgraph__NightLightCalculator"); + expected_node.set_calculator("NightLightCalculator"); + NightLightCalculatorOptions expected_node_options; + expected_node_options.add_num_lights(8); + expected_node.add_node_options()->PackFrom(expected_node_options); + *expected_node.add_option_value() = "num_lights:options/chain_length"; + EXPECT_THAT(actual_node, EqualsProto(expected_node)); + + MP_EXPECT_OK(graph.StartRun({})); + MP_EXPECT_OK(graph.CloseAllPacketSources()); + MP_EXPECT_OK(graph.WaitUntilDone()); + + // Ensure static protobuf packet registration. + MakePacket(); + MakePacket(); +} + +// Retrieves the description of a protobuf message and a nested protobuf message +// from the OptionsRegistry. TEST_F(OptionsUtilTest, GetProtobufDescriptorRegistered) { - const proto_ns::Descriptor* descriptor = + const tool::Descriptor* options_descriptor = tool::OptionsRegistry::GetProtobufDescriptor( "mediapipe.NightLightCalculatorOptions"); - EXPECT_NE(nullptr, descriptor); + EXPECT_NE(nullptr, options_descriptor); + const tool::Descriptor* bundle_descriptor = + tool::OptionsRegistry::GetProtobufDescriptor( + "mediapipe.NightLightCalculatorOptions.LightBundle"); + EXPECT_NE(nullptr, bundle_descriptor); + EXPECT_EQ(options_descriptor->full_name(), + "mediapipe.NightLightCalculatorOptions"); + const tool::FieldDescriptor* bundle_field = + options_descriptor->FindFieldByName("bundle"); + EXPECT_EQ(bundle_field->message_type(), bundle_descriptor); +} + +// Constructs the FieldPath for a nested node-option. +TEST_F(OptionsUtilTest, OptionsSyntaxUtil) { + const tool::Descriptor* descriptor = + tool::OptionsRegistry::GetProtobufDescriptor( + "mediapipe.NightLightCalculatorOptions"); + std::string tag; + tool::OptionsSyntaxUtil::FieldPath field_path; + { + // The default tag syntax. + tool::OptionsSyntaxUtil syntax_util; + tag = syntax_util.OptionFieldsTag("options/sub_options/num_lights"); + EXPECT_EQ(tag, "OPTIONS/sub_options/num_lights"); + field_path = syntax_util.OptionFieldPath(tag, descriptor); + EXPECT_EQ(field_path.size(), 2); + EXPECT_EQ(field_path[0].first->name(), "sub_options"); + EXPECT_EQ(field_path[1].first->name(), "num_lights"); + } + { + // A tag syntax with a text-coded separator. + tool::OptionsSyntaxUtil syntax_util("OPTIONS", "options", "_Z0Z_"); + tag = syntax_util.OptionFieldsTag("options/sub_options/num_lights"); + EXPECT_EQ(tag, "OPTIONS_Z0Z_sub_options_Z0Z_num_lights"); + field_path = syntax_util.OptionFieldPath(tag, descriptor); + EXPECT_EQ(field_path.size(), 2); + EXPECT_EQ(field_path[0].first->name(), "sub_options"); + EXPECT_EQ(field_path[1].first->name(), "num_lights"); + } } } // namespace diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index e4844d5cd..4ad623a2e 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -196,6 +196,27 @@ absl::Status ProtoUtilLite::GetFieldRange( return absl::OkStatus(); } +// Returns the number of field values in a repeated protobuf field. +absl::Status ProtoUtilLite::GetFieldCount(const FieldValue& message, + ProtoPath proto_path, + FieldType field_type, + int* field_count) { + int field_id, index; + std::tie(field_id, index) = proto_path.back(); + proto_path.pop_back(); + std::vector parent; + if (proto_path.empty()) { + parent.push_back(std::string(message)); + } else { + MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange( + message, proto_path, 1, WireFormatLite::TYPE_MESSAGE, &parent)); + } + FieldAccess access(field_id, field_type); + MP_RETURN_IF_ERROR(access.SetMessage(parent[0])); + *field_count = access.mutable_field_values()->size(); + return absl::OkStatus(); +} + // If ok, returns OkStatus, otherwise returns InvalidArgumentError. template absl::Status SyntaxStatus(bool ok, const std::string& text, T* result) { diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index fcbcd7469..71221291f 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -75,6 +75,11 @@ class ProtoUtilLite { FieldType field_type, std::vector* field_values); + // Returns the number of field values in a repeated protobuf field. + static absl::Status GetFieldCount(const FieldValue& message, + ProtoPath proto_path, FieldType field_type, + int* field_count); + // Serialize one or more protobuf field values from text. static absl::Status Serialize(const std::vector& text_values, FieldType field_type, diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc index 13405f7ec..18f255c70 100644 --- a/mediapipe/framework/tool/subgraph_expansion.cc +++ b/mediapipe/framework/tool/subgraph_expansion.cc @@ -278,6 +278,8 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, graph_registry = graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; RET_CHECK(config); + MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions( + CalculatorGraphConfig::Node(), config)); auto* nodes = config->mutable_node(); while (1) { auto subgraph_nodes_start = std::stable_partition( @@ -297,6 +299,7 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, ASSIGN_OR_RETURN(auto subgraph, graph_registry->CreateByName( config->package(), node.calculator(), &subgraph_context)); + MP_RETURN_IF_ERROR(mediapipe::tool::DefineGraphOptions(node, &subgraph)); MP_RETURN_IF_ERROR(PrefixNames(node_name, &subgraph)); MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph)); subgraphs.push_back(subgraph); diff --git a/mediapipe/gpu/gl_simple_shaders.cc b/mediapipe/gpu/gl_simple_shaders.cc index 1e6eefb5a..5cb718dc3 100644 --- a/mediapipe/gpu/gl_simple_shaders.cc +++ b/mediapipe/gpu/gl_simple_shaders.cc @@ -99,6 +99,26 @@ const GLchar* const kScaledVertexShader = VERTEX_PREAMBLE _STRINGIFY( sample_coordinate = texture_coordinate.xy; }); +const GLchar* const kTransformedVertexShader = VERTEX_PREAMBLE _STRINGIFY( + in vec4 position; in mediump vec4 texture_coordinate; + out mediump vec2 sample_coordinate; uniform mat3 transform; + uniform vec2 viewport_size; + + void main() { + // switch from clip to viewport aspect ratio in order to properly + // apply transformation + vec2 half_viewport_size = viewport_size * 0.5; + vec3 pos = vec3(position.xy * half_viewport_size, 1); + + // apply transform + pos = transform * pos; + + // switch back to clip space + gl_Position = vec4(pos.xy / half_viewport_size, 0, 1); + + sample_coordinate = texture_coordinate.xy; + }); + const GLchar* const kBasicTexturedFragmentShader = FRAGMENT_PREAMBLE _STRINGIFY( DEFAULT_PRECISION(mediump, float) diff --git a/mediapipe/gpu/gl_simple_shaders.h b/mediapipe/gpu/gl_simple_shaders.h index 8bc612ddd..e77a75327 100644 --- a/mediapipe/gpu/gl_simple_shaders.h +++ b/mediapipe/gpu/gl_simple_shaders.h @@ -38,6 +38,17 @@ extern const GLchar* const kBasicVertexShader; // vec2 sample_coordinate - texture coordinate for shader extern const GLchar* const kScaledVertexShader; +// Applies an affine transformation to the vertex and leaves texture coordinates +// as is. Input attributes: +// vec4 position - vertex position +// vec4 texture_coordinate - texture coordinate +// Input uniform: +// mat3 homogeneous affine transform - transformation matrix for vertices +// vec2 viewport_size - size of the viewport +// Output varying: +// vec2 sample_coordinate - texture coordinate for shader +extern const GLchar* const kTransformedVertexShader; + // Outputs the texture as it is. // Input varying: // vec2 sample_coordinate - texture coordinate diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 702c80294..278ec444e 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -77,16 +77,24 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, }}, {GpuBufferFormat::kOneComponent8, { - // This should be GL_RED, but it would change the output for existing - // shaders. It would not be a good representation of a grayscale texture, - // unless we use texture swizzling. We could add swizzle parameters (e.g. - // GL_TEXTURE_SWIZZLE_R) in GLES 3 and desktop GL, and use GL_LUMINANCE - // in GLES 2. Or we could just punt and make it a red texture. - // {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, + // This format is like RGBA grayscale: GL_LUMINANCE replicates + // the single channel texel values to RGB channels, and set alpha + // to 1.0. If it is desired to see only the texel values in the R + // channel, use kOneComponent8Red instead. #if !TARGET_OS_OSX {GL_LUMINANCE, GL_LUMINANCE, GL_UNSIGNED_BYTE, 1}, +#else + {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, #endif // TARGET_OS_OSX }}, + {GpuBufferFormat::kOneComponent8Red, + { + {GL_R8, GL_RED, GL_UNSIGNED_BYTE, 1}, + }}, + {GpuBufferFormat::kTwoComponent8, + { + {GL_RG8, GL_RG, GL_UNSIGNED_BYTE, 1}, + }}, #ifdef __APPLE__ // TODO: figure out GL_RED_EXT etc. on Android. {GpuBufferFormat::kBiPlanar420YpCbCr8VideoRange, @@ -195,6 +203,8 @@ ImageFormat::Format ImageFormatForGpuBufferFormat(GpuBufferFormat format) { case GpuBufferFormat::kTwoComponentFloat32: return ImageFormat::VEC32F2; case GpuBufferFormat::kGrayHalf16: + case GpuBufferFormat::kOneComponent8Red: + case GpuBufferFormat::kTwoComponent8: case GpuBufferFormat::kTwoComponentHalf16: case GpuBufferFormat::kRGBAHalf64: case GpuBufferFormat::kRGBAFloat128: diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index a92c5712c..66999f755 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -38,6 +38,8 @@ enum class GpuBufferFormat : uint32_t { kGrayFloat32 = MEDIAPIPE_FOURCC('L', '0', '0', 'f'), kGrayHalf16 = MEDIAPIPE_FOURCC('L', '0', '0', 'h'), kOneComponent8 = MEDIAPIPE_FOURCC('L', '0', '0', '8'), + kOneComponent8Red = MEDIAPIPE_FOURCC('R', '0', '0', '8'), + kTwoComponent8 = MEDIAPIPE_FOURCC('2', 'C', '0', '8'), kTwoComponentHalf16 = MEDIAPIPE_FOURCC('2', 'C', '0', 'h'), kTwoComponentFloat32 = MEDIAPIPE_FOURCC('2', 'C', '0', 'f'), kBiPlanar420YpCbCr8VideoRange = MEDIAPIPE_FOURCC('4', '2', '0', 'v'), @@ -82,6 +84,10 @@ inline OSType CVPixelFormatForGpuBufferFormat(GpuBufferFormat format) { return kCVPixelFormatType_OneComponent32Float; case GpuBufferFormat::kOneComponent8: return kCVPixelFormatType_OneComponent8; + case GpuBufferFormat::kOneComponent8Red: + return -1; + case GpuBufferFormat::kTwoComponent8: + return kCVPixelFormatType_TwoComponent8; case GpuBufferFormat::kTwoComponentHalf16: return kCVPixelFormatType_TwoComponent16Half; case GpuBufferFormat::kTwoComponentFloat32: @@ -114,6 +120,8 @@ inline GpuBufferFormat GpuBufferFormatForCVPixelFormat(OSType format) { return GpuBufferFormat::kGrayFloat32; case kCVPixelFormatType_OneComponent8: return GpuBufferFormat::kOneComponent8; + case kCVPixelFormatType_TwoComponent8: + return GpuBufferFormat::kTwoComponent8; case kCVPixelFormatType_TwoComponent16Half: return GpuBufferFormat::kTwoComponentHalf16; case kCVPixelFormatType_TwoComponent32Float: diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionResult.java b/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionResult.java index 881391a8d..3d2493dca 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionResult.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/ImageSolutionResult.java @@ -19,6 +19,7 @@ import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.TextureFrame; +import java.util.ArrayList; import java.util.List; /** @@ -32,6 +33,9 @@ public class ImageSolutionResult implements SolutionResult { private Bitmap cachedBitmap; // A list of the output image packets produced by the graph. protected List imageResultPackets; + // The cached texture frames. + protected List imageResultTextureFrames; + private TextureFrame cachedTextureFrame; // Result timestamp, which is set to the timestamp of the corresponding input image. May return // Long.MIN_VALUE if the input image is not associated with a timestamp. @@ -61,8 +65,38 @@ public class ImageSolutionResult implements SolutionResult { return PacketGetter.getTextureFrame(imagePacket); } + // Returns the cached input image as a {@link TextureFrame}. + public TextureFrame getCachedInputTextureFrame() { + return cachedTextureFrame; + } + + // Produces all texture frames from image packets and caches them for further use. The caller must + // release the cached {@link TextureFrame}s after using. + void produceAllTextureFrames() { + cachedTextureFrame = acquireInputTextureFrame(); + if (imageResultPackets == null) { + return; + } + imageResultTextureFrames = new ArrayList<>(); + for (Packet p : imageResultPackets) { + imageResultTextureFrames.add(PacketGetter.getTextureFrame(p)); + } + } + + // Releases all cached {@link TextureFrame}s. + void releaseCachedTextureFrames() { + if (cachedTextureFrame != null) { + cachedTextureFrame.release(); + } + if (imageResultTextureFrames != null) { + for (TextureFrame textureFrame : imageResultTextureFrames) { + textureFrame.release(); + } + } + } + // Releases image packet and the underlying data. - void releaseImagePacket() { + void releaseImagePackets() { imagePacket.release(); if (imageResultPackets != null) { for (Packet p : imageResultPackets) { diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java b/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java index c9236069f..51763431b 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/OutputHandler.java @@ -79,7 +79,7 @@ public class OutputHandler { } if (solutionResult instanceof ImageSolutionResult) { ImageSolutionResult imageSolutionResult = (ImageSolutionResult) solutionResult; - imageSolutionResult.releaseImagePacket(); + imageSolutionResult.releaseImagePackets(); } } } diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceView.java b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceView.java index 2c6ef0ace..8edbcb530 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceView.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceView.java @@ -50,13 +50,25 @@ public class SolutionGlSurfaceView extends GLSurf } /** - * Sets the next textureframe and solution result to render. + * Sets the next input {@link TextureFrame} and solution result to render. * * @param solutionResult a solution result object that contains the solution outputs and a * textureframe. */ public void setRenderData(T solutionResult) { - renderer.setRenderData(solutionResult); + renderer.setRenderData(solutionResult, false); + } + + /** + * Sets the next input {@link TextureFrame} and solution result to render. + * + * @param solutionResult a solution result object that contains the solution outputs and a {@link + * TextureFrame}. + * @param produceTextureFrames whether to produce and cache all the {@link TextureFrame}s for + * further use. + */ + public void setRenderData(T solutionResult, boolean produceTextureFrames) { + renderer.setRenderData(solutionResult, produceTextureFrames); } /** Sets if the input image needs to be rendered. Default to true. */ diff --git a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java index 8ba2715df..ccaa1e725 100644 --- a/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java +++ b/mediapipe/java/com/google/mediapipe/solutioncore/SolutionGlSurfaceViewRenderer.java @@ -27,9 +27,10 @@ import javax.microedition.khronos.opengles.GL10; * MediaPipe Solution's GlSurfaceViewRenderer. * *

Users can provide a custom {@link ResultGlRenderer} for rendering MediaPipe solution results. - * For setting the latest solution result, call {@link #setRenderData(ImageSolutionResult)}. By - * default, the renderer renders the input images. Call {@link #setRenderInputImage(boolean)} to - * explicitly set whether the input images should be rendered or not. + * For setting the latest solution result, call {@link #setRenderData(ImageSolutionResult, + * boolean)}. By default, the renderer renders the input images. Call {@link + * #setRenderInputImage(boolean)} to explicitly set whether the input images should be rendered or + * not. */ public class SolutionGlSurfaceViewRenderer extends GlSurfaceViewRenderer { @@ -49,16 +50,24 @@ public class SolutionGlSurfaceViewRenderer } /** - * Sets the next textureframe and solution result to render. + * Sets the next input {@link TextureFrame} and solution result to render. * * @param solutionResult a solution result object that contains the solution outputs and a * textureframe. + * @param produceTextureFrames whether to produce and cache all the {@link TextureFrame}s for + * further use. */ - public void setRenderData(T solutionResult) { + public void setRenderData(T solutionResult, boolean produceTextureFrames) { TextureFrame frame = solutionResult.acquireInputTextureFrame(); setFrameSize(frame.getWidth(), frame.getHeight()); setNextFrame(frame); - nextSolutionResult.getAndSet(solutionResult); + if (produceTextureFrames) { + solutionResult.produceAllTextureFrames(); + } + T oldSolutionResult = nextSolutionResult.getAndSet(solutionResult); + if (oldSolutionResult != null) { + oldSolutionResult.releaseCachedTextureFrames(); + } } @Override @@ -78,8 +87,9 @@ public class SolutionGlSurfaceViewRenderer GLES20.glActiveTexture(GLES20.GL_TEXTURE0); ShaderUtil.checkGlError("glActiveTexture"); } + T solutionResult = null; if (nextSolutionResult != null) { - T solutionResult = nextSolutionResult.getAndSet(null); + solutionResult = nextSolutionResult.getAndSet(null); float[] textureBoundary = calculateTextureBoundary(); // Scales the values from [0, 1] to [-1, 1]. ResultGlBoundary resultGlBoundary = @@ -91,6 +101,9 @@ public class SolutionGlSurfaceViewRenderer resultGlRenderer.renderResult(solutionResult, resultGlBoundary); } flush(frame); + if (solutionResult != null) { + solutionResult.releaseCachedTextureFrames(); + } } @Override diff --git a/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt b/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt index 13d474021..235040190 100644 --- a/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_full_range_cpu.pbtxt @@ -64,9 +64,10 @@ node { options: { [mediapipe.InferenceCalculatorOptions.ext] { model_path: "mediapipe/modules/face_detection/face_detection_full_range_sparse.tflite" - delegate { xnnpack {} } + delegate { + xnnpack {} + } } - # } } diff --git a/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt index c46b243ad..98f126676 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt @@ -50,9 +50,10 @@ node { options: { [mediapipe.InferenceCalculatorOptions.ext] { model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite" - delegate { xnnpack {} } + delegate { + xnnpack {} + } } - # } } diff --git a/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt b/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt index d46d6a5c5..79ee1ac94 100644 --- a/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt +++ b/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt @@ -72,9 +72,10 @@ node { options: { [mediapipe.InferenceCalculatorOptions.ext] { model_path: "mediapipe/modules/pose_detection/pose_detection.tflite" - delegate { xnnpack {} } + delegate { + xnnpack {} + } } - # } } diff --git a/mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt b/mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt index 3075e2604..ac86233ef 100644 --- a/mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt +++ b/mediapipe/modules/pose_landmark/tensors_to_pose_landmarks_and_segmentation.pbtxt @@ -245,7 +245,7 @@ node { output_stream: "enabled_segmentation_tensor" options: { [mediapipe.GateCalculatorOptions.ext] { - allow: true + allow: false } } } diff --git a/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt b/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt index 550cee906..591824851 100644 --- a/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt +++ b/mediapipe/modules/selfie_segmentation/selfie_segmentation_cpu.pbtxt @@ -96,9 +96,10 @@ node { input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" options: { [mediapipe.InferenceCalculatorOptions.ext] { - delegate { xnnpack {} } + delegate { + xnnpack {} + } } - # } } diff --git a/mediapipe/python/pybind/calculator_graph.cc b/mediapipe/python/pybind/calculator_graph.cc index 017e16b3f..afcf4a65c 100644 --- a/mediapipe/python/pybind/calculator_graph.cc +++ b/mediapipe/python/pybind/calculator_graph.cc @@ -288,7 +288,10 @@ void CalculatorGraphSubmodule(pybind11::module* module) { calculator_graph.def( "wait_until_done", - [](CalculatorGraph* self) { RaisePyErrorIfNotOk(self->WaitUntilDone()); }, + [](CalculatorGraph* self) { + py::gil_scoped_release gil_release; + RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true); + }, R"doc(Wait for the current run to finish. A blocking call to wait for the current run to finish (block the current @@ -313,7 +316,10 @@ void CalculatorGraphSubmodule(pybind11::module* module) { calculator_graph.def( "wait_until_idle", - [](CalculatorGraph* self) { RaisePyErrorIfNotOk(self->WaitUntilIdle()); }, + [](CalculatorGraph* self) { + py::gil_scoped_release gil_release; + RaisePyErrorIfNotOk(self->WaitUntilIdle(), /**acquire_gil=*/true); + }, R"doc(Wait until the running graph is in the idle mode. Wait until the running graph is in the idle mode, which is when nothing can @@ -399,12 +405,9 @@ void CalculatorGraphSubmodule(pybind11::module* module) { stream_name, [callback_fn, stream_name](const Packet& packet) { absl::MutexLock lock(&callback_mutex); - py::gil_scoped_release gil_release; - { - // Acquires GIL before calling Python callback. - py::gil_scoped_acquire gil_acquire; - callback_fn(stream_name, packet); - } + // Acquires GIL before calling Python callback. + py::gil_scoped_acquire gil_acquire; + callback_fn(stream_name, packet); return absl::OkStatus(); }, observe_timestamp_bounds)); @@ -439,7 +442,8 @@ void CalculatorGraphSubmodule(pybind11::module* module) { "close", [](CalculatorGraph* self) { RaisePyErrorIfNotOk(self->CloseAllPacketSources()); - RaisePyErrorIfNotOk(self->WaitUntilDone()); + py::gil_scoped_release gil_release; + RaisePyErrorIfNotOk(self->WaitUntilDone(), /**acquire_gil=*/true); }, R"doc(Close all the input sources and shutdown the graph.)doc"); diff --git a/mediapipe/python/pybind/util.h b/mediapipe/python/pybind/util.h index 099f75bd6..94b6081c6 100644 --- a/mediapipe/python/pybind/util.h +++ b/mediapipe/python/pybind/util.h @@ -19,6 +19,7 @@ #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/timestamp.h" +#include "pybind11/gil.h" #include "pybind11/pybind11.h" namespace mediapipe { @@ -45,10 +46,17 @@ inline PyObject* StatusCodeToPyError(const ::absl::StatusCode& code) { } } -inline void RaisePyErrorIfNotOk(const absl::Status& status) { +inline void RaisePyErrorIfNotOk(const absl::Status& status, + bool acquire_gil = false) { if (!status.ok()) { - throw RaisePyError(StatusCodeToPyError(status.code()), - status.message().data()); + if (acquire_gil) { + py::gil_scoped_acquire acquire; + throw RaisePyError(StatusCodeToPyError(status.code()), + status.message().data()); + } else { + throw RaisePyError(StatusCodeToPyError(status.code()), + status.message().data()); + } } } diff --git a/mediapipe/python/solution_base.py b/mediapipe/python/solution_base.py index 1837f7f5e..b46a13209 100644 --- a/mediapipe/python/solution_base.py +++ b/mediapipe/python/solution_base.py @@ -441,7 +441,7 @@ class SolutionBase: else: field_label = calculator_options.DESCRIPTOR.fields_by_name[ field_name].label - if field_label is descriptor.FieldDescriptor.LABEL_REPEATED: + if field_label == descriptor.FieldDescriptor.LABEL_REPEATED: if not isinstance(field_value, Iterable): raise ValueError( f'{field_name} is a repeated proto field but the value ' diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index b468f75a4..5576e606f 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -111,7 +111,7 @@ def draw_detection( image_rows) rect_end_point = _normalized_to_pixel_coordinates( relative_bounding_box.xmin + relative_bounding_box.width, - relative_bounding_box.ymin + +relative_bounding_box.height, image_cols, + relative_bounding_box.ymin + relative_bounding_box.height, image_cols, image_rows) cv2.rectangle(image, rect_start_point, rect_end_point, bbox_drawing_spec.color, bbox_drawing_spec.thickness) diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 11f08ade5..e66b27722 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -175,8 +175,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":resource_util_custom", - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:str_format", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:singleton", "//mediapipe/framework/port:status", diff --git a/setup_opencv.sh b/setup_opencv.sh index 608bb1aa9..b055b3a87 100644 --- a/setup_opencv.sh +++ b/setup_opencv.sh @@ -38,10 +38,17 @@ workspace_file="$( cd "$(dirname "$0")" ; pwd -P )"/WORKSPACE if [ -z "$1" ] then echo "Installing OpenCV from source" - sudo apt update && sudo apt install build-essential git - sudo apt install cmake ffmpeg libavformat-dev libdc1394-22-dev libgtk2.0-dev \ - libjpeg-dev libpng-dev libswscale-dev libtbb2 libtbb-dev \ - libtiff-dev + if [[ -x "$(command -v apt)" ]]; then + sudo apt update && sudo apt install build-essential git + sudo apt install cmake ffmpeg libavformat-dev libdc1394-22-dev libgtk2.0-dev \ + libjpeg-dev libpng-dev libswscale-dev libtbb2 libtbb-dev \ + libtiff-dev + elif [[ -x "$(command -v dnf)" ]]; then + sudo dnf update && sudo dnf install cmake gcc gcc-c git + sudo dnf install ffmpeg-devel libdc1394-devel gtk2-devel \ + libjpeg-turbo-devel libpng-devel tbb-devel \ + libtiff-devel + fi rm -rf /tmp/build_opencv mkdir /tmp/build_opencv cd /tmp/build_opencv