diff --git a/.bazelrc b/.bazelrc index 0ec819c0d..37a0bc114 100644 --- a/.bazelrc +++ b/.bazelrc @@ -5,9 +5,12 @@ common --experimental_repo_remote_exec # Basic build settings build --jobs 128 -build --define='absl=1' +build --define='absl=1' # for gtest build --enable_platform_specific_config +# Enable stack traces +test --test_env="GTEST_INSTALL_FAILURE_SIGNAL_HANDLER=1" + # Linux build:linux --cxxopt=-std=c++17 build:linux --host_cxxopt=-std=c++17 diff --git a/MANIFEST.in b/MANIFEST.in index f277566d4..1994721f3 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,3 +8,5 @@ include README.md include requirements.txt recursive-include mediapipe/modules *.tflite *.txt *.binarypb +exclude mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite +exclude mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite diff --git a/README.md b/README.md index 444b2b1f6..06fa39b5e 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Hair Segmentation []() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) :---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | :-----------------------------------------------------------: | :--------------------------------------------------------------------: -[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | | ✅ +[Face 
Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | [Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | [Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -44,7 +44,7 @@ Hair Segmentation [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/WORKSPACE b/WORKSPACE index d88d8fc95..32b466e6c 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -38,8 +38,8 @@ http_archive( http_archive( name = "rules_foreign_cc", - strip_prefix = "rules_foreign_cc-master", - url = "https://github.com/bazelbuild/rules_foreign_cc/archive/master.zip", + strip_prefix = "rules_foreign_cc-main", + url = "https://github.com/bazelbuild/rules_foreign_cc/archive/main.zip", ) load("@rules_foreign_cc//:workspace_definitions.bzl", "rules_foreign_cc_dependencies") diff --git a/docs/framework_concepts/calculators.md b/docs/framework_concepts/calculators.md index 0ee3473e6..3e1236aaa 100644 --- a/docs/framework_concepts/calculators.md +++ b/docs/framework_concepts/calculators.md @@ -67,26 +67,26 @@ class CalculatorBase { // The subclasses of CalculatorBase must implement GetContract. // ... 
- static ::MediaPipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // Open is called before any Process() calls, on a freshly constructed // calculator. Subclasses may override this method to perform necessary // setup, and possibly output Packets and/or set output streams' headers. // ... - virtual ::MediaPipe::Status Open(CalculatorContext* cc) { - return ::MediaPipe::OkStatus(); + virtual absl::Status Open(CalculatorContext* cc) { + return absl::OkStatus(); } // Processes the incoming inputs. May call the methods on cc to access // inputs and produce outputs. // ... - virtual ::MediaPipe::Status Process(CalculatorContext* cc) = 0; + virtual absl::Status Process(CalculatorContext* cc) = 0; // Is called if Open() was called and succeeded. Is called either // immediately after processing is complete or after a graph run has ended // (if an error occurred in the graph). ... - virtual ::MediaPipe::Status Close(CalculatorContext* cc) { - return ::MediaPipe::OkStatus(); + virtual absl::Status Close(CalculatorContext* cc) { + return absl::OkStatus(); } ... @@ -199,7 +199,7 @@ name and index number. In the function below input are output are identified: // c++ Code snippet describing the SomeAudioVideoCalculator GetContract() method class SomeAudioVideoCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); // SetAny() is used to specify that whatever the type of the // stream is, it's acceptable. 
This does not mean that any @@ -209,13 +209,13 @@ class SomeAudioVideoCalculator : public CalculatorBase { cc->Outputs().Tag("VIDEO").Set(); cc->Outputs().Get("AUDIO", 0).Set(); cc->Outputs().Get("AUDIO", 1).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ``` ## Processing -`Process()` called on a non-source node must return `::mediapipe::OkStatus()` to +`Process()` called on a non-source node must return `absl::OkStatus()` to indicate that all went well, or any other status code to signal an error If a non-source calculator returns `tool::StatusStop()`, then this signals the @@ -224,12 +224,12 @@ input streams will be closed (and remaining Packets will propagate through the graph). A source node in a graph will continue to have `Process()` called on it as long -as it returns `::mediapipe::OkStatus(`). To indicate that there is no more data -to be generated return `tool::StatusStop()`. Any other status indicates an error -has occurred. +as it returns `absl::OkStatus()`. To indicate that there is no more data to be +generated return `tool::StatusStop()`. Any other status indicates an error has +occurred. -`Close()` returns `::mediapipe::OkStatus()` to indicate success. Any other -status indicates a failure. +`Close()` returns `absl::OkStatus()` to indicate success. Any other status +indicates a failure. Here is the basic `Process()` function. It uses the `Input()` method (which can be used only if the calculator has a single input) to request its input data. It @@ -238,13 +238,13 @@ and does the calculations. When done it releases the pointer when adding it to the output stream. ```c++ -::util::Status MyCalculator::Process() { +absl::Status MyCalculator::Process() { const Matrix& input = Input()->Get(); std::unique_ptr output(new Matrix(input.rows(), input.cols())); // do your magic here.... // output->row(n) = ... 
Output()->Add(output.release(), InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ``` @@ -312,7 +312,7 @@ namespace mediapipe { // class PacketClonerCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const int tick_signal_index = cc->Inputs().NumEntries() - 1; // cc->Inputs().NumEntries() returns the number of input streams // for the PacketClonerCalculator @@ -322,10 +322,10 @@ class PacketClonerCalculator : public CalculatorBase { cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); } cc->Inputs().Index(tick_signal_index).SetAny(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { tick_signal_index_ = cc->Inputs().NumEntries() - 1; current_.resize(tick_signal_index_); // Pass along the header for each stream if present. @@ -336,10 +336,10 @@ class PacketClonerCalculator : public CalculatorBase { // the header for the input stream of index i } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Store input signals. for (int i = 0; i < tick_signal_index_; ++i) { if (!cc->Inputs().Index(i).Value().IsEmpty()) { @@ -364,7 +364,7 @@ class PacketClonerCalculator : public CalculatorBase { } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/docs/framework_concepts/gpu.md b/docs/framework_concepts/gpu.md index 77d566e8d..8f9df6067 100644 --- a/docs/framework_concepts/gpu.md +++ b/docs/framework_concepts/gpu.md @@ -66,10 +66,10 @@ calculator derived from base class GlSimpleCalculator. The GPU calculator // See GlSimpleCalculator for inputs, outputs and input side packets. 
class LuminanceCalculator : public GlSimpleCalculator { public: - ::mediapipe::Status GlSetup() override; - ::mediapipe::Status GlRender(const GlTexture& src, - const GlTexture& dst) override; - ::mediapipe::Status GlTeardown() override; + absl::Status GlSetup() override; + absl::Status GlRender(const GlTexture& src, + const GlTexture& dst) override; + absl::Status GlTeardown() override; private: GLuint program_ = 0; @@ -77,8 +77,8 @@ class LuminanceCalculator : public GlSimpleCalculator { }; REGISTER_CALCULATOR(LuminanceCalculator); -::mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src, - const GlTexture& dst) { +absl::Status LuminanceCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -128,7 +128,7 @@ REGISTER_CALCULATOR(LuminanceCalculator); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } ``` diff --git a/docs/framework_concepts/graphs.md b/docs/framework_concepts/graphs.md index 83f95e5bb..2bbdd6856 100644 --- a/docs/framework_concepts/graphs.md +++ b/docs/framework_concepts/graphs.md @@ -219,23 +219,23 @@ packet timestamps 0, 1, 2, 3, ... 
```c++ class UnitDelayCalculator : public Calculator { public: -  static ::util::Status FillExpectations( +  static absl::Status FillExpectations(      const CalculatorOptions& extendable_options, PacketTypeSet* inputs,      PacketTypeSet* outputs, PacketTypeSet* input_side_packets) {    inputs->Index(0)->Set("An integer.");    outputs->Index(0)->Set("The input delayed by one time unit."); -    return ::mediapipe::OkStatus(); +    return absl::OkStatus();  } -  ::util::Status Open() final { +  absl::Status Open() final {    Output()->Add(new int(0), Timestamp(0)); -    return ::mediapipe::OkStatus(); +    return absl::OkStatus();  } -  ::util::Status Process() final { +  absl::Status Process() final {    const Packet& packet = Input()->Value();    Output()->AddPacket(packet.At(packet.Timestamp().NextAllowedInStream())); -    return ::mediapipe::OkStatus(); +    return absl::OkStatus();  } }; ``` diff --git a/docs/getting_started/android_archive_library.md b/docs/getting_started/android_archive_library.md index bd6a1e1c1..735bd7a39 100644 --- a/docs/getting_started/android_archive_library.md +++ b/docs/getting_started/android_archive_library.md @@ -45,7 +45,8 @@ each project. 2. Run the Bazel build command to generate the AAR. ```bash - bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --fat_apk_cpu=arm64-v8a,armeabi-v7a \ + bazel build -c opt --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --fat_apk_cpu=arm64-v8a,armeabi-v7a --strip=ALWAYS \ //path/to/the/aar/build/file:aar_name ``` @@ -86,16 +87,14 @@ each project. 
Build the MediaPipe binary graph and copy the assets into app/src/main/assets, e.g., for the face detection graph, you need to build and copy - [the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41), - [the tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/models/face_detection_front.tflite), + [the binary graph](https://github.com/google/mediapipe/blob/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectiongpu/BUILD#L41) and - [the label map](https://github.com/google/mediapipe/blob/master/mediapipe/models/face_detection_front_labelmap.txt). + [the face detection tflite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_detection/face_detection_front.tflite). ```bash bazel build -c opt mediapipe/mediapipe/graphs/face_detection:mobile_gpu_binary_graph cp bazel-bin/mediapipe/graphs/face_detection/mobile_gpu.binarypb /path/to/your/app/src/main/assets/ - cp mediapipe/models/face_detection_front.tflite /path/to/your/app/src/main/assets/ - cp mediapipe/models/face_detection_front_labelmap.txt /path/to/your/app/src/main/assets/ + cp mediapipe/modules/face_detection/face_detection_front.tflite /path/to/your/app/src/main/assets/ ``` ![Screenshot](../images/mobile/assets_location.png) diff --git a/docs/getting_started/hello_world_android.md b/docs/getting_started/hello_world_android.md index c828ed226..9f277f799 100644 --- a/docs/getting_started/hello_world_android.md +++ b/docs/getting_started/hello_world_android.md @@ -59,7 +59,7 @@ node: { output_stream: "luma_video" } -# Applies the Sobel filter to luminance images sotred in RGB format. +# Applies the Sobel filter to luminance images stored in RGB format. 
node: { calculator: "SobelEdgesCalculator" input_stream: "luma_video" diff --git a/docs/getting_started/hello_world_cpp.md b/docs/getting_started/hello_world_cpp.md index f46e88698..e3d34d9b4 100644 --- a/docs/getting_started/hello_world_cpp.md +++ b/docs/getting_started/hello_world_cpp.md @@ -44,7 +44,7 @@ nav_order: 1 `PrintHelloWorld()` function, defined in a [`CalculatorGraphConfig`] proto. ```C++ - ::mediapipe::Status PrintHelloWorld() { + absl::Status PrintHelloWorld() { // Configures a simple graph, which concatenates 2 PassThroughCalculators. CalculatorGraphConfig config = ParseTextProtoOrDie(R"( input_stream: "in" diff --git a/docs/getting_started/hello_world_ios.md b/docs/getting_started/hello_world_ios.md index 0441623e3..06d79c67d 100644 --- a/docs/getting_started/hello_world_ios.md +++ b/docs/getting_started/hello_world_ios.md @@ -492,6 +492,9 @@ in our app: if (![self.mediapipeGraph startWithError:&error]) { NSLog(@"Failed to start graph: %@", error); } + else if (![self.mediapipeGraph waitUntilIdleWithError:&error]) { + NSLog(@"Failed to complete graph initial run: %@", error); + } dispatch_async(_videoQueue, ^{ [_cameraSource start]; @@ -500,8 +503,9 @@ in our app: }]; ``` -Note: It is important to start the graph before starting the camera, so that the -graph is ready to process frames as soon as the camera starts sending them. +Note: It is important to start the graph before starting the camera and wait +until completion, so that the graph is ready to process frames as soon as the +camera starts sending them. Earlier, when we received frames from the camera in the `processVideoFrame` function, we displayed them in the `_liveView` using the `_renderer`. 
Now, we diff --git a/docs/getting_started/javascript.md b/docs/getting_started/javascript.md index 95a1e2610..c6df75bd8 100644 --- a/docs/getting_started/javascript.md +++ b/docs/getting_started/javascript.md @@ -19,9 +19,10 @@ MediaPipe currently offers the following solutions: Solution | NPM Package | Example ----------------- | ----------------------------- | ------- [Face Mesh][F-pg] | [@mediapipe/face_mesh][F-npm] | [mediapipe.dev/demo/face_mesh][F-demo] +[Face Detection][Fd-pg] | [@mediapipe/face_detection][Fd-npm] | [mediapipe.dev/demo/face_detection][Fd-demo] [Hands][H-pg] | [@mediapipe/hands][H-npm] | [mediapipe.dev/demo/hands][H-demo] -[Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo] [Holistic][Ho-pg] | [@mediapipe/holistic][Ho-npm] | [mediapipe.dev/demo/holistic][Ho-demo] +[Pose][P-pg] | [@mediapipe/pose][P-npm] | [mediapipe.dev/demo/pose][P-demo] Click on a solution link above for more information, including API and code snippets. @@ -63,10 +64,12 @@ affecting your work, restrict your request to a `` number. e.g., [Ho-pg]: ../solutions/holistic#javascript-solution-api [F-pg]: ../solutions/face_mesh#javascript-solution-api +[Fd-pg]: ../solutions/face_detection#javascript-solution-api [H-pg]: ../solutions/hands#javascript-solution-api [P-pg]: ../solutions/pose#javascript-solution-api [Ho-npm]: https://www.npmjs.com/package/@mediapipe/holistic [F-npm]: https://www.npmjs.com/package/@mediapipe/face_mesh +[Fd-npm]: https://www.npmjs.com/package/@mediapipe/face_detection [H-npm]: https://www.npmjs.com/package/@mediapipe/hands [P-npm]: https://www.npmjs.com/package/@mediapipe/pose [draw-npm]: https://www.npmjs.com/package/@mediapipe/pose @@ -74,14 +77,17 @@ affecting your work, restrict your request to a `` number. 
e.g., [ctrl-npm]: https://www.npmjs.com/package/@mediapipe/pose [Ho-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/holistic [F-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_mesh +[Fd-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/face_detection [H-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/hands [P-jsd]: https://www.jsdelivr.com/package/npm/@mediapipe/pose [Ho-pen]: https://code.mediapipe.dev/codepen/holistic [F-pen]: https://code.mediapipe.dev/codepen/face_mesh +[Fd-pen]: https://code.mediapipe.dev/codepen/face_detection [H-pen]: https://code.mediapipe.dev/codepen/hands [P-pen]: https://code.mediapipe.dev/codepen/pose [Ho-demo]: https://mediapipe.dev/demo/holistic [F-demo]: https://mediapipe.dev/demo/face_mesh +[Fd-demo]: https://mediapipe.dev/demo/face_detection [H-demo]: https://mediapipe.dev/demo/hands [P-demo]: https://mediapipe.dev/demo/pose [npm]: https://www.npmjs.com/package/@mediapipe diff --git a/docs/getting_started/python.md b/docs/getting_started/python.md index c97f2a839..5d4bc2fb9 100644 --- a/docs/getting_started/python.md +++ b/docs/getting_started/python.md @@ -45,17 +45,23 @@ Tip: Use command `deactivate` to later exit the Python virtual environment. 
To learn more about configuration options and usage examples, please find details in each solution via the links below: +* [MediaPipe Face Detection](../solutions/face_detection#python-solution-api) * [MediaPipe Face Mesh](../solutions/face_mesh#python-solution-api) * [MediaPipe Hands](../solutions/hands#python-solution-api) -* [MediaPipe Pose](../solutions/pose#python-solution-api) * [MediaPipe Holistic](../solutions/holistic#python-solution-api) +* [MediaPipe Objectron](../solutions/objectron#python-solution-api) +* [MediaPipe Pose](../solutions/pose#python-solution-api) ## MediaPipe on Google Colab +* [MediaPipe Face Detection Colab](https://mediapipe.page.link/face_detection_py_colab) * [MediaPipe Face Mesh Colab](https://mediapipe.page.link/face_mesh_py_colab) * [MediaPipe Hands Colab](https://mediapipe.page.link/hands_py_colab) -* [MediaPipe Pose Colab](https://mediapipe.page.link/pose_py_colab) * [MediaPipe Holistic Colab](https://mediapipe.page.link/holistic_py_colab) +* [MediaPipe Objectron Colab](https://mediapipe.page.link/objectron_py_colab) +* [MediaPipe Pose Colab](https://mediapipe.page.link/pose_py_colab) +* [MediaPipe Pose Classification Colab (Basic)](https://mediapipe.page.link/pose_classification_basic) +* [MediaPipe Pose Classification Colab (Extended)](https://mediapipe.page.link/pose_classification_extended) ## MediaPipe Python Framework diff --git a/docs/images/box_coordinate.svg b/docs/images/box_coordinate.svg new file mode 100644 index 000000000..f436de896 --- /dev/null +++ b/docs/images/box_coordinate.svg @@ -0,0 +1,3 @@ + + +
+Z
+Z
UP
UP
Front
Front
(0, 0, 0)
(0, 0, 0)
+Y
+Y
+X
+X
Viewer does not support full SVG 1.1
diff --git a/docs/images/camera_coordinate.svg b/docs/images/camera_coordinate.svg new file mode 100644 index 000000000..4cd3158ee --- /dev/null +++ b/docs/images/camera_coordinate.svg @@ -0,0 +1,3 @@ + + +
+Z
+Z
+Y
+Y
+X
+X
-Z
-Z
(l, t, -n)
(l, t,...
(l, b, -n)
(l, b, -...
(r, t, -n)
(r, t,...
(r, b, -n)
(r, b, -...
Viewer does not support full SVG 1.1
diff --git a/docs/images/mobile/pose_classification_pairwise_distances.png b/docs/images/mobile/pose_classification_pairwise_distances.png new file mode 100644 index 000000000..1aa2206df Binary files /dev/null and b/docs/images/mobile/pose_classification_pairwise_distances.png differ diff --git a/docs/images/mobile/pose_classification_pushups_and_squats.gif b/docs/images/mobile/pose_classification_pushups_and_squats.gif new file mode 100644 index 000000000..fe75f3bca Binary files /dev/null and b/docs/images/mobile/pose_classification_pushups_and_squats.gif differ diff --git a/docs/images/mobile/pose_classification_pushups_un_and_down_samples.jpg b/docs/images/mobile/pose_classification_pushups_un_and_down_samples.jpg new file mode 100644 index 000000000..269e1b86b Binary files /dev/null and b/docs/images/mobile/pose_classification_pushups_un_and_down_samples.jpg differ diff --git a/docs/images/ndc_coordinate.svg b/docs/images/ndc_coordinate.svg new file mode 100644 index 000000000..038660fd4 --- /dev/null +++ b/docs/images/ndc_coordinate.svg @@ -0,0 +1,3 @@ + + +
+Z
+Z
(0, 0, 0)
(0, 0, 0)
+Y
+Y
+X
+X
(-1, 1, -1)
(-1, 1, -1)
(1, -1, -1)
(1, -1, -1)
(-1, -1, -1)
(-1, -1, -1)
(-1, -1, 1)
(-1, -1, 1)
(1, -1, 1)
(1, -1, 1)
(1, 1, 1)
(1, 1, 1)
Viewer does not support full SVG 1.1
diff --git a/docs/index.md b/docs/index.md index 806e31c7f..d3db8892d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,7 +34,7 @@ Hair Segmentation []() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) :---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | :-----------------------------------------------------------: | :--------------------------------------------------------------------: -[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | | ✅ +[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | [Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | [Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -44,7 +44,7 @@ Hair Segmentation [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | | 
+[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/docs/solutions/face_detection.md b/docs/solutions/face_detection.md index 3a74bee0a..f04af27d7 100644 --- a/docs/solutions/face_detection.md +++ b/docs/solutions/face_detection.md @@ -39,6 +39,169 @@ section. ![face_detection_android_gpu.gif](../images/mobile/face_detection_android_gpu.gif) +## Solution APIs + +### Configuration Options + +Naming style and availability may differ slightly across platforms/languages. + +#### min_detection_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the face detection model for the +detection to be considered successful. Default to `0.5`. + +### Output + +Naming style may differ slightly across platforms/languages. + +#### detections + +Collection of detected faces, where each face is represented as a detection +proto message that contains a bounding box and 6 key points (right eye, left +eye, nose tip, mouth center, right ear tragion, and left ear tragion). The +bounding box is composed of `xmin` and `width` (both normalized to `[0.0, 1.0]` +by the image width) and `ymin` and `height` (both normalized to `[0.0, 1.0]` by +the image height). Each key point is composed of `x` and `y`, which are +normalized to `[0.0, 1.0]` by the image width and height respectively. + +### Python Solution API + +Please first follow general [instructions](../getting_started/python.md) to +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. 
+ +Supported configuration options: + +* [min_detection_confidence](#min_detection_confidence) + +```python +import cv2 +import mediapipe as mp +mp_face_detection = mp.solutions.face_detection + +# For static images: +with mp_face_detection.FaceDetection( + min_detection_confidence=0.5) as face_detection: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + # Convert the BGR image to RGB and process it with MediaPipe Face Detection. + results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Draw face detections of each face. + if not results.detections: + continue + annotated_image = image.copy() + for detection in results.detections: + print('Nose tip:') + print(mp_face_detection.get_key_point( + detection, mp_face_detection.FaceKeyPoint.NOSE_TIP)) + mp_drawing.draw_detection(annotated_image, detection) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + +# For webcam input: +cap = cv2.VideoCapture(0) +with mp_face_detection.FaceDetection( + min_detection_confidence=0.5) as face_detection: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = face_detection.process(image) + + # Draw the face detection annotations on the image. 
+ image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if results.detections: + for detection in results.detections: + mp_drawing.draw_detection(image, detection) + cv2.imshow('MediaPipe Face Detection', image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + +### JavaScript Solution API + +Please first see general [introduction](../getting_started/javascript.md) on +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. + +Supported configuration options: + +* [minDetectionConfidence](#min_detection_confidence) + +```html + + + + + + + + + + + +
+ + +
+ + +``` + +```javascript + +``` + ## Example Apps Please first see general instructions for @@ -108,3 +271,5 @@ to cross-compile and run MediaPipe examples on the ([presentation](https://docs.google.com/presentation/d/1YCtASfnYyZtH-41QvnW5iZxELFnf0MF-pPWSLGj8yjQ/present?slide=id.g5bc8aeffdd_1_0)) ([poster](https://drive.google.com/file/d/1u6aB6wxDY7X2TmeUUKgFydulNtXkb3pu/view)) * [Models and model cards](./models.md#face_detection) +* [Web demo](https://code.mediapipe.dev/codepen/face_detection) +* [Python Colab](https://mediapipe.page.link/face_detection_py_colab) diff --git a/docs/solutions/face_mesh.md b/docs/solutions/face_mesh.md index bea135105..0c620120c 100644 --- a/docs/solutions/face_mesh.md +++ b/docs/solutions/face_mesh.md @@ -185,8 +185,8 @@ following steps are executed in the given order: The geometry pipeline is implemented as a MediaPipe [calculator](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc). For your convenience, the face geometry pipeline calculator is bundled together -with the face landmark module into a unified MediaPipe -[subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_front_gpu.pbtxt). +with corresponding metadata into a unified MediaPipe +[subgraph](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt). The face geometry format is defined as a Protocol Buffer [message](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/protos/face_geometry.proto). @@ -264,8 +264,8 @@ magnitude of `z` uses roughly the same scale as `x`. ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to -install MediaPipe Python package, then learn more in the companion [Colab] and -the following usage example. 
+install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. Supported configuration options: @@ -281,74 +281,73 @@ mp_drawing = mp.solutions.drawing_utils mp_face_mesh = mp.solutions.face_mesh # For static images: -face_mesh = mp_face_mesh.FaceMesh( +drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) +with mp_face_mesh.FaceMesh( static_image_mode=True, max_num_faces=1, - min_detection_confidence=0.5) -drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) -for idx, file in enumerate(file_list): - image = cv2.imread(file) - # Convert the BGR image to RGB before processing. - results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + min_detection_confidence=0.5) as face_mesh: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + # Convert the BGR image to RGB before processing. + results = face_mesh.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # Print and draw face mesh landmarks on the image. - if not results.multi_face_landmarks: - continue - annotated_image = image.copy() - for face_landmarks in results.multi_face_landmarks: - print('face_landmarks:', face_landmarks) - mp_drawing.draw_landmarks( - image=annotated_image, - landmark_list=face_landmarks, - connections=mp_face_mesh.FACE_CONNECTIONS, - landmark_drawing_spec=drawing_spec, - connection_drawing_spec=drawing_spec) - cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) -face_mesh.close() - -# For webcam input: -face_mesh = mp_face_mesh.FaceMesh( - min_detection_confidence=0.5, min_tracking_confidence=0.5) -drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) -cap = cv2.VideoCapture(0) -while cap.isOpened(): - success, image = cap.read() - if not success: - print("Ignoring empty camera frame.") - # If loading a video, use 'break' instead of 'continue'. 
- continue - - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) - # To improve performance, optionally mark the image as not writeable to - # pass by reference. - image.flags.writeable = False - results = face_mesh.process(image) - - # Draw the face mesh annotations on the image. - image.flags.writeable = True - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - if results.multi_face_landmarks: + # Print and draw face mesh landmarks on the image. + if not results.multi_face_landmarks: + continue + annotated_image = image.copy() for face_landmarks in results.multi_face_landmarks: + print('face_landmarks:', face_landmarks) mp_drawing.draw_landmarks( - image=image, + image=annotated_image, landmark_list=face_landmarks, connections=mp_face_mesh.FACE_CONNECTIONS, landmark_drawing_spec=drawing_spec, connection_drawing_spec=drawing_spec) - cv2.imshow('MediaPipe FaceMesh', image) - if cv2.waitKey(5) & 0xFF == 27: - break -face_mesh.close() + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + +# For webcam input: +drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1) +cap = cv2.VideoCapture(0) +with mp_face_mesh.FaceMesh( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as face_mesh: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = face_mesh.process(image) + + # Draw the face mesh annotations on the image. 
+ image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if results.multi_face_landmarks: + for face_landmarks in results.multi_face_landmarks: + mp_drawing.draw_landmarks( + image=image, + landmark_list=face_landmarks, + connections=mp_face_mesh.FACE_CONNECTIONS, + landmark_drawing_spec=drawing_spec, + connection_drawing_spec=drawing_spec) + cv2.imshow('MediaPipe FaceMesh', image) + if cv2.waitKey(5) & 0xFF == 27: + break cap.release() ``` ### JavaScript Solution API Please first see general [introduction](../getting_started/javascript.md) on -MediaPipe in JavaScript, then learn more in the companion [web demo] and the -following usage example. +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. Supported configuration options: @@ -503,7 +502,5 @@ only works for a single face. For visual reference, please refer to *Fig. 4*. [OBJ](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/data/canonical_face_model.obj), [UV visualization](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/data/canonical_face_model_uv_visualization.png) * [Models and model cards](./models.md#face_mesh) - -[Colab]:https://mediapipe.page.link/face_mesh_py_colab - -[web demo]:https://code.mediapipe.dev/codepen/face_mesh +* [Web demo](https://code.mediapipe.dev/codepen/face_mesh) +* [Python Colab](https://mediapipe.page.link/face_mesh_py_colab) diff --git a/docs/solutions/hands.md b/docs/solutions/hands.md index 3d07411c2..ac10124f2 100644 --- a/docs/solutions/hands.md +++ b/docs/solutions/hands.md @@ -91,13 +91,14 @@ To detect initial hand locations, we designed a mobile real-time uses in a manner similar to the face detection model in [MediaPipe Face Mesh](./face_mesh.md). 
Detecting hands is a decidedly complex task: our -[model](https://github.com/google/mediapipe/tree/master/mediapipe/models/palm_detection.tflite) has -to work across a variety of hand sizes with a large scale span (~20x) relative -to the image frame and be able to detect occluded and self-occluded hands. -Whereas faces have high contrast patterns, e.g., in the eye and mouth region, -the lack of such features in hands makes it comparatively difficult to detect -them reliably from their visual features alone. Instead, providing additional -context, like arm, body, or person features, aids accurate hand localization. +[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/palm_detection/palm_detection.tflite) +has to work across a variety of hand sizes with a large scale span (~20x) +relative to the image frame and be able to detect occluded and self-occluded +hands. Whereas faces have high contrast patterns, e.g., in the eye and mouth +region, the lack of such features in hands makes it comparatively difficult to +detect them reliably from their visual features alone. Instead, providing +additional context, like arm, body, or person features, aids accurate hand +localization. Our method addresses the above challenges using different strategies. First, we train a palm detector instead of a hand detector, since estimating bounding @@ -119,7 +120,7 @@ just 86.22%. ### Hand Landmark Model After the palm detection over the whole image our subsequent hand landmark -[model](https://github.com/google/mediapipe/tree/master/mediapipe/models/hand_landmark.tflite) +[model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite) performs precise keypoint localization of 21 3D hand-knuckle coordinates inside the detected hand regions via regression, that is direct coordinate prediction. 
The model learns a consistent internal hand pose representation and is robust @@ -136,11 +137,9 @@ to the corresponding 3D coordinates. :--------------------------------------------------------: | *Fig 2. 21 hand landmarks.* | -| ![hand_crops.png](../images/mobile/hand_crops.png) | -| :-------------------------------------------------------------------------: | -| *Fig 3. Top: Aligned hand crops passed to the tracking network with ground | -: truth annotation. Bottom\: Rendered synthetic hand images with ground truth : -: annotation.* : +![hand_crops.png](../images/mobile/hand_crops.png) | +:-------------------------------------------------------------------------: | +*Fig 3. Top: Aligned hand crops passed to the tracking network with ground truth annotation. Bottom: Rendered synthetic hand images with ground truth annotation.* | ## Solution APIs @@ -206,8 +205,8 @@ is not the case, please swap the handedness output in the application. ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to -install MediaPipe Python package, then learn more in the companion [Colab] and -the following usage example. +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. Supported configuration options: @@ -223,74 +222,73 @@ mp_drawing = mp.solutions.drawing_utils mp_hands = mp.solutions.hands # For static images: -hands = mp_hands.Hands( +with mp_hands.Hands( static_image_mode=True, max_num_hands=2, - min_detection_confidence=0.5) -for idx, file in enumerate(file_list): - # Read an image, flip it around y-axis for correct handedness output (see - # above). - image = cv2.flip(cv2.imread(file), 1) - # Convert the BGR image to RGB before processing. 
- results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + min_detection_confidence=0.5) as hands: + for idx, file in enumerate(file_list): + # Read an image, flip it around y-axis for correct handedness output (see + # above). + image = cv2.flip(cv2.imread(file), 1) + # Convert the BGR image to RGB before processing. + results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # Print handedness and draw hand landmarks on the image. - print('Handedness:', results.multi_handedness) - if not results.multi_hand_landmarks: - continue - image_hight, image_width, _ = image.shape - annotated_image = image.copy() - for hand_landmarks in results.multi_hand_landmarks: - print('hand_landmarks:', hand_landmarks) - print( - f'Index finger tip coordinates: (', - f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].x * image_width}, ' - f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_hight})' - ) - mp_drawing.draw_landmarks( - annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS) - cv2.imwrite( - '/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1)) -hands.close() + # Print handedness and draw hand landmarks on the image. 
+ print('Handedness:', results.multi_handedness) + if not results.multi_hand_landmarks: + continue + image_height, image_width, _ = image.shape + annotated_image = image.copy() + for hand_landmarks in results.multi_hand_landmarks: + print('hand_landmarks:', hand_landmarks) + print( + f'Index finger tip coordinates: (', + f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].x * image_width}, ' + f'{hand_landmarks.landmark[mp_hands.HandLandmark.INDEX_FINGER_TIP].y * image_height})' + ) + mp_drawing.draw_landmarks( + annotated_image, hand_landmarks, mp_hands.HAND_CONNECTIONS) + cv2.imwrite( + '/tmp/annotated_image' + str(idx) + '.png', cv2.flip(annotated_image, 1)) # For webcam input: -hands = mp_hands.Hands( - min_detection_confidence=0.5, min_tracking_confidence=0.5) cap = cv2.VideoCapture(0) -while cap.isOpened(): - success, image = cap.read() - if not success: - print("Ignoring empty camera frame.") - # If loading a video, use 'break' instead of 'continue'. - continue +with mp_hands.Hands( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as hands: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) - # To improve performance, optionally mark the image as not writeable to - # pass by reference. - image.flags.writeable = False - results = hands.process(image) + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = hands.process(image) - # Draw the hand annotations on the image. 
- image.flags.writeable = True - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - if results.multi_hand_landmarks: - for hand_landmarks in results.multi_hand_landmarks: - mp_drawing.draw_landmarks( - image, hand_landmarks, mp_hands.HAND_CONNECTIONS) - cv2.imshow('MediaPipe Hands', image) - if cv2.waitKey(5) & 0xFF == 27: - break -hands.close() + # Draw the hand annotations on the image. + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if results.multi_hand_landmarks: + for hand_landmarks in results.multi_hand_landmarks: + mp_drawing.draw_landmarks( + image, hand_landmarks, mp_hands.HAND_CONNECTIONS) + cv2.imshow('MediaPipe Hands', image) + if cv2.waitKey(5) & 0xFF == 27: + break cap.release() ``` ### JavaScript Solution API Please first see general [introduction](../getting_started/javascript.md) on -MediaPipe in JavaScript, then learn more in the companion [web demo] and a -[fun application], and the following usage example. +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and a [fun application], and the following usage example. Supported configuration options: @@ -425,8 +423,6 @@ it, in the graph file modify the option of `ConstantSidePacketCalculator`. 
[MediaPipe Hands: On-device Real-time Hand Tracking](https://arxiv.org/abs/2006.10214) ([presentation](https://www.youtube.com/watch?v=I-UOrvxxXEk)) * [Models and model cards](./models.md#hands) - -[Colab]:https://mediapipe.page.link/hands_py_colab - -[web demo]:https://code.mediapipe.dev/codepen/hands -[fun application]:https://code.mediapipe.dev/codepen/defrost +* [Web demo](https://code.mediapipe.dev/codepen/hands) +* [Fun application](https://code.mediapipe.dev/codepen/defrost) +* [Python Colab](https://mediapipe.page.link/hands_py_colab) diff --git a/docs/solutions/holistic.md b/docs/solutions/holistic.md index c2de9185a..8ee0f8ff6 100644 --- a/docs/solutions/holistic.md +++ b/docs/solutions/holistic.md @@ -201,8 +201,8 @@ A list of 21 hand landmarks on the right hand, in the same representation as ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to -install MediaPipe Python package, then learn more in the companion [Colab] and -the following usage example. +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. Supported configuration options: @@ -219,74 +219,75 @@ mp_drawing = mp.solutions.drawing_utils mp_holistic = mp.solutions.holistic # For static images: -holistic = mp_holistic.Holistic(static_image_mode=True) -for idx, file in enumerate(file_list): - image = cv2.imread(file) - image_hight, image_width, _ = image.shape - # Convert the BGR image to RGB before processing. - results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) +with mp_holistic.Holistic(static_image_mode=True) as holistic: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + image_height, image_width, _ = image.shape + # Convert the BGR image to RGB before processing. 
+ results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - if results.pose_landmarks: - print( - f'Nose coordinates: (' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_hight})' - ) - # Draw pose, left and right hands, and face landmarks on the image. - annotated_image = image.copy() - mp_drawing.draw_landmarks( - annotated_image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) - mp_drawing.draw_landmarks( - annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) - cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) -holistic.close() + if results.pose_landmarks: + print( + f'Nose coordinates: (' + f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' + f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_height})' + ) + # Draw pose, left and right hands, and face landmarks on the image. + annotated_image = image.copy() + mp_drawing.draw_landmarks( + annotated_image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + mp_drawing.draw_landmarks( + annotated_image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks( + annotated_image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + # Use mp_holistic.UPPER_BODY_POSE_CONNECTIONS for drawing below when + # upper_body_only is set to True. 
+ mp_drawing.draw_landmarks( + annotated_image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) # For webcam input: -holistic = mp_holistic.Holistic( - min_detection_confidence=0.5, min_tracking_confidence=0.5) cap = cv2.VideoCapture(0) -while cap.isOpened(): - success, image = cap.read() - if not success: - print("Ignoring empty camera frame.") - # If loading a video, use 'break' instead of 'continue'. - continue +with mp_holistic.Holistic( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as holistic: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) - # To improve performance, optionally mark the image as not writeable to - # pass by reference. - image.flags.writeable = False - results = holistic.process(image) + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. + image.flags.writeable = False + results = holistic.process(image) - # Draw landmark annotation on the image. 
- image.flags.writeable = True - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - mp_drawing.draw_landmarks( - image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) - mp_drawing.draw_landmarks( - image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) - mp_drawing.draw_landmarks( - image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) - cv2.imshow('MediaPipe Holistic', image) - if cv2.waitKey(5) & 0xFF == 27: - break -holistic.close() + # Draw landmark annotation on the image. + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + mp_drawing.draw_landmarks( + image, results.face_landmarks, mp_holistic.FACE_CONNECTIONS) + mp_drawing.draw_landmarks( + image, results.left_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks( + image, results.right_hand_landmarks, mp_holistic.HAND_CONNECTIONS) + mp_drawing.draw_landmarks( + image, results.pose_landmarks, mp_holistic.POSE_CONNECTIONS) + cv2.imshow('MediaPipe Holistic', image) + if cv2.waitKey(5) & 0xFF == 27: + break cap.release() ``` ### JavaScript Solution API Please first see general [introduction](../getting_started/javascript.md) on -MediaPipe in JavaScript, then learn more in the companion [web demo] and the -following usage example. +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. Supported configuration options: @@ -407,7 +408,5 @@ on how to build MediaPipe examples. 
* Google AI Blog: [MediaPipe Holistic - Simultaneous Face, Hand and Pose Prediction, on Device](https://ai.googleblog.com/2020/12/mediapipe-holistic-simultaneous-face.html) * [Models and model cards](./models.md#holistic) - -[Colab]:https://mediapipe.page.link/holistic_py_colab - -[web demo]:https://code.mediapipe.dev/codepen/holistic +* [Web demo](https://code.mediapipe.dev/codepen/holistic) +* [Python Colab](https://mediapipe.page.link/holistic_py_colab) diff --git a/docs/solutions/instant_motion_tracking.md b/docs/solutions/instant_motion_tracking.md index 720fe80f6..36e5e83e0 100644 --- a/docs/solutions/instant_motion_tracking.md +++ b/docs/solutions/instant_motion_tracking.md @@ -117,6 +117,25 @@ Please first see general instructions for * Android target (or download prebuilt [ARM64 APK](https://drive.google.com/file/d/1KnaBBoKpCHR73nOBJ4fL_YdWVTAcwe6L/view?usp=sharing)): [`mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking:instantmotiontracking`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/instantmotiontracking/BUILD) +* Assets rendered by the [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) must be preprocessed into an OpenGL-ready custom .uuu format. This can be done +for user assets as follows: +> First run +> +> ```shell +> ./mediapipe/graphs/object_detection_3d/obj_parser/obj_cleanup.sh [INPUT_DIR] [INTERMEDIATE_OUTPUT_DIR] +> ``` +> and then run +> +> ```build +> bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] +> ``` +> INPUT_DIR should be the folder with initial asset .obj files to be processed, +> and OUTPUT_DIR is the folder where the processed asset .uuu file will be placed. 
+> +> Note: ObjParser combines all .obj files found in the given directory into a +> single .uuu animation file, using the order given by sorting the filenames alphanumerically. Also the ObjParser directory inputs must be given as +> absolute paths, not relative paths. See parser utility library at [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/) for more details. + ## Resources * Google Developers Blog: diff --git a/docs/solutions/models.md b/docs/solutions/models.md index 4bc3d849e..b0f1fad7a 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -41,8 +41,9 @@ nav_order: 30 [TF.js model](https://tfhub.dev/mediapipe/handdetector/1) * Hand landmark model: [TFLite model](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark.tflite), + [TFLite model (sparse)](https://github.com/google/mediapipe/tree/master/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite), [TF.js model](https://tfhub.dev/mediapipe/handskeleton/1) -* [Model card](https://mediapipe.page.link/handmc) +* [Model card](https://mediapipe.page.link/handmc), [Model card (sparse)](https://mediapipe.page.link/handmc-sparse) ### [Pose](https://google.github.io/mediapipe/solutions/pose) @@ -73,12 +74,12 @@ nav_order: 30 ### [Objectron](https://google.github.io/mediapipe/solutions/objectron) -* [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers.tflite) -* [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair.tflite) -* [TFLite model for cameras](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_camera.tflite) -* [TFLite model for cups](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_cup.tflite) -* [Single-stage TFLite model for 
shoes](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_sneakers_1stage.tflite) -* [Single-stage TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_3d_chair_1stage.tflite) +* [TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_sneakers.tflite) +* [TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_chair.tflite) +* [TFLite model for cameras](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_camera.tflite) +* [TFLite model for cups](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_cup.tflite) +* [Single-stage TFLite model for shoes](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite) +* [Single-stage TFLite model for chairs](https://github.com/google/mediapipe/tree/master/mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite) * [Model card](https://mediapipe.page.link/objectron-mc) ### [KNIFT](https://google.github.io/mediapipe/solutions/knift) diff --git a/docs/solutions/objectron.md b/docs/solutions/objectron.md index c65b72bae..c689f9c40 100644 --- a/docs/solutions/objectron.md +++ b/docs/solutions/objectron.md @@ -186,6 +186,175 @@ trained our 3D object detection models. The technical details of the Objectron dataset, including usage and tutorials, are available on the [dataset website](https://github.com/google-research-datasets/Objectron/). +## Solution APIs + +### Cross-platform Configuration Options + +Naming style and availability may differ slightly across platforms/languages. + +#### static_image_mode + +If set to `false`, the solution treats the input images as a video stream. 
It +will try to detect objects in the very first images, and upon successful +detection further localizes the 3D bounding box landmarks. In subsequent images, +once all [max_num_objects](#max_num_objects) objects are detected and the +corresponding 3D bounding box landmarks are localized, it simply tracks those +landmarks without invoking another detection until it loses track of any of the +objects. This reduces latency and is ideal for processing video frames. If set +to `true`, object detection runs on every input image, ideal for processing a batch +of static, possibly unrelated, images. Default to `false`. + +#### max_num_objects + +Maximum number of objects to detect. Default to `5`. + +#### min_detection_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the object-detection model for the +detection to be considered successful. Default to `0.5`. + +#### min_tracking_confidence + +Minimum confidence value (`[0.0, 1.0]`) from the landmark-tracking model for the +3D bounding box landmarks to be considered tracked successfully, or otherwise +object detection will be invoked automatically on the next input image. Setting +it to a higher value can increase robustness of the solution, at the expense of +a higher latency. Ignored if [static_image_mode](#static_image_mode) is `true`, +where object detection simply runs on every image. Default to `0.99`. + +#### model_name + +Name of the model to use for predicting 3D bounding box landmarks. Currently supports +`{'Shoe', 'Chair', 'Cup', 'Camera'}`. + +#### focal_length + +Camera focal length `(fx, fy)`, by default is defined in +[NDC space](#ndc-space). To use focal length `(fx_pixel, fy_pixel)` in +[pixel space](#pixel-space), users should provide `image_size` = `(image_width, +image_height)` to enable conversions inside the API. For further details about +NDC and pixel space, please see [Coordinate Systems](#coordinate-systems).
+ +#### principal_point + +Camera principal point `(px, py)`, by default is defined in +[NDC space](#ndc-space). To use principal point `(px_pixel, py_pixel)` in +[pixel space](#pixel-space), users should provide `image_size` = `(image_width, +image_height)` to enable conversions inside the API. For further details about +NDC and pixel space, please see [Coordinate Systems](#coordinate-systems). + +#### image_size + +(**Optional**) size `(image_width, image_height)` of the input image, **ONLY** +needed when using `focal_length` and `principal_point` in pixel space. + +### Output + + + +#### detected_objects + +A list of detected 3D bounding boxes. Each 3D bounding box consists of the +following: + +* `landmarks_2d` : 2D landmarks of the object's 3D bounding box. The landmark + coordinates are normalized to `[0.0, 1.0]` by the image width and height + respectively. + +* `landmarks_3d` : 3D landmarks of the object's 3D bounding box. The landmark + coordinates are represented in [camera coordinate](#camera-coordinate) + frame. + +* `rotation` : rotation matrix from object coordinate frame to camera + coordinate frame. + +* `translation` : translation vector from object coordinate frame to camera + coordinate frame. + +* `scale` : relative scale of the object along `x`, `y` and `z` directions. + +### Python Solution API + +Please first follow general [instructions](../getting_started/python.md) to +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example.
+ +Supported configuration options: + +* [static_image_mode](#static_image_mode) +* [max_num_objects](#max_num_objects) +* [min_detection_confidence](#min_detection_confidence) +* [min_tracking_confidence](#min_tracking_confidence) +* [model_name](#model_name) +* [focal_length](#focal_length) +* [principal_point](#principal_point) +* [image_size](#image_size) + +```python +import cv2 +import mediapipe as mp +mp_drawing = mp.solutions.drawing_utils +mp_objectron = mp.solutions.objectron + +# For static images: +with mp_objectron.Objectron(static_image_mode=True, + max_num_objects=5, + min_detection_confidence=0.5, + model_name='Shoe') as objectron: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + # Convert the BGR image to RGB and process it with MediaPipe Objectron. + results = objectron.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # Draw box landmarks. + if not results.detected_objects: + print(f'No box landmarks detected on {file}') + continue + print(f'Box landmarks of {file}:') + annotated_image = image.copy() + for detected_object in results.detected_objects: + mp_drawing.draw_landmarks( + annotated_image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS) + mp_drawing.draw_axis(annotated_image, detected_object.rotation, + detected_object.translation) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) + +# For webcam input: +cap = cv2.VideoCapture(0) +with mp_objectron.Objectron(static_image_mode=False, + max_num_objects=5, + min_detection_confidence=0.5, + min_tracking_confidence=0.99, + model_name='Shoe') as objectron: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue + + # Convert the BGR image to RGB. + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. 
+ image.flags.writeable = False + results = objectron.process(image) + + # Draw the box landmarks on the image. + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + if results.detected_objects: + for detected_object in results.detected_objects: + mp_drawing.draw_landmarks( + image, detected_object.landmarks_2d, mp_objectron.BOX_CONNECTIONS) + mp_drawing.draw_axis(image, detected_object.rotation, + detected_object.translation) + cv2.imshow('MediaPipe Objectron', image) + if cv2.waitKey(5) & 0xFF == 27: + break +cap.release() +``` + ## Example Apps Please first see general instructions for @@ -259,6 +428,104 @@ to visualize its associated subgraphs, please see * iOS target: Not available +### Assets + +Example app bounding boxes are rendered with [GlAnimationOverlayCalculator](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc) using a parsing of the sequenced .obj file + format into a custom .uuu format. This can be done for user assets as follows: +> First run +> +> ```shell +> ./mediapipe/graphs/object_detection_3d/obj_parser/obj_cleanup.sh [INPUT_DIR] [INTERMEDIATE_OUTPUT_DIR] +> ``` +> and then run +> +> ```build +> bazel run -c opt mediapipe/graphs/object_detection_3d/obj_parser:ObjParser -- input_dir=[INTERMEDIATE_OUTPUT_DIR] output_dir=[OUTPUT_DIR] +> ``` +> INPUT_DIR should be the folder with initial asset .obj files to be processed, +> and OUTPUT_DIR is the folder where the processed asset .uuu file will be placed. +> +> Note: ObjParser combines all .obj files found in the given directory into a +> single .uuu animation file, using the order given by sorting the filenames alphanumerically. Also the ObjParser directory inputs must be given as +> absolute paths, not relative paths. 
See parser utility library at [`mediapipe/graphs/object_detection_3d/obj_parser/`](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection_3d/obj_parser/) for more details. + +### Coordinate Systems + +#### Object Coordinate + +Each object has its object coordinate frame. We use the below object coordinate +definition, with `+x` pointing right, `+y` pointing up and `+z` pointing front, +origin is at the center of the 3D bounding box. + +![box_coordinate.svg](../images/box_coordinate.svg) + +#### Camera Coordinate + +A 3D object is parameterized by its `scale` and `rotation`, `translation` with +regard to the camera coordinate frame. In this API we use the below camera +coordinate definition, with `+x` pointing right, `+y` pointing up and `-z` +pointing to the scene. + +![camera_coordinate.svg](../images/camera_coordinate.svg) + +To work with box landmarks, one can first derive landmark coordinates in object +frame by scaling an origin-centered unit box with `scale`, then transform to +camera frame by applying `rotation` and `translation`: + +``` +landmarks_3d = rotation * scale * unit_box + translation +``` + +#### NDC Space + +In this API we use +[NDC(normalized device coordinates)](http://www.songho.ca/opengl/gl_projectionmatrix.html) +as an intermediate space when projecting points from 3D to 2D. In NDC space, +`x`, `y` are confined to `[-1, 1]`. + +![ndc_coordinate.svg](../images/ndc_coordinate.svg) + +By default the camera parameters `(fx, fy)` and `(px, py)` are defined in NDC +space. Given `(X, Y, Z)` of 3D points in camera coordinate, one can project 3D +points to NDC space as follows: + +``` +x_ndc = -fx * X / Z + px +y_ndc = -fy * Y / Z + py +z_ndc = 1 / Z +``` + +#### Pixel Space + +In this API we set the upper-left corner of an image as the origin of pixel +coordinate.
One can convert from NDC to pixel space as follows: + +``` +x_pixel = (1 + x_ndc) / 2.0 * image_width +y_pixel = (1 - y_ndc) / 2.0 * image_height +``` + +Alternatively one can directly project from camera coordinate to pixel +coordinate with camera parameters `(fx_pixel, fy_pixel)` and `(px_pixel, +py_pixel)` defined in pixel space as follows: + +``` +x_pixel = -fx_pixel * X / Z + px_pixel +y_pixel = fy_pixel * Y / Z + py_pixel +``` + +Conversion of camera parameters from pixel space to NDC space: + +``` +fx = fx_pixel * 2.0 / image_width +fy = fy_pixel * 2.0 / image_height +``` + +``` +px = -px_pixel * 2.0 / image_width + 1.0 +py = -py_pixel * 2.0 / image_height + 1.0 +``` + ## Resources * Google AI Blog: @@ -271,3 +538,4 @@ to visualize its associated subgraphs, please see [Instant 3D Object Tracking with Applications in Augmented Reality](https://drive.google.com/open?id=1O_zHmlgXIzAdKljp20U_JUkEHOGG52R8) ([presentation](https://www.youtube.com/watch?v=9ndF1AIo7h0)) * [Models and model cards](./models.md#objectron) +* [Python Colab](https://mediapipe.page.link/objectron_py_colab) diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 0130a5f46..9190484e7 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -21,13 +21,15 @@ nav_order: 5 ## Overview Human pose estimation from video plays a critical role in various applications -such as quantifying physical exercises, sign language recognition, and full-body -gesture control. For example, it can form the basis for yoga, dance, and fitness -applications. It can also enable the overlay of digital content and information -on top of the physical world in augmented reality. +such as +[quantifying physical exercises](#pose-classification-and-repetition-counting), +sign language recognition, and full-body gesture control. For example, it can +form the basis for yoga, dance, and fitness applications. 
It can also enable the +overlay of digital content and information on top of the physical world in +augmented reality. MediaPipe Pose is a ML solution for high-fidelity body pose tracking, inferring -33 2D landmarks on the whole body (or 25 upper-body landmarks) from RGB video +33 3D landmarks on the whole body (or 25 upper-body landmarks) from RGB video frames utilizing our [BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) research that also powers the @@ -35,7 +37,7 @@ research that also powers the Current state-of-the-art approaches rely primarily on powerful desktop environments for inference, whereas our method achieves real-time performance on most modern [mobile phones](#mobile), [desktops/laptops](#desktop), in -[python](#python) and even on the [web](#web). +[python](#python-solution-api) and even on the [web](#javascript-solution-api). ![pose_tracking_upper_body_example.gif](../images/mobile/pose_tracking_upper_body_example.gif) | :--------------------------------------------------------------------------------------------: | @@ -92,7 +94,7 @@ hip midpoints. :----------------------------------------------------------------------------------------------------: | *Fig 2. Vitruvian man aligned via two virtual keypoints predicted by BlazePose detector in addition to the face bounding box.* | -### Pose Landmark Model (BlazePose Tracker) +### Pose Landmark Model (BlazePose GHUM 3D) The landmark model in MediaPipe Pose comes in two versions: a full-body model that predicts the location of 33 pose landmarks (see figure below), and an @@ -163,16 +165,21 @@ A list of pose landmarks. Each lanmark consists of the following: * `x` and `y`: Landmark coordinates normalized to `[0.0, 1.0]` by the image width and height respectively. -* `z`: Should be discarded as currently the model is not fully trained to - predict depth, but this is something on the roadmap. 
+* `z`: Represents the landmark depth with the depth at the midpoint of hips + being the origin, and the smaller the value the closer the landmark is to + the camera. The magnitude of `z` uses roughly the same scale as `x`. + + Note: `z` is predicted only in full-body mode, and should be discarded when + [upper_body_only](#upper_body_only) is `true`. + * `visibility`: A value in `[0.0, 1.0]` indicating the likelihood of the landmark being visible (present and not occluded) in the image. ### Python Solution API Please first follow general [instructions](../getting_started/python.md) to -install MediaPipe Python package, then learn more in the companion [Colab] and -the following usage example. +install MediaPipe Python package, then learn more in the companion +[Python Colab](#resources) and the following usage example. Supported configuration options: @@ -189,64 +196,65 @@ mp_drawing = mp.solutions.drawing_utils mp_pose = mp.solutions.pose # For static images: -pose = mp_pose.Pose( - static_image_mode=True, min_detection_confidence=0.5) -for idx, file in enumerate(file_list): - image = cv2.imread(file) - image_hight, image_width, _ = image.shape - # Convert the BGR image to RGB before processing. - results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) +with mp_pose.Pose( + static_image_mode=True, min_detection_confidence=0.5) as pose: + for idx, file in enumerate(file_list): + image = cv2.imread(file) + image_height, image_width, _ = image.shape + # Convert the BGR image to RGB before processing. + results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - if not results.pose_landmarks: - continue - print( - f'Nose coordinates: (' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].x * image_width}, ' - f'{results.pose_landmarks.landmark[mp_holistic.PoseLandmark.NOSE].y * image_hight})' - ) - # Draw pose landmarks on the image. 
- annotated_image = image.copy() - mp_drawing.draw_landmarks( - annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) - cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) -pose.close() + if not results.pose_landmarks: + continue + print( + f'Nose coordinates: (' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].x * image_width}, ' + f'{results.pose_landmarks.landmark[mp_pose.PoseLandmark.NOSE].y * image_height})' + ) + # Draw pose landmarks on the image. + annotated_image = image.copy() + # Use mp_pose.UPPER_BODY_POSE_CONNECTIONS for drawing below when + # upper_body_only is set to True. + mp_drawing.draw_landmarks( + annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + cv2.imwrite('/tmp/annotated_image' + str(idx) + '.png', annotated_image) # For webcam input: -pose = mp_pose.Pose( - min_detection_confidence=0.5, min_tracking_confidence=0.5) cap = cv2.VideoCapture(0) -while cap.isOpened(): - success, image = cap.read() - if not success: - print("Ignoring empty camera frame.") - # If loading a video, use 'break' instead of 'continue'. - continue +with mp_pose.Pose( + min_detection_confidence=0.5, + min_tracking_confidence=0.5) as pose: + while cap.isOpened(): + success, image = cap.read() + if not success: + print("Ignoring empty camera frame.") + # If loading a video, use 'break' instead of 'continue'. + continue - # Flip the image horizontally for a later selfie-view display, and convert - # the BGR image to RGB. - image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) - # To improve performance, optionally mark the image as not writeable to - # pass by reference. + # Flip the image horizontally for a later selfie-view display, and convert + # the BGR image to RGB. + image = cv2.cvtColor(cv2.flip(image, 1), cv2.COLOR_BGR2RGB) + # To improve performance, optionally mark the image as not writeable to + # pass by reference. 
+ image.flags.writeable = False + results = pose.process(image) - # Draw the pose annotation on the image. - image.flags.writeable = True - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - mp_drawing.draw_landmarks( - image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) - cv2.imshow('MediaPipe Pose', image) - if cv2.waitKey(5) & 0xFF == 27: - break -pose.close() + # Draw the pose annotation on the image. + image.flags.writeable = True + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + mp_drawing.draw_landmarks( + image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + cv2.imshow('MediaPipe Pose', image) + if cv2.waitKey(5) & 0xFF == 27: + break cap.release() ``` ### JavaScript Solution API Please first see general [introduction](../getting_started/javascript.md) on -MediaPipe in JavaScript, then learn more in the companion [web demo] and the -following usage example. +MediaPipe in JavaScript, then learn more in the companion [web demo](#resources) +and the following usage example. Supported configuration options: @@ -379,6 +387,121 @@ on how to build MediaPipe examples. * Target: [`mediapipe/examples/desktop/upper_body_pose_tracking:upper_body_pose_tracking_gpu`](https://github.com/google/mediapipe/tree/master/mediapipe/examples/desktop/upper_body_pose_tracking/BUILD) +## Pose Classification and Repetition Counting + +One of the applications +[BlazePose](https://ai.googleblog.com/2020/08/on-device-real-time-body-pose-tracking.html) +can enable is fitness. More specifically - pose classification and repetition +counting. In this section we'll provide basic guidance on building a custom pose +classifier with the help of a +[Colab](https://drive.google.com/file/d/19txHpN8exWhstO6WVkfmYYVC6uug_oVR/view?usp=sharing) +and wrap it in a simple +[fitness app](https://mediapipe.page.link/mlkit-pose-classification-demo-app) +powered by [ML Kit](https://developers.google.com/ml-kit). Push-ups and squats +are used for demonstration purposes as the most common exercises. 
+ +![pose_classification_pushups_and_squats.gif](../images/mobile/pose_classification_pushups_and_squats.gif) | +:--------------------------------------------------------------------------------------------------------: | +*Fig 4. Pose classification and repetition counting with MediaPipe Pose.* | + +We picked the +[k-nearest neighbors algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm) +(k-NN) as the classifier. It's simple and easy to start with. The algorithm +determines the object's class based on the closest samples in the training set. +To build it, one needs to: + +* Collect image samples of the target exercises and run pose prediction on + them, +* Convert obtained pose landmarks to a representation suitable for the k-NN + classifier and form a training set, +* Perform the classification itself followed by repetition counting. + +### Training Set + +To build a good classifier appropriate samples should be collected for the +training set: about a few hundred samples for each terminal state of each +exercise (e.g., "up" and "down" positions for push-ups). It's important that +collected samples cover different camera angles, environment conditions, body +shapes, and exercise variations. + +![pose_classification_pushups_un_and_down_samples.jpg](../images/mobile/pose_classification_pushups_un_and_down_samples.jpg) | +:--------------------------------------------------------------------------------------------------------------------------: | +*Fig 5. Two terminal states of push-ups.* | + +To transform samples into a k-NN classifier training set, either +[basic](https://drive.google.com/file/d/1z4IM8kG6ipHN6keadjD-F6vMiIIgViKK/view?usp=sharing) +or +[extended](https://drive.google.com/file/d/19txHpN8exWhstO6WVkfmYYVC6uug_oVR/view?usp=sharing) +Colab could be used. They both use the +[Python Solution API](#python-solution-api) to run the BlazePose models on given +images and dump predicted pose landmarks to a CSV file. 
Additionally, the +extended Colab provides useful tools to find outliers (e.g., wrongly predicted +poses) and underrepresented classes (e.g., not covering all camera angles) by +classifying each sample against the entire training set. After that, you'll be +able to test the classifier on an arbitrary video right in the Colab. + +### Classification + +Code of the classifier is available both in the +[extended](https://drive.google.com/file/d/19txHpN8exWhstO6WVkfmYYVC6uug_oVR/view?usp=sharing) +Colab and in the +[ML Kit demo app](https://mediapipe.page.link/mlkit-pose-classification-demo-app). +Please refer to them for details of the approach described below. + +The k-NN algorithm used for pose classification requires a feature vector +representation of each sample and a metric to compute the distance between two +such vectors to find the nearest pose samples to a target one. + +To convert pose landmarks to a feature vector, we use pairwise distances between +predefined lists of pose joints, such as distances between wrist and shoulder, +ankle and hip, and two wrists. Since the algorithm relies on distances, all +poses are normalized to have the same torso size and vertical torso orientation +before the conversion. + +![pose_classification_pairwise_distances.png](../images/mobile/pose_classification_pairwise_distances.png) | +:--------------------------------------------------------------------------------------------------------: | +*Fig 6. 
Main pairwise distances used for the pose feature vector.* | + +To get a better classification result, k-NN search is invoked twice with +different distance metrics: + +* First, to filter out samples that are almost the same as the target one but + have only a few different values in the feature vector (which means + differently bent joints and thus other pose class), minimum per-coordinate + distance is used as distance metric, +* Then average per-coordinate distance is used to find the nearest pose + cluster among those from the first search. + +Finally, we apply +[exponential moving average](https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average) +(EMA) smoothing to level any noise from pose prediction or classification. To do +that, we search not only for the nearest pose cluster, but we calculate a +probability for each of them and use it for smoothing over time. + +### Repetition Counter + +To count the repetitions, the algorithm monitors the probability of a target +pose class. Let's take push-ups with its "up" and "down" terminal states: + +* When the probability of the "down" pose class passes a certain threshold for + the first time, the algorithm marks that the "down" pose class is entered. +* Once the probability drops below the threshold, the algorithm marks that the + "down" pose class has been exited and increases the counter. + +To avoid cases when the probability fluctuates around the threshold (e.g., when +the user pauses between "up" and "down" states) causing phantom counts, the +threshold used to detect when the state is exited is actually slightly lower +than the one used to detect when the state is entered. It creates an interval +where the pose class and the counter can't be changed. + +### Future Work + +We are actively working on improving BlazePose GHUM 3D's Z prediction. 
It will +allow us to use joint angles in the feature vectors, which are more natural and +easier to configure (although distances can still be useful to detect touches +between body parts) and to perform rotation normalization of poses and reduce +the number of camera angles required for accurate k-NN classification. + ## Resources * Google AI Blog: @@ -387,7 +510,7 @@ on how to build MediaPipe examples. [BlazePose: On-device Real-time Body Pose Tracking](https://arxiv.org/abs/2006.10204) ([presentation](https://youtu.be/YPpUOTRn5tA)) * [Models and model cards](./models.md#pose) - -[Colab]:https://mediapipe.page.link/pose_py_colab - -[web demo]:https://code.mediapipe.dev/codepen/pose +* [Web demo](https://code.mediapipe.dev/codepen/pose) +* [Python Colab](https://mediapipe.page.link/pose_py_colab) +* [Pose Classification Colab (Basic)](https://mediapipe.page.link/pose_classification_basic) +* [Pose Classification Colab (Extended)](https://mediapipe.page.link/pose_classification_extended) diff --git a/docs/solutions/solutions.md b/docs/solutions/solutions.md index a0dce94a0..c78dffea0 100644 --- a/docs/solutions/solutions.md +++ b/docs/solutions/solutions.md @@ -18,7 +18,7 @@ has_toc: false []() | [Android](https://google.github.io/mediapipe/getting_started/android) | [iOS](https://google.github.io/mediapipe/getting_started/ios) | [C++](https://google.github.io/mediapipe/getting_started/cpp) | [Python](https://google.github.io/mediapipe/getting_started/python) | [JS](https://google.github.io/mediapipe/getting_started/javascript) | [Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/README.md) :---------------------------------------------------------------------------------------- | :-------------------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------: | :-----------------------------------------------------------: | 
:-----------------------------------------------------------: | :--------------------------------------------------------------------: -[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | | | ✅ +[Face Detection](https://google.github.io/mediapipe/solutions/face_detection) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ [Face Mesh](https://google.github.io/mediapipe/solutions/face_mesh) | ✅ | ✅ | ✅ | ✅ | ✅ | [Iris](https://google.github.io/mediapipe/solutions/iris) | ✅ | ✅ | ✅ | | | [Hands](https://google.github.io/mediapipe/solutions/hands) | ✅ | ✅ | ✅ | ✅ | ✅ | @@ -28,7 +28,7 @@ has_toc: false [Object Detection](https://google.github.io/mediapipe/solutions/object_detection) | ✅ | ✅ | ✅ | | | ✅ [Box Tracking](https://google.github.io/mediapipe/solutions/box_tracking) | ✅ | ✅ | ✅ | | | [Instant Motion Tracking](https://google.github.io/mediapipe/solutions/instant_motion_tracking) | ✅ | | | | | -[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | | | +[Objectron](https://google.github.io/mediapipe/solutions/objectron) | ✅ | | | ✅ | | [KNIFT](https://google.github.io/mediapipe/solutions/knift) | ✅ | | | | | [AutoFlip](https://google.github.io/mediapipe/solutions/autoflip) | | | ✅ | | | [MediaSequence](https://google.github.io/mediapipe/solutions/media_sequence) | | | ✅ | | | diff --git a/docs/tools/visualizer.md b/docs/tools/visualizer.md index ecd4487a8..9324576a2 100644 --- a/docs/tools/visualizer.md +++ b/docs/tools/visualizer.md @@ -37,7 +37,7 @@ The graph can be modified by adding and editing code in the Editor view. ![New Button](../images/upload_button.png) * Pressing the "Upload" button will prompt the user to select a local PBTXT - file, which will everwrite the current code within the editor. + file, which will overwrite the current code within the editor. * Alternatively, code can be pasted directly into the editor window. 
diff --git a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen index 862282f72..d3cd4971a 100644 --- a/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen +++ b/mediapipe/MediaPipe.tulsiproj/Configs/MediaPipe.tulsigen @@ -14,6 +14,7 @@ "mediapipe/examples/ios/iristrackinggpu/BUILD", "mediapipe/examples/ios/objectdetectioncpu/BUILD", "mediapipe/examples/ios/objectdetectiongpu/BUILD", + "mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD", "mediapipe/examples/ios/posetrackinggpu/BUILD", "mediapipe/examples/ios/upperbodyposetrackinggpu/BUILD", "mediapipe/framework/BUILD", @@ -33,6 +34,7 @@ "//mediapipe/examples/ios/iristrackinggpu:IrisTrackingGpuApp", "//mediapipe/examples/ios/objectdetectioncpu:ObjectDetectionCpuApp", "//mediapipe/examples/ios/objectdetectiongpu:ObjectDetectionGpuApp", + "//mediapipe/examples/ios/objectdetectiontrackinggpu:ObjectDetectionTrackingGpuApp", "//mediapipe/examples/ios/posetrackinggpu:PoseTrackingGpuApp", "//mediapipe/examples/ios/upperbodyposetrackinggpu:UpperBodyPoseTrackingGpuApp", "//mediapipe/objc:mediapipe_framework_ios" diff --git a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf index 0b0f6569c..7303828ad 100644 --- a/mediapipe/MediaPipe.tulsiproj/project.tulsiconf +++ b/mediapipe/MediaPipe.tulsiproj/project.tulsiconf @@ -20,6 +20,7 @@ "mediapipe/examples/ios/iristrackinggpu", "mediapipe/examples/ios/objectdetectioncpu", "mediapipe/examples/ios/objectdetectiongpu", + "mediapipe/examples/ios/objectdetectiontrackinggpu", "mediapipe/examples/ios/posetrackinggpu", "mediapipe/examples/ios/upperbodyposetrackinggpu", "mediapipe/objc" diff --git a/mediapipe/calculators/audio/BUILD b/mediapipe/calculators/audio/BUILD index b32529b79..9667e11d5 100644 --- a/mediapipe/calculators/audio/BUILD +++ b/mediapipe/calculators/audio/BUILD @@ -1,4 +1,4 @@ -# Copyright 2019 The MediaPipe Authors. 
+# Copyright 2019, 2021 The MediaPipe Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -167,7 +167,7 @@ cc_library( "//mediapipe/util:time_series_util", "@com_google_absl//absl/strings", "@com_google_audio_tools//audio/dsp:resampler", - "@com_google_audio_tools//audio/dsp:resampler_rational_factor", + "@com_google_audio_tools//audio/dsp:resampler_q", "@eigen_archive//:eigen", ], alwayslink = 1, @@ -242,6 +242,7 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:time_series_header_cc_proto", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", ], diff --git a/mediapipe/calculators/audio/audio_decoder_calculator.cc b/mediapipe/calculators/audio/audio_decoder_calculator.cc index 1ff70eb23..49c201b37 100644 --- a/mediapipe/calculators/audio/audio_decoder_calculator.cc +++ b/mediapipe/calculators/audio/audio_decoder_calculator.cc @@ -48,17 +48,17 @@ namespace mediapipe { // TODO: support decoding multiple streams. 
class AudioDecoderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: std::unique_ptr decoder_; }; -mediapipe::Status AudioDecoderCalculator::GetContract(CalculatorContract* cc) { +absl::Status AudioDecoderCalculator::GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); if (cc->InputSidePackets().HasTag("OPTIONS")) { cc->InputSidePackets().Tag("OPTIONS").Set(); @@ -67,10 +67,10 @@ mediapipe::Status AudioDecoderCalculator::GetContract(CalculatorContract* cc) { if (cc->Outputs().HasTag("AUDIO_HEADER")) { cc->Outputs().Tag("AUDIO_HEADER").SetNone(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) { +absl::Status AudioDecoderCalculator::Open(CalculatorContext* cc) { const std::string& input_file_path = cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); const auto& decoder_options = @@ -87,10 +87,10 @@ mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) { cc->Outputs().Tag("AUDIO_HEADER").SetHeader(Adopt(header.release())); } cc->Outputs().Tag("AUDIO_HEADER").Close(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioDecoderCalculator::Process(CalculatorContext* cc) { +absl::Status AudioDecoderCalculator::Process(CalculatorContext* cc) { Packet data; int options_index = -1; auto status = decoder_->GetData(&options_index, &data); @@ -100,7 +100,7 @@ mediapipe::Status AudioDecoderCalculator::Process(CalculatorContext* cc) { return status; 
} -mediapipe::Status AudioDecoderCalculator::Close(CalculatorContext* cc) { +absl::Status AudioDecoderCalculator::Close(CalculatorContext* cc) { return decoder_->Close(); } diff --git a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc index be0cd1836..33ab9e04f 100644 --- a/mediapipe/calculators/audio/audio_decoder_calculator_test.cc +++ b/mediapipe/calculators/audio/audio_decoder_calculator_test.cc @@ -15,6 +15,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/time_series_header.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.cc b/mediapipe/calculators/audio/basic_time_series_calculators.cc index f372e5a7c..f7b24f6f6 100644 --- a/mediapipe/calculators/audio/basic_time_series_calculators.cc +++ b/mediapipe/calculators/audio/basic_time_series_calculators.cc @@ -38,7 +38,7 @@ static bool SafeMultiply(int x, int y, int* result) { } } // namespace -mediapipe::Status BasicTimeSeriesCalculatorBase::GetContract( +absl::Status BasicTimeSeriesCalculatorBase::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. @@ -46,10 +46,10 @@ mediapipe::Status BasicTimeSeriesCalculatorBase::GetContract( cc->Outputs().Index(0).Set( // Output stream with TimeSeriesHeader. 
); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) { +absl::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) { TimeSeriesHeader input_header; MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( cc->Inputs().Index(0).Header(), &input_header)); @@ -57,11 +57,13 @@ mediapipe::Status BasicTimeSeriesCalculatorBase::Open(CalculatorContext* cc) { auto output_header = new TimeSeriesHeader(input_header); MP_RETURN_IF_ERROR(MutateHeader(output_header)); cc->Outputs().Index(0).SetHeader(Adopt(output_header)); - return mediapipe::OkStatus(); + + cc->SetOffset(0); + + return absl::OkStatus(); } -mediapipe::Status BasicTimeSeriesCalculatorBase::Process( - CalculatorContext* cc) { +absl::Status BasicTimeSeriesCalculatorBase::Process(CalculatorContext* cc) { const Matrix& input = cc->Inputs().Index(0).Get(); MP_RETURN_IF_ERROR(time_series_util::IsMatrixShapeConsistentWithHeader( input, cc->Inputs().Index(0).Header().Get())); @@ -71,12 +73,12 @@ mediapipe::Status BasicTimeSeriesCalculatorBase::Process( *output, cc->Outputs().Index(0).Header().Get())); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BasicTimeSeriesCalculatorBase::MutateHeader( +absl::Status BasicTimeSeriesCalculatorBase::MutateHeader( TimeSeriesHeader* output_header) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Calculator to sum an input time series across channels. 
This is @@ -86,9 +88,9 @@ mediapipe::Status BasicTimeSeriesCalculatorBase::MutateHeader( class SumTimeSeriesAcrossChannelsCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_channels(1); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -104,9 +106,9 @@ REGISTER_CALCULATOR(SumTimeSeriesAcrossChannelsCalculator); class AverageTimeSeriesAcrossChannelsCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_channels(1); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -122,7 +124,7 @@ REGISTER_CALCULATOR(AverageTimeSeriesAcrossChannelsCalculator); // Options proto: None. class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { if (output_header->num_channels() != 1) { return tool::StatusInvalid( absl::StrCat("Expected single-channel input, got ", @@ -131,7 +133,7 @@ class SummarySaiToPitchogramCalculator : public BasicTimeSeriesCalculatorBase { output_header->set_num_channels(output_header->num_samples()); output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -160,7 +162,7 @@ REGISTER_CALCULATOR(ReverseChannelOrderCalculator); // Options proto: None. 
class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { const int num_input_channels = output_header->num_channels(); const int num_input_samples = output_header->num_samples(); RET_CHECK(num_input_channels >= 0) @@ -174,7 +176,7 @@ class FlattenPacketCalculator : public BasicTimeSeriesCalculatorBase { output_header->set_num_channels(output_num_channels); output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -253,10 +255,10 @@ REGISTER_CALCULATOR(DivideByMeanAcrossChannelsCalculator); // Options proto: None. class MeanCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -272,10 +274,10 @@ REGISTER_CALCULATOR(MeanCalculator); // Options proto: None. class StandardDeviationCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_samples(1); output_header->set_sample_rate(output_header->packet_rate()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -293,9 +295,9 @@ REGISTER_CALCULATOR(StandardDeviationCalculator); // Options proto: None. 
class CovarianceCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_samples(output_header->num_channels()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -313,9 +315,9 @@ REGISTER_CALCULATOR(CovarianceCalculator); // Options proto: None. class L2NormCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { output_header->set_num_channels(1); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { @@ -385,12 +387,12 @@ REGISTER_CALCULATOR(ElementwiseSquareCalculator); // Options proto: None. class FirstHalfSlicerCalculator : public BasicTimeSeriesCalculatorBase { protected: - mediapipe::Status MutateHeader(TimeSeriesHeader* output_header) final { + absl::Status MutateHeader(TimeSeriesHeader* output_header) final { const int num_input_samples = output_header->num_samples(); RET_CHECK(num_input_samples >= 0) << "FirstHalfSlicerCalculator: num_input_samples < 0"; output_header->set_num_samples(num_input_samples / 2); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix ProcessMatrix(const Matrix& input_matrix) final { diff --git a/mediapipe/calculators/audio/basic_time_series_calculators.h b/mediapipe/calculators/audio/basic_time_series_calculators.h index f08939440..ef31f3448 100644 --- a/mediapipe/calculators/audio/basic_time_series_calculators.h +++ b/mediapipe/calculators/audio/basic_time_series_calculators.h @@ -28,16 +28,16 @@ namespace mediapipe { class BasicTimeSeriesCalculatorBase : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - 
mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; protected: // Open() calls this method to mutate the output stream header. The input // to this function will contain a copy of the input stream header, so // subclasses that do not need to mutate the header do not need to override // it. - virtual mediapipe::Status MutateHeader(TimeSeriesHeader* output_header); + virtual absl::Status MutateHeader(TimeSeriesHeader* output_header); // Process() calls this method on each packet to compute the output matrix. virtual Matrix ProcessMatrix(const Matrix& input_matrix) = 0; diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators.cc b/mediapipe/calculators/audio/mfcc_mel_calculators.cc index d6d5cf56c..a63b9d6ea 100644 --- a/mediapipe/calculators/audio/mfcc_mel_calculators.cc +++ b/mediapipe/calculators/audio/mfcc_mel_calculators.cc @@ -66,7 +66,7 @@ std::string PortableDebugString(const TimeSeriesHeader& header) { // rows corresponding to the new feature space). class FramewiseTransformCalculatorBase : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Sequence of Matrices, each column describing a particular time frame, // each row a feature dimension, with TimeSeriesHeader. @@ -75,11 +75,11 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { // Sequence of Matrices, each column describing a particular time frame, // each row a feature dimension, with TimeSeriesHeader. 
); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; int num_output_channels(void) { return num_output_channels_; } @@ -90,8 +90,8 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { private: // Takes header and options, and sets up state including calling // set_num_output_channels() on the base object. - virtual mediapipe::Status ConfigureTransform(const TimeSeriesHeader& header, - CalculatorContext* cc) = 0; + virtual absl::Status ConfigureTransform(const TimeSeriesHeader& header, + CalculatorContext* cc) = 0; // Takes a vector corresponding to an input frame, and // perform the specific transformation to produce an output frame. @@ -102,23 +102,23 @@ class FramewiseTransformCalculatorBase : public CalculatorBase { int num_output_channels_; }; -mediapipe::Status FramewiseTransformCalculatorBase::Open( - CalculatorContext* cc) { +absl::Status FramewiseTransformCalculatorBase::Open(CalculatorContext* cc) { TimeSeriesHeader input_header; MP_RETURN_IF_ERROR(time_series_util::FillTimeSeriesHeaderIfValid( cc->Inputs().Index(0).Header(), &input_header)); - mediapipe::Status status = ConfigureTransform(input_header, cc); + absl::Status status = ConfigureTransform(input_header, cc); auto output_header = new TimeSeriesHeader(input_header); output_header->set_num_channels(num_output_channels_); cc->Outputs().Index(0).SetHeader(Adopt(output_header)); + cc->SetOffset(0); + return status; } -mediapipe::Status FramewiseTransformCalculatorBase::Process( - CalculatorContext* cc) { +absl::Status FramewiseTransformCalculatorBase::Process(CalculatorContext* cc) { const Matrix& input = cc->Inputs().Index(0).Get(); const int num_frames = input.cols(); std::unique_ptr output(new Matrix(num_output_channels_, num_frames)); @@ -145,7 +145,7 @@ 
mediapipe::Status FramewiseTransformCalculatorBase::Process( } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Calculator wrapper around the dsp/mfcc/mfcc.cc routine. @@ -170,13 +170,13 @@ mediapipe::Status FramewiseTransformCalculatorBase::Process( // } class MfccCalculator : public FramewiseTransformCalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return FramewiseTransformCalculatorBase::GetContract(cc); } private: - mediapipe::Status ConfigureTransform(const TimeSeriesHeader& header, - CalculatorContext* cc) override { + absl::Status ConfigureTransform(const TimeSeriesHeader& header, + CalculatorContext* cc) override { MfccCalculatorOptions mfcc_options = cc->Options(); mfcc_.reset(new audio_dsp::Mfcc()); int input_length = header.num_channels(); @@ -194,7 +194,7 @@ class MfccCalculator : public FramewiseTransformCalculatorBase { // audio_dsp::MelFilterBank needs to know this to // correctly interpret the spectrogram bins. if (!header.has_audio_sample_rate()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ", PortableDebugString(header))); } @@ -203,10 +203,10 @@ class MfccCalculator : public FramewiseTransformCalculatorBase { mfcc_->Initialize(input_length, header.audio_sample_rate()); if (initialized) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { - return mediapipe::Status(mediapipe::StatusCode::kInternal, - "Mfcc::Initialize returned uninitialized"); + return absl::Status(absl::StatusCode::kInternal, + "Mfcc::Initialize returned uninitialized"); } } @@ -228,13 +228,13 @@ REGISTER_CALCULATOR(MfccCalculator); // if you ask for too many channels. 
class MelSpectrumCalculator : public FramewiseTransformCalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return FramewiseTransformCalculatorBase::GetContract(cc); } private: - mediapipe::Status ConfigureTransform(const TimeSeriesHeader& header, - CalculatorContext* cc) override { + absl::Status ConfigureTransform(const TimeSeriesHeader& header, + CalculatorContext* cc) override { MelSpectrumCalculatorOptions mel_spectrum_options = cc->Options(); mel_filterbank_.reset(new audio_dsp::MelFilterbank()); @@ -245,7 +245,7 @@ class MelSpectrumCalculator : public FramewiseTransformCalculatorBase { // audio_dsp::MelFilterBank needs to know this to // correctly interpret the spectrogram bins. if (!header.has_audio_sample_rate()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("No audio_sample_rate in input TimeSeriesHeader ", PortableDebugString(header))); } @@ -255,10 +255,10 @@ class MelSpectrumCalculator : public FramewiseTransformCalculatorBase { mel_spectrum_options.max_frequency_hertz()); if (initialized) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { - return mediapipe::Status(mediapipe::StatusCode::kInternal, - "mfcc::Initialize returned uninitialized"); + return absl::Status(absl::StatusCode::kInternal, + "mfcc::Initialize returned uninitialized"); } } diff --git a/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc b/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc index e7e7ac3a0..e7e312db9 100644 --- a/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc +++ b/mediapipe/calculators/audio/mfcc_mel_calculators_test.cc @@ -84,7 +84,7 @@ class FramewiseTransformCalculatorTest num_samples_per_packet_ = GenerateRandomNonnegInputStream(kNumPackets); } - mediapipe::Status Run() { return this->RunGraph(); } + absl::Status Run() { return this->RunGraph(); } void CheckResults(int 
expected_num_channels) { const auto& output_header = diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.cc b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc index b335fbe40..1a4210c30 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.cc +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019, 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -16,22 +16,18 @@ #include "mediapipe/calculators/audio/rational_factor_resample_calculator.h" -#include "audio/dsp/resampler_rational_factor.h" +#include "audio/dsp/resampler_q.h" -using audio_dsp::DefaultResamplingKernel; -using audio_dsp::RationalFactorResampler; using audio_dsp::Resampler; namespace mediapipe { -mediapipe::Status RationalFactorResampleCalculator::Process( - CalculatorContext* cc) { +absl::Status RationalFactorResampleCalculator::Process(CalculatorContext* cc) { return ProcessInternal(cc->Inputs().Index(0).Get(), false, cc); } -mediapipe::Status RationalFactorResampleCalculator::Close( - CalculatorContext* cc) { +absl::Status RationalFactorResampleCalculator::Close(CalculatorContext* cc) { if (initial_timestamp_ == Timestamp::Unstarted()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } Matrix empty_input_frame(num_channels_, 0); return ProcessInternal(empty_input_frame, true, cc); @@ -40,11 +36,8 @@ mediapipe::Status RationalFactorResampleCalculator::Close( namespace { void CopyChannelToVector(const Matrix& matrix, int channel, std::vector* vec) { - vec->clear(); - vec->reserve(matrix.cols()); - for (int sample = 0; sample < matrix.cols(); ++sample) { - vec->push_back(matrix(channel, sample)); - } + vec->resize(matrix.cols()); + Eigen::Map(vec->data(), vec->size()) = matrix.row(channel); } void 
CopyVectorToChannel(const std::vector& vec, Matrix* matrix, @@ -53,17 +46,14 @@ void CopyVectorToChannel(const std::vector& vec, Matrix* matrix, matrix->resize(matrix->rows(), vec.size()); } else { CHECK_EQ(vec.size(), matrix->cols()); - CHECK_LT(channel, matrix->rows()); - } - for (int sample = 0; sample < matrix->cols(); ++sample) { - (*matrix)(channel, sample) = vec[sample]; } + CHECK_LT(channel, matrix->rows()); + matrix->row(channel) = + Eigen::Map(vec.data(), vec.size()); } - } // namespace -mediapipe::Status RationalFactorResampleCalculator::Open( - CalculatorContext* cc) { +absl::Status RationalFactorResampleCalculator::Open(CalculatorContext* cc) { RationalFactorResampleCalculatorOptions resample_options = cc->Options(); @@ -88,7 +78,7 @@ mediapipe::Status RationalFactorResampleCalculator::Open( resample_options); if (!r) { LOG(ERROR) << "Failed to initialize resampler."; - return mediapipe::UnknownError("Failed to initialize resampler."); + return absl::UnknownError("Failed to initialize resampler."); } } } @@ -106,10 +96,10 @@ mediapipe::Status RationalFactorResampleCalculator::Open( initial_timestamp_ = Timestamp::Unstarted(); check_inconsistent_timestamps_ = resample_options.check_inconsistent_timestamps(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RationalFactorResampleCalculator::ProcessInternal( +absl::Status RationalFactorResampleCalculator::ProcessInternal( const Matrix& input_frame, bool should_flush, CalculatorContext* cc) { if (initial_timestamp_ == Timestamp::Unstarted()) { initial_timestamp_ = cc->InputTimestamp(); @@ -131,7 +121,7 @@ mediapipe::Status RationalFactorResampleCalculator::ProcessInternal( *output_frame = input_frame; } else { if (!Resample(input_frame, output_frame.get(), should_flush)) { - return mediapipe::UnknownError("Resample() failed."); + return absl::UnknownError("Resample() failed."); } } cumulative_output_samples_ += output_frame->cols(); @@ -139,7 +129,7 @@ mediapipe::Status 
RationalFactorResampleCalculator::ProcessInternal( if (output_frame->cols() > 0) { cc->Outputs().Index(0).Add(output_frame.release(), output_timestamp); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool RationalFactorResampleCalculator::Resample(const Matrix& input_frame, @@ -167,25 +157,28 @@ RationalFactorResampleCalculator::ResamplerFromOptions( std::unique_ptr> resampler; const auto& rational_factor_options = options.resampler_rational_factor_options(); - std::unique_ptr kernel; + audio_dsp::QResamplerParams params; if (rational_factor_options.has_radius() && rational_factor_options.has_cutoff() && rational_factor_options.has_kaiser_beta()) { - kernel = absl::make_unique( - source_sample_rate, target_sample_rate, - rational_factor_options.radius(), rational_factor_options.cutoff(), - rational_factor_options.kaiser_beta()); - } else { - kernel = absl::make_unique(source_sample_rate, - target_sample_rate); + // Convert RationalFactorResampler kernel parameters to QResampler + // settings. + params.filter_radius_factor = + rational_factor_options.radius() * + std::min(1.0, target_sample_rate / source_sample_rate); + params.cutoff_proportion = 2 * rational_factor_options.cutoff() / + std::min(source_sample_rate, target_sample_rate); + params.kaiser_beta = rational_factor_options.kaiser_beta(); } - // Set large enough so that the resampling factor between common sample // rates (e.g. 8kHz, 16kHz, 22.05kHz, 32kHz, 44.1kHz, 48kHz) is exact, and // that any factor is represented with error less than 0.025%. - const int kMaxDenominator = 2000; - resampler = absl::make_unique>( - *kernel, kMaxDenominator); + params.max_denominator = 2000; + + // NOTE: QResampler supports multichannel resampling, so the code might be + // simplified using a single instance rather than one per channel. 
+ resampler = absl::make_unique>( + source_sample_rate, target_sample_rate, /*num_channels=*/1, params); if (resampler != nullptr && !resampler->Valid()) { resampler = std::unique_ptr>(); } diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.h b/mediapipe/calculators/audio/rational_factor_resample_calculator.h index dc0719b39..325886dc7 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.h +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.h @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019, 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,28 +36,31 @@ namespace mediapipe { // stream's sampling rate is specified by target_sample_rate in the // RationalFactorResampleCalculatorOptions. The output time series may have // a varying number of samples per frame. +// +// NOTE: This calculator uses QResampler, despite the name, which supersedes +// RationalFactorResampler. class RationalFactorResampleCalculator : public CalculatorBase { public: struct TestAccess; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Single input stream with TimeSeriesHeader. ); cc->Outputs().Index(0).Set( // Resampled stream with TimeSeriesHeader. ); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns FAIL if the input stream header is invalid or if the // resampler cannot be initialized. - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Resamples a packet of TimeSeries data. Returns FAIL if the // resampler state becomes inconsistent. - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Flushes any remaining state. 
Returns FAIL if the resampler state // becomes inconsistent. - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; protected: typedef audio_dsp::Resampler ResamplerType; @@ -72,8 +75,8 @@ class RationalFactorResampleCalculator : public CalculatorBase { // Does Timestamp bookkeeping and resampling common to Process() and // Close(). Returns FAIL if the resampler state becomes // inconsistent. - mediapipe::Status ProcessInternal(const Matrix& input_frame, - bool should_flush, CalculatorContext* cc); + absl::Status ProcessInternal(const Matrix& input_frame, bool should_flush, + CalculatorContext* cc); // Uses the internal resampler_ objects to actually resample each // row of the input TimeSeries. Returns false if the resampler diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator.proto b/mediapipe/calculators/audio/rational_factor_resample_calculator.proto index 6eb36e672..97d7f202c 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator.proto +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator.proto @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019, 2021 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -18,6 +18,8 @@ package mediapipe; import "mediapipe/framework/calculator.proto"; +// NOTE: This calculator uses QResampler, despite the name, which supersedes +// RationalFactorResampler. message RationalFactorResampleCalculatorOptions { extend CalculatorOptions { optional RationalFactorResampleCalculatorOptions ext = 259760074; @@ -27,8 +29,7 @@ message RationalFactorResampleCalculatorOptions { // stream. Required. Must be greater than 0. optional double target_sample_rate = 1; - // Parameters for initializing the RationalFactorResampler. See - // RationalFactorResampler for more details. 
+ // Parameters for initializing QResampler. See QResampler for more details. message ResamplerRationalFactorOptions { // Kernel radius in units of input samples. optional double radius = 1; diff --git a/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc b/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc index aefa21205..6ae360303 100644 --- a/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc +++ b/mediapipe/calculators/audio/rational_factor_resample_calculator_test.cc @@ -80,7 +80,7 @@ class RationalFactorResampleCalculatorTest } // Initializes and runs the test graph. - mediapipe::Status Run(double output_sample_rate) { + absl::Status Run(double output_sample_rate) { options_.set_target_sample_rate(output_sample_rate); InitializeGraph(); @@ -120,7 +120,6 @@ class RationalFactorResampleCalculatorTest // The exact number of expected samples may vary based on the implementation // of the resampler since the exact value is not an integer. - // TODO: Reduce this offset to + 1 once cl/185829520 is submitted. const double expected_num_output_samples = num_input_samples_ * factor; EXPECT_LE(ceil(expected_num_output_samples), num_output_samples); EXPECT_GE(ceil(expected_num_output_samples) + 11, num_output_samples); diff --git a/mediapipe/calculators/audio/spectrogram_calculator.cc b/mediapipe/calculators/audio/spectrogram_calculator.cc index dd2dae886..bd2234f86 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator.cc @@ -66,7 +66,7 @@ namespace mediapipe { // analysis frame will advance from its predecessor by the same time step. class SpectrogramCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. 
); @@ -96,26 +96,34 @@ class SpectrogramCalculator : public CalculatorBase { ); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns FAIL if the input stream header is invalid. - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Outputs at most one packet consisting of a single Matrix with one or // more columns containing the spectral values from as many input frames // as are completed by the input samples. Always returns OK. - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Performs zero-padding and processing of any remaining samples // if pad_final_packet is set. // Returns OK. - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: Timestamp CurrentOutputTimestamp(CalculatorContext* cc) { if (use_local_timestamp_) { - return cc->InputTimestamp(); + const Timestamp now = cc->InputTimestamp(); + if (now == Timestamp::Done()) { + // During Close the timestamp is not available, send an estimate. + return last_local_output_timestamp_ + + round(last_completed_frames_ * frame_step_samples() * + Timestamp::kTimestampUnitsPerSecond / input_sample_rate_); + } + last_local_output_timestamp_ = now; + return now; } return CumulativeOutputTimestamp(); } @@ -138,17 +146,20 @@ class SpectrogramCalculator : public CalculatorBase { // Convert the output of the spectrogram object into a Matrix (or an // Eigen::MatrixXcf if complex-valued output is requested) and pass to // MediaPipe output. - mediapipe::Status ProcessVector(const Matrix& input_stream, - CalculatorContext* cc); + absl::Status ProcessVector(const Matrix& input_stream, CalculatorContext* cc); // Templated function to process either real- or complex-output spectrogram. 
template - mediapipe::Status ProcessVectorToOutput( + absl::Status ProcessVectorToOutput( const Matrix& input_stream, const OutputMatrixType postprocess_output_fn(const OutputMatrixType&), CalculatorContext* cc); + // Use the MediaPipe timestamp instead of the estimated one. Useful when the + // data is intermittent. bool use_local_timestamp_; + Timestamp last_local_output_timestamp_; + double input_sample_rate_; bool pad_final_packet_; int frame_duration_samples_; @@ -157,6 +168,9 @@ class SpectrogramCalculator : public CalculatorBase { int64 cumulative_input_samples_; // How many frames we've emitted, used for calculating output time stamps. int64 cumulative_completed_frames_; + // How many frames were emitted last, used for estimating the timestamp on + // Close when use_local_timestamp_ is true; + int64 last_completed_frames_; Timestamp initial_input_timestamp_; int num_input_channels_; // How many frequency bins we emit (=N_FFT/2 + 1). @@ -177,7 +191,7 @@ REGISTER_CALCULATOR(SpectrogramCalculator); // Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0). const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518; -mediapipe::Status SpectrogramCalculator::Open(CalculatorContext* cc) { +absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { SpectrogramCalculatorOptions spectrogram_options = cc->Options(); @@ -271,11 +285,20 @@ mediapipe::Status SpectrogramCalculator::Open(CalculatorContext* cc) { Adopt(multichannel_output_header.release())); } cumulative_completed_frames_ = 0; + last_completed_frames_ = 0; initial_input_timestamp_ = Timestamp::Unstarted(); - return mediapipe::OkStatus(); + if (use_local_timestamp_) { + // Inform the framework that the calculator will output packets at the same + // timestamps as input packets to enable packet queueing optimizations. 
The + // final packet (emitted from Close()) does not follow this rule but it's + // sufficient that its timestamp is strictly greater than the timestamp of + // the previous packet. + cc->SetOffset(0); + } + return absl::OkStatus(); } -mediapipe::Status SpectrogramCalculator::Process(CalculatorContext* cc) { +absl::Status SpectrogramCalculator::Process(CalculatorContext* cc) { if (initial_input_timestamp_ == Timestamp::Unstarted()) { initial_input_timestamp_ = cc->InputTimestamp(); } @@ -291,7 +314,7 @@ mediapipe::Status SpectrogramCalculator::Process(CalculatorContext* cc) { } template -mediapipe::Status SpectrogramCalculator::ProcessVectorToOutput( +absl::Status SpectrogramCalculator::ProcessVectorToOutput( const Matrix& input_stream, const OutputMatrixType postprocess_output_fn(const OutputMatrixType&), CalculatorContext* cc) { @@ -311,8 +334,8 @@ mediapipe::Status SpectrogramCalculator::ProcessVectorToOutput( if (!spectrogram_generators_[channel]->ComputeSpectrogram( input_vector, &output_vectors)) { - return mediapipe::Status(mediapipe::StatusCode::kInternal, - "Spectrogram returned failure"); + return absl::Status(absl::StatusCode::kInternal, + "Spectrogram returned failure"); } if (channel == 0) { // Record the number of time frames we expect from each channel. @@ -354,12 +377,19 @@ mediapipe::Status SpectrogramCalculator::ProcessVectorToOutput( CurrentOutputTimestamp(cc)); } cumulative_completed_frames_ += output_vectors.size(); + last_completed_frames_ = output_vectors.size(); + if (!use_local_timestamp_) { + // In non-local timestamp mode the timestamp of the next packet will be + // equal to CumulativeOutputTimestamp(). Inform the framework about this + // fact to enable packet queueing optimizations. 
+ cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp()); + } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SpectrogramCalculator::ProcessVector( - const Matrix& input_stream, CalculatorContext* cc) { +absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream, + CalculatorContext* cc) { switch (output_type_) { // These blocks deliberately ignore clang-format to preserve the // "silhouette" of the different cases. @@ -394,13 +424,13 @@ mediapipe::Status SpectrogramCalculator::ProcessVector( } // clang-format on default: { - return mediapipe::Status(mediapipe::StatusCode::kInvalidArgument, - "Unrecognized spectrogram output type."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Unrecognized spectrogram output type."); } } } -mediapipe::Status SpectrogramCalculator::Close(CalculatorContext* cc) { +absl::Status SpectrogramCalculator::Close(CalculatorContext* cc) { if (cumulative_input_samples_ > 0 && pad_final_packet_) { // We can flush any remaining samples by sending frame_step_samples - 1 // zeros to the Process method, and letting it do its thing, @@ -416,7 +446,7 @@ mediapipe::Status SpectrogramCalculator::Close(CalculatorContext* cc) { Matrix::Zero(num_input_channels_, required_padding_samples), cc); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/audio/spectrogram_calculator_test.cc b/mediapipe/calculators/audio/spectrogram_calculator_test.cc index c28ffb4d2..3c2b8435d 100644 --- a/mediapipe/calculators/audio/spectrogram_calculator_test.cc +++ b/mediapipe/calculators/audio/spectrogram_calculator_test.cc @@ -50,7 +50,7 @@ class SpectrogramCalculatorTest } // Initializes and runs the test graph. - mediapipe::Status Run() { + absl::Status Run() { // Now that options are set, we can set up some internal constants. 
frame_duration_samples_ = round(options_.frame_duration_seconds() * input_sample_rate_); diff --git a/mediapipe/calculators/audio/stabilized_log_calculator.cc b/mediapipe/calculators/audio/stabilized_log_calculator.cc index 20d062bfb..0c697a196 100644 --- a/mediapipe/calculators/audio/stabilized_log_calculator.cc +++ b/mediapipe/calculators/audio/stabilized_log_calculator.cc @@ -41,17 +41,17 @@ namespace mediapipe { // } class StabilizedLogCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. ); cc->Outputs().Index(0).Set( // Output stabilized log stream with TimeSeriesHeader. ); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { StabilizedLogCalculatorOptions stabilized_log_calculator_options = cc->Options(); @@ -70,23 +70,23 @@ class StabilizedLogCalculator : public CalculatorBase { cc->Outputs().Index(0).SetHeader( Adopt(new TimeSeriesHeader(input_header))); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { auto input_matrix = cc->Inputs().Index(0).Get(); if (input_matrix.array().isNaN().any()) { - return mediapipe::InvalidArgumentError("NaN input to log operation."); + return absl::InvalidArgumentError("NaN input to log operation."); } if (check_nonnegativity_) { if (input_matrix.minCoeff() < 0.0) { - return mediapipe::OutOfRangeError("Negative input to log operation."); + return absl::OutOfRangeError("Negative input to log operation."); } } std::unique_ptr output_frame(new Matrix( output_scale_ * (input_matrix.array() + stabilizer_).log().matrix())); cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); - return 
mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/audio/time_series_framer_calculator.cc b/mediapipe/calculators/audio/time_series_framer_calculator.cc index bffda6723..fbbf34226 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator.cc @@ -66,26 +66,26 @@ namespace mediapipe { // cumulative_completed_samples / sample_rate_. class TimeSeriesFramerCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set( // Input stream with TimeSeriesHeader. ); cc->Outputs().Index(0).Set( // Fixed length time series Packets with TimeSeriesHeader. ); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns FAIL if the input stream header is invalid. - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Outputs as many framed packets as possible given the accumulated // input. Always returns OK. - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Flushes any remaining samples in a zero-padded packet. Always // returns OK. - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: // Adds input data to the internal buffer. @@ -134,7 +134,6 @@ class TimeSeriesFramerCalculator : public CalculatorBase { // emulate_fractional_frame_overlap is true. 
double average_frame_step_samples_; int samples_still_to_drop_; - int64 cumulative_input_samples_; int64 cumulative_output_frames_; // "Completed" samples are samples that are no longer needed because // the framer has completely stepped past them (taking into account @@ -163,8 +162,6 @@ void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) { sample_buffer_.emplace_back(std::make_pair( input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i))); } - - cumulative_input_samples_ += input_frame.cols(); } void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { @@ -203,9 +200,15 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) { ++cumulative_output_frames_; cumulative_completed_samples_ += frame_step_samples; } + if (!use_local_timestamp_) { + // In non-local timestamp mode the timestamp of the next packet will be + // equal to CumulativeOutputTimestamp(). Inform the framework about this + // fact to enable packet queueing optimizations. 
+ cc->Outputs().Index(0).SetNextTimestampBound(CumulativeOutputTimestamp()); + } } -mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { +absl::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { if (initial_input_timestamp_ == Timestamp::Unstarted()) { initial_input_timestamp_ = cc->InputTimestamp(); current_timestamp_ = initial_input_timestamp_; @@ -214,10 +217,10 @@ mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { EnqueueInput(cc); FrameOutput(cc); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { +absl::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { while (samples_still_to_drop_ > 0 && !sample_buffer_.empty()) { sample_buffer_.pop_front(); --samples_still_to_drop_; @@ -234,10 +237,10 @@ mediapipe::Status TimeSeriesFramerCalculator::Close(CalculatorContext* cc) { CurrentOutputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { +absl::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { TimeSeriesFramerCalculatorOptions framer_options = cc->Options(); @@ -286,7 +289,6 @@ mediapipe::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { } cc->Outputs().Index(0).SetHeader(Adopt(output_header)); cumulative_completed_samples_ = 0; - cumulative_input_samples_ = 0; cumulative_output_frames_ = 0; samples_still_to_drop_ = 0; initial_input_timestamp_ = Timestamp::Unstarted(); @@ -317,7 +319,7 @@ mediapipe::Status TimeSeriesFramerCalculator::Open(CalculatorContext* cc) { } use_local_timestamp_ = framer_options.use_local_timestamp(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc index 
c4978a00e..ca88cebb5 100644 --- a/mediapipe/calculators/audio/time_series_framer_calculator_test.cc +++ b/mediapipe/calculators/audio/time_series_framer_calculator_test.cc @@ -69,7 +69,7 @@ class TimeSeriesFramerCalculatorTest } // Initializes and runs the test graph. - mediapipe::Status Run() { + absl::Status Run() { InitializeGraph(); FillInputHeader(); @@ -441,7 +441,7 @@ class TimeSeriesFramerCalculatorTimestampingTest } } - mediapipe::Status RunTimestampTest() { + absl::Status RunTimestampTest() { InitializeGraph(); InitializeInputForTimeStampingTest(); FillInputHeader(); diff --git a/mediapipe/calculators/core/BUILD b/mediapipe/calculators/core/BUILD index af98fef3a..61d402f74 100644 --- a/mediapipe/calculators/core/BUILD +++ b/mediapipe/calculators/core/BUILD @@ -249,6 +249,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":concatenate_vector_calculator_cc_proto", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/api2:port", "//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:tensor", @@ -554,6 +556,7 @@ cc_library( ], deps = [ "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", "//mediapipe/framework/port:ret_check", ], alwayslink = 1, diff --git a/mediapipe/calculators/core/add_header_calculator.cc b/mediapipe/calculators/core/add_header_calculator.cc index dc0fa8aed..1c636afd0 100644 --- a/mediapipe/calculators/core/add_header_calculator.cc +++ b/mediapipe/calculators/core/add_header_calculator.cc @@ -53,27 +53,28 @@ class AddHeaderCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kHeader, kHeaderSide, kData, kOut); - static mediapipe::Status UpdateContract(CalculatorContract* cc) { + static absl::Status UpdateContract(CalculatorContract* cc) { if (kHeader(cc).IsConnected() == kHeaderSide(cc).IsConnected()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Header must be provided 
via exactly one of side input and input " "stream"); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const PacketBase& header = kHeader(cc).IsConnected() ? kHeader(cc).Header() : kHeaderSide(cc); if (!header.IsEmpty()) { kOut(cc).SetHeader(header); } - return mediapipe::OkStatus(); + cc->SetOffset(0); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { kOut(cc).Send(kData(cc).packet()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/core/add_header_calculator_test.cc b/mediapipe/calculators/core/add_header_calculator_test.cc index 8aa5d3424..4e197918d 100644 --- a/mediapipe/calculators/core/add_header_calculator_test.cc +++ b/mediapipe/calculators/core/add_header_calculator_test.cc @@ -153,7 +153,7 @@ TEST_F(AddHeaderCalculatorTest, UsingBothSideInputAndStream) { } // Run should fail because header can only be provided one way. 
- EXPECT_EQ(runner.Run().code(), mediapipe::InvalidArgumentError("").code()); + EXPECT_EQ(runner.Run().code(), absl::InvalidArgumentError("").code()); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc b/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc index 85834491a..b627e5b23 100644 --- a/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc +++ b/mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc @@ -42,22 +42,22 @@ REGISTER_CALCULATOR(BeginLoopIntegerCalculator); class IncrementCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const int& input_int = cc->Inputs().Index(0).Get(); auto output_int = absl::make_unique(input_int + 1); cc->Outputs().Index(0).Add(output_int.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; @@ -166,19 +166,19 @@ TEST_F(BeginEndLoopCalculatorGraphTest, MultipleVectors) { // bound update. 
class PassThroughOrEmptyVectorCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->SetProcessTimestampBounds(true); cc->Inputs().Index(0).Set>(); cc->Outputs().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (!cc->Inputs().Index(0).IsEmpty()) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); } else { @@ -186,7 +186,7 @@ class PassThroughOrEmptyVectorCalculator : public CalculatorBase { MakePacket>(std::vector()) .At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; @@ -311,24 +311,24 @@ TEST_F(BeginEndLoopCalculatorGraphProcessingEmptyPacketsTest, MultipleVectors) { class MultiplierCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Inputs().Index(1).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const int& input_int = cc->Inputs().Index(0).Get(); const int& multiplier_int = cc->Inputs().Index(1).Get(); auto output_int = absl::make_unique(input_int * multiplier_int); cc->Outputs().Index(0).Add(output_int.release(), 
cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/core/begin_loop_calculator.h b/mediapipe/calculators/core/begin_loop_calculator.h index a655d1871..a9d29e687 100644 --- a/mediapipe/calculators/core/begin_loop_calculator.h +++ b/mediapipe/calculators/core/begin_loop_calculator.h @@ -61,7 +61,7 @@ class BeginLoopCalculator : public CalculatorBase { using ItemT = typename IterableT::value_type; public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { // The below enables processing of timestamp bound updates, and that enables // correct timestamp propagation by the companion EndLoopCalculator. // @@ -106,10 +106,10 @@ class BeginLoopCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { Timestamp last_timestamp = loop_internal_timestamp_; if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) { const IterableT& collection = @@ -139,7 +139,7 @@ class BeginLoopCalculator : public CalculatorBase { .AddPacket(MakePacket(cc->InputTimestamp()) .At(Timestamp(loop_internal_timestamp_ - 1))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/clip_vector_size_calculator.h b/mediapipe/calculators/core/clip_vector_size_calculator.h index c16fd6dcc..00de9be7f 100644 --- a/mediapipe/calculators/core/clip_vector_size_calculator.h +++ b/mediapipe/calculators/core/clip_vector_size_calculator.h @@ -43,13 +43,13 @@ namespace mediapipe { template class ClipVectorSizeCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() == 1); if 
(cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>() .max_vec_size() < 1) { - return mediapipe::InternalError( + return absl::InternalError( "max_vec_size should be greater than or equal to 1."); } @@ -60,10 +60,10 @@ class ClipVectorSizeCalculator : public CalculatorBase { cc->InputSidePackets().Index(0).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); max_vec_size_ = cc->Options<::mediapipe::ClipVectorSizeCalculatorOptions>() .max_vec_size(); @@ -72,23 +72,23 @@ class ClipVectorSizeCalculator : public CalculatorBase { !cc->InputSidePackets().Index(0).IsEmpty()) { max_vec_size_ = cc->InputSidePackets().Index(0).Get(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (max_vec_size_ < 1) { - return mediapipe::InternalError( + return absl::InternalError( "max_vec_size should be greater than or equal to 1."); } if (cc->Inputs().Index(0).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } return ClipVectorSize(std::is_copy_constructible(), cc); } template - mediapipe::Status ClipVectorSize(std::true_type, CalculatorContext* cc) { + absl::Status ClipVectorSize(std::true_type, CalculatorContext* cc) { auto output = absl::make_unique>(); const std::vector& input_vector = cc->Inputs().Index(0).Get>(); @@ -100,24 +100,23 @@ class ClipVectorSizeCalculator : public CalculatorBase { } } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } template - mediapipe::Status ClipVectorSize(std::false_type, CalculatorContext* cc) { + absl::Status ClipVectorSize(std::false_type, CalculatorContext* cc) { return ConsumeAndClipVectorSize(std::is_move_constructible(), cc); } template - 
mediapipe::Status ConsumeAndClipVectorSize(std::true_type, - CalculatorContext* cc) { + absl::Status ConsumeAndClipVectorSize(std::true_type, CalculatorContext* cc) { auto output = absl::make_unique>(); - mediapipe::StatusOr>> input_status = + absl::StatusOr>> input_status = cc->Inputs().Index(0).Value().Consume>(); if (input_status.ok()) { std::unique_ptr> input_vector = - std::move(input_status).ValueOrDie(); + std::move(input_status).value(); auto begin_it = input_vector->begin(); auto end_it = input_vector->end(); if (max_vec_size_ < input_vector->size()) { @@ -129,13 +128,13 @@ class ClipVectorSizeCalculator : public CalculatorBase { return input_status.status(); } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } template - mediapipe::Status ConsumeAndClipVectorSize(std::false_type, - CalculatorContext* cc) { - return mediapipe::InternalError( + absl::Status ConsumeAndClipVectorSize(std::false_type, + CalculatorContext* cc) { + return absl::InternalError( "Cannot copy or move input vectors and clip their size."); } diff --git a/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc b/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc index 161a323cf..fd7d324b2 100644 --- a/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_detection_vector_calculator.cc @@ -1,4 +1,4 @@ -// Copyright 2019 The MediaPipe Authors. +// Copyright 2019-2020 The MediaPipe Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -20,14 +20,16 @@ namespace mediapipe { // Example config: +// // node { // calculator: "ConcatenateDetectionVectorCalculator" // input_stream: "detection_vector_1" // input_stream: "detection_vector_2" // output_stream: "concatenated_detection_vector" // } +// typedef ConcatenateVectorCalculator<::mediapipe::Detection> ConcatenateDetectionVectorCalculator; -REGISTER_CALCULATOR(ConcatenateDetectionVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateDetectionVectorCalculator); } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc b/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc index fb405533b..f0a4043a7 100644 --- a/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc +++ b/mediapipe/calculators/core/concatenate_normalized_landmark_list_calculator.cc @@ -36,35 +36,35 @@ class ConcatenateNormalizedLandmarkListCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - static mediapipe::Status UpdateContract(CalculatorContract* cc) { + static absl::Status UpdateContract(CalculatorContract* cc) { RET_CHECK_GE(kIn(cc).Count(), 1); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { only_emit_if_all_present_ = cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() .only_emit_if_all_present(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (only_emit_if_all_present_) { - for (int i = 0; i < kIn(cc).Count(); ++i) { - if (kIn(cc)[i].IsEmpty()) return mediapipe::OkStatus(); + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) return absl::OkStatus(); } } NormalizedLandmarkList output; - for (int i = 0; i < kIn(cc).Count(); ++i) { - if (kIn(cc)[i].IsEmpty()) continue; - const 
NormalizedLandmarkList& input = *kIn(cc)[i]; - for (int j = 0; j < input.landmark_size(); ++j) { - *output.add_landmark() = input.landmark(j); + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) continue; + const NormalizedLandmarkList& list = *input; + for (int j = 0; j < list.landmark_size(); ++j) { + *output.add_landmark() = list.landmark(j); } } kOut(cc).Send(std::move(output)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.cc b/mediapipe/calculators/core/concatenate_vector_calculator.cc index 39be14f46..20d6a3286 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator.cc @@ -25,7 +25,7 @@ #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) namespace mediapipe { @@ -37,7 +37,7 @@ namespace mediapipe { // output_stream: "concatenated_float_vector" // } typedef ConcatenateVectorCalculator ConcatenateFloatVectorCalculator; -REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateFloatVectorCalculator); // Example config: // node { @@ -47,13 +47,13 @@ REGISTER_CALCULATOR(ConcatenateFloatVectorCalculator); // output_stream: "concatenated_int32_vector" // } typedef ConcatenateVectorCalculator ConcatenateInt32VectorCalculator; -REGISTER_CALCULATOR(ConcatenateInt32VectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateInt32VectorCalculator); typedef ConcatenateVectorCalculator ConcatenateUInt64VectorCalculator; -REGISTER_CALCULATOR(ConcatenateUInt64VectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateUInt64VectorCalculator); typedef ConcatenateVectorCalculator ConcatenateBoolVectorCalculator; -REGISTER_CALCULATOR(ConcatenateBoolVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateBoolVectorCalculator); // Example config: 
// node { @@ -64,31 +64,31 @@ REGISTER_CALCULATOR(ConcatenateBoolVectorCalculator); // } typedef ConcatenateVectorCalculator ConcatenateTfLiteTensorVectorCalculator; -REGISTER_CALCULATOR(ConcatenateTfLiteTensorVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateTfLiteTensorVectorCalculator); typedef ConcatenateVectorCalculator ConcatenateTensorVectorCalculator; -REGISTER_CALCULATOR(ConcatenateTensorVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateTensorVectorCalculator); typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark> ConcatenateLandmarkVectorCalculator; -REGISTER_CALCULATOR(ConcatenateLandmarkVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator); typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList> ConcatenateLandmarListVectorCalculator; -REGISTER_CALCULATOR(ConcatenateLandmarListVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator); typedef ConcatenateVectorCalculator ConcatenateClassificationListVectorCalculator; -REGISTER_CALCULATOR(ConcatenateClassificationListVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateClassificationListVectorCalculator); #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) typedef ConcatenateVectorCalculator<::tflite::gpu::gl::GlBuffer> ConcatenateGlBufferVectorCalculator; -REGISTER_CALCULATOR(ConcatenateGlBufferVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateGlBufferVectorCalculator); #endif typedef ConcatenateVectorCalculator ConcatenateRenderDataVectorCalculator; -REGISTER_CALCULATOR(ConcatenateRenderDataVectorCalculator); +MEDIAPIPE_REGISTER_NODE(ConcatenateRenderDataVectorCalculator); } // namespace mediapipe diff --git a/mediapipe/calculators/core/concatenate_vector_calculator.h b/mediapipe/calculators/core/concatenate_vector_calculator.h index 01b729ed9..c6687814c 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator.h +++ b/mediapipe/calculators/core/concatenate_vector_calculator.h @@ -20,120 +20,96 @@ 
#include #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { +// Note: since this is a calculator template that can be included by other +// source files, we do not place this in namespace api2 directly, but qualify +// the api2 names below, to avoid changing the visible name of the class. +// We cannot simply write "using mediapipe::api2" since it's a header file. +// This distinction will go away once api2 is finalized. // Concatenates several objects of type T or std::vector following stream // index order. This class assumes that every input stream contains either T or // vector type. To use this class for a particular type T, regisiter a // calculator using ConcatenateVectorCalculator. template -class ConcatenateVectorCalculator : public CalculatorBase { +class ConcatenateVectorCalculator : public api2::Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK(cc->Inputs().NumEntries() != 0); - RET_CHECK(cc->Outputs().NumEntries() == 1); + static constexpr + typename api2::Input>>::Multiple kIn{""}; + static constexpr api2::Output> kOut{""}; - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - // Actual type T or vector will be validated in Process(). 
- cc->Inputs().Index(i).SetAny(); - } + MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - cc->Outputs().Index(0).Set>(); - - return mediapipe::OkStatus(); + static absl::Status UpdateContract(CalculatorContract* cc) { + RET_CHECK_GE(kIn(cc).Count(), 1); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { - cc->SetOffset(TimestampDiff(0)); + absl::Status Open(CalculatorContext* cc) override { only_emit_if_all_present_ = cc->Options<::mediapipe::ConcatenateVectorCalculatorOptions>() .only_emit_if_all_present(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (only_emit_if_all_present_) { - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - if (cc->Inputs().Index(i).IsEmpty()) return mediapipe::OkStatus(); + for (const auto& input : kIn(cc)) { + if (input.IsEmpty()) return ::absl::OkStatus(); } } - return ConcatenateVectors(std::is_copy_constructible(), cc); } template - mediapipe::Status ConcatenateVectors(std::true_type, CalculatorContext* cc) { - auto output = absl::make_unique>(); - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - auto& input = cc->Inputs().Index(i); - + absl::Status ConcatenateVectors(std::true_type, CalculatorContext* cc) { + auto output = std::vector(); + for (const auto& input : kIn(cc)) { if (input.IsEmpty()) continue; - - if (input.Value().ValidateAsType().ok()) { - const U& value = input.Get(); - output->push_back(value); - } else if (input.Value().ValidateAsType>().ok()) { - const std::vector& value = input.Get>(); - output->insert(output->end(), value.begin(), value.end()); - } else { - return mediapipe::InvalidArgumentError("Invalid input stream type."); - } + input.Visit([&output](const U& value) { output.push_back(value); }, + [&output](const std::vector& value) { + output.insert(output.end(), value.begin(), value.end()); + }); } - 
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + kOut(cc).Send(std::move(output)); + return absl::OkStatus(); } template - mediapipe::Status ConcatenateVectors(std::false_type, CalculatorContext* cc) { + absl::Status ConcatenateVectors(std::false_type, CalculatorContext* cc) { return ConsumeAndConcatenateVectors(std::is_move_constructible(), cc); } template - mediapipe::Status ConsumeAndConcatenateVectors(std::true_type, - CalculatorContext* cc) { - auto output = absl::make_unique>(); - for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { - auto& input = cc->Inputs().Index(i); - + absl::Status ConsumeAndConcatenateVectors(std::true_type, + CalculatorContext* cc) { + auto output = std::vector(); + for (auto input : kIn(cc)) { if (input.IsEmpty()) continue; - - if (input.Value().ValidateAsType().ok()) { - mediapipe::StatusOr> value_status = - input.Value().Consume(); - if (value_status.ok()) { - std::unique_ptr value = std::move(value_status).ValueOrDie(); - output->push_back(std::move(*value)); - } else { - return value_status.status(); - } - } else if (input.Value().ValidateAsType>().ok()) { - mediapipe::StatusOr>> value_status = - input.Value().Consume>(); - if (value_status.ok()) { - std::unique_ptr> value = - std::move(value_status).ValueOrDie(); - output->insert(output->end(), std::make_move_iterator(value->begin()), - std::make_move_iterator(value->end())); - } else { - return value_status.status(); - } - } else { - return mediapipe::InvalidArgumentError("Invalid input stream type."); - } + MP_RETURN_IF_ERROR(input.ConsumeAndVisit( + [&output](std::unique_ptr value) { + output.push_back(std::move(*value)); + }, + [&output](std::unique_ptr> value) { + output.insert(output.end(), std::make_move_iterator(value->begin()), + std::make_move_iterator(value->end())); + })); } - cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + kOut(cc).Send(std::move(output)); + 
return absl::OkStatus(); } template - mediapipe::Status ConsumeAndConcatenateVectors(std::false_type, - CalculatorContext* cc) { - return mediapipe::InternalError( + absl::Status ConsumeAndConcatenateVectors(std::false_type, + CalculatorContext* cc) { + return absl::InternalError( "Cannot copy or move inputs to concatenate them"); } diff --git a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc index eaf23700c..83f058086 100644 --- a/mediapipe/calculators/core/concatenate_vector_calculator_test.cc +++ b/mediapipe/calculators/core/concatenate_vector_calculator_test.cc @@ -28,7 +28,7 @@ namespace mediapipe { typedef ConcatenateVectorCalculator TestConcatenateIntVectorCalculator; -REGISTER_CALCULATOR(TestConcatenateIntVectorCalculator); +MEDIAPIPE_REGISTER_NODE(TestConcatenateIntVectorCalculator); void AddInputVector(int index, const std::vector& input, int64 timestamp, CalculatorRunner* runner) { @@ -384,7 +384,7 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) { typedef ConcatenateVectorCalculator> TestConcatenateUniqueIntPtrCalculator; -REGISTER_CALCULATOR(TestConcatenateUniqueIntPtrCalculator); +MEDIAPIPE_REGISTER_NODE(TestConcatenateUniqueIntPtrCalculator); TEST(TestConcatenateUniqueIntVectorCalculatorTest, ConsumeOneTimestamp) { /* Note: We don't use CalculatorRunner for this test because it keeps copies diff --git a/mediapipe/calculators/core/constant_side_packet_calculator.cc b/mediapipe/calculators/core/constant_side_packet_calculator.cc index 4b6952deb..ff328377e 100644 --- a/mediapipe/calculators/core/constant_side_packet_calculator.cc +++ b/mediapipe/calculators/core/constant_side_packet_calculator.cc @@ -54,7 +54,7 @@ namespace {} // namespace // } class ConstantSidePacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& 
options = cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>(); RET_CHECK_EQ(cc->OutputSidePackets().NumEntries(kPacketTag), @@ -80,14 +80,14 @@ class ConstantSidePacketCalculator : public CalculatorBase { } else if (packet_options.has_classification_list_value()) { packet.Set(); } else { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "None of supported values were specified in options."); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const auto& options = cc->Options<::mediapipe::ConstantSidePacketCalculatorOptions>(); int index = 0; @@ -109,15 +109,15 @@ class ConstantSidePacketCalculator : public CalculatorBase { packet.Set(MakePacket( packet_options.classification_list_value())); } else { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "None of supported values were specified in options."); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/constant_side_packet_calculator_test.cc b/mediapipe/calculators/core/constant_side_packet_calculator_test.cc index 497dc5e55..1357a99a5 100644 --- a/mediapipe/calculators/core/constant_side_packet_calculator_test.cc +++ b/mediapipe/calculators/core/constant_side_packet_calculator_test.cc @@ -49,7 +49,7 @@ void DoTestSingleSidePacket(absl::string_view packet_spec, MP_ASSERT_OK(graph.GetOutputSidePacket("packet")); auto actual_value = - graph.GetOutputSidePacket("packet").ValueOrDie().template Get(); + graph.GetOutputSidePacket("packet").value().template Get(); EXPECT_EQ(actual_value, expected_value); } @@ -89,28 +89,24 @@ TEST(ConstantSidePacketCalculatorTest, MultiplePackets) { 
MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.GetOutputSidePacket("int_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("int_packet").ValueOrDie().Get(), - 256); + EXPECT_EQ(graph.GetOutputSidePacket("int_packet").value().Get(), 256); MP_ASSERT_OK(graph.GetOutputSidePacket("float_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("float_packet").ValueOrDie().Get(), + EXPECT_EQ(graph.GetOutputSidePacket("float_packet").value().Get(), 0.5f); MP_ASSERT_OK(graph.GetOutputSidePacket("bool_packet")); - EXPECT_FALSE( - graph.GetOutputSidePacket("bool_packet").ValueOrDie().Get()); + EXPECT_FALSE(graph.GetOutputSidePacket("bool_packet").value().Get()); MP_ASSERT_OK(graph.GetOutputSidePacket("string_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("string_packet") - .ValueOrDie() - .Get(), - "string"); + EXPECT_EQ( + graph.GetOutputSidePacket("string_packet").value().Get(), + "string"); MP_ASSERT_OK(graph.GetOutputSidePacket("another_string_packet")); EXPECT_EQ(graph.GetOutputSidePacket("another_string_packet") - .ValueOrDie() + .value() .Get(), "another string"); MP_ASSERT_OK(graph.GetOutputSidePacket("another_int_packet")); - EXPECT_EQ( - graph.GetOutputSidePacket("another_int_packet").ValueOrDie().Get(), - 128); + EXPECT_EQ(graph.GetOutputSidePacket("another_int_packet").value().Get(), + 128); } TEST(ConstantSidePacketCalculatorTest, ProcessingPacketsWithCorrectTagOnly) { @@ -142,19 +138,16 @@ TEST(ConstantSidePacketCalculatorTest, ProcessingPacketsWithCorrectTagOnly) { MP_ASSERT_OK(graph.WaitUntilIdle()); MP_ASSERT_OK(graph.GetOutputSidePacket("int_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("int_packet").ValueOrDie().Get(), - 256); + EXPECT_EQ(graph.GetOutputSidePacket("int_packet").value().Get(), 256); MP_ASSERT_OK(graph.GetOutputSidePacket("float_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("float_packet").ValueOrDie().Get(), + EXPECT_EQ(graph.GetOutputSidePacket("float_packet").value().Get(), 0.5f); 
MP_ASSERT_OK(graph.GetOutputSidePacket("bool_packet")); - EXPECT_FALSE( - graph.GetOutputSidePacket("bool_packet").ValueOrDie().Get()); + EXPECT_FALSE(graph.GetOutputSidePacket("bool_packet").value().Get()); MP_ASSERT_OK(graph.GetOutputSidePacket("string_packet")); - EXPECT_EQ(graph.GetOutputSidePacket("string_packet") - .ValueOrDie() - .Get(), - "string"); + EXPECT_EQ( + graph.GetOutputSidePacket("string_packet").value().Get(), + "string"); } TEST(ConstantSidePacketCalculatorTest, IncorrectConfig_MoreOptionsThanPackets) { diff --git a/mediapipe/calculators/core/counting_source_calculator.cc b/mediapipe/calculators/core/counting_source_calculator.cc index efd8148e9..0b731d9ce 100644 --- a/mediapipe/calculators/core/counting_source_calculator.cc +++ b/mediapipe/calculators/core/counting_source_calculator.cc @@ -30,7 +30,7 @@ namespace mediapipe { // provided, then batches are of size 1. class CountingSourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN")) { @@ -55,13 +55,13 @@ class CountingSourceCalculator : public CalculatorBase { if (cc->InputSidePackets().HasTag("INCREMENT")) { cc->InputSidePackets().Tag("INCREMENT").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { if (cc->InputSidePackets().HasTag("ERROR_ON_OPEN") && cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get()) { - return mediapipe::NotFoundError("expected error"); + return absl::NotFoundError("expected error"); } if (cc->InputSidePackets().HasTag("ERROR_COUNT")) { error_count_ = cc->InputSidePackets().Tag("ERROR_COUNT").Get(); @@ -83,12 +83,12 @@ class CountingSourceCalculator : public CalculatorBase { RET_CHECK_LT(0, increment_); } RET_CHECK(error_count_ >= 0 || max_count_ 
>= 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (error_count_ >= 0 && batch_counter_ >= error_count_) { - return mediapipe::InternalError("expected error"); + return absl::InternalError("expected error"); } if (max_count_ >= 0 && batch_counter_ >= max_count_) { return tool::StatusStop(); @@ -98,7 +98,7 @@ class CountingSourceCalculator : public CalculatorBase { counter_ += increment_; } ++batch_counter_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc index 373c15461..04a7e55a0 100644 --- a/mediapipe/calculators/core/dequantize_byte_array_calculator.cc +++ b/mediapipe/calculators/core/dequantize_byte_array_calculator.cc @@ -37,34 +37,34 @@ namespace mediapipe { class DequantizeByteArrayCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("ENCODED").Set(); cc->Outputs().Tag("FLOAT_VECTOR").Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { const auto options = cc->Options<::mediapipe::DequantizeByteArrayCalculatorOptions>(); if (!options.has_max_quantized_value() || !options.has_min_quantized_value()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Both max_quantized_value and min_quantized_value must be provided " "in DequantizeByteArrayCalculatorOptions."); } float max_quantized_value = options.max_quantized_value(); float min_quantized_value = options.min_quantized_value(); if (max_quantized_value < min_quantized_value + FLT_EPSILON) { - return mediapipe::InvalidArgumentError( 
+ return absl::InvalidArgumentError( "max_quantized_value must be greater than min_quantized_value."); } float range = max_quantized_value - min_quantized_value; scalar_ = range / 255.0; bias_ = (range / 512.0) + min_quantized_value; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const std::string& encoded = cc->Inputs().Tag("ENCODED").Value().Get(); std::vector float_vector; @@ -77,7 +77,7 @@ class DequantizeByteArrayCalculator : public CalculatorBase { .Tag("FLOAT_VECTOR") .AddPacket(MakePacket>(float_vector) .At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/end_loop_calculator.h b/mediapipe/calculators/core/end_loop_calculator.h index 969ed003a..e40301e81 100644 --- a/mediapipe/calculators/core/end_loop_calculator.h +++ b/mediapipe/calculators/core/end_loop_calculator.h @@ -57,7 +57,7 @@ class EndLoopCalculator : public CalculatorBase { using ItemT = typename IterableT::value_type; public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("BATCH_END")) << "Missing BATCH_END tagged input_stream."; cc->Inputs().Tag("BATCH_END").Set(); @@ -67,10 +67,10 @@ class EndLoopCalculator : public CalculatorBase { RET_CHECK(cc->Outputs().HasTag("ITERABLE")); cc->Outputs().Tag("ITERABLE").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (!cc->Inputs().Tag("ITEM").IsEmpty()) { if (!input_stream_collection_) { input_stream_collection_.reset(new IterableT); @@ -94,7 +94,7 @@ class EndLoopCalculator : public CalculatorBase { .SetNextTimestampBound(Timestamp(loop_control_ts.Value() + 1)); } } - return mediapipe::OkStatus(); + 
return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/flow_limiter_calculator.cc b/mediapipe/calculators/core/flow_limiter_calculator.cc index ffee4b8ed..4fbfced96 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator.cc @@ -67,7 +67,7 @@ namespace mediapipe { // class FlowLimiterCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { auto& side_inputs = cc->InputSidePackets(); side_inputs.Tag("OPTIONS").Set().Optional(); cc->Inputs().Tag("OPTIONS").Set().Optional(); @@ -81,10 +81,10 @@ class FlowLimiterCalculator : public CalculatorBase { cc->Outputs().Tag("ALLOW").Set().Optional(); cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { options_ = cc->Options(); options_ = tool::RetrieveOptions(options_, cc->InputSidePackets()); if (cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { @@ -93,7 +93,7 @@ class FlowLimiterCalculator : public CalculatorBase { } input_queues_.resize(cc->Inputs().NumEntries("")); RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns true if an additional frame can be released for processing. @@ -151,7 +151,7 @@ class FlowLimiterCalculator : public CalculatorBase { } // Releases input packets allowed by the max_in_flight constraint. - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { options_ = tool::RetrieveOptions(options_, cc->Inputs()); // Process the FINISHED input stream. 
@@ -216,7 +216,7 @@ class FlowLimiterCalculator : public CalculatorBase { } ProcessAuxiliaryInputs(cc); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/flow_limiter_calculator_test.cc b/mediapipe/calculators/core/flow_limiter_calculator_test.cc index 12cacfc72..303c1a053 100644 --- a/mediapipe/calculators/core/flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/flow_limiter_calculator_test.cc @@ -71,19 +71,19 @@ std::vector PacketValues(const std::vector& packets) { } // A Calculator::Process callback function. -typedef std::function +typedef std::function ProcessFunction; // A testing callback function that passes through all packets. -mediapipe::Status PassthroughFunction(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status PassthroughFunction(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Tests demonstrating an FlowLimiterCalculator operating in a cyclic graph. @@ -111,8 +111,8 @@ class FlowLimiterCalculatorSemaphoreTest : public testing::Test { {"callback_1", Adopt(new auto(semaphore_1_func))}, })); - allow_poller_.reset(new OutputStreamPoller( - graph_.AddOutputStreamPoller("allow").ValueOrDie())); + allow_poller_.reset( + new OutputStreamPoller(graph_.AddOutputStreamPoller("allow").value())); } // Adds a packet to a graph input stream. @@ -203,22 +203,22 @@ TEST_F(FlowLimiterCalculatorSemaphoreTest, FramesDropped) { // A calculator that sleeps during Process. 
class SleepCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("PACKET").SetAny(); cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); cc->InputSidePackets().Tag("SLEEP_TIME").Set(); cc->InputSidePackets().Tag("WARMUP_TIME").Set(); cc->InputSidePackets().Tag("CLOCK").Set(); cc->SetTimestampOffset(0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { clock_ = cc->InputSidePackets().Tag("CLOCK").Get(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { ++packet_count; absl::Duration sleep_time = absl::Microseconds( packet_count == 1 @@ -226,7 +226,7 @@ class SleepCalculator : public CalculatorBase { : cc->InputSidePackets().Tag("SLEEP_TIME").Get()); clock_->Sleep(sleep_time); cc->Outputs().Tag("PACKET").AddPacket(cc->Inputs().Tag("PACKET").Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -239,15 +239,15 @@ REGISTER_CALCULATOR(SleepCalculator); // Drops the 3rd packet, and optionally the corresponding timestamp bound. 
class DropCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("PACKET").SetAny(); cc->Outputs().Tag("PACKET").SetSameAs(&cc->Inputs().Tag("PACKET")); cc->InputSidePackets().Tag("DROP_TIMESTAMPS").Set(); cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (!cc->Inputs().Tag("PACKET").Value().IsEmpty()) { ++packet_count; } @@ -259,7 +259,7 @@ class DropCalculator : public CalculatorBase { cc->Outputs().Tag("PACKET").SetNextTimestampBound( cc->InputTimestamp().NextAllowedInStream()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -365,11 +365,11 @@ TEST_F(FlowLimiterCalculatorTest, FinishedTimestamps) { MP_ASSERT_OK(graph_.Initialize(graph_config)); MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { out_1_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { allow_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); simulation_clock_->ThreadStart(); MP_ASSERT_OK(graph_.StartRun(side_packets)); @@ -437,11 +437,11 @@ TEST_F(FlowLimiterCalculatorTest, FinishedLost) { MP_ASSERT_OK(graph_.Initialize(graph_config)); MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { out_1_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { allow_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); simulation_clock_->ThreadStart(); MP_ASSERT_OK(graph_.StartRun(side_packets)); @@ -501,11 +501,11 @@ TEST_F(FlowLimiterCalculatorTest, FinishedDelayed) { 
MP_ASSERT_OK(graph_.Initialize(graph_config)); MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { out_1_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { allow_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); simulation_clock_->ThreadStart(); MP_ASSERT_OK(graph_.StartRun(side_packets)); @@ -596,16 +596,16 @@ TEST_F(FlowLimiterCalculatorTest, TwoInputStreams) { MP_ASSERT_OK(graph_.Initialize(graph_config)); MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { out_1_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); std::vector out_2_packets; MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) { out_2_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { allow_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); simulation_clock_->ThreadStart(); MP_ASSERT_OK(graph_.StartRun(side_packets)); @@ -705,16 +705,16 @@ TEST_F(FlowLimiterCalculatorTest, ZeroQueue) { MP_ASSERT_OK(graph_.Initialize(graph_config)); MP_EXPECT_OK(graph_.ObserveOutputStream("out_1", [this](Packet p) { out_1_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); std::vector out_2_packets; MP_EXPECT_OK(graph_.ObserveOutputStream("in_2_sampled", [&](Packet p) { out_2_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph_.ObserveOutputStream("allow", [this](Packet p) { allow_packets_.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); simulation_clock_->ThreadStart(); MP_ASSERT_OK(graph_.StartRun(side_packets)); diff --git a/mediapipe/calculators/core/gate_calculator.cc b/mediapipe/calculators/core/gate_calculator.cc index 
95ae9b03f..189671860 100644 --- a/mediapipe/calculators/core/gate_calculator.cc +++ b/mediapipe/calculators/core/gate_calculator.cc @@ -82,8 +82,7 @@ class GateCalculator : public CalculatorBase { public: GateCalculator() {} - static mediapipe::Status CheckAndInitAllowDisallowInputs( - CalculatorContract* cc) { + static absl::Status CheckAndInitAllowDisallowInputs(CalculatorContract* cc) { bool input_via_side_packet = cc->InputSidePackets().HasTag("ALLOW") || cc->InputSidePackets().HasTag("DISALLOW"); bool input_via_stream = @@ -110,10 +109,10 @@ class GateCalculator : public CalculatorBase { cc->Inputs().Tag("DISALLOW").Set(); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_OK(CheckAndInitAllowDisallowInputs(cc)); const int num_data_streams = cc->Inputs().NumEntries(""); @@ -130,10 +129,10 @@ class GateCalculator : public CalculatorBase { cc->Outputs().Tag("STATE_CHANGE").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { use_side_packet_for_allow_disallow_ = false; if (cc->InputSidePackets().HasTag("ALLOW")) { use_side_packet_for_allow_disallow_ = true; @@ -153,10 +152,10 @@ class GateCalculator : public CalculatorBase { const auto& options = cc->Options<::mediapipe::GateCalculatorOptions>(); empty_packets_as_allow_ = options.empty_packets_as_allow(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { bool allow = empty_packets_as_allow_; if (use_side_packet_for_allow_disallow_) { allow = allow_by_side_packet_decision_; @@ -195,7 +194,7 @@ class GateCalculator : public CalculatorBase { cc->Outputs().Get("", i).Close(); } } - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } // Process data streams. @@ -205,7 +204,7 @@ class GateCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/gate_calculator_test.cc b/mediapipe/calculators/core/gate_calculator_test.cc index d1dcae09d..0b78b9b75 100644 --- a/mediapipe/calculators/core/gate_calculator_test.cc +++ b/mediapipe/calculators/core/gate_calculator_test.cc @@ -25,7 +25,7 @@ namespace { class GateCalculatorTest : public ::testing::Test { protected: // Helper to run a graph and return status. - static mediapipe::Status RunGraph(const std::string& proto) { + static absl::Status RunGraph(const std::string& proto) { auto runner = absl::make_unique( ParseTextProtoOrDie(proto)); return runner->Run(); diff --git a/mediapipe/calculators/core/immediate_mux_calculator.cc b/mediapipe/calculators/core/immediate_mux_calculator.cc index e0e129f4b..0e51cda5e 100644 --- a/mediapipe/calculators/core/immediate_mux_calculator.cc +++ b/mediapipe/calculators/core/immediate_mux_calculator.cc @@ -43,16 +43,16 @@ class ImmediateMuxCalculator : public CalculatorBase { public: // This calculator combines any set of input streams into a single // output stream. All input stream types must match the output stream type. - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // Passes any input packet to the output stream immediately, unless the // packet timestamp is lower than a previously passed packet. 
- mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(ImmediateMuxCalculator); -mediapipe::Status ImmediateMuxCalculator::GetContract(CalculatorContract* cc) { +absl::Status ImmediateMuxCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Outputs().NumEntries() >= 1 && cc->Outputs().NumEntries() <= 2) << "This calculator produces only one or two output streams."; cc->Outputs().Index(0).SetAny(); @@ -62,15 +62,15 @@ mediapipe::Status ImmediateMuxCalculator::GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetSameAs(&cc->Outputs().Index(0)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) { +absl::Status ImmediateMuxCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) { +absl::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) { // Pass along the first packet, unless it has been superseded. 
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { const Packet& packet = cc->Inputs().Index(i).Value(); @@ -88,7 +88,7 @@ mediapipe::Status ImmediateMuxCalculator::Process(CalculatorContext* cc) { } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/immediate_mux_calculator_test.cc b/mediapipe/calculators/core/immediate_mux_calculator_test.cc index d691e0f73..6913fd000 100644 --- a/mediapipe/calculators/core/immediate_mux_calculator_test.cc +++ b/mediapipe/calculators/core/immediate_mux_calculator_test.cc @@ -289,19 +289,19 @@ TEST_F(ImmediateMuxCalculatorTest, SimultaneousTimestamps) { } // A Calculator::Process callback function. -typedef std::function +typedef std::function ProcessFunction; // A testing callback function that passes through all packets. -mediapipe::Status PassThrough(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status PassThrough(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } TEST_F(ImmediateMuxCalculatorTest, Demux) { @@ -325,7 +325,7 @@ TEST_F(ImmediateMuxCalculatorTest, Demux) { auto out_cb = [&](const Packet& p) { absl::MutexLock lock(&out_mutex); out_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); }; auto wait_for = [&](std::function cond) { absl::MutexLock lock(&out_mutex); diff --git a/mediapipe/calculators/core/make_pair_calculator.cc b/mediapipe/calculators/core/make_pair_calculator.cc index 5d3cf1daf..561656861 100644 --- a/mediapipe/calculators/core/make_pair_calculator.cc +++ b/mediapipe/calculators/core/make_pair_calculator.cc @@ -41,14 +41,14 @@ class MakePairCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kPair); - static mediapipe::Status 
UpdateContract(CalculatorContract* cc) { + static absl::Status UpdateContract(CalculatorContract* cc) { RET_CHECK_EQ(kIn(cc).Count(), 2); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { kPair(cc).Send({kIn(cc)[0].packet(), kIn(cc)[1].packet()}); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/core/matrix_multiply_calculator.cc b/mediapipe/calculators/core/matrix_multiply_calculator.cc index fbc18297b..d52e0c2fa 100644 --- a/mediapipe/calculators/core/matrix_multiply_calculator.cc +++ b/mediapipe/calculators/core/matrix_multiply_calculator.cc @@ -38,13 +38,13 @@ class MatrixMultiplyCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kOut, kSide); - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; MEDIAPIPE_REGISTER_NODE(MatrixMultiplyCalculator); -mediapipe::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) { +absl::Status MatrixMultiplyCalculator::Process(CalculatorContext* cc) { kOut(cc).Send(*kSide(cc) * *kIn(cc)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace api2 diff --git a/mediapipe/calculators/core/matrix_subtract_calculator.cc b/mediapipe/calculators/core/matrix_subtract_calculator.cc index f526a0ceb..09471a5ee 100644 --- a/mediapipe/calculators/core/matrix_subtract_calculator.cc +++ b/mediapipe/calculators/core/matrix_subtract_calculator.cc @@ -50,32 +50,31 @@ class MatrixSubtractCalculator : public Node { static constexpr Output kOut{""}; MEDIAPIPE_NODE_CONTRACT(kMinuend, kSubtrahend, kOut); - static mediapipe::Status UpdateContract(CalculatorContract* cc); + static absl::Status UpdateContract(CalculatorContract* cc); - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; 
MEDIAPIPE_REGISTER_NODE(MatrixSubtractCalculator); // static -mediapipe::Status MatrixSubtractCalculator::UpdateContract( - CalculatorContract* cc) { +absl::Status MatrixSubtractCalculator::UpdateContract(CalculatorContract* cc) { // TODO: the next restriction could be relaxed. RET_CHECK(kMinuend(cc).IsStream() ^ kSubtrahend(cc).IsStream()) << "MatrixSubtractCalculator only accepts exactly one input stream and " "one input side packet"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) { +absl::Status MatrixSubtractCalculator::Process(CalculatorContext* cc) { const Matrix& minuend = *kMinuend(cc); const Matrix& subtrahend = *kSubtrahend(cc); if (minuend.rows() != subtrahend.rows() || minuend.cols() != subtrahend.cols()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Minuend and subtrahend must have the same dimensions."); } kOut(cc).Send(minuend - subtrahend); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace api2 diff --git a/mediapipe/calculators/core/matrix_to_vector_calculator.cc b/mediapipe/calculators/core/matrix_to_vector_calculator.cc index cd10d2668..90a36053b 100644 --- a/mediapipe/calculators/core/matrix_to_vector_calculator.cc +++ b/mediapipe/calculators/core/matrix_to_vector_calculator.cc @@ -49,12 +49,19 @@ class MatrixToVectorCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kOut); + absl::Status Open(CalculatorContext* cc) override; + // Outputs a packet containing a vector for each input packet. 
- mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; MEDIAPIPE_REGISTER_NODE(MatrixToVectorCalculator); -mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) { +absl::Status MatrixToVectorCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(0); + return absl::OkStatus(); +} + +absl::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) { const Matrix& input = *kIn(cc); auto output = absl::make_unique>(); @@ -66,7 +73,7 @@ mediapipe::Status MatrixToVectorCalculator::Process(CalculatorContext* cc) { output_as_matrix = input; kOut(cc).Send(std::move(output)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace api2 diff --git a/mediapipe/calculators/core/merge_calculator.cc b/mediapipe/calculators/core/merge_calculator.cc index 96056b4e3..a283842ae 100644 --- a/mediapipe/calculators/core/merge_calculator.cc +++ b/mediapipe/calculators/core/merge_calculator.cc @@ -50,7 +50,7 @@ class MergeCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - static mediapipe::Status UpdateContract(CalculatorContract* cc) { + static absl::Status UpdateContract(CalculatorContract* cc) { RET_CHECK_GT(kIn(cc).Count(), 0) << "Needs at least one input stream"; if (kIn(cc).Count() == 1) { LOG(WARNING) @@ -59,23 +59,23 @@ class MergeCalculator : public Node { "correctly or consider removing this calculator to reduce " "unnecessary overhead."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Output the packet from the first input stream with a packet ready at this // timestamp. 
- for (int i = 0; i < kIn(cc).Count(); ++i) { - if (!kIn(cc)[i].IsEmpty()) { - kOut(cc).Send(kIn(cc)[i].packet()); - return mediapipe::OkStatus(); + for (const auto& input : kIn(cc)) { + if (!input.IsEmpty()) { + kOut(cc).Send(input.packet()); + return absl::OkStatus(); } } LOG(WARNING) << "Empty input packets at timestamp " << cc->InputTimestamp().Value(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/core/mux_calculator.cc b/mediapipe/calculators/core/mux_calculator.cc index f488a5c98..a0ce2ae34 100644 --- a/mediapipe/calculators/core/mux_calculator.cc +++ b/mediapipe/calculators/core/mux_calculator.cc @@ -40,13 +40,13 @@ class MuxCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kSelect, kIn, kOut, StreamHandler("MuxInputStreamHandler")); - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int select = *kSelect(cc); RET_CHECK(0 <= select && select < kIn(cc).Count()); if (!kIn(cc)[select].IsEmpty()) { kOut(cc).Send(kIn(cc)[select].packet()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/core/mux_calculator_test.cc b/mediapipe/calculators/core/mux_calculator_test.cc index c99919a38..46cce74d0 100644 --- a/mediapipe/calculators/core/mux_calculator_test.cc +++ b/mediapipe/calculators/core/mux_calculator_test.cc @@ -134,7 +134,7 @@ void RunGraph(const std::string& graph_config_proto, const std::string& input_stream_name, int num_input_packets, std::function input_fn, const std::string& output_stream_name, - std::function output_fn) { + std::function output_fn) { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(graph_config_proto); CalculatorGraph graph; @@ -165,9 +165,9 @@ TEST(MuxCalculatorTest, InputStreamSelector_DefaultInputStreamHandler) { // Output and handling. std::vector output; // This function collects the output from the packet. 
- auto output_fn = [&output](const Packet& p) -> mediapipe::Status { + auto output_fn = [&output](const Packet& p) -> absl::Status { output.push_back(p.Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); }; RunGraph(kTestGraphConfig1, {}, kInputName, input_packets.size(), input_fn, @@ -191,9 +191,9 @@ TEST(MuxCalculatorTest, InputSidePacketSelector_DefaultInputStreamHandler) { // Output and handling. std::vector output; // This function collects the output from the packet. - auto output_fn = [&output](const Packet& p) -> mediapipe::Status { + auto output_fn = [&output](const Packet& p) -> absl::Status { output.push_back(p.Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); }; RunGraph(kTestGraphConfig2, {{kInputSelector, MakePacket(0)}}, @@ -225,9 +225,9 @@ TEST(MuxCalculatorTest, InputStreamSelector_MuxInputStreamHandler) { // Output and handling. std::vector output; // This function collects the output from the packet. - auto output_fn = [&output](const Packet& p) -> mediapipe::Status { + auto output_fn = [&output](const Packet& p) -> absl::Status { output.push_back(p.Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); }; RunGraph(kTestGraphConfig3, {}, kInputName, input_packets.size(), input_fn, @@ -260,7 +260,7 @@ TEST(MuxCalculatorTest, DiscardSkippedInputs_MuxInputStreamHandler) { MP_ASSERT_OK( graph.ObserveOutputStream("test_output", [&output](const Packet& p) { output = p.Get>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); diff --git a/mediapipe/calculators/core/packet_cloner_calculator.cc b/mediapipe/calculators/core/packet_cloner_calculator.cc index c2a2979c7..41bddbfa7 100644 --- a/mediapipe/calculators/core/packet_cloner_calculator.cc +++ b/mediapipe/calculators/core/packet_cloner_calculator.cc @@ -45,17 +45,17 @@ namespace mediapipe { // packet_inner_join_calculator.cc: Don't output unless all inputs are new. 
class PacketClonerCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const int tick_signal_index = cc->Inputs().NumEntries() - 1; for (int i = 0; i < tick_signal_index; ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); } cc->Inputs().Index(tick_signal_index).SetAny(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { // Load options. const auto calculator_options = cc->Options(); @@ -71,10 +71,10 @@ class PacketClonerCalculator : public CalculatorBase { cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Store input signals. for (int i = 0; i < tick_signal_index_; ++i) { if (!cc->Inputs().Index(i).Value().IsEmpty()) { @@ -88,7 +88,7 @@ class PacketClonerCalculator : public CalculatorBase { // Return if one of the input is null. for (int i = 0; i < tick_signal_index_; ++i) { if (current_[i].IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } } @@ -103,7 +103,7 @@ class PacketClonerCalculator : public CalculatorBase { } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/packet_inner_join_calculator.cc b/mediapipe/calculators/core/packet_inner_join_calculator.cc index 1f77d4149..6ffffb58b 100644 --- a/mediapipe/calculators/core/packet_inner_join_calculator.cc +++ b/mediapipe/calculators/core/packet_inner_join_calculator.cc @@ -34,10 +34,10 @@ namespace mediapipe { // packet_cloner_calculator.cc: Repeats last-seen packets from empty inputs. 
class PacketInnerJoinCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: int num_streams_; @@ -45,8 +45,7 @@ class PacketInnerJoinCalculator : public CalculatorBase { REGISTER_CALCULATOR(PacketInnerJoinCalculator); -mediapipe::Status PacketInnerJoinCalculator::GetContract( - CalculatorContract* cc) { +absl::Status PacketInnerJoinCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == cc->Outputs().NumEntries()) << "The number of input and output streams must match."; const int num_streams = cc->Inputs().NumEntries(); @@ -54,25 +53,25 @@ mediapipe::Status PacketInnerJoinCalculator::GetContract( cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) { +absl::Status PacketInnerJoinCalculator::Open(CalculatorContext* cc) { num_streams_ = cc->Inputs().NumEntries(); cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) { +absl::Status PacketInnerJoinCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < num_streams_; ++i) { if (cc->Inputs().Index(i).Value().IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } for (int i = 0; i < num_streams_; ++i) { cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/packet_presence_calculator.cc 
b/mediapipe/calculators/core/packet_presence_calculator.cc index 7f823f27e..cb119a76d 100644 --- a/mediapipe/calculators/core/packet_presence_calculator.cc +++ b/mediapipe/calculators/core/packet_presence_calculator.cc @@ -57,26 +57,26 @@ namespace mediapipe { // } class PacketPresenceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("PACKET").SetAny(); cc->Outputs().Tag("PRESENCE").Set(); // Process() function is invoked in response to input stream timestamp // bound updates. cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->Outputs() .Tag("PRESENCE") .AddPacket(MakePacket(!cc->Inputs().Tag("PACKET").IsEmpty()) .At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(PacketPresenceCalculator); diff --git a/mediapipe/calculators/core/packet_resampler_calculator.cc b/mediapipe/calculators/core/packet_resampler_calculator.cc index 4a08c1f1c..32b1c850a 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.cc +++ b/mediapipe/calculators/core/packet_resampler_calculator.cc @@ -47,8 +47,7 @@ TimestampDiff TimestampDiffFromSeconds(double seconds) { } } // namespace -mediapipe::Status PacketResamplerCalculator::GetContract( - CalculatorContract* cc) { +absl::Status PacketResamplerCalculator::GetContract(CalculatorContract* cc) { const auto& resampler_options = cc->Options(); if (cc->InputSidePackets().HasTag("OPTIONS")) { @@ -78,10 +77,10 @@ mediapipe::Status PacketResamplerCalculator::GetContract( 
RET_CHECK(cc->InputSidePackets().HasTag("SEED")); cc->InputSidePackets().Tag("SEED").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { +absl::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { const auto resampler_options = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), "OPTIONS"); @@ -156,8 +155,8 @@ mediapipe::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { const auto& seed = cc->InputSidePackets().Tag("SEED").Get(); random_ = CreateSecureRandom(seed); if (random_ == nullptr) { - return mediapipe::Status( - mediapipe::StatusCode::kInvalidArgument, + return absl::Status( + absl::StatusCode::kInvalidArgument, "SecureRandom is not available. With \"jitter\" specified, " "PacketResamplerCalculator processing cannot proceed."); } @@ -165,17 +164,17 @@ mediapipe::Status PacketResamplerCalculator::Open(CalculatorContext* cc) { } packet_reservoir_ = std::make_unique(packet_reservoir_random_.get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { +absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream() && cc->Inputs().UsesTags() && cc->Inputs().HasTag("VIDEO_HEADER") && !cc->Inputs().Tag("VIDEO_HEADER").IsEmpty()) { video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); video_header_.frame_rate = frame_rate_; if (cc->Inputs().Get(input_data_id_).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } if (jitter_ != 0.0 && random_ != nullptr) { @@ -192,7 +191,7 @@ mediapipe::Status PacketResamplerCalculator::Process(CalculatorContext* cc) { MP_RETURN_IF_ERROR(ProcessWithoutJitter(cc)); } last_packet_ = cc->Inputs().Get(input_data_id_).Value(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void 
PacketResamplerCalculator::InitializeNextOutputTimestampWithJitter() { @@ -229,7 +228,7 @@ void PacketResamplerCalculator::UpdateNextOutputTimestampWithJitter() { ((1.0 - jitter_) + 2.0 * jitter_ * random_->RandFloat()); } -mediapipe::Status PacketResamplerCalculator::ProcessWithJitter( +absl::Status PacketResamplerCalculator::ProcessWithJitter( CalculatorContext* cc) { RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); RET_CHECK_NE(jitter_, 0.0); @@ -243,7 +242,7 @@ mediapipe::Status PacketResamplerCalculator::ProcessWithJitter( cc->Inputs().Get(input_data_id_).Value().At(next_output_timestamp_)); UpdateNextOutputTimestampWithJitter(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (frame_time_usec_ < @@ -266,11 +265,21 @@ mediapipe::Status PacketResamplerCalculator::ProcessWithJitter( : cc->Inputs().Get(input_data_id_).Value()) .At(next_output_timestamp_)); UpdateNextOutputTimestampWithJitter(); + // From now on every time a packet is emitted the timestamp of the next + // packet becomes known; that timestamp is stored in next_output_timestamp_. + // The only exception to this rule is the packet emitted from Close() which + // can only happen when jitter_with_reflection is enabled but in this case + // next_output_timestamp_min_ is a non-decreasing lower bound of any + // subsequent packet. + const Timestamp timestamp_bound = jitter_with_reflection_ + ? 
next_output_timestamp_min_ + : next_output_timestamp_; + cc->Outputs().Get(output_data_id_).SetNextTimestampBound(timestamp_bound); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketResamplerCalculator::ProcessWithoutJitter( +absl::Status PacketResamplerCalculator::ProcessWithoutJitter( CalculatorContext* cc) { RET_CHECK_GT(cc->InputTimestamp(), Timestamp::PreStream()); RET_CHECK_EQ(jitter_, 0.0); @@ -333,12 +342,12 @@ mediapipe::Status PacketResamplerCalculator::ProcessWithoutJitter( .Get(output_data_id_) .SetNextTimestampBound(PeriodIndexToTimestamp(period_count_)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketResamplerCalculator::Close(CalculatorContext* cc) { +absl::Status PacketResamplerCalculator::Close(CalculatorContext* cc) { if (!cc->GraphStatus().ok()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Emit the last packet received if we have at least one packet, but // haven't sent anything for its period. @@ -350,7 +359,7 @@ mediapipe::Status PacketResamplerCalculator::Close(CalculatorContext* cc) { if (!packet_reservoir_->IsEmpty()) { OutputWithinLimits(cc, packet_reservoir_->GetSample()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } Timestamp PacketResamplerCalculator::PeriodIndexToTimestamp(int64 index) const { diff --git a/mediapipe/calculators/core/packet_resampler_calculator.h b/mediapipe/calculators/core/packet_resampler_calculator.h index c07eb1c24..4a1a3ffaa 100644 --- a/mediapipe/calculators/core/packet_resampler_calculator.h +++ b/mediapipe/calculators/core/packet_resampler_calculator.h @@ -99,11 +99,11 @@ class PacketReservoir { // packet_downsampler_calculator.cc: skips packets regardless of timestamps. 
class PacketResamplerCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Calculates the first sampled timestamp that incorporates a jittering @@ -113,10 +113,10 @@ class PacketResamplerCalculator : public CalculatorBase { void UpdateNextOutputTimestampWithJitter(); // Logic for Process() when jitter_ != 0.0. - mediapipe::Status ProcessWithJitter(CalculatorContext* cc); + absl::Status ProcessWithJitter(CalculatorContext* cc); // Logic for Process() when jitter_ == 0.0. - mediapipe::Status ProcessWithoutJitter(CalculatorContext* cc); + absl::Status ProcessWithoutJitter(CalculatorContext* cc); // Given the current count of periods that have passed, this returns // the next valid timestamp of the middle point of the next period: diff --git a/mediapipe/calculators/core/packet_thinner_calculator.cc b/mediapipe/calculators/core/packet_thinner_calculator.cc index 4795ad5e4..d3d391b61 100644 --- a/mediapipe/calculators/core/packet_thinner_calculator.cc +++ b/mediapipe/calculators/core/packet_thinner_calculator.cc @@ -90,7 +90,7 @@ class PacketThinnerCalculator : public CalculatorBase { PacketThinnerCalculator() {} ~PacketThinnerCalculator() override {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { if (cc->InputSidePackets().HasTag(kOptionsTag)) { cc->InputSidePackets().Tag(kOptionsTag).Set(); } @@ -99,21 +99,21 @@ class PacketThinnerCalculator : public CalculatorBase { if (cc->InputSidePackets().HasTag(kPeriodTag)) { 
cc->InputSidePackets().Tag(kPeriodTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override { if (cc->InputTimestamp() < start_time_) { - return mediapipe::OkStatus(); // Drop packets before start_time_. + return absl::OkStatus(); // Drop packets before start_time_. } else if (cc->InputTimestamp() >= end_time_) { if (!cc->Outputs().Index(0).IsClosed()) { cc->Outputs() .Index(0) .Close(); // No more Packets will be output after end_time_. } - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { return thinner_type_ == PacketThinnerCalculatorOptions::ASYNC ? AsyncThinnerProcess(cc) @@ -123,8 +123,8 @@ class PacketThinnerCalculator : public CalculatorBase { private: // Implementation of ASYNC and SYNC versions of thinner algorithm. - mediapipe::Status AsyncThinnerProcess(CalculatorContext* cc); - mediapipe::Status SyncThinnerProcess(CalculatorContext* cc); + absl::Status AsyncThinnerProcess(CalculatorContext* cc); + absl::Status SyncThinnerProcess(CalculatorContext* cc); // Cached option. PacketThinnerCalculatorOptions::ThinnerType thinner_type_; @@ -153,7 +153,7 @@ namespace { TimestampDiff abs(TimestampDiff t) { return t < 0 ? 
-t : t; } } // namespace -mediapipe::Status PacketThinnerCalculator::Open(CalculatorContext* cc) { +absl::Status PacketThinnerCalculator::Open(CalculatorContext* cc) { PacketThinnerCalculatorOptions options = mediapipe::tool::RetrieveOptions( cc->Options(), cc->InputSidePackets(), kOptionsTag); @@ -224,10 +224,10 @@ mediapipe::Status PacketThinnerCalculator::Open(CalculatorContext* cc) { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketThinnerCalculator::Close(CalculatorContext* cc) { +absl::Status PacketThinnerCalculator::Close(CalculatorContext* cc) { // Emit any saved packets before quitting. if (!saved_packet_.IsEmpty()) { // Only sync thinner should have saved packets. @@ -239,10 +239,10 @@ mediapipe::Status PacketThinnerCalculator::Close(CalculatorContext* cc) { cc->Outputs().Index(0).AddPacket(saved_packet_); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketThinnerCalculator::AsyncThinnerProcess( +absl::Status PacketThinnerCalculator::AsyncThinnerProcess( CalculatorContext* cc) { if (cc->InputTimestamp() >= next_valid_timestamp_) { cc->Outputs().Index(0).AddPacket( @@ -251,10 +251,10 @@ mediapipe::Status PacketThinnerCalculator::AsyncThinnerProcess( // Guaranteed not to emit packets seen during refractory period. cc->Outputs().Index(0).SetNextTimestampBound(next_valid_timestamp_); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketThinnerCalculator::SyncThinnerProcess( +absl::Status PacketThinnerCalculator::SyncThinnerProcess( CalculatorContext* cc) { if (saved_packet_.IsEmpty()) { // If no packet has been saved, store the current packet. 
@@ -290,7 +290,7 @@ mediapipe::Status PacketThinnerCalculator::SyncThinnerProcess( saved_packet_ = cc->Inputs().Index(0).Value(); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } Timestamp PacketThinnerCalculator::NearestSyncTimestamp(Timestamp now) const { diff --git a/mediapipe/calculators/core/pass_through_calculator.cc b/mediapipe/calculators/core/pass_through_calculator.cc index d07104733..197e1331a 100644 --- a/mediapipe/calculators/core/pass_through_calculator.cc +++ b/mediapipe/calculators/core/pass_through_calculator.cc @@ -28,9 +28,9 @@ namespace mediapipe { // ignored. class PassThroughCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and output streams to PassThroughCalculator must use " "matching tags and indexes."); } @@ -46,7 +46,7 @@ class PassThroughCalculator : public CalculatorBase { if (cc->OutputSidePackets().NumEntries() != 0) { if (!cc->InputSidePackets().TagMap()->SameAs( *cc->OutputSidePackets().TagMap())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and output side packets to PassThroughCalculator must use " "matching tags and indexes."); } @@ -56,10 +56,10 @@ class PassThroughCalculator : public CalculatorBase { &cc->InputSidePackets().Get(id)); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { if (!cc->Inputs().Get(id).Header().IsEmpty()) { @@ -73,10 +73,10 @@ class PassThroughCalculator : public CalculatorBase { } } cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - 
mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->GetCounter("PassThrough")->Increment(); if (cc->Inputs().NumEntries() == 0) { return tool::StatusStop(); @@ -90,7 +90,7 @@ class PassThroughCalculator : public CalculatorBase { cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(PassThroughCalculator); diff --git a/mediapipe/calculators/core/previous_loopback_calculator.cc b/mediapipe/calculators/core/previous_loopback_calculator.cc index e42261c28..d67e6c061 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator.cc @@ -65,19 +65,19 @@ class PreviousLoopbackCalculator : public Node { StreamHandler("ImmediateInputStreamHandler"), TimestampChange::Arbitrary()); - static mediapipe::Status UpdateContract(CalculatorContract* cc) { + static absl::Status UpdateContract(CalculatorContract* cc) { // Process() function is invoked in response to MAIN/LOOP stream timestamp // bound updates. cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { kPrevLoop(cc).SetHeader(kLoop(cc).Header()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Non-empty packets and empty packets indicating timestamp bound updates // are guaranteed to have timestamps greater than timestamps of previous // packets within the same stream. Calculator tracks and operates on such @@ -106,7 +106,7 @@ class PreviousLoopbackCalculator : public Node { while (!main_packet_specs_.empty() && !loop_packets_.empty()) { // The earliest MAIN packet. 
- const MainPacketSpec& main_spec = main_packet_specs_.front(); + MainPacketSpec main_spec = main_packet_specs_.front(); // The earliest LOOP packet. const PacketBase& loop_candidate = loop_packets_.front(); // Match LOOP and MAIN packets. @@ -139,7 +139,7 @@ class PreviousLoopbackCalculator : public Node { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/previous_loopback_calculator_test.cc b/mediapipe/calculators/core/previous_loopback_calculator_test.cc index 4c4d9b6e8..54959edae 100644 --- a/mediapipe/calculators/core/previous_loopback_calculator_test.cc +++ b/mediapipe/calculators/core/previous_loopback_calculator_test.cc @@ -136,27 +136,27 @@ TEST(PreviousLoopbackCalculator, CorrectTimestamps) { // A Calculator that outputs a summary packet in CalculatorBase::Close(). class PacketOnCloseCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { sum_ += cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket( MakePacket(sum_).At(Timestamp::Max())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -700,19 +700,19 @@ TEST_F(PreviousLoopbackCalculatorProcessingTimestampsTest, // Similar to 
GateCalculator, but it doesn't propagate timestamp bound updates. class DroppingGateCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Inputs().Tag("DISALLOW").Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (!cc->Inputs().Index(0).IsEmpty() && !cc->Inputs().Tag("DISALLOW").Get()) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(DroppingGateCalculator); diff --git a/mediapipe/calculators/core/quantize_float_vector_calculator.cc b/mediapipe/calculators/core/quantize_float_vector_calculator.cc index 514159145..e95509298 100644 --- a/mediapipe/calculators/core/quantize_float_vector_calculator.cc +++ b/mediapipe/calculators/core/quantize_float_vector_calculator.cc @@ -43,32 +43,32 @@ namespace mediapipe { class QuantizeFloatVectorCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("FLOAT_VECTOR").Set>(); cc->Outputs().Tag("ENCODED").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { const auto options = cc->Options<::mediapipe::QuantizeFloatVectorCalculatorOptions>(); if (!options.has_max_quantized_value() || !options.has_min_quantized_value()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Both max_quantized_value and min_quantized_value must be provided " "in QuantizeFloatVectorCalculatorOptions."); } max_quantized_value_ = 
options.max_quantized_value(); min_quantized_value_ = options.min_quantized_value(); if (max_quantized_value_ < min_quantized_value_ + FLT_EPSILON) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "max_quantized_value must be greater than min_quantized_value."); } range_ = max_quantized_value_ - min_quantized_value_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const std::vector& float_vector = cc->Inputs().Tag("FLOAT_VECTOR").Value().Get>(); int feature_size = float_vector.size(); @@ -88,7 +88,7 @@ class QuantizeFloatVectorCalculator : public CalculatorBase { } cc->Outputs().Tag("ENCODED").AddPacket( MakePacket(encoded_features).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc index 0f7cde49a..277f83fe2 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator.cc @@ -75,7 +75,7 @@ namespace mediapipe { // } class RealTimeFlowLimiterCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { int num_data_streams = cc->Inputs().NumEntries(""); RET_CHECK_GE(num_data_streams, 1); RET_CHECK_EQ(cc->Outputs().NumEntries(""), num_data_streams) @@ -95,10 +95,10 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { cc->SetInputStreamHandler("ImmediateInputStreamHandler"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { finished_id_ = cc->Inputs().GetId("FINISHED", 0); max_in_flight_ = 1; if 
(cc->InputSidePackets().HasTag("MAX_IN_FLIGHT")) { @@ -113,12 +113,12 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { num_data_streams_ = cc->Inputs().NumEntries(""); data_stream_bound_ts_.resize(num_data_streams_); RET_CHECK_OK(CopyInputHeadersToOutputs(cc->Inputs(), &(cc->Outputs()))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool Allow() { return num_in_flight_ < max_in_flight_; } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { bool old_allow = Allow(); Timestamp lowest_incomplete_ts = Timestamp::Done(); @@ -180,7 +180,7 @@ class RealTimeFlowLimiterCalculator : public CalculatorBase { .Get(allowed_id_) .AddPacket(MakePacket(Allow()).At(++allow_ctr_ts_)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc index 7f4ce1db1..73c50e56d 100644 --- a/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc +++ b/mediapipe/calculators/core/real_time_flow_limiter_calculator_test.cc @@ -127,25 +127,25 @@ TEST(RealTimeFlowLimiterCalculator, BasicTest) { } // A Calculator::Process callback function. -typedef std::function +typedef std::function ProcessFunction; // A testing callback function that passes through all packets. -mediapipe::Status PassthroughFunction(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status PassthroughFunction(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // A Calculator that runs a testing callback function in Close. 
class CloseCallbackCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { cc->Inputs().Get(id).SetAny(); @@ -154,18 +154,17 @@ class CloseCallbackCalculator : public CalculatorBase { id < cc->Outputs().EndId(); ++id) { cc->Outputs().Get(id).SetAny(); } - cc->InputSidePackets().Index(0).Set>(); - return mediapipe::OkStatus(); + cc->InputSidePackets().Index(0).Set>(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return PassthroughFunction(cc->Inputs(), &(cc->Outputs())); } - mediapipe::Status Close(CalculatorContext* cc) override { - const auto& callback = cc->InputSidePackets() - .Index(0) - .Get>(); + absl::Status Close(CalculatorContext* cc) override { + const auto& callback = + cc->InputSidePackets().Index(0).Get>(); return callback(); } }; @@ -196,9 +195,9 @@ class RealTimeFlowLimiterCalculatorTest : public testing::Test { exit_semaphore_.Acquire(1); return PassthroughFunction(inputs, outputs); }; - std::function close_func = [this]() { + std::function close_func = [this]() { close_count_++; - return mediapipe::OkStatus(); + return absl::OkStatus(); }; MP_ASSERT_OK(graph_.Initialize( graph_config_, { diff --git a/mediapipe/calculators/core/round_robin_demux_calculator.cc b/mediapipe/calculators/core/round_robin_demux_calculator.cc index 8fe2c2b9c..8c93bba71 100644 --- a/mediapipe/calculators/core/round_robin_demux_calculator.cc +++ b/mediapipe/calculators/core/round_robin_demux_calculator.cc @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/port/ret_check.h" namespace mediapipe { +namespace api2 { // Forwards the input packet to one of the n output streams "OUTPUT:0", // "OUTPUT:1", ..., in round robin fashion. The index of the selected output @@ -71,50 +73,34 @@ namespace mediapipe { // output with MakePairCalculator, MakeVectorCalculator, or a similar variant to // use it with MuxCalculator and later unpack, or can create new variants of // MuxCalculator/MuxInputStreamHandler. -class RoundRobinDemuxCalculator : public CalculatorBase { +class RoundRobinDemuxCalculator : public Node { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - RET_CHECK_EQ(cc->Inputs().NumEntries(), 1); - cc->Inputs().Index(0).SetAny(); - if (cc->Outputs().HasTag("SELECT")) { - cc->Outputs().Tag("SELECT").Set(); - } - for (CollectionItemId id = cc->Outputs().BeginId("OUTPUT"); - id < cc->Outputs().EndId("OUTPUT"); ++id) { - cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Index(0)); - } - return mediapipe::OkStatus(); - } + static constexpr Input kIn{""}; + static constexpr Output::Optional kSelect{"SELECT"}; + static constexpr Output>::Multiple kOut{"OUTPUT"}; - mediapipe::Status Open(CalculatorContext* cc) override { - select_output_ = cc->Outputs().GetId("SELECT", 0); + MEDIAPIPE_NODE_CONTRACT(kIn, kSelect, kOut); + + absl::Status Open(CalculatorContext* cc) override { output_data_stream_index_ = 0; - output_data_stream_base_ = cc->Outputs().GetId("OUTPUT", 0); - num_output_data_streams_ = cc->Outputs().NumEntries("OUTPUT"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - cc->Outputs() - .Get(output_data_stream_base_ + output_data_stream_index_) - .AddPacket(cc->Inputs().Index(0).Value()); - if (select_output_.IsValid()) { - cc->Outputs() - .Get(select_output_) - .Add(new 
int(output_data_stream_index_), cc->InputTimestamp()); + absl::Status Process(CalculatorContext* cc) override { + kOut(cc)[output_data_stream_index_].Send(kIn(cc).packet()); + if (kSelect(cc).IsConnected()) { + kSelect(cc).Send(output_data_stream_index_); } output_data_stream_index_ = - (output_data_stream_index_ + 1) % num_output_data_streams_; - return mediapipe::OkStatus(); + (output_data_stream_index_ + 1) % kOut(cc).Count(); + return absl::OkStatus(); } private: - CollectionItemId select_output_; - CollectionItemId output_data_stream_base_; - int num_output_data_streams_; int output_data_stream_index_; }; -REGISTER_CALCULATOR(RoundRobinDemuxCalculator); +MEDIAPIPE_REGISTER_NODE(RoundRobinDemuxCalculator); +} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/core/sequence_shift_calculator.cc b/mediapipe/calculators/core/sequence_shift_calculator.cc index c2c9b8e5e..66dbdef2e 100644 --- a/mediapipe/calculators/core/sequence_shift_calculator.cc +++ b/mediapipe/calculators/core/sequence_shift_calculator.cc @@ -39,8 +39,8 @@ class SequenceShiftCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kOffset, kOut, TimestampChange::Arbitrary()); // Reads from options to set cache_size_ and packet_offset_. 
- mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // A positive offset means we want a packet to be output with the timestamp of @@ -69,7 +69,7 @@ class SequenceShiftCalculator : public Node { }; MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator); -mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { +absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { packet_offset_ = kOffset(cc).GetOr( cc->Options().packet_offset()); cache_size_ = abs(packet_offset_); @@ -77,10 +77,10 @@ mediapipe::Status SequenceShiftCalculator::Open(CalculatorContext* cc) { if (packet_offset_ == 0) { cc->Outputs().Index(0).SetOffset(0); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { +absl::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { if (packet_offset_ > 0) { ProcessPositiveOffset(cc); } else if (packet_offset_ < 0) { @@ -88,7 +88,7 @@ mediapipe::Status SequenceShiftCalculator::Process(CalculatorContext* cc) { } else { kOut(cc).Send(kIn(cc).packet()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) { diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc index 4ad359bbe..ed89889df 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator.cc @@ -89,10 +89,10 @@ class SidePacketToStreamCalculator : public CalculatorBase { SidePacketToStreamCalculator() = default; ~SidePacketToStreamCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status 
Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: bool is_tick_processing_ = false; @@ -100,8 +100,7 @@ class SidePacketToStreamCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(SidePacketToStreamCalculator); -mediapipe::Status SidePacketToStreamCalculator::GetContract( - CalculatorContract* cc) { +absl::Status SidePacketToStreamCalculator::GetContract(CalculatorContract* cc) { const auto& tags = cc->Outputs().GetTags(); RET_CHECK(tags.size() == 1 && kTimestampMap->count(*tags.begin()) == 1) << "Only one of AT_PRESTREAM, AT_POSTSTREAM, AT_ZERO, AT_TICK and " @@ -138,10 +137,10 @@ mediapipe::Status SidePacketToStreamCalculator::GetContract( cc->Inputs().Tag(kTagTick).SetAny(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) { +absl::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) { output_tag_ = GetOutputTag(*cc); if (cc->Inputs().HasTag(kTagTick)) { is_tick_processing_ = true; @@ -149,10 +148,10 @@ mediapipe::Status SidePacketToStreamCalculator::Open(CalculatorContext* cc) { // timestamp bound update. cc->SetOffset(TimestampDiff(0)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { +absl::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { if (is_tick_processing_) { // TICK input is guaranteed to be non-empty, as it's the only input stream // for this calculator. 
@@ -163,13 +162,13 @@ mediapipe::Status SidePacketToStreamCalculator::Process(CalculatorContext* cc) { .AddPacket(cc->InputSidePackets().Index(i).At(timestamp)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } return mediapipe::tool::StatusStop(); } -mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { +absl::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { if (!cc->Outputs().HasTag(kTagAtTick) && !cc->Outputs().HasTag(kTagAtTimestamp)) { const auto& timestamp = kTimestampMap->at(output_tag_); @@ -187,7 +186,7 @@ mediapipe::Status SidePacketToStreamCalculator::Close(CalculatorContext* cc) { .AddPacket(cc->InputSidePackets().Index(i).At(Timestamp(timestamp))); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc index 706825f19..b6b3d4e5c 100644 --- a/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc +++ b/mediapipe/calculators/core/side_packet_to_stream_calculator_test.cc @@ -189,7 +189,7 @@ void DoTestNonAtTickOutputTag(absl::string_view tag, MP_ASSERT_OK(graph.ObserveOutputStream( "packet", [&output_packets](const Packet& packet) { output_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK( graph.StartRun({{"side_packet", MakePacket(expected_value)}})); diff --git a/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc b/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc index c51a0a42f..d57cebe9c 100644 --- a/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc +++ b/mediapipe/calculators/core/split_normalized_landmark_list_calculator.cc @@ -35,7 +35,7 @@ namespace mediapipe { // NormalizedLandmarkList. 
class SplitNormalizedLandmarkListCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() != 0); @@ -55,7 +55,7 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { range_0.begin() < range_1.end()) || (range_1.begin() >= range_0.begin() && range_1.begin() < range_0.end())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Ranges must be non-overlapping when using combine_outputs " "option."); } @@ -63,7 +63,7 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { } } else { if (cc->Outputs().NumEntries() != options.ranges_size()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The number of output streams should match the number of ranges " "specified in the CalculatorOptions."); } @@ -72,13 +72,13 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 || options.ranges(i).begin() >= options.ranges(i).end()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Indices should be non-negative and begin index should be less " "than the end index."); } if (options.element_only()) { if (options.ranges(i).end() - options.ranges(i).begin() != 1) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Since element_only is true, all ranges should be of size 1."); } cc->Outputs().Index(i).Set(); @@ -88,10 +88,10 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); 
const auto& options = @@ -106,10 +106,10 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { total_elements_ += range.end() - range.begin(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const NormalizedLandmarkList& input = cc->Inputs().Index(0).Get(); RET_CHECK_GE(input.landmark_size(), max_range_end_) @@ -148,7 +148,7 @@ class SplitNormalizedLandmarkListCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/core/split_vector_calculator.cc b/mediapipe/calculators/core/split_vector_calculator.cc index 730a8e34e..c8f1177d5 100644 --- a/mediapipe/calculators/core/split_vector_calculator.cc +++ b/mediapipe/calculators/core/split_vector_calculator.cc @@ -26,7 +26,7 @@ #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) namespace mediapipe { diff --git a/mediapipe/calculators/core/split_vector_calculator.h b/mediapipe/calculators/core/split_vector_calculator.h index 6fb863377..c77c6a40d 100644 --- a/mediapipe/calculators/core/split_vector_calculator.h +++ b/mediapipe/calculators/core/split_vector_calculator.h @@ -58,7 +58,7 @@ using IsNotMovable = template class SplitVectorCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().NumEntries() == 1); RET_CHECK(cc->Outputs().NumEntries() != 0); @@ -79,7 +79,7 @@ class SplitVectorCalculator : public CalculatorBase { RET_CHECK_OK(checkRangesDontOverlap(options)); } else { if (cc->Outputs().NumEntries() != options.ranges_size()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The number 
of output streams should match the number of ranges " "specified in the CalculatorOptions."); } @@ -88,13 +88,13 @@ class SplitVectorCalculator : public CalculatorBase { for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { if (options.ranges(i).begin() < 0 || options.ranges(i).end() < 0 || options.ranges(i).begin() >= options.ranges(i).end()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Indices should be non-negative and begin index should be less " "than the end index."); } if (options.element_only()) { if (options.ranges(i).end() - options.ranges(i).begin() != 1) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Since element_only is true, all ranges should be of size 1."); } cc->Outputs().Index(i).Set(); @@ -104,10 +104,10 @@ class SplitVectorCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -122,11 +122,11 @@ class SplitVectorCalculator : public CalculatorBase { total_elements_ += range.end() - range.begin(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - if (cc->Inputs().Index(0).IsEmpty()) return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + if (cc->Inputs().Index(0).IsEmpty()) return absl::OkStatus(); if (move_elements) { return ProcessMovableElements(cc); @@ -136,7 +136,7 @@ class SplitVectorCalculator : public CalculatorBase { } template = true> - mediapipe::Status ProcessCopyableElements(CalculatorContext* cc) { + absl::Status ProcessCopyableElements(CalculatorContext* cc) { // static_assert(std::is_copy_constructible::value, // "Cannot copy non-copyable elements"); const auto& input = cc->Inputs().Index(0).Get>(); @@ -167,21 +167,21 @@ class 
SplitVectorCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } template = true> - mediapipe::Status ProcessCopyableElements(CalculatorContext* cc) { - return mediapipe::InternalError("Cannot copy non-copyable elements."); + absl::Status ProcessCopyableElements(CalculatorContext* cc) { + return absl::InternalError("Cannot copy non-copyable elements."); } template = true> - mediapipe::Status ProcessMovableElements(CalculatorContext* cc) { - mediapipe::StatusOr>> input_status = + absl::Status ProcessMovableElements(CalculatorContext* cc) { + absl::StatusOr>> input_status = cc->Inputs().Index(0).Value().Consume>(); if (!input_status.ok()) return input_status.status(); std::unique_ptr> input_vector = - std::move(input_status).ValueOrDie(); + std::move(input_status).value(); RET_CHECK_GE(input_vector->size(), max_range_end_); if (combine_outputs_) { @@ -214,16 +214,16 @@ class SplitVectorCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } template = true> - mediapipe::Status ProcessMovableElements(CalculatorContext* cc) { - return mediapipe::InternalError("Cannot move non-movable elements."); + absl::Status ProcessMovableElements(CalculatorContext* cc) { + return absl::InternalError("Cannot move non-movable elements."); } private: - static mediapipe::Status checkRangesDontOverlap( + static absl::Status checkRangesDontOverlap( const ::mediapipe::SplitVectorCalculatorOptions& options) { for (int i = 0; i < options.ranges_size() - 1; ++i) { for (int j = i + 1; j < options.ranges_size(); ++j) { @@ -233,13 +233,13 @@ class SplitVectorCalculator : public CalculatorBase { range_0.begin() < range_1.end()) || (range_1.begin() >= range_0.begin() && range_1.begin() < range_0.end())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Ranges must be non-overlapping when using combine_outputs " "option."); } } } - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } std::vector> ranges_; diff --git a/mediapipe/calculators/core/stream_to_side_packet_calculator.cc b/mediapipe/calculators/core/stream_to_side_packet_calculator.cc index 07bb8c852..9dc25142a 100644 --- a/mediapipe/calculators/core/stream_to_side_packet_calculator.cc +++ b/mediapipe/calculators/core/stream_to_side_packet_calculator.cc @@ -30,17 +30,17 @@ namespace mediapipe { // } class StreamToSidePacketCalculator : public mediapipe::CalculatorBase { public: - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) { + static absl::Status GetContract(mediapipe::CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->OutputSidePackets().Index(0).SetAny(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override { + absl::Status Process(mediapipe::CalculatorContext* cc) override { mediapipe::Packet& packet = cc->Inputs().Index(0).Value(); cc->OutputSidePackets().Index(0).Set( packet.At(mediapipe::Timestamp::Unset())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(StreamToSidePacketCalculator); diff --git a/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc b/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc index 12f417c58..606f0e352 100644 --- a/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc +++ b/mediapipe/calculators/core/stream_to_side_packet_calculator_test.cc @@ -44,7 +44,7 @@ class StreamToSidePacketCalculatorTest : public Test { TEST_F(StreamToSidePacketCalculatorTest, StreamToSidePacketCalculatorWithEmptyStreamFails) { - EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kUnavailable); + EXPECT_EQ(runner_->Run().code(), absl::StatusCode::kUnavailable); } TEST_F(StreamToSidePacketCalculatorTest, @@ -61,7 +61,7 @@ TEST_F(StreamToSidePacketCalculatorTest, Adopt(new std::string("test1")).At(Timestamp(1))); 
runner_->MutableInputs()->Index(0).packets.push_back( Adopt(new std::string("test2")).At(Timestamp(2))); - EXPECT_EQ(runner_->Run().code(), mediapipe::StatusCode::kAlreadyExists); + EXPECT_EQ(runner_->Run().code(), absl::StatusCode::kAlreadyExists); } } // namespace mediapipe diff --git a/mediapipe/calculators/core/string_to_int_calculator.cc b/mediapipe/calculators/core/string_to_int_calculator.cc index 5f8a6e325..13a9a29e0 100644 --- a/mediapipe/calculators/core/string_to_int_calculator.cc +++ b/mediapipe/calculators/core/string_to_int_calculator.cc @@ -36,25 +36,25 @@ namespace mediapipe { template class StringToIntCalculatorTemplate : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).Set(); cc->OutputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { IntType number; if (!absl::SimpleAtoi(cc->InputSidePackets().Index(0).Get(), &number)) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The std::string could not be parsed as an integer."); } cc->OutputSidePackets().Index(0).Set(MakePacket(number)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 7328ec3d9..e94fb7ec7 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -21,9 +21,7 @@ package(default_visibility = ["//visibility:private"]) mediapipe_proto_library( name = "opencv_image_encoder_calculator_proto", srcs = ["opencv_image_encoder_calculator.proto"], - visibility = [ - 
"//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -64,9 +62,7 @@ mediapipe_proto_library( mediapipe_proto_library( name = "bilateral_filter_calculator_proto", srcs = ["bilateral_filter_calculator.proto"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -87,9 +83,7 @@ mediapipe_proto_library( cc_library( name = "color_convert_calculator", srcs = ["color_convert_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:timestamp", @@ -123,9 +117,7 @@ cc_library( cc_library( name = "opencv_image_encoder_calculator", srcs = ["opencv_image_encoder_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ ":opencv_image_encoder_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -181,9 +173,7 @@ cc_library( cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], - visibility = [ - "//visibility:public", - ], + visibility = ["//visibility:public"], deps = [ ":bilateral_filter_calculator_cc_proto", "//mediapipe/framework:calculator_options_cc_proto", @@ -448,7 +438,6 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", ], ) @@ -467,7 +456,6 @@ cc_test( "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", ], ) @@ -503,6 +491,7 @@ mediapipe_proto_library( mediapipe_proto_library( name = "feature_detector_calculator_proto", srcs 
= ["feature_detector_calculator.proto"], + visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_proto", @@ -528,7 +517,7 @@ cc_library( cc_library( name = "feature_detector_calculator", srcs = ["feature_detector_calculator.cc"], - visibility = ["//mediapipe:__subpackages__"], + visibility = ["//visibility:public"], deps = [ ":feature_detector_calculator_cc_proto", "//mediapipe/framework:calculator_framework", @@ -579,6 +568,5 @@ cc_test( "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/port:status", ], ) diff --git a/mediapipe/calculators/image/bilateral_filter_calculator.cc b/mediapipe/calculators/image/bilateral_filter_calculator.cc index ae89c4f4a..3d878bffc 100644 --- a/mediapipe/calculators/image/bilateral_filter_calculator.cc +++ b/mediapipe/calculators/image/bilateral_filter_calculator.cc @@ -28,11 +28,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -82,18 +82,18 @@ class BilateralFilterCalculator : public CalculatorBase { BilateralFilterCalculator() = default; ~BilateralFilterCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // From Calculator. 
- mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status RenderGpu(CalculatorContext* cc); - mediapipe::Status RenderCpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); - mediapipe::Status GlSetup(CalculatorContext* cc); + absl::Status GlSetup(CalculatorContext* cc); void GlRender(CalculatorContext* cc); mediapipe::BilateralFilterCalculatorOptions options_; @@ -102,7 +102,7 @@ class BilateralFilterCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; GLuint vao_; @@ -111,71 +111,70 @@ class BilateralFilterCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(BilateralFilterCalculator); -mediapipe::Status BilateralFilterCalculator::GetContract( - CalculatorContract* cc) { +absl::Status BilateralFilterCalculator::GetContract(CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { - return mediapipe::InternalError("Cannot have multiple input images."); + return absl::InternalError("Cannot have multiple input images."); } if (cc->Inputs().HasTag(kInputFrameTagGpu) != cc->Outputs().HasTag(kOutputFrameTagGpu)) { - return mediapipe::InternalError("GPU output must have GPU input."); + return absl::InternalError("GPU output must have GPU input."); } bool use_gpu = false; // Input image to filter. 
-#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); } // Input guide image mask (optional) -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputGuideTagGpu)) { cc->Inputs().Tag(kInputGuideTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputGuideTag)) { cc->Inputs().Tag(kInputGuideTag).Set(); } // Output image. -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BilateralFilterCalculator::Open(CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU use_gpu_ = true; #else RET_CHECK_FAIL() << "GPU processing not enabled."; @@ -189,36 +188,35 @@ mediapipe::Status BilateralFilterCalculator::Open(CalculatorContext* cc) { if (!use_gpu_) sigma_color_ *= 255.0; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU 
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> mediapipe::Status { - if (!gpu_initialized_) { - MP_RETURN_IF_ERROR(GlSetup(cc)); - gpu_initialized_ = true; - } - MP_RETURN_IF_ERROR(RenderGpu(cc)); - return mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (!gpu_initialized_) { + MP_RETURN_IF_ERROR(GlSetup(cc)); + gpu_initialized_ = true; + } + MP_RETURN_IF_ERROR(RenderGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status BilateralFilterCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); if (vao_) glDeleteVertexArrays(1, &vao_); @@ -228,14 +226,14 @@ mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) { vbo_[0] = 0; vbo_[1] = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_frame = 
cc->Inputs().Tag(kInputFrameTag).Get(); @@ -243,7 +241,7 @@ mediapipe::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { // Only 1 or 3 channel images supported by OpenCV. if ((input_mat.channels() == 1 || input_mat.channels() == 3)) { - return mediapipe::InternalError( + return absl::InternalError( "CPU filtering supports only 1 or 3 channel input images."); } @@ -254,7 +252,7 @@ mediapipe::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { if (has_guide_image) { // cv::jointBilateralFilter() is in contrib module 'ximgproc'. - return mediapipe::UnimplementedError( + return absl::UnimplementedError( "CPU joint filtering support is not implemented yet."); } else { auto output_mat = mediapipe::formats::MatView(output_frame.get()); @@ -266,14 +264,14 @@ mediapipe::Status BilateralFilterCalculator::RenderCpu(CalculatorContext* cc) { cc->Outputs() .Tag(kOutputFrameTag) .Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BilateralFilterCalculator::RenderGpu(CalculatorContext* cc) { +absl::Status BilateralFilterCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); @@ -283,8 +281,7 @@ mediapipe::Status BilateralFilterCalculator::RenderGpu(CalculatorContext* cc) { // Setup textures and Update image in GPU shader. 
if (has_guide_image) { - if (cc->Inputs().Tag(kInputGuideTagGpu).IsEmpty()) - return mediapipe::OkStatus(); + if (cc->Inputs().Tag(kInputGuideTagGpu).IsEmpty()) return absl::OkStatus(); // joint bilateral filter glUseProgram(program_); const auto& guide_image = @@ -330,13 +327,13 @@ mediapipe::Status BilateralFilterCalculator::RenderGpu(CalculatorContext* cc) { // Cleanup input_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // bring back vao and vbo glBindVertexArray(vao_); @@ -345,11 +342,11 @@ void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { // cleanup glBindVertexArray(0); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -513,9 +510,9 @@ mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { glBindBuffer(GL_ARRAY_BUFFER, 0); glBindVertexArray(0); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/color_convert_calculator.cc b/mediapipe/calculators/image/color_convert_calculator.cc index b0f72d29b..bdac932bb 100644 --- a/mediapipe/calculators/image/color_convert_calculator.cc +++ b/mediapipe/calculators/image/color_convert_calculator.cc @@ -78,12 +78,12 @@ constexpr char kGrayOutTag[] = "GRAY_OUT"; class ColorConvertCalculator : public CalculatorBase { public: ~ColorConvertCalculator() override = default; - 
static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Process(CalculatorContext* cc) override; - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -91,16 +91,16 @@ class ColorConvertCalculator : public CalculatorBase { // conversion. The ImageFrame on input_tag is converted using the // open_cv_convert_code provided and then output on the output_tag stream. // Note that the output_format must match the destination conversion code. - mediapipe::Status ConvertAndOutput(const std::string& input_tag, - const std::string& output_tag, - ImageFormat::Format output_format, - int open_cv_convert_code, - CalculatorContext* cc); + absl::Status ConvertAndOutput(const std::string& input_tag, + const std::string& output_tag, + ImageFormat::Format output_format, + int open_cv_convert_code, + CalculatorContext* cc); }; REGISTER_CALCULATOR(ColorConvertCalculator); -mediapipe::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { +absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is allowed."; RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) @@ -138,10 +138,10 @@ mediapipe::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) { cc->Outputs().Tag(kBgraOutTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ColorConvertCalculator::ConvertAndOutput( +absl::Status ColorConvertCalculator::ConvertAndOutput( const std::string& input_tag, const std::string& output_tag, ImageFormat::Format output_format, int open_cv_convert_code, CalculatorContext* cc) { @@ -160,10 +160,10 @@ mediapipe::Status 
ColorConvertCalculator::ConvertAndOutput( cc->Outputs() .Tag(output_tag) .Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ColorConvertCalculator::Process(CalculatorContext* cc) { +absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) { // RGBA -> RGB if (cc->Inputs().HasTag(kRgbaInTag) && cc->Outputs().HasTag(kRgbOutTag)) { return ConvertAndOutput(kRgbaInTag, kRgbOutTag, ImageFormat::SRGB, diff --git a/mediapipe/calculators/image/feature_detector_calculator.cc b/mediapipe/calculators/image/feature_detector_calculator.cc index d3b774bca..389a33696 100644 --- a/mediapipe/calculators/image/feature_detector_calculator.cc +++ b/mediapipe/calculators/image/feature_detector_calculator.cc @@ -50,10 +50,10 @@ class FeatureDetectorCalculator : public CalculatorBase { public: ~FeatureDetectorCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: FeatureDetectorCalculatorOptions options_; @@ -71,8 +71,7 @@ class FeatureDetectorCalculator : public CalculatorBase { REGISTER_CALCULATOR(FeatureDetectorCalculator); -mediapipe::Status FeatureDetectorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status FeatureDetectorCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("IMAGE")) { cc->Inputs().Tag("IMAGE").Set(); } @@ -85,10 +84,10 @@ mediapipe::Status FeatureDetectorCalculator::GetContract( if (cc->Outputs().HasTag("PATCHES")) { cc->Outputs().Tag("PATCHES").Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FeatureDetectorCalculator::Open(CalculatorContext* cc) { +absl::Status 
FeatureDetectorCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag) .GetExtension(FeatureDetectorCalculatorOptions::ext); @@ -97,14 +96,14 @@ mediapipe::Status FeatureDetectorCalculator::Open(CalculatorContext* cc) { options_.pyramid_level(), kPatchSize - 1, 0, 2, cv::ORB::FAST_SCORE); pool_ = absl::make_unique("ThreadPool", kNumThreads); pool_->StartWorkers(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FeatureDetectorCalculator::Process(CalculatorContext* cc) { +absl::Status FeatureDetectorCalculator::Process(CalculatorContext* cc) { const Timestamp& timestamp = cc->InputTimestamp(); if (timestamp == Timestamp::PreStream()) { // Indicator packet. - return mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* input_frame = &(cc->Inputs().Tag("IMAGE")); cv::Mat input_view = formats::MatView(&input_frame->Get()); @@ -176,7 +175,7 @@ mediapipe::Status FeatureDetectorCalculator::Process(CalculatorContext* cc) { cc->Outputs().Tag("PATCHES").Add(patches.release(), timestamp); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void FeatureDetectorCalculator::ComputeImagePyramid( diff --git a/mediapipe/calculators/image/image_cropping_calculator.cc b/mediapipe/calculators/image/image_cropping_calculator.cc index 591fcb47b..e4b0b7218 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.cc +++ b/mediapipe/calculators/image/image_cropping_calculator.cc @@ -24,11 +24,11 @@ #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -38,9 +38,9 @@ namespace mediapipe { 
namespace { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU constexpr char kRectTag[] = "RECT"; constexpr char kNormRectTag[] = "NORM_RECT"; @@ -53,7 +53,7 @@ constexpr char kWidthTag[] = "WIDTH"; REGISTER_CALCULATOR(ImageCroppingCalculator); -mediapipe::Status ImageCroppingCalculator::GetContract(CalculatorContract* cc) { +absl::Status ImageCroppingCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kImageTag) ^ cc->Inputs().HasTag(kImageGpuTag)); RET_CHECK(cc->Outputs().HasTag(kImageTag) ^ cc->Outputs().HasTag(kImageGpuTag)); @@ -65,14 +65,14 @@ mediapipe::Status ImageCroppingCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kImageTag).Set(); cc->Outputs().Tag(kImageTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kImageGpuTag)) { RET_CHECK(cc->Outputs().HasTag(kImageGpuTag)); cc->Inputs().Tag(kImageGpuTag).Set(); cc->Outputs().Tag(kImageGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU int flags = 0; if (cc->Inputs().HasTag(kRectTag)) { @@ -110,15 +110,15 @@ mediapipe::Status ImageCroppingCalculator::GetContract(CalculatorContract* cc) { } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->Inputs().HasTag(kImageGpuTag)) { @@ -132,11 +132,11 @@ mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { options_.has_output_max_height() ? 
options_.output_max_height() : FLT_MAX; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #else RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } // Validate border mode. @@ -146,56 +146,55 @@ mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(ValidateBorderModeForCPU(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kRectTag) && cc->Inputs().Tag(kRectTag).IsEmpty()) { VLOG(1) << "RECT is empty for timestamp: " << cc->InputTimestamp(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().HasTag(kNormRectTag) && cc->Inputs().Tag(kNormRectTag).IsEmpty()) { VLOG(1) << "NORM_RECT is empty for timestamp: " << cc->InputTimestamp(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> mediapipe::Status { - if (!gpu_initialized_) { - MP_RETURN_IF_ERROR(InitGpu(cc)); - gpu_initialized_ = true; - } - MP_RETURN_IF_ERROR(RenderGpu(cc)); - return mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (!gpu_initialized_) { + MP_RETURN_IF_ERROR(InitGpu(cc)); + gpu_initialized_ = true; + } + MP_RETURN_IF_ERROR(RenderGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status 
ImageCroppingCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); gpu_initialized_ = false; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageCroppingCalculator::ValidateBorderModeForCPU( +absl::Status ImageCroppingCalculator::ValidateBorderModeForCPU( CalculatorContext* cc) { int border_mode; return GetBorderModeForOpenCV(cc, &border_mode); } -mediapipe::Status ImageCroppingCalculator::ValidateBorderModeForGPU( +absl::Status ImageCroppingCalculator::ValidateBorderModeForGPU( CalculatorContext* cc) { mediapipe::ImageCroppingCalculatorOptions options = cc->Options(); @@ -212,12 +211,12 @@ mediapipe::Status ImageCroppingCalculator::ValidateBorderModeForGPU( << options.border_mode(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kImageTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_img = cc->Inputs().Tag(kImageTag).Get(); cv::Mat input_mat = formats::MatView(&input_img); @@ -267,14 +266,14 @@ mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { cropped_image.copyTo(output_mat); cc->Outputs().Tag(kImageTag).Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { +absl::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU const Packet& input_packet = 
cc->Inputs().Tag(kImageGpuTag).Value(); const auto& input_buffer = input_packet.Get(); auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer); @@ -305,13 +304,13 @@ mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { // Cleanup src_tex.Release(); dst_tex.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } void ImageCroppingCalculator::GlRender() { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -355,11 +354,11 @@ void ImageCroppingCalculator::GlRender() { glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -405,9 +404,9 @@ mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { // Parameters glUseProgram(program_); glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } // For GPU only. 
@@ -533,7 +532,7 @@ RectSpec ImageCroppingCalculator::GetCropSpecs(const CalculatorContext* cc, return {crop_width, crop_height, x_center, y_center, rotation}; } -mediapipe::Status ImageCroppingCalculator::GetBorderModeForOpenCV( +absl::Status ImageCroppingCalculator::GetBorderModeForOpenCV( CalculatorContext* cc, int* border_mode) { mediapipe::ImageCroppingCalculatorOptions options = cc->Options(); @@ -550,7 +549,7 @@ mediapipe::Status ImageCroppingCalculator::GetBorderModeForOpenCV( << options.border_mode(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_cropping_calculator.h b/mediapipe/calculators/image/image_cropping_calculator.h index 2f0324879..39d99cc55 100644 --- a/mediapipe/calculators/image/image_cropping_calculator.h +++ b/mediapipe/calculators/image/image_cropping_calculator.h @@ -6,9 +6,9 @@ #include "mediapipe/calculators/image/image_cropping_calculator.pb.h" #include "mediapipe/framework/calculator_framework.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU // Crops the input texture to the given rectangle region. The rectangle can // be at arbitrary location on the image with rotation. 
If there's rotation, the @@ -58,24 +58,23 @@ class ImageCroppingCalculator : public CalculatorBase { ImageCroppingCalculator() = default; ~ImageCroppingCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; static RectSpec GetCropSpecs(const CalculatorContext* cc, int src_width, int src_height); private: - mediapipe::Status ValidateBorderModeForCPU(CalculatorContext* cc); - mediapipe::Status ValidateBorderModeForGPU(CalculatorContext* cc); - mediapipe::Status RenderCpu(CalculatorContext* cc); - mediapipe::Status RenderGpu(CalculatorContext* cc); - mediapipe::Status InitGpu(CalculatorContext* cc); + absl::Status ValidateBorderModeForCPU(CalculatorContext* cc); + absl::Status ValidateBorderModeForGPU(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); void GlRender(); void GetOutputDimensions(CalculatorContext* cc, int src_width, int src_height, int* dst_width, int* dst_height); - mediapipe::Status GetBorderModeForOpenCV(CalculatorContext* cc, - int* border_mode); + absl::Status GetBorderModeForOpenCV(CalculatorContext* cc, int* border_mode); mediapipe::ImageCroppingCalculatorOptions options_; @@ -84,11 +83,11 @@ class ImageCroppingCalculator : public CalculatorBase { float transformed_points_[8]; float output_max_width_ = FLT_MAX; float output_max_height_ = FLT_MAX; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU bool gpu_initialized_ = false; mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ 
= 0; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_cropping_calculator_test.cc b/mediapipe/calculators/image/image_cropping_calculator_test.cc index c511014aa..bb75826a1 100644 --- a/mediapipe/calculators/image/image_cropping_calculator_test.cc +++ b/mediapipe/calculators/image/image_cropping_calculator_test.cc @@ -59,8 +59,8 @@ TEST(ImageCroppingCalculatorTest, GetCroppingDimensionsNormal) { auto calculator_state = absl::make_unique( "Node", 0, "Calculator", calculator_node, nullptr); auto cc = absl::make_unique( - calculator_state.get(), tool::CreateTagMap({}).ValueOrDie(), - tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), tool::CreateTagMap({}).value(), + tool::CreateTagMap({}).value()); RectSpec expectRect = { .width = 60, @@ -99,8 +99,8 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecInOptions) { auto calculator_state = absl::make_unique( "Node", 0, "Calculator", calculator_node, nullptr); auto cc = absl::make_unique( - calculator_state.get(), tool::CreateTagMap({}).ValueOrDie(), - tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), tool::CreateTagMap({}).value(), + tool::CreateTagMap({}).value()); RectSpec expectRect = { .width = 50, .height = 50, @@ -144,9 +144,9 @@ TEST(ImageCroppingCalculatorTest, RedundantSpectWithInputStream) { "HEIGHT:0:crop_height", "WIDTH:0:crop_width", }) - .ValueOrDie(); + .value(); auto cc = absl::make_unique( - calculator_state.get(), inputTags, tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); inputs.Tag(kHeightTag).Value() = MakePacket(1); inputs.Tag(kWidthTag).Value() = MakePacket(1); @@ -191,9 +191,9 @@ TEST(ImageCroppingCalculatorTest, RedundantSpecWithInputStream) { auto inputTags = tool::CreateTagMap({ "RECT:0:rect", }) - .ValueOrDie(); + .value(); auto cc = absl::make_unique( - calculator_state.get(), inputTags, 
tool::CreateTagMap({}).ValueOrDie()); + calculator_state.get(), inputTags, tool::CreateTagMap({}).value()); auto& inputs = cc->Inputs(); mediapipe::Rect rect = ParseTextProtoOrDie( R"( diff --git a/mediapipe/calculators/image/image_file_properties_calculator.cc b/mediapipe/calculators/image/image_file_properties_calculator.cc index a0636acbe..9c6d8caca 100644 --- a/mediapipe/calculators/image/image_file_properties_calculator.cc +++ b/mediapipe/calculators/image/image_file_properties_calculator.cc @@ -28,24 +28,24 @@ namespace { // sqrt(36^2 + 24^2). static const double SENSOR_DIAGONAL_35MM = std::sqrt(1872.0); -mediapipe::StatusOr ComputeFocalLengthInPixels(int image_width, - int image_height, - double focal_length_35mm, - double focal_length_mm) { +absl::StatusOr ComputeFocalLengthInPixels(int image_width, + int image_height, + double focal_length_35mm, + double focal_length_mm) { // TODO: Allow returning image file properties even when focal length // computation is not possible. if (image_width == 0 || image_height == 0) { - return mediapipe::InternalError( + return absl::InternalError( "Image dimensions should be non-zero to compute focal length in " "pixels."); } if (focal_length_mm == 0) { - return mediapipe::InternalError( + return absl::InternalError( "Focal length in mm should be non-zero to compute focal length in " "pixels."); } if (focal_length_35mm == 0) { - return mediapipe::InternalError( + return absl::InternalError( "Focal length in 35 mm should be non-zero to compute focal length in " "pixels."); } @@ -77,13 +77,13 @@ mediapipe::StatusOr ComputeFocalLengthInPixels(int image_width, return focal_length_pixels; } -mediapipe::StatusOr GetImageFileProperites( +absl::StatusOr GetImageFileProperites( const std::string& image_bytes) { easyexif::EXIFInfo result; int code = result.parseFrom(image_bytes); if (code) { - return mediapipe::InternalError("Error parsing EXIF, code: " + - std::to_string(code)); + return absl::InternalError("Error parsing EXIF, 
code: " + + std::to_string(code)); } ImageFileProperties properties; @@ -126,7 +126,7 @@ mediapipe::StatusOr GetImageFileProperites( // } class ImageFilePropertiesCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { if (cc->Inputs().NumEntries() != 0) { RET_CHECK(cc->Inputs().NumEntries() == 1); cc->Inputs().Index(0).Set(); @@ -142,10 +142,10 @@ class ImageFilePropertiesCalculator : public CalculatorBase { cc->OutputSidePackets().Index(0).Set<::mediapipe::ImageFileProperties>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); if (cc->InputSidePackets().NumEntries() == 1) { @@ -160,13 +160,13 @@ class ImageFilePropertiesCalculator : public CalculatorBase { MakePacket(properties_)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().NumEntries() == 1) { if (cc->Inputs().Index(0).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const std::string& image_bytes = cc->Inputs().Index(0).Get(); ASSIGN_OR_RETURN(properties_, GetImageFileProperites(image_bytes)); @@ -184,7 +184,7 @@ class ImageFilePropertiesCalculator : public CalculatorBase { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/image/image_properties_calculator.cc b/mediapipe/calculators/image/image_properties_calculator.cc index 84e67c0cb..5fbd64012 100644 --- a/mediapipe/calculators/image/image_properties_calculator.cc +++ b/mediapipe/calculators/image/image_properties_calculator.cc @@ -15,9 +15,9 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" 
-#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { constexpr char kImageFrameTag[] = "IMAGE"; @@ -44,31 +44,31 @@ namespace mediapipe { // } class ImagePropertiesCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) ^ cc->Inputs().HasTag(kGpuBufferTag)); if (cc->Inputs().HasTag(kImageFrameTag)) { cc->Inputs().Tag(kImageFrameTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set<::mediapipe::GpuBuffer>(); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag("SIZE")) { cc->Outputs().Tag("SIZE").Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { int width; int height; @@ -78,7 +78,7 @@ class ImagePropertiesCalculator : public CalculatorBase { width = image.Width(); height = image.Height(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag) && !cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { const auto& image = @@ -86,13 +86,13 @@ class ImagePropertiesCalculator : public CalculatorBase { width = image.width(); height = image.height(); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU cc->Outputs().Tag("SIZE").AddPacket( MakePacket>(width, height) .At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } }; REGISTER_CALCULATOR(ImagePropertiesCalculator); diff --git a/mediapipe/calculators/image/image_transformation_calculator.cc b/mediapipe/calculators/image/image_transformation_calculator.cc index 0c5ff8bdd..bb98f14e0 100644 --- a/mediapipe/calculators/image/image_transformation_calculator.cc +++ b/mediapipe/calculators/image/image_transformation_calculator.cc @@ -22,12 +22,12 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/scale_mode.pb.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU #if defined(__ANDROID__) // The size of Java arrays is dynamic, which makes it difficult to @@ -42,9 +42,9 @@ typedef int DimensionsPacketType[2]; namespace mediapipe { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { constexpr char kImageFrameTag[] = "IMAGE"; @@ -163,16 +163,16 @@ class ImageTransformationCalculator : public CalculatorBase { ImageTransformationCalculator() = default; ~ImageTransformationCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status RenderCpu(CalculatorContext* cc); - mediapipe::Status RenderGpu(CalculatorContext* cc); - mediapipe::Status GlSetup(); + absl::Status RenderCpu(CalculatorContext* cc); + 
absl::Status RenderGpu(CalculatorContext* cc); + absl::Status GlSetup(); void ComputeOutputDimensions(int input_width, int input_height, int* output_width, int* output_height); @@ -189,17 +189,17 @@ class ImageTransformationCalculator : public CalculatorBase { bool flip_vertically_ = false; bool use_gpu_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU GlCalculatorHelper gpu_helper_; std::unique_ptr rgb_renderer_; std::unique_ptr yuv_renderer_; std::unique_ptr ext_rgb_renderer_; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(ImageTransformationCalculator); // static -mediapipe::Status ImageTransformationCalculator::GetContract( +absl::Status ImageTransformationCalculator::GetContract( CalculatorContract* cc) { // Only one input can be set, and the output type must match. RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) ^ @@ -212,14 +212,14 @@ mediapipe::Status ImageTransformationCalculator::GetContract( cc->Inputs().Tag(kImageFrameTag).Set(); cc->Outputs().Tag(kImageFrameTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { RET_CHECK(cc->Outputs().HasTag(kGpuBufferTag)); cc->Inputs().Tag(kGpuBufferTag).Set(); cc->Outputs().Tag(kGpuBufferTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag("ROTATION_DEGREES")) { cc->Inputs().Tag("ROTATION_DEGREES").Set(); @@ -249,15 +249,15 @@ mediapipe::Status ImageTransformationCalculator::GetContract( } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { // Inform 
the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(TimestampDiff(0)); @@ -303,19 +303,18 @@ mediapipe::Status ImageTransformationCalculator::Open(CalculatorContext* cc) { scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE); if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Let the helper access the GL context information. MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #else RET_CHECK_FAIL() << "GPU processing not enabled."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageTransformationCalculator::Process( - CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::Process(CalculatorContext* cc) { // Override values if specified so. if (cc->Inputs().HasTag("ROTATION_DEGREES") && !cc->Inputs().Tag("ROTATION_DEGREES").IsEmpty()) { @@ -332,25 +331,25 @@ mediapipe::Status ImageTransformationCalculator::Process( } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } return gpu_helper_.RunInGlContext( - [this, cc]() -> mediapipe::Status { return RenderGpu(cc); }); -#endif // !MEDIAPIPE_DISABLE_GPU + [this, cc]() -> absl::Status { return RenderGpu(cc); }); +#endif // !MEDIAPIPE_DISABLE_GPU } else { if (cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } return RenderCpu(cc); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageTransformationCalculator::Close(CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::Close(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU QuadRenderer* rgb_renderer = rgb_renderer_.release(); QuadRenderer* yuv_renderer = 
yuv_renderer_.release(); QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release(); @@ -368,14 +367,13 @@ mediapipe::Status ImageTransformationCalculator::Close(CalculatorContext* cc) { delete yuv_renderer; } }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageTransformationCalculator::RenderCpu( - CalculatorContext* cc) { +absl::Status ImageTransformationCalculator::RenderCpu(CalculatorContext* cc) { cv::Mat input_mat; mediapipe::ImageFormat::Format format; @@ -479,12 +477,11 @@ mediapipe::Status ImageTransformationCalculator::RenderCpu( .Tag(kImageFrameTag) .Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageTransformationCalculator::RenderGpu( - CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status ImageTransformationCalculator::RenderGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const auto& input = cc->Inputs().Tag(kGpuBufferTag).Get(); const int input_width = input.width(); const int input_height = input.height(); @@ -567,9 +564,9 @@ mediapipe::Status ImageTransformationCalculator::RenderGpu( auto output = dst.template GetFrame(); cc->Outputs().Tag(kGpuBufferTag).Add(output.release(), cc->InputTimestamp()); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } void ImageTransformationCalculator::ComputeOutputDimensions( diff --git a/mediapipe/calculators/image/luminance_calculator.cc b/mediapipe/calculators/image/luminance_calculator.cc index 503d2b22e..d5122c7a4 100644 --- a/mediapipe/calculators/image/luminance_calculator.cc +++ b/mediapipe/calculators/image/luminance_calculator.cc @@ -26,10 +26,9 @@ namespace mediapipe { // See GlSimpleCalculatorBase for inputs, outputs and input side packets. 
class LuminanceCalculator : public GlSimpleCalculator { public: - mediapipe::Status GlSetup() override; - mediapipe::Status GlRender(const GlTexture& src, - const GlTexture& dst) override; - mediapipe::Status GlTeardown() override; + absl::Status GlSetup() override; + absl::Status GlRender(const GlTexture& src, const GlTexture& dst) override; + absl::Status GlTeardown() override; private: GLuint program_ = 0; @@ -37,7 +36,7 @@ class LuminanceCalculator : public GlSimpleCalculator { }; REGISTER_CALCULATOR(LuminanceCalculator); -mediapipe::Status LuminanceCalculator::GlSetup() { +absl::Status LuminanceCalculator::GlSetup() { // Load vertex and fragment shaders const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, @@ -83,11 +82,11 @@ mediapipe::Status LuminanceCalculator::GlSetup() { (const GLchar**)&attr_name[0], attr_location, &program_); RET_CHECK(program_) << "Problem initializing the program."; frame_ = glGetUniformLocation(program_, "video_frame"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src, - const GlTexture& dst) { +absl::Status LuminanceCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -137,15 +136,15 @@ mediapipe::Status LuminanceCalculator::GlRender(const GlTexture& src, glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LuminanceCalculator::GlTeardown() { +absl::Status LuminanceCalculator::GlTeardown() { if (program_) { glDeleteProgram(program_); program_ = 0; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/mask_overlay_calculator.cc b/mediapipe/calculators/image/mask_overlay_calculator.cc index 9844f2ed5..5fbf9e4f4 100644 --- 
a/mediapipe/calculators/image/mask_overlay_calculator.cc +++ b/mediapipe/calculators/image/mask_overlay_calculator.cc @@ -52,14 +52,14 @@ class MaskOverlayCalculator : public CalculatorBase { MaskOverlayCalculator() {} ~MaskOverlayCalculator(); - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; - mediapipe::Status GlSetup( + absl::Status GlSetup( const MaskOverlayCalculatorOptions::MaskChannel mask_channel); - mediapipe::Status GlRender(const float mask_const); + absl::Status GlRender(const float mask_const); private: GlCalculatorHelper helper_; @@ -73,7 +73,7 @@ class MaskOverlayCalculator : public CalculatorBase { REGISTER_CALCULATOR(MaskOverlayCalculator); // static -mediapipe::Status MaskOverlayCalculator::GetContract(CalculatorContract* cc) { +absl::Status MaskOverlayCalculator::GetContract(CalculatorContract* cc) { MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); cc->Inputs().Get("VIDEO", 0).Set(); cc->Inputs().Get("VIDEO", 1).Set(); @@ -82,13 +82,13 @@ mediapipe::Status MaskOverlayCalculator::GetContract(CalculatorContract* cc) { else if (cc->Inputs().HasTag("CONST_MASK")) cc->Inputs().Tag("CONST_MASK").Set(); else - return mediapipe::Status(mediapipe::StatusCode::kNotFound, - "At least one mask input stream must be present."); + return absl::Status(absl::StatusCode::kNotFound, + "At least one mask input stream must be present."); cc->Outputs().Tag("OUTPUT").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MaskOverlayCalculator::Open(CalculatorContext* cc) { +absl::Status MaskOverlayCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->Inputs().HasTag("MASK")) { use_mask_tex_ = true; @@ 
-96,8 +96,8 @@ mediapipe::Status MaskOverlayCalculator::Open(CalculatorContext* cc) { return helper_.Open(cc); } -mediapipe::Status MaskOverlayCalculator::Process(CalculatorContext* cc) { - return helper_.RunInGlContext([this, &cc]() -> mediapipe::Status { +absl::Status MaskOverlayCalculator::Process(CalculatorContext* cc) { + return helper_.RunInGlContext([this, &cc]() -> absl::Status { if (!initialized_) { const auto& options = cc->Options(); const auto mask_channel = options.mask_channel(); @@ -115,7 +115,7 @@ mediapipe::Status MaskOverlayCalculator::Process(CalculatorContext* cc) { if (mask_packet.IsEmpty()) { cc->Outputs().Tag("OUTPUT").AddPacket(input1_packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input0_buffer = cc->Inputs().Get("VIDEO", 0).Get(); @@ -172,11 +172,11 @@ mediapipe::Status MaskOverlayCalculator::Process(CalculatorContext* cc) { dst.Release(); cc->Outputs().Tag("OUTPUT").Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); }); } -mediapipe::Status MaskOverlayCalculator::GlSetup( +absl::Status MaskOverlayCalculator::GlSetup( const MaskOverlayCalculatorOptions::MaskChannel mask_channel) { // Load vertex and fragment shaders const GLint attr_location[NUM_ATTRIBUTES] = { @@ -247,10 +247,10 @@ mediapipe::Status MaskOverlayCalculator::GlSetup( unif_frame1_ = glGetUniformLocation(program_, "frame1"); unif_frame2_ = glGetUniformLocation(program_, "frame2"); unif_mask_ = glGetUniformLocation(program_, "mask"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MaskOverlayCalculator::GlRender(const float mask_const) { +absl::Status MaskOverlayCalculator::GlRender(const float mask_const) { glUseProgram(program_); glVertexAttribPointer(ATTRIB_VERTEX, 2, GL_FLOAT, 0, 0, kBasicSquareVertices); glEnableVertexAttribArray(ATTRIB_VERTEX); @@ -266,7 +266,7 @@ mediapipe::Status MaskOverlayCalculator::GlRender(const float mask_const) { 
glUniform1f(unif_mask_, mask_const); glDrawArrays(GL_TRIANGLE_STRIP, 0, 4); - return mediapipe::OkStatus(); + return absl::OkStatus(); } MaskOverlayCalculator::~MaskOverlayCalculator() { diff --git a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc index 6d30bc290..21bc587f3 100644 --- a/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc +++ b/mediapipe/calculators/image/opencv_encoded_image_to_image_frame_calculator.cc @@ -34,29 +34,29 @@ namespace mediapipe { // } class OpenCvEncodedImageToImageFrameCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: mediapipe::OpenCvEncodedImageToImageFrameCalculatorOptions options_; }; -mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::GetContract( +absl::Status OpenCvEncodedImageToImageFrameCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::Open( +absl::Status OpenCvEncodedImageToImageFrameCalculator::Open( CalculatorContext* cc) { options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::Process( +absl::Status OpenCvEncodedImageToImageFrameCalculator::Process( CalculatorContext* cc) { const std::string& contents = cc->Inputs().Index(0).Get(); const std::vector contents_vector(contents.begin(), contents.end()); @@ -84,8 +84,9 @@ mediapipe::Status 
OpenCvEncodedImageToImageFrameCalculator::Process( cv::cvtColor(decoded_mat, output_mat, cv::COLOR_BGR2RGB); break; case 4: - return mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) - << "4-channel image isn't supported yet"; + image_format = ImageFormat::SRGBA; + cv::cvtColor(decoded_mat, output_mat, cv::COLOR_BGR2RGBA); + break; default: return mediapipe::FailedPreconditionErrorBuilder(MEDIAPIPE_LOC) << "Unsupported number of channels: " << decoded_mat.channels(); @@ -95,7 +96,7 @@ mediapipe::Status OpenCvEncodedImageToImageFrameCalculator::Process( ImageFrame::kGlDefaultAlignmentBoundary); output_mat.copyTo(formats::MatView(output_frame.get())); cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvEncodedImageToImageFrameCalculator); diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator.cc b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc index 6f72346da..93ec9435f 100644 --- a/mediapipe/calculators/image/opencv_image_encoder_calculator.cc +++ b/mediapipe/calculators/image/opencv_image_encoder_calculator.cc @@ -38,29 +38,28 @@ namespace mediapipe { // } class OpenCvImageEncoderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: int encoding_quality_; }; -mediapipe::Status OpenCvImageEncoderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status OpenCvImageEncoderCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); 
cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) { +absl::Status OpenCvImageEncoderCalculator::Open(CalculatorContext* cc) { auto options = cc->Options(); encoding_quality_ = options.quality(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvImageEncoderCalculator::Process(CalculatorContext* cc) { +absl::Status OpenCvImageEncoderCalculator::Process(CalculatorContext* cc) { const ImageFrame& image_frame = cc->Inputs().Index(0).Get(); CHECK_EQ(1, image_frame.ByteDepth()); @@ -104,15 +103,14 @@ mediapipe::Status OpenCvImageEncoderCalculator::Process(CalculatorContext* cc) { << "Fail to encode the image to be jpeg format."; } - encoded_result->set_encoded_image(std::string(absl::string_view( - reinterpret_cast(&encode_buffer[0]), encode_buffer.size()))); + encoded_result->set_encoded_image(&encode_buffer[0], encode_buffer.size()); cc->Outputs().Index(0).Add(encoded_result.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvImageEncoderCalculator::Close(CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status OpenCvImageEncoderCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvImageEncoderCalculator); diff --git a/mediapipe/calculators/image/opencv_image_encoder_calculator.proto b/mediapipe/calculators/image/opencv_image_encoder_calculator.proto index 43172b319..0564fb270 100644 --- a/mediapipe/calculators/image/opencv_image_encoder_calculator.proto +++ b/mediapipe/calculators/image/opencv_image_encoder_calculator.proto @@ -29,11 +29,13 @@ message OpenCvImageEncoderCalculatorOptions { // TODO: Consider renaming it to EncodedImage. message OpenCvImageEncoderCalculatorResults { - // Encoded image - optional string encoded_image = 1; + // Pixel data encoded as JPEG. 
+ optional bytes encoded_image = 1; - // Dimensions of the encoded image + // Height of the image data under #1 once decoded. optional int32 height = 2; + + // Width of the image data under #1 once decoded. optional int32 width = 3; enum ColorSpace { diff --git a/mediapipe/calculators/image/opencv_put_text_calculator.cc b/mediapipe/calculators/image/opencv_put_text_calculator.cc index e7769486f..82a4b3a53 100644 --- a/mediapipe/calculators/image/opencv_put_text_calculator.cc +++ b/mediapipe/calculators/image/opencv_put_text_calculator.cc @@ -32,17 +32,17 @@ namespace mediapipe { // TODO: Generalize the calculator for other text use cases. class OpenCvPutTextCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Process(CalculatorContext* cc) override; }; -mediapipe::Status OpenCvPutTextCalculator::GetContract(CalculatorContract* cc) { +absl::Status OpenCvPutTextCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) { +absl::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) { const std::string& text_content = cc->Inputs().Index(0).Get(); cv::Mat mat = cv::Mat::zeros(640, 640, CV_8UC4); cv::putText(mat, text_content, cv::Point(15, 70), cv::FONT_HERSHEY_PLAIN, 3, @@ -51,7 +51,7 @@ mediapipe::Status OpenCvPutTextCalculator::Process(CalculatorContext* cc) { ImageFormat::SRGBA, mat.size().width, mat.size().height); mat.copyTo(formats::MatView(output_frame.get())); cc->Outputs().Index(0).Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvPutTextCalculator); diff --git 
a/mediapipe/calculators/image/recolor_calculator.cc b/mediapipe/calculators/image/recolor_calculator.cc index db0a46c7f..6a12025f6 100644 --- a/mediapipe/calculators/image/recolor_calculator.cc +++ b/mediapipe/calculators/image/recolor_calculator.cc @@ -24,11 +24,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/util/color.pb.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; @@ -84,17 +84,17 @@ class RecolorCalculator : public CalculatorBase { RecolorCalculator() = default; ~RecolorCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status LoadOptions(CalculatorContext* cc); - mediapipe::Status InitGpu(CalculatorContext* cc); - mediapipe::Status RenderGpu(CalculatorContext* cc); - mediapipe::Status RenderCpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); void GlRender(); bool initialized_ = false; @@ -102,46 +102,46 @@ class RecolorCalculator : public CalculatorBase { mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_; bool use_gpu_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU 
mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(RecolorCalculator); // static -mediapipe::Status RecolorCalculator::GetContract(CalculatorContract* cc) { +absl::Status RecolorCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); bool use_gpu = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kImageFrameTag)) { cc->Inputs().Tag(kImageFrameTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kMaskGpuTag)) { cc->Inputs().Tag(kMaskGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kMaskCpuTag)) { cc->Inputs().Tag(kMaskCpuTag).Set(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kGpuBufferTag)) { cc->Outputs().Tag(kGpuBufferTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kImageFrameTag)) { cc->Outputs().Tag(kImageFrameTag).Set(); } @@ -154,62 +154,62 @@ mediapipe::Status RecolorCalculator::GetContract(CalculatorContract* cc) { cc->Outputs().HasTag(kGpuBufferTag)); if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RecolorCalculator::Open(CalculatorContext* cc) { +absl::Status RecolorCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if 
(cc->Inputs().HasTag(kGpuBufferTag)) { use_gpu_ = true; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } MP_RETURN_IF_ERROR(LoadOptions(cc)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) { +absl::Status RecolorCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &cc]() -> mediapipe::Status { + gpu_helper_.RunInGlContext([this, &cc]() -> absl::Status { if (!initialized_) { MP_RETURN_IF_ERROR(InitGpu(cc)); initialized_ = true; } MP_RETURN_IF_ERROR(RenderGpu(cc)); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status RecolorCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { +absl::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kMaskCpuTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Get inputs and setup output. 
const auto& input_img = cc->Inputs().Tag(kImageFrameTag).Get(); @@ -265,14 +265,14 @@ mediapipe::Status RecolorCalculator::RenderCpu(CalculatorContext* cc) { .Tag(kImageFrameTag) .Add(output_img.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { +absl::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kMaskGpuTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Get inputs and setup output. const Packet& input_packet = cc->Inputs().Tag(kGpuBufferTag).Value(); const Packet& mask_packet = cc->Inputs().Tag(kMaskGpuTag).Value(); @@ -311,13 +311,13 @@ mediapipe::Status RecolorCalculator::RenderGpu(CalculatorContext* cc) { img_tex.Release(); mask_tex.Release(); dst_tex.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } void RecolorCalculator::GlRender() { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -365,10 +365,10 @@ void RecolorCalculator::GlRender() { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { +absl::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { const auto& options = cc->Options(); mask_channel_ = options.mask_channel(); @@ -379,11 +379,11 @@ mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { color_.push_back(options.color().g()); color_.push_back(options.color().b()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { -#if 
!defined(MEDIAPIPE_DISABLE_GPU) +absl::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -452,9 +452,9 @@ mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { glUniform1i(glGetUniformLocation(program_, "mask"), 2); glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0] / 255.0, color_[1] / 255.0, color_[2] / 255.0); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/scale_image_calculator.cc b/mediapipe/calculators/image/scale_image_calculator.cc index 6d321b474..575268da5 100644 --- a/mediapipe/calculators/image/scale_image_calculator.cc +++ b/mediapipe/calculators/image/scale_image_calculator.cc @@ -44,7 +44,7 @@ namespace { // Given an upscaling algorithm, determine which OpenCV interpolation algorithm // to use. 
-mediapipe::Status FindInterpolationAlgorithm( +absl::Status FindInterpolationAlgorithm( ScaleImageCalculatorOptions::ScaleAlgorithm upscaling_algorithm, int* interpolation_algorithm) { switch (upscaling_algorithm) { @@ -70,7 +70,7 @@ mediapipe::Status FindInterpolationAlgorithm( RET_CHECK_FAIL() << absl::Substitute("Unknown upscaling algorithm: $0", upscaling_algorithm); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void CropImageFrame(const ImageFrame& original, int col_start, int row_start, @@ -147,7 +147,7 @@ class ScaleImageCalculator : public CalculatorBase { ScaleImageCalculator(); ~ScaleImageCalculator() override; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { ScaleImageCalculatorOptions options = cc->Options(); @@ -184,35 +184,35 @@ class ScaleImageCalculator : public CalculatorBase { if (cc->Inputs().HasTag("OVERRIDE_OPTIONS")) { cc->Inputs().Tag("OVERRIDE_OPTIONS").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // From Calculator. - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Initialize some data members from options_. This can be called either from // Open or Process depending on whether OVERRIDE_OPTIONS is used. - mediapipe::Status InitializeFromOptions(); + absl::Status InitializeFromOptions(); // Initialize crop and output parameters based on set member variable // values. This function will also send the header information on // the VIDEO_HEADER stream if it hasn't been done yet. - mediapipe::Status InitializeFrameInfo(CalculatorContext* cc); + absl::Status InitializeFrameInfo(CalculatorContext* cc); // Validate that input_format_ and output_format_ are supported image // formats. 
- mediapipe::Status ValidateImageFormats() const; + absl::Status ValidateImageFormats() const; // Validate that the image frame has the proper format and dimensions. // If the dimensions and format weren't initialized by the header, // then the first frame on which this function is called is used // to initialize. - mediapipe::Status ValidateImageFrame(CalculatorContext* cc, - const ImageFrame& image_frame); + absl::Status ValidateImageFrame(CalculatorContext* cc, + const ImageFrame& image_frame); // Validate that the YUV image has the proper dimensions. If the // dimensions weren't initialized by the header, then the first image // on which this function is called is used to initialize. - mediapipe::Status ValidateYUVImage(CalculatorContext* cc, - const YUVImage& yuv_image); + absl::Status ValidateYUVImage(CalculatorContext* cc, + const YUVImage& yuv_image); bool has_header_; // True if the input stream has a header. int input_width_; @@ -251,8 +251,7 @@ ScaleImageCalculator::ScaleImageCalculator() {} ScaleImageCalculator::~ScaleImageCalculator() {} -mediapipe::Status ScaleImageCalculator::InitializeFrameInfo( - CalculatorContext* cc) { +absl::Status ScaleImageCalculator::InitializeFrameInfo(CalculatorContext* cc) { MP_RETURN_IF_ERROR( scale_image::FindCropDimensions(input_width_, input_height_, // options_.min_aspect_ratio(), // @@ -299,10 +298,10 @@ mediapipe::Status ScaleImageCalculator::InitializeFrameInfo( .Add(header.release(), Timestamp::PreStream()); cc->Outputs().Tag("VIDEO_HEADER").Close(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ScaleImageCalculator::Open(CalculatorContext* cc) { +absl::Status ScaleImageCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); input_data_id_ = cc->Inputs().GetId("FRAMES", 0); @@ -339,7 +338,7 @@ mediapipe::Status ScaleImageCalculator::Open(CalculatorContext* cc) { // has a header. 
At this point in the code, the ScaleImageCalculator // config may be changed by the new options at PreStream, so the output // header can't be determined. - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "OVERRIDE_OPTIONS stream can't be used when the main input stream " "has a header."); } @@ -406,10 +405,10 @@ mediapipe::Status ScaleImageCalculator::Open(CalculatorContext* cc) { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ScaleImageCalculator::InitializeFromOptions() { +absl::Status ScaleImageCalculator::InitializeFromOptions() { if (options_.has_input_format()) { input_format_ = options_.input_format(); } else { @@ -423,10 +422,10 @@ mediapipe::Status ScaleImageCalculator::InitializeFromOptions() { downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ScaleImageCalculator::ValidateImageFormats() const { +absl::Status ScaleImageCalculator::ValidateImageFormats() const { RET_CHECK_NE(input_format_, ImageFormat::UNKNOWN) << "The input image format was UNKNOWN."; RET_CHECK_NE(output_format_, ImageFormat::UNKNOWN) @@ -440,10 +439,10 @@ mediapipe::Status ScaleImageCalculator::ValidateImageFormats() const { input_format_ == ImageFormat::YCBCR420P) << "Conversion of the color space (except from " "YCbCr420P to SRGB) is not yet supported."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ScaleImageCalculator::ValidateImageFrame( +absl::Status ScaleImageCalculator::ValidateImageFrame( CalculatorContext* cc, const ImageFrame& image_frame) { if (!has_header_) { if (input_width_ != image_frame.Width() || @@ -494,11 +493,11 @@ mediapipe::Status ScaleImageCalculator::ValidateImageFrame( image_frame_format_desc, " but expected ", input_format_desc)); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ScaleImageCalculator::ValidateYUVImage( - 
CalculatorContext* cc, const YUVImage& yuv_image) { +absl::Status ScaleImageCalculator::ValidateYUVImage(CalculatorContext* cc, + const YUVImage& yuv_image) { CHECK_EQ(input_format_, ImageFormat::YCBCR420P); if (!has_header_) { if (input_width_ != yuv_image.width() || @@ -528,14 +527,14 @@ mediapipe::Status ScaleImageCalculator::ValidateYUVImage( input_width_, "x", input_height_)); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) { +absl::Status ScaleImageCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream()) { if (cc->Inputs().HasTag("OVERRIDE_OPTIONS")) { if (cc->Inputs().Tag("OVERRIDE_OPTIONS").IsEmpty()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The OVERRIDE_OPTIONS input stream must be non-empty at PreStream " "time if used."); } @@ -549,7 +548,7 @@ mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) { input_video_header_ = cc->Inputs().Tag("VIDEO_HEADER").Get(); } if (cc->Inputs().Get(input_data_id_).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -603,7 +602,7 @@ mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) { cc->Outputs() .Get(output_data_id_) .Add(output_image.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } else { image_frame = &cc->Inputs().Get(input_data_id_).Get(); @@ -664,7 +663,7 @@ mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) { .Add(output_frame.release(), cc->InputTimestamp()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Rescale the image frame. 
@@ -698,7 +697,7 @@ mediapipe::Status ScaleImageCalculator::Process(CalculatorContext* cc) { cc->Outputs() .Get(output_data_id_) .Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/scale_image_utils.cc b/mediapipe/calculators/image/scale_image_utils.cc index 20049af76..738e83da0 100644 --- a/mediapipe/calculators/image/scale_image_utils.cc +++ b/mediapipe/calculators/image/scale_image_utils.cc @@ -35,11 +35,11 @@ double ParseRational(const std::string& rational) { } } // namespace -mediapipe::Status FindCropDimensions(int input_width, int input_height, // - const std::string& min_aspect_ratio, // - const std::string& max_aspect_ratio, // - int* crop_width, int* crop_height, // - int* col_start, int* row_start) { +absl::Status FindCropDimensions(int input_width, int input_height, // + const std::string& min_aspect_ratio, // + const std::string& max_aspect_ratio, // + int* crop_width, int* crop_height, // + int* col_start, int* row_start) { CHECK(crop_width); CHECK(crop_height); CHECK(col_start); @@ -85,16 +85,16 @@ mediapipe::Status FindCropDimensions(int input_width, int input_height, // CHECK_LE(*crop_width, input_width); CHECK_LE(*crop_height, input_height); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FindOutputDimensions(int input_width, // - int input_height, // - int target_width, // - int target_height, // - bool preserve_aspect_ratio, // - int scale_to_multiple_of, // - int* output_width, int* output_height) { +absl::Status FindOutputDimensions(int input_width, // + int input_height, // + int target_width, // + int target_height, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height) { CHECK(output_width); CHECK(output_height); @@ -122,7 +122,7 @@ mediapipe::Status FindOutputDimensions(int input_width, // *output_width = target_width; 
*output_height = target_height; - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (target_width > 0) { @@ -139,7 +139,7 @@ mediapipe::Status FindOutputDimensions(int input_width, // // was within the image, so use these dimensions. *output_width = try_width; *output_height = try_height; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -157,7 +157,7 @@ mediapipe::Status FindOutputDimensions(int input_width, // // was within the image, so use these dimensions. *output_width = try_width; *output_height = try_height; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } RET_CHECK_FAIL() diff --git a/mediapipe/calculators/image/scale_image_utils.h b/mediapipe/calculators/image/scale_image_utils.h index 86b014ca3..c2c0b8f7c 100644 --- a/mediapipe/calculators/image/scale_image_utils.h +++ b/mediapipe/calculators/image/scale_image_utils.h @@ -28,11 +28,11 @@ namespace scale_image { // is a centered, cropped portion of the image that falls within the min // and max aspect ratio. If either the min or max aspect ratio argument // is empty or has a 0 in the numerator or denominator then it is ignored. -mediapipe::Status FindCropDimensions(int input_width, int input_height, // - const std::string& min_aspect_ratio, // - const std::string& max_aspect_ratio, // - int* crop_width, int* crop_height, // - int* col_start, int* row_start); +absl::Status FindCropDimensions(int input_width, int input_height, // + const std::string& min_aspect_ratio, // + const std::string& max_aspect_ratio, // + int* crop_width, int* crop_height, // + int* col_start, int* row_start); // Given an input width and height, a target width and height, whether to // preserve the aspect ratio, and whether to round-down to the multiple of a @@ -43,12 +43,12 @@ mediapipe::Status FindCropDimensions(int input_width, int input_height, // // output_height will be reduced as necessary to preserve_aspect_ratio if the // option is specified. 
If preserving the aspect ratio is desired, you must set // scale_to_multiple_of to 2. -mediapipe::Status FindOutputDimensions(int input_width, int input_height, // - int target_width, - int target_height, // - bool preserve_aspect_ratio, // - int scale_to_multiple_of, // - int* output_width, int* output_height); +absl::Status FindOutputDimensions(int input_width, int input_height, // + int target_width, + int target_height, // + bool preserve_aspect_ratio, // + int scale_to_multiple_of, // + int* output_width, int* output_height); } // namespace scale_image } // namespace mediapipe diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index 683efce6b..08c150d21 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -25,11 +25,11 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -87,18 +87,18 @@ class SetAlphaCalculator : public CalculatorBase { SetAlphaCalculator() = default; ~SetAlphaCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // From Calculator. 
- mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status RenderGpu(CalculatorContext* cc); - mediapipe::Status RenderCpu(CalculatorContext* cc); + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); - mediapipe::Status GlSetup(CalculatorContext* cc); + absl::Status GlSetup(CalculatorContext* cc); void GlRender(CalculatorContext* cc); mediapipe::SetAlphaCalculatorOptions options_; @@ -106,81 +106,81 @@ class SetAlphaCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(SetAlphaCalculator); -mediapipe::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { +absl::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; if (cc->Inputs().HasTag(kInputFrameTag) && cc->Inputs().HasTag(kInputFrameTagGpu)) { - return mediapipe::InternalError("Cannot have multiple input images."); + return absl::InternalError("Cannot have multiple input images."); } if (cc->Inputs().HasTag(kInputFrameTagGpu) != cc->Outputs().HasTag(kOutputFrameTagGpu)) { - return mediapipe::InternalError("GPU output must have GPU input."); + return absl::InternalError("GPU output must have GPU input."); } // Input image to add/edit alpha channel. 
-#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().Tag(kInputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputFrameTag)) { cc->Inputs().Tag(kInputFrameTag).Set(); } // Input alpha image mask (optional) -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputAlphaTagGpu)) { cc->Inputs().Tag(kInputAlphaTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kInputAlphaTag)) { cc->Inputs().Tag(kInputAlphaTag).Set(); } // RGBA output image. -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().Tag(kOutputFrameTagGpu).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kOutputFrameTag)) { cc->Outputs().Tag(kOutputFrameTag).Set(); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); if (cc->Inputs().HasTag(kInputFrameTagGpu) && cc->Outputs().HasTag(kOutputFrameTagGpu)) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU use_gpu_ = true; #else RET_CHECK_FAIL() << "GPU processing not enabled."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } // Get global value from options (-1 if not set). 
@@ -193,48 +193,47 @@ mediapipe::Status SetAlphaCalculator::Open(CalculatorContext* cc) { RET_CHECK_FAIL() << "Must use either image mask or options alpha value."; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); #endif } // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> mediapipe::Status { - if (!gpu_initialized_) { - MP_RETURN_IF_ERROR(GlSetup(cc)); - gpu_initialized_ = true; - } - MP_RETURN_IF_ERROR(RenderGpu(cc)); - return mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + if (!gpu_initialized_) { + MP_RETURN_IF_ERROR(GlSetup(cc)); + gpu_initialized_ = true; + } + MP_RETURN_IF_ERROR(RenderGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(RenderCpu(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status SetAlphaCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Setup source image @@ -294,14 +293,14 @@ 
mediapipe::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { .Tag(kOutputFrameTag) .Add(output_frame.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { +absl::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Setup source texture. const auto& input_frame = cc->Inputs().Tag(kInputFrameTagGpu).Get(); @@ -354,13 +353,13 @@ mediapipe::Status SetAlphaCalculator::RenderGpu(CalculatorContext* cc) { // Cleanup input_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } void SetAlphaCalculator::GlRender(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -409,11 +408,11 @@ void SetAlphaCalculator::GlRender(CalculatorContext* cc) { glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -466,9 +465,9 @@ mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2); glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // 
namespace mediapipe diff --git a/mediapipe/calculators/image/sobel_edges_calculator.cc b/mediapipe/calculators/image/sobel_edges_calculator.cc index a9a8d637b..6154a246b 100644 --- a/mediapipe/calculators/image/sobel_edges_calculator.cc +++ b/mediapipe/calculators/image/sobel_edges_calculator.cc @@ -27,10 +27,9 @@ namespace mediapipe { // See GlSimpleCalculatorBase for inputs, outputs and input side packets. class SobelEdgesCalculator : public GlSimpleCalculator { public: - mediapipe::Status GlSetup() override; - mediapipe::Status GlRender(const GlTexture& src, - const GlTexture& dst) override; - mediapipe::Status GlTeardown() override; + absl::Status GlSetup() override; + absl::Status GlRender(const GlTexture& src, const GlTexture& dst) override; + absl::Status GlTeardown() override; private: GLuint program_ = 0; @@ -40,7 +39,7 @@ class SobelEdgesCalculator : public GlSimpleCalculator { }; REGISTER_CALCULATOR(SobelEdgesCalculator); -mediapipe::Status SobelEdgesCalculator::GlSetup() { +absl::Status SobelEdgesCalculator::GlSetup() { // Load vertex and fragment shaders const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, @@ -166,11 +165,11 @@ mediapipe::Status SobelEdgesCalculator::GlSetup() { frame_ = glGetUniformLocation(program_, "inputImage"); pixel_w_ = glGetUniformLocation(program_, "pixelW"); pixel_h_ = glGetUniformLocation(program_, "pixelH"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SobelEdgesCalculator::GlRender(const GlTexture& src, - const GlTexture& dst) { +absl::Status SobelEdgesCalculator::GlRender(const GlTexture& src, + const GlTexture& dst) { static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -225,15 +224,15 @@ mediapipe::Status SobelEdgesCalculator::GlRender(const GlTexture& src, glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SobelEdgesCalculator::GlTeardown() { 
+absl::Status SobelEdgesCalculator::GlTeardown() { if (program_) { glDeleteProgram(program_); program_ = 0; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/internal/callback_packet_calculator.cc b/mediapipe/calculators/internal/callback_packet_calculator.cc index e78007fbb..cc153483e 100644 --- a/mediapipe/calculators/internal/callback_packet_calculator.cc +++ b/mediapipe/calculators/internal/callback_packet_calculator.cc @@ -50,7 +50,7 @@ void DumpPostStreamPacket(Packet* post_stream_packet, const Packet& packet) { // while that pointer is still alive. class CallbackPacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); switch (options.type()) { case CallbackPacketCalculatorOptions::VECTOR_PACKET: @@ -63,10 +63,10 @@ class CallbackPacketCalculator : public CalculatorBase { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Invalid type of callback to produce."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const auto& options = cc->Options(); void* ptr; if (sscanf(options.pointer().c_str(), "%p", &ptr) != 1) { @@ -90,11 +90,11 @@ class CallbackPacketCalculator : public CalculatorBase { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Invalid type to dump into."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index dc465d4cd..5a0631007 100644 --- a/mediapipe/calculators/tensor/BUILD +++ 
b/mediapipe/calculators/tensor/BUILD @@ -15,6 +15,12 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load( + "//mediapipe/framework/tool:mediapipe_graph.bzl", + "mediapipe_binary_graph", +) +load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") +load("//mediapipe/framework:encode_binary_proto.bzl", "encode_binary_proto") licenses(["notice"]) @@ -38,81 +44,128 @@ mediapipe_proto_library( ) cc_library( - name = "inference_calculator", + name = "inference_calculator_interface", srcs = ["inference_calculator.cc"], + hdrs = ["inference_calculator.h"], copts = select({ + # TODO: fix tensor.h not to require this, if possible "//mediapipe:apple": [ "-x objective-c++", "-fobjc-arc", # enable reference-counting ], "//conditions:default": [], }), - features = ["-layering_check"], # allow depending on inference_calculator_gpu_deps - linkopts = select({ + deps = [ + ":inference_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", + "//mediapipe/framework/tool:subgraph_expansion", + "//mediapipe/util/tflite:config", + "//mediapipe/util/tflite:tflite_model_loader", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_gl", + srcs = ["inference_calculator_gl.cc"], + tags = ["nomac"], # config problem with cpuinfo via TF + deps = [ + "inference_calculator_interface", + "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/util/tflite:tflite_gpu_runner", + "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", + 
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", + "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", + ], + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_metal", + srcs = ["inference_calculator_metal.cc"], + copts = [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + linkopts = [ + "-framework CoreVideo", + "-framework MetalKit", + ], + tags = ["ios"], + deps = [ + "inference_calculator_interface", + "//mediapipe/gpu:MPPMetalHelper", + "//mediapipe/gpu:MPPMetalUtil", + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/objc:mediapipe_framework_ios", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", + "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", + "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", + "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", + ], + alwayslink = 1, +) + +cc_library( + name = "inference_calculator_cpu", + srcs = [ + "inference_calculator_cpu.cc", + ], + copts = select({ + # TODO: fix tensor.h not to require this, if possible "//mediapipe:apple": [ - "-framework CoreVideo", - "-framework MetalKit", + "-x objective-c++", + "-fobjc-arc", # enable reference-counting ], "//conditions:default": [], }), - visibility = ["//visibility:public"], deps = [ - ":inference_calculator_cc_proto", + ":inference_calculator_interface", "@com_google_absl//absl/memory", - "//mediapipe/framework/api2:node", - "//mediapipe/framework:calculator_framework", - "//mediapipe/framework/formats:tensor", - "//mediapipe/util/tflite:tflite_model_loader", - "//mediapipe/util/tflite:config", - "@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", - "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", - 
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", - "//mediapipe/framework/port:ret_check", ] + select({ - ":compute_shader_unavailable": [], - "//conditions:default": [":inference_calculator_gpu_deps"], - }) + select({ - "//conditions:default": [], - "//mediapipe:android": [ - "//mediapipe/util/android/file/base", - "@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate", - ], - }) + select({ "//conditions:default": [ "//mediapipe/util:cpu_util", ], + }) + select({ + "//conditions:default": [], + "//mediapipe:android": ["@org_tensorflow//tensorflow/lite/delegates/nnapi:nnapi_delegate"], }), alwayslink = 1, ) cc_library( - name = "inference_calculator_gpu_deps", - deps = selects.with_or({ - "//mediapipe:ios": [ - "//mediapipe/gpu:MPPMetalHelper", - "//mediapipe/gpu:MPPMetalUtil", - "//mediapipe/gpu:gpu_buffer", - "//mediapipe/objc:mediapipe_framework_ios", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", - "@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert", - "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", - "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate_internal", - ], - "//mediapipe:macos": [], - "//conditions:default": [ - "//mediapipe/util/tflite:tflite_gpu_runner", - "//mediapipe/gpu:gl_calculator_helper", - "//mediapipe/gpu:gpu_buffer", - "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", - "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", - "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", - ], + name = "inference_calculator_gl_if_compute_shader_available", + deps = select({ + ":compute_shader_unavailable": [], + "//conditions:default": [":inference_calculator_gl"], }), ) +cc_library( + name = "inference_calculator", + visibility = ["//visibility:public"], + deps = [ + 
":inference_calculator_interface", + ":inference_calculator_cpu", + ] + select({ + "//conditions:default": [":inference_calculator_gl_if_compute_shader_available"], + "//mediapipe:ios": [":inference_calculator_metal"], + }), + alwayslink = 1, +) + mediapipe_proto_library( name = "tensor_converter_calculator_proto", srcs = ["tensor_converter_calculator.proto"], @@ -357,6 +410,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":tensors_to_classification_calculator_cc_proto", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "//mediapipe/framework/api2:node", @@ -427,6 +481,7 @@ cc_library( ":image_to_tensor_converter_opencv", ":image_to_tensor_utils", "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", @@ -452,6 +507,7 @@ cc_library( ], "//mediapipe:apple": [ ":image_to_tensor_converter_metal", + "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:gpu_buffer", ], @@ -499,21 +555,21 @@ cc_test( "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:tensor", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/tool:validate_type", 
"@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@org_tensorflow//tensorflow/lite:framework", ], ) @@ -529,7 +585,7 @@ cc_library( }), deps = [ ":image_to_tensor_utils", - "//mediapipe/framework:packet", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:statusor", ], @@ -550,9 +606,9 @@ cc_library( ":image_to_tensor_converter", ":image_to_tensor_utils", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:image_frame", - "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:image_opencv", "//mediapipe/framework/formats:tensor", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", @@ -579,7 +635,7 @@ cc_library( "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", "//mediapipe/gpu:gl_calculator_helper", - "//mediapipe/gpu:gpu_buffer", + "//mediapipe/framework/formats:image", "//mediapipe/gpu:gpu_buffer_format", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:types", @@ -612,7 +668,7 @@ cc_library( "//mediapipe/framework/port:statusor", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_simple_shaders", - "//mediapipe/gpu:gpu_buffer", + "//mediapipe/framework/formats:image", "//mediapipe/gpu:shader_util", ], }), @@ -663,7 +719,7 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", - "//mediapipe/gpu:gpu_buffer", + "//mediapipe/framework/formats:image", "//mediapipe/gpu:gpu_buffer_format", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:types", diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc 
b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index 775e0e70b..91eba2de5 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -22,6 +22,7 @@ #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" #include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" @@ -29,6 +30,7 @@ #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/statusor.h" #if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" @@ -60,11 +62,18 @@ using GpuBuffer = mediapipe::GpuBuffer; // normalization, according to specified inputs and options. // // Inputs: -// IMAGE - ImageFrame [ImageFormat::SRGB/SRGBA] -// Image to extract from. +// IMAGE - Image[ImageFormat::SRGB / SRGBA, GpuBufferFormat::kBGRA32] or +// ImageFrame [ImageFormat::SRGB/SRGBA] (for backward compatibility +// with existing graphs that use IMAGE for ImageFrame input) // IMAGE_GPU - GpuBuffer [GpuBufferFormat::kBGRA32] // Image to extract from. -// (Either IMAGE or IMAGE_GPU has to be specified.) +// +// Note: +// - One and only one of IMAGE and IMAGE_GPU should be specified. +// - IMAGE input of type Image is processed on GPU if the data is already on +// GPU (i.e., Image::UsesGpu() returns true), or otherwise processed on CPU. +// - IMAGE input of type ImageFrame is always processed on CPU. +// - IMAGE_GPU input (of type GpuBuffer) is always processed on GPU. // // NORM_RECT - NormalizedRect @Optional // Describes region of image to extract. 
@@ -112,7 +121,8 @@ using GpuBuffer = mediapipe::GpuBuffer; // } class ImageToTensorCalculator : public Node { public: - static constexpr Input::Optional kInCpu{"IMAGE"}; + static constexpr Input< + OneOf>::Optional kIn{"IMAGE"}; static constexpr Input::Optional kInGpu{"IMAGE_GPU"}; static constexpr Input::Optional kInNormRect{ "NORM_RECT"}; @@ -121,10 +131,10 @@ class ImageToTensorCalculator : public Node { "LETTERBOX_PADDING"}; static constexpr Output>::Optional kOutMatrix{"MATRIX"}; - MEDIAPIPE_NODE_CONTRACT(kInCpu, kInGpu, kInNormRect, kOutTensors, + MEDIAPIPE_NODE_CONTRACT(kIn, kInGpu, kInNormRect, kOutTensors, kOutLetterboxPadding, kOutMatrix); - static ::mediapipe::Status UpdateContract(CalculatorContract* cc) { + static absl::Status UpdateContract(CalculatorContract* cc) { const auto& options = cc->Options(); @@ -138,69 +148,47 @@ class ImageToTensorCalculator : public Node { RET_CHECK_GT(options.output_tensor_height(), 0) << "Valid output tensor height is required."; - RET_CHECK(kInCpu(cc).IsConnected() ^ kInGpu(cc).IsConnected()) - << "One and only one of CPU or GPU input is expected."; + RET_CHECK(kIn(cc).IsConnected() ^ kInGpu(cc).IsConnected()) + << "One and only one of IMAGE and IMAGE_GPU input is expected."; - if (kInGpu(cc).IsConnected()) { #if MEDIAPIPE_DISABLE_GPU - return mediapipe::UnimplementedError("GPU processing is disabled"); -#else - -#if MEDIAPIPE_METAL_ENABLED - MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); -#else - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // MEDIAPIPE_METAL_ENABLED - -#endif // MEDIAPIPE_DISABLE_GPU + if (kInGpu(cc).IsConnected()) { + return absl::UnimplementedError( + "GPU processing is disabled in build flags"); } - return mediapipe::OkStatus(); +#else // !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_METAL_ENABLED + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); +#else + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // 
MEDIAPIPE_METAL_ENABLED +#endif // MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) { + absl::Status Open(CalculatorContext* cc) { options_ = cc->Options(); output_width_ = options_.output_tensor_width(); output_height_ = options_.output_tensor_height(); range_min_ = options_.output_tensor_float_range().min(); range_max_ = options_.output_tensor_float_range().max(); - if (kInCpu(cc).IsConnected()) { - ASSIGN_OR_RETURN(converter_, CreateOpenCvConverter(cc, GetBorderMode())); - } else { -#if MEDIAPIPE_DISABLE_GPU - return mediapipe::UnimplementedError("GPU processing is disabled"); -#else - -#if MEDIAPIPE_METAL_ENABLED - ASSIGN_OR_RETURN(converter_, CreateMetalConverter(cc, GetBorderMode())); -#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - ASSIGN_OR_RETURN(converter_, - CreateImageToGlBufferTensorConverter( - cc, DoesInputStartAtBottom(), GetBorderMode())); -#else - ASSIGN_OR_RETURN(converter_, - CreateImageToGlTextureTensorConverter( - cc, DoesInputStartAtBottom(), GetBorderMode())); -#endif // MEDIAPIPE_METAL_ENABLED - -#endif // MEDIAPIPE_DISABLE_GPU - } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) { - const PacketBase& image_packet = - kInCpu(cc).IsConnected() ? kInCpu(cc).packet() : kInGpu(cc).packet(); - if (image_packet.IsEmpty()) { - // Timestamp bound update happens automatically. (See Open().) - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) { + if ((kIn(cc).IsConnected() && kIn(cc).IsEmpty()) || + (kInGpu(cc).IsConnected() && kInGpu(cc).IsEmpty())) { + // Timestamp bound update happens automatically. + return absl::OkStatus(); } absl::optional norm_rect; if (kInNormRect(cc).IsConnected()) { if (kInNormRect(cc).IsEmpty()) { // Timestamp bound update happens automatically. (See Open().) 
- return mediapipe::OkStatus(); + return absl::OkStatus(); } norm_rect = *kInNormRect(cc); if (norm_rect->width() == 0 && norm_rect->height() == 0) { @@ -211,11 +199,12 @@ class ImageToTensorCalculator : public Node { // NOTE: usage of sentinel rects should be avoided. DLOG(WARNING) << "Updating timestamp bound in response to a sentinel rect"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } - const Size& size = converter_->GetImageSize(image_packet); + ASSIGN_OR_RETURN(auto image, GetInputImage(cc)); + const Size size{image->width(), image->height()}; RotatedRect roi = GetRoi(size.width, size.height, norm_rect); ASSIGN_OR_RETURN(auto padding, PadRoi(options_.output_tensor_width(), options_.output_tensor_height(), @@ -231,16 +220,19 @@ class ImageToTensorCalculator : public Node { kOutMatrix(cc).Send(std::move(matrix)); } - ASSIGN_OR_RETURN( - Tensor tensor, - converter_->Convert(image_packet, roi, {output_width_, output_height_}, - range_min_, range_max_)); + // Lazy initialization of the GPU or CPU converter. + MP_RETURN_IF_ERROR(InitConverterIfNecessary(cc, image->UsesGpu())); + + ASSIGN_OR_RETURN(Tensor tensor, + (image->UsesGpu() ? 
gpu_converter_ : cpu_converter_) + ->Convert(*image, roi, {output_width_, output_height_}, + range_min_, range_max_)); auto result = std::make_unique>(); result->push_back(std::move(tensor)); kOutTensors(cc).Send(std::move(result)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -261,7 +253,62 @@ class ImageToTensorCalculator : public Node { } } - std::unique_ptr converter_; + absl::StatusOr> GetInputImage( + CalculatorContext* cc) { + if (kIn(cc).IsConnected()) { + const auto& packet = kIn(cc).packet(); + return kIn(cc).Visit( + [&packet](const mediapipe::Image&) { + return SharedPtrWithPacket(packet); + }, + [&packet](const mediapipe::ImageFrame&) { + return std::make_shared( + std::const_pointer_cast( + SharedPtrWithPacket(packet))); + }); + } else { // if (kInGpu(cc).IsConnected()) +#if !MEDIAPIPE_DISABLE_GPU + const GpuBuffer& input = *kInGpu(cc); + // A shallow copy is okay since the resulting 'image' object is local in + // Process(), and thus never outlives 'input'. + return std::make_shared(input); +#else + return absl::UnimplementedError( + "GPU processing is disabled in build flags"); +#endif // !MEDIAPIPE_DISABLE_GPU + } + } + + absl::Status InitConverterIfNecessary(CalculatorContext* cc, bool use_gpu) { + // Lazy initialization of the GPU or CPU converter. 
+ if (use_gpu) { + if (!gpu_converter_) { +#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_METAL_ENABLED + ASSIGN_OR_RETURN(gpu_converter_, + CreateMetalConverter(cc, GetBorderMode())); +#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + ASSIGN_OR_RETURN(gpu_converter_, + CreateImageToGlBufferTensorConverter( + cc, DoesInputStartAtBottom(), GetBorderMode())); +#else + ASSIGN_OR_RETURN(gpu_converter_, + CreateImageToGlTextureTensorConverter( + cc, DoesInputStartAtBottom(), GetBorderMode())); +#endif // MEDIAPIPE_METAL_ENABLED +#endif // !MEDIAPIPE_DISABLE_GPU + } + } else { + if (!cpu_converter_) { + ASSIGN_OR_RETURN(cpu_converter_, + CreateOpenCvConverter(cc, GetBorderMode())); + } + } + return absl::OkStatus(); + } + + std::unique_ptr gpu_converter_; + std::unique_ptr cpu_converter_; mediapipe::ImageToTensorCalculatorOptions options_; int output_width_ = 0; int output_height_ = 0; diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index c11b61c51..233424720 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -22,11 +22,13 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/opencv_core_inc.h" @@ -54,10 +56,12 @@ cv::Mat GetRgba(absl::string_view path) { // Image to tensor test 
template. // No processing/assertions should be done after the function is invoked. -void RunTest(cv::Mat input, cv::Mat expected_result, float range_min, - float range_max, int tensor_width, int tensor_height, - bool keep_aspect, absl::optional border_mode, - const mediapipe::NormalizedRect& roi) { +void RunTestWithInputImagePacket(const Packet& input_image_packet, + cv::Mat expected_result, float range_min, + float range_max, int tensor_width, + int tensor_height, bool keep_aspect, + absl::optional border_mode, + const mediapipe::NormalizedRect& roi) { std::string border_mode_str; if (border_mode) { switch (*border_mode) { @@ -107,12 +111,8 @@ void RunTest(cv::Mat input, cv::Mat expected_result, float range_min, MP_ASSERT_OK(graph.Initialize(graph_config)); MP_ASSERT_OK(graph.StartRun({})); - ImageFrame input_image( - input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, - input.cols, input.rows, input.step, input.data, [](uint8*) {}); - MP_ASSERT_OK(graph.AddPacketToInputStream( - "input_image", - MakePacket(std::move(input_image)).At(Timestamp(0)))); + MP_ASSERT_OK(graph.AddPacketToInputStream("input_image", input_image_packet)); + MP_ASSERT_OK(graph.AddPacketToInputStream( "roi", MakePacket(std::move(roi)).At(Timestamp(0)))); @@ -133,8 +133,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, float range_min, const_cast(view.buffer())); cv::Mat result_rgb; auto transformation = - GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f) - .ValueOrDie(); + GetValueRangeTransformation(range_min, range_max, 0.0f, 255.0f).value(); tensor_mat.convertTo(result_rgb, CV_8UC3, transformation.scale, transformation.offset); @@ -152,6 +151,38 @@ void RunTest(cv::Mat input, cv::Mat expected_result, float range_min, MP_ASSERT_OK(graph.WaitUntilDone()); } +Packet MakeImageFramePacket(cv::Mat input) { + ImageFrame input_image( + input.channels() == 4 ? 
ImageFormat::SRGBA : ImageFormat::SRGB, + input.cols, input.rows, input.step, input.data, [](uint8*) {}); + return MakePacket(std::move(input_image)).At(Timestamp(0)); +} + +Packet MakeImagePacket(cv::Mat input) { + mediapipe::Image input_image(std::make_shared( + input.channels() == 4 ? ImageFormat::SRGBA : ImageFormat::SRGB, + input.cols, input.rows, input.step, input.data, [](uint8*) {})); + return MakePacket(std::move(input_image)).At(Timestamp(0)); +} + +enum class InputType { kImageFrame, kImage }; + +const std::vector kInputTypesToTest = {InputType::kImageFrame, + InputType::kImage}; + +void RunTest(cv::Mat input, cv::Mat expected_result, float range_min, + float range_max, int tensor_width, int tensor_height, + bool keep_aspect, absl::optional border_mode, + const mediapipe::NormalizedRect& roi) { + for (auto input_type : kInputTypesToTest) { + RunTestWithInputImagePacket( + input_type == InputType::kImageFrame ? MakeImageFramePacket(input) + : MakeImagePacket(input), + expected_result, range_min, range_max, tensor_width, tensor_height, + keep_aspect, border_mode, roi); + } +} + TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { mediapipe::NormalizedRect roi; roi.set_x_center(0.65f); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter.h b/mediapipe/calculators/tensor/image_to_tensor_converter.h index ef4cac9d1..39fd1ee0d 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter.h @@ -16,8 +16,8 @@ #define MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_CONVERTER_H_ #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/tensor.h" -#include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/statusor.h" namespace mediapipe { @@ -38,20 +38,17 @@ class ImageToTensorConverter { public: virtual ~ImageToTensorConverter() = default; - virtual Size 
GetImageSize(const Packet& image_packet) = 0; - // Converts image to tensor. - // @image_packet contains image to extract from. + // @image contains image to extract from. // @roi describes region of interest within the image to extract (absolute // values). // @output_dims dimensions of output tensor. // @range_min/max describes output tensor range image pixels should converted // to. - virtual mediapipe::StatusOr Convert(const Packet& image_packet, - const RotatedRect& roi, - const Size& output_dims, - float range_min, - float range_max) = 0; + virtual absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, + float range_min, float range_max) = 0; }; } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc index b8633fc5d..c6c9a19f4 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.cc @@ -27,13 +27,13 @@ #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_buffer.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/command_queue.h" @@ -54,7 +54,7 @@ class SubRectExtractorGl { public: // Extracts a region defined by @sub_rect, removes A channel, transforms input // pixels as alpha * x + beta 
and resizes result into destination. - mediapipe::Status ExtractSubRectToBuffer( + absl::Status ExtractSubRectToBuffer( const tflite::gpu::gl::GlTexture& texture, const tflite::gpu::HW& texture_size, const RotatedRect& sub_rect, bool flip_horizontaly, float alpha, float beta, @@ -62,7 +62,7 @@ class SubRectExtractorGl { tflite::gpu::gl::CommandQueue* command_queue, tflite::gpu::gl::GlBuffer* destination); - static mediapipe::StatusOr Create( + static absl::StatusOr Create( const mediapipe::GlContext& gl_context, bool input_starts_at_bottom, BorderMode border_mode); @@ -82,8 +82,8 @@ class SubRectExtractorGl { BorderMode border_mode_ = BorderMode::kReplicate; }; -mediapipe::Status SetMat4x4(const tflite::gpu::gl::GlProgram& program, - const std::string& name, float* data) { +absl::Status SetMat4x4(const tflite::gpu::gl::GlProgram& program, + const std::string& name, float* data) { GLint uniform_id; MP_RETURN_IF_ERROR(TFLITE_GPU_CALL_GL(glGetUniformLocation, &uniform_id, program.id(), name.c_str())); @@ -151,7 +151,7 @@ void main() { } )"; -mediapipe::Status SubRectExtractorGl::ExtractSubRectToBuffer( +absl::Status SubRectExtractorGl::ExtractSubRectToBuffer( const tflite::gpu::gl::GlTexture& texture, const tflite::gpu::HW& texture_size, const RotatedRect& texture_sub_rect, bool flip_horizontaly, float alpha, float beta, @@ -205,10 +205,10 @@ mediapipe::Status SubRectExtractorGl::ExtractSubRectToBuffer( glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::StatusOr SubRectExtractorGl::Create( +absl::StatusOr SubRectExtractorGl::Create( const mediapipe::GlContext& gl_context, bool input_starts_at_bottom, BorderMode border_mode) { bool use_custom_zero_border = border_mode == BorderMode::kZero && @@ -244,11 +244,11 @@ mediapipe::StatusOr SubRectExtractorGl::Create( class GlProcessor : public ImageToTensorConverter { 
public: - mediapipe::Status Init(CalculatorContext* cc, bool input_starts_at_bottom, - BorderMode border_mode) { + absl::Status Init(CalculatorContext* cc, bool input_starts_at_bottom, + BorderMode border_mode) { MP_RETURN_IF_ERROR(gl_helper_.Open(cc)); return gl_helper_.RunInGlContext([this, input_starts_at_bottom, - border_mode]() -> mediapipe::Status { + border_mode]() -> absl::Status { tflite::gpu::GpuInfo gpu_info; MP_RETURN_IF_ERROR(tflite::gpu::gl::RequestGpuInfo(&gpu_info)); RET_CHECK(gpu_info.IsApiOpenGl31OrAbove()) @@ -260,20 +260,14 @@ class GlProcessor : public ImageToTensorConverter { SubRectExtractorGl::Create(gl_helper_.GetGlContext(), input_starts_at_bottom, border_mode)); extractor_ = absl::make_unique(std::move(extractor)); - return mediapipe::OkStatus(); + return absl::OkStatus(); }); } - Size GetImageSize(const Packet& image_packet) override { - const auto& image = image_packet.Get(); - return {image.width(), image.height()}; - } - - mediapipe::StatusOr Convert(const Packet& image_packet, - const RotatedRect& roi, - const Size& output_dims, float range_min, - float range_max) override { - const auto& input = image_packet.Get(); + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { return InvalidArgumentError( absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", @@ -284,40 +278,39 @@ class GlProcessor : public ImageToTensorConverter { Tensor tensor(Tensor::ElementType::kFloat32, {1, output_dims.height, output_dims.width, kNumChannels}); - MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext( - [this, &tensor, &input, &roi, &output_dims, range_min, - range_max]() -> mediapipe::Status { - constexpr int kRgbaNumChannels = 4; - auto source_texture = gl_helper_.CreateSourceTexture(input); - tflite::gpu::gl::GlTexture input_texture( - GL_TEXTURE_2D, source_texture.name(), GL_RGBA, - 
source_texture.width() * source_texture.height() * - kRgbaNumChannels * sizeof(uint8_t), - /*layer=*/0, - /*owned=*/false); + MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext([this, &tensor, &input, &roi, + &output_dims, range_min, + range_max]() -> absl::Status { + constexpr int kRgbaNumChannels = 4; + auto source_texture = gl_helper_.CreateSourceTexture(input); + tflite::gpu::gl::GlTexture input_texture( + GL_TEXTURE_2D, source_texture.name(), GL_RGBA, + source_texture.width() * source_texture.height() * kRgbaNumChannels * + sizeof(uint8_t), + /*layer=*/0, + /*owned=*/false); - constexpr float kInputImageRangeMin = 0.0f; - constexpr float kInputImageRangeMax = 1.0f; - ASSIGN_OR_RETURN(auto transform, - GetValueRangeTransformation(kInputImageRangeMin, - kInputImageRangeMax, - range_min, range_max)); + constexpr float kInputImageRangeMin = 0.0f; + constexpr float kInputImageRangeMax = 1.0f; + ASSIGN_OR_RETURN( + auto transform, + GetValueRangeTransformation(kInputImageRangeMin, kInputImageRangeMax, + range_min, range_max)); - auto buffer_view = tensor.GetOpenGlBufferWriteView(); - tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER, - buffer_view.name(), tensor.bytes(), - /*offset=*/0, - /*has_ownership=*/false); - MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer( - input_texture, - tflite::gpu::HW(source_texture.height(), source_texture.width()), - roi, - /*flip_horizontaly=*/false, transform.scale, transform.offset, - tflite::gpu::HW(output_dims.height, output_dims.width), - command_queue_.get(), &output)); + auto buffer_view = tensor.GetOpenGlBufferWriteView(); + tflite::gpu::gl::GlBuffer output(GL_SHADER_STORAGE_BUFFER, + buffer_view.name(), tensor.bytes(), + /*offset=*/0, + /*has_ownership=*/false); + MP_RETURN_IF_ERROR(extractor_->ExtractSubRectToBuffer( + input_texture, + tflite::gpu::HW(source_texture.height(), source_texture.width()), roi, + /*flip_horizontaly=*/false, transform.scale, transform.offset, + tflite::gpu::HW(output_dims.height, 
output_dims.width), + command_queue_.get(), &output)); - return mediapipe::OkStatus(); - })); + return absl::OkStatus(); + })); return tensor; } @@ -338,7 +331,7 @@ class GlProcessor : public ImageToTensorConverter { } // namespace -mediapipe::StatusOr> +absl::StatusOr> CreateImageToGlBufferTensorConverter(CalculatorContext* cc, bool input_starts_at_bottom, BorderMode border_mode) { diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h index da167b5c4..437b16b70 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h @@ -30,7 +30,7 @@ namespace mediapipe { // Creates image to tensor (represented as OpenGL buffer) converter. // NOTE: mediapipe::GlCalculatorHelper::UpdateContract invocation must precede // converter creation. -mediapipe::StatusOr> +absl::StatusOr> CreateImageToGlBufferTensorConverter(CalculatorContext* cc, bool input_starts_at_bottom, BorderMode border_mode); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc index 3bd99ea77..26c31eaf5 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.cc @@ -27,6 +27,7 @@ #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" @@ -34,7 +35,6 @@ #include "mediapipe/framework/port/statusor.h" #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" -#include 
"mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" namespace mediapipe { @@ -47,11 +47,11 @@ constexpr int kNumAttributes = 2; class GlProcessor : public ImageToTensorConverter { public: - mediapipe::Status Init(CalculatorContext* cc, bool input_starts_at_bottom, - BorderMode border_mode) { + absl::Status Init(CalculatorContext* cc, bool input_starts_at_bottom, + BorderMode border_mode) { MP_RETURN_IF_ERROR(gl_helper_.Open(cc)); return gl_helper_.RunInGlContext([this, input_starts_at_bottom, - border_mode]() -> mediapipe::Status { + border_mode]() -> absl::Status { use_custom_zero_border_ = border_mode == BorderMode::kZero && !IsGlClampToBorderSupported(gl_helper_.GetGlContext()); @@ -164,20 +164,14 @@ class GlProcessor : public ImageToTensorConverter { glBindBuffer(GL_ARRAY_BUFFER, 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); }); } - Size GetImageSize(const Packet& image_packet) override { - const auto& image = image_packet.Get(); - return {image.width(), image.height()}; - } - - mediapipe::StatusOr Convert(const Packet& image_packet, - const RotatedRect& roi, - const Size& output_dims, float range_min, - float range_max) override { - const auto& input = image_packet.Get(); + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { return InvalidArgumentError( absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", @@ -189,9 +183,9 @@ class GlProcessor : public ImageToTensorConverter { Tensor::ElementType::kFloat32, Tensor::Shape{1, output_dims.height, output_dims.width, kNumChannels}); - MP_RETURN_IF_ERROR(gl_helper_.RunInGlContext( - [this, &tensor, &input, &roi, &output_dims, range_min, - range_max]() -> mediapipe::Status { + MP_RETURN_IF_ERROR( + gl_helper_.RunInGlContext([this, &tensor, &input, &roi, &output_dims, + range_min, range_max]() -> absl::Status 
{ auto input_texture = gl_helper_.CreateSourceTexture(input); constexpr float kInputImageRangeMin = 0.0f; @@ -205,21 +199,18 @@ class GlProcessor : public ImageToTensorConverter { /*flip_horizontaly=*/false, transform.scale, transform.offset, output_dims, &tensor_view)); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); return tensor; } - mediapipe::Status ExtractSubRect(const mediapipe::GlTexture& texture, - const RotatedRect& sub_rect, - bool flip_horizontaly, float alpha, - float beta, const Size& output_dims, - Tensor::OpenGlTexture2dView* output) { + absl::Status ExtractSubRect(const mediapipe::GlTexture& texture, + const RotatedRect& sub_rect, + bool flip_horizontaly, float alpha, float beta, + const Size& output_dims, + Tensor::OpenGlTexture2dView* output) { std::array transform_mat; - GetRotatedSubRectToRectTransformMatrix(sub_rect, texture.width(), - texture.height(), flip_horizontaly, - &transform_mat); glDisable(GL_DEPTH_TEST); glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_); @@ -258,7 +249,24 @@ class GlProcessor : public ImageToTensorConverter { glUseProgram(program_); glUniform1f(alpha_id_, alpha); glUniform1f(beta_id_, beta); - glUniformMatrix4fv(matrix_id_, 1, GL_TRUE, transform_mat.data()); + + // If our context is ES2, then we must use GL_FALSE for our 'transpose' + // GLboolean in glUniformMatrix4fv, or else we'll get an INVALID_VALUE + // error. So in that case, we'll grab the transpose of our original matrix + // and send that instead. 
+ const auto gl_context = mediapipe::GlContext::GetCurrent(); + LOG_IF(FATAL, !gl_context) << "GlContext is not bound to the thread."; + if (gl_context->GetGlVersion() == mediapipe::GlVersion::kGLES2) { + GetTransposedRotatedSubRectToRectTransformMatrix( + sub_rect, texture.width(), texture.height(), flip_horizontaly, + &transform_mat); + glUniformMatrix4fv(matrix_id_, 1, GL_FALSE, transform_mat.data()); + } else { + GetRotatedSubRectToRectTransformMatrix(sub_rect, texture.width(), + texture.height(), flip_horizontaly, + &transform_mat); + glUniformMatrix4fv(matrix_id_, 1, GL_TRUE, transform_mat.data()); + } // vao glBindVertexArray(vao_); @@ -292,7 +300,7 @@ class GlProcessor : public ImageToTensorConverter { glActiveTexture(GL_TEXTURE0); glBindTexture(GL_TEXTURE_2D, 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } ~GlProcessor() override { @@ -320,7 +328,7 @@ class GlProcessor : public ImageToTensorConverter { } // namespace -mediapipe::StatusOr> +absl::StatusOr> CreateImageToGlTextureTensorConverter(CalculatorContext* cc, bool input_starts_at_bottom, BorderMode border_mode) { diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h index 8802f7602..269abf141 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h @@ -30,7 +30,7 @@ namespace mediapipe { // Creates image to tensor (represented as OpenGL texture) converter. // NOTE: mediapipe::GlCalculatorHelper::UpdateContract invocation must precede // converter creation. 
-mediapipe::StatusOr> +absl::StatusOr> CreateImageToGlTextureTensorConverter(CalculatorContext* cc, bool input_starts_at_bottom, BorderMode border_mode); diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc index 4c8dc3d6d..9482cfc2a 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_gl_utils_test.cc @@ -14,7 +14,7 @@ namespace { TEST(ImageToTensorConverterGlUtilsTest, GlTexParameteriOverrider) { auto status_or_context = mediapipe::GlContext::Create(nullptr, false); MP_ASSERT_OK(status_or_context); - auto context = status_or_context.ValueOrDie(); + auto context = status_or_context.value(); std::vector min_filter_changes; context->Run([&min_filter_changes]() { diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc index 01546253f..565dd85b9 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.cc @@ -26,13 +26,13 @@ #include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "mediapipe/gpu/MPPMetalHelper.h" -#include "mediapipe/gpu/gpu_buffer.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/types.h" @@ -146,7 +146,7 @@ int GetBytesPerRaw(OutputFormat output_format, const tflite::gpu::HW& size) { class 
SubRectExtractorMetal { public: - static mediapipe::StatusOr> Make( + static absl::StatusOr> Make( id device, OutputFormat output_format, BorderMode border_mode) { id pipeline_state; @@ -174,12 +174,12 @@ class SubRectExtractorMetal { options:MTLResourceOptionCPUCacheModeDefault]; } - mediapipe::Status Execute(id input_texture, - const RotatedRect& sub_rect, bool flip_horizontaly, - float alpha, float beta, - const tflite::gpu::HW& destination_size, - id command_buffer, - id destination) { + absl::Status Execute(id input_texture, + const RotatedRect& sub_rect, bool flip_horizontaly, + float alpha, float beta, + const tflite::gpu::HW& destination_size, + id command_buffer, + id destination) { auto output_texture = MTLTextureWithBuffer(destination_size, destination); return InternalExecute(input_texture, sub_rect, flip_horizontaly, alpha, beta, destination_size, command_buffer, @@ -205,13 +205,12 @@ class SubRectExtractorMetal { return texture; } - mediapipe::Status InternalExecute(id input_texture, - const RotatedRect& sub_rect, - bool flip_horizontaly, float alpha, - float beta, - const tflite::gpu::HW& destination_size, - id command_buffer, - id output_texture) { + absl::Status InternalExecute(id input_texture, + const RotatedRect& sub_rect, + bool flip_horizontaly, float alpha, float beta, + const tflite::gpu::HW& destination_size, + id command_buffer, + id output_texture) { RET_CHECK(command_buffer != nil); RET_CHECK(output_texture != nil); @@ -254,10 +253,10 @@ class SubRectExtractorMetal { vertexCount:6]; [command_encoder endEncoding]; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status MakePipelineState( + static absl::Status MakePipelineState( id device, OutputFormat output_format, BorderMode border_mode, id* pipeline_state) { RET_CHECK(pipeline_state != nil); @@ -328,7 +327,7 @@ class SubRectExtractorMetal { RET_CHECK(error == nil) << "Couldn't create a pipeline state" << [[error localizedDescription] UTF8String]; - 
return mediapipe::OkStatus(); + return absl::OkStatus(); } id positions_buffer_; @@ -340,25 +339,19 @@ class SubRectExtractorMetal { class MetalProcessor : public ImageToTensorConverter { public: - mediapipe::Status Init(CalculatorContext* cc, BorderMode border_mode) { + absl::Status Init(CalculatorContext* cc, BorderMode border_mode) { metal_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(metal_helper_); ASSIGN_OR_RETURN(extractor_, SubRectExtractorMetal::Make( metal_helper_.mtlDevice, OutputFormat::kF32C4, border_mode)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - Size GetImageSize(const Packet& image_packet) override { - const auto& image = image_packet.Get(); - return {image.width(), image.height()}; - } - - mediapipe::StatusOr Convert(const Packet& image_packet, - const RotatedRect& roi, - const Size& output_dims, float range_min, - float range_max) override { - const auto& input = image_packet.Get(); + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { return InvalidArgumentError( absl::StrCat("Only BGRA/RGBA textures are supported, passed " @@ -367,7 +360,8 @@ class MetalProcessor : public ImageToTensorConverter { } @autoreleasepool { - id texture = [metal_helper_ metalTextureWithGpuBuffer:input]; + id texture = + [metal_helper_ metalTextureWithGpuBuffer:input.GetGpuBuffer()]; constexpr int kNumChannels = 4; Tensor tensor(Tensor::ElementType::kFloat32, @@ -400,8 +394,8 @@ class MetalProcessor : public ImageToTensorConverter { } // namespace -mediapipe::StatusOr> -CreateMetalConverter(CalculatorContext* cc, BorderMode border_mode) { +absl::StatusOr> CreateMetalConverter( + CalculatorContext* cc, BorderMode border_mode) { auto result = absl::make_unique(); MP_RETURN_IF_ERROR(result->Init(cc, border_mode)); diff --git 
a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.h b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.h index fe46c67b4..0fe5a87d0 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_metal.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_metal.h @@ -30,8 +30,8 @@ namespace mediapipe { // Creates Metal image-to-tensor converter. // NOTE: [MPPMetalHelper updateContract:...] invocation must precede // converter creation. -mediapipe::StatusOr> -CreateMetalConverter(CalculatorContext* cc, BorderMode border_mode); +absl::StatusOr> CreateMetalConverter( + CalculatorContext* cc, BorderMode border_mode); } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc index ed109d4ef..b8d1b0a8b 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.cc @@ -20,9 +20,9 @@ #include "mediapipe/calculators/tensor/image_to_tensor_converter.h" #include "mediapipe/calculators/tensor/image_to_tensor_utils.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_format.pb.h" -#include "mediapipe/framework/formats/image_frame.h" -#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/image_opencv.h" #include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/opencv_core_inc.h" @@ -46,21 +46,15 @@ class OpenCvProcessor : public ImageToTensorConverter { } } - Size GetImageSize(const Packet& image_packet) override { - const auto& image = image_packet.Get(); - return {image.Width(), image.Height()}; - } - - mediapipe::StatusOr Convert(const Packet& image_packet, - const RotatedRect& roi, - const Size& output_dims, float range_min, - 
float range_max) override { - const auto& input = image_packet.Get(); - if (input.Format() != mediapipe::ImageFormat::SRGB && - input.Format() != mediapipe::ImageFormat::SRGBA) { + absl::StatusOr Convert(const mediapipe::Image& input, + const RotatedRect& roi, + const Size& output_dims, float range_min, + float range_max) override { + if (input.image_format() != mediapipe::ImageFormat::SRGB && + input.image_format() != mediapipe::ImageFormat::SRGBA) { return InvalidArgumentError( absl::StrCat("Only RGBA/RGB formats are supported, passed format: ", - static_cast(input.Format()))); + static_cast(input.image_format()))); } cv::Mat src = mediapipe::formats::MatView(&input); @@ -118,8 +112,8 @@ class OpenCvProcessor : public ImageToTensorConverter { } // namespace -mediapipe::StatusOr> -CreateOpenCvConverter(CalculatorContext* cc, BorderMode border_mode) { +absl::StatusOr> CreateOpenCvConverter( + CalculatorContext* cc, BorderMode border_mode) { // Simply "return absl::make_unique()" failed to build on // macOS with bazel. return std::unique_ptr( diff --git a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h index a10bffaf1..3ccecc557 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h +++ b/mediapipe/calculators/tensor/image_to_tensor_converter_opencv.h @@ -24,8 +24,8 @@ namespace mediapipe { // Creates OpenCV image-to-tensor converter. 
-mediapipe::StatusOr> -CreateOpenCvConverter(CalculatorContext* cc, BorderMode border_mode); +absl::StatusOr> CreateOpenCvConverter( + CalculatorContext* cc, BorderMode border_mode); } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.cc b/mediapipe/calculators/tensor/image_to_tensor_utils.cc index dc5946760..6b3bf08cd 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.cc @@ -38,10 +38,10 @@ RotatedRect GetRoi(int input_width, int input_height, /*rotation =*/0}; } -mediapipe::StatusOr> PadRoi(int input_tensor_width, - int input_tensor_height, - bool keep_aspect_ratio, - RotatedRect* roi) { +absl::StatusOr> PadRoi(int input_tensor_width, + int input_tensor_height, + bool keep_aspect_ratio, + RotatedRect* roi) { if (!keep_aspect_ratio) { return std::array{0.0f, 0.0f, 0.0f, 0.0f}; } @@ -76,7 +76,7 @@ mediapipe::StatusOr> PadRoi(int input_tensor_width, horizontal_padding, vertical_padding}; } -mediapipe::StatusOr GetValueRangeTransformation( +absl::StatusOr GetValueRangeTransformation( float from_range_min, float from_range_max, float to_range_min, float to_range_max) { RET_CHECK_LT(from_range_min, from_range_max) @@ -173,4 +173,45 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, matrix[15] = 1.0f; } +void GetTransposedRotatedSubRectToRectTransformMatrix( + const RotatedRect& sub_rect, int rect_width, int rect_height, + bool flip_horizontaly, std::array* matrix_ptr) { + std::array& matrix = *matrix_ptr; + // See comments in GetRotatedSubRectToRectTransformMatrix for detailed + // calculations. + const float a = sub_rect.width; + const float b = sub_rect.height; + const float flip = flip_horizontaly ? 
-1 : 1; + const float c = std::cos(sub_rect.rotation); + const float d = std::sin(sub_rect.rotation); + const float e = sub_rect.center_x; + const float f = sub_rect.center_y; + const float g = 1.0f / rect_width; + const float h = 1.0f / rect_height; + + // row 1 (indices 0,4,8,12 from non-transposed fcn) + matrix[0] = a * c * flip * g; + matrix[1] = a * d * flip * h; + matrix[2] = 0.0f; + matrix[3] = 0.0f; + + // row 2 (indices 1,5,9,13 from non-transposed fcn) + matrix[4] = -b * d * g; + matrix[5] = b * c * h; + matrix[6] = 0.0f; + matrix[7] = 0.0f; + + // row 3 (indices 2,6,10,14 from non-transposed fcn) + matrix[8] = 0.0f; + matrix[9] = 0.0f; + matrix[10] = a * g; + matrix[11] = 0.0f; + + // row 4 (indices 3,7,11,15 from non-transposed fcn) + matrix[12] = (-0.5f * a * c * flip + 0.5f * b * d + e) * g; + matrix[13] = (-0.5f * b * c - 0.5f * a * d * flip + f) * h; + matrix[14] = 0.0f; + matrix[15] = 1.0f; +} + } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils.h b/mediapipe/calculators/tensor/image_to_tensor_utils.h index 44ba28902..f913875e3 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils.h +++ b/mediapipe/calculators/tensor/image_to_tensor_utils.h @@ -37,10 +37,10 @@ RotatedRect GetRoi(int input_width, int input_height, // Pads ROI, so extraction happens correctly if aspect ratio is to be kept. // Returns letterbox padding applied. -mediapipe::StatusOr> PadRoi(int input_tensor_width, - int input_tensor_height, - bool keep_aspect_ratio, - RotatedRect* roi); +absl::StatusOr> PadRoi(int input_tensor_width, + int input_tensor_height, + bool keep_aspect_ratio, + RotatedRect* roi); // Represents a transformation of value which involves scaling and offsetting. // To apply transformation: @@ -55,7 +55,7 @@ struct ValueTransformation { // [from_range_min, from_range_max] into [to_range_min, to_range_max] range. 
// from_range_min must be less than from_range_max // to_range_min must be less than to_range_max -mediapipe::StatusOr GetValueRangeTransformation( +absl::StatusOr GetValueRangeTransformation( float from_range_min, float from_range_max, float to_range_min, float to_range_max); @@ -77,6 +77,24 @@ void GetRotatedSubRectToRectTransformMatrix(const RotatedRect& sub_rect, bool flip_horizontaly, std::array* matrix); +// Returns the transpose of the matrix found with +// "GetRotatedSubRectToRectTransformMatrix". That is to say, this populates a +// 4x4 "matrix" with col major order transformation matrix which maps (x, y) in +// range [0, 1] (describing points of @sub_rect) to (x', y') in range [0, 1]*** +// (describing points of a rect: [0, @rect_width] x [0, @rect_height] = RECT). +// +// *** (x', y') will go out of the range for points from @sub_rect +// which are not contained by RECT and it's expected behavior +// +// @sub_rect - rotated sub rect in absolute coordinates +// @rect_width - rect width +// @rect_height - rect height +// @flip_horizontaly - we need to flip the output buffer. 
+// @matrix - 4x4 matrix (array of 16 elements) to populate +void GetTransposedRotatedSubRectToRectTransformMatrix( + const RotatedRect& sub_rect, int rect_width, int rect_height, + bool flip_horizontaly, std::array* matrix); + } // namespace mediapipe #endif // MEDIAPIPE_CALCULATORS_TENSOR_IMAGE_TO_TENSOR_UTILS_H_ diff --git a/mediapipe/calculators/tensor/image_to_tensor_utils_test.cc b/mediapipe/calculators/tensor/image_to_tensor_utils_test.cc index e9baecc20..814b4c34f 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_utils_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_utils_test.cc @@ -70,7 +70,7 @@ TEST(PadRoi, NoPadding) { .rotation = 5}; auto status_or_value = PadRoi(10, 10, /*keep_aspect_ratio=*/false, &roi); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), + EXPECT_THAT(status_or_value.value(), ElementsAreArray({0.0f, 0.0f, 0.0f, 0.0f})); EXPECT_THAT(roi, EqRotatedRect(100, 200, 20, 10, 5)); } @@ -83,7 +83,7 @@ TEST(PadRoi, HorizontalPadding) { .rotation = 5}; auto status_or_value = PadRoi(10, 10, /*keep_aspect_ratio=*/true, &roi); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), + EXPECT_THAT(status_or_value.value(), ElementsAreArray({0.25f, 0.0f, 0.25f, 0.0f})); EXPECT_THAT(roi, EqRotatedRect(200, 200, 20, 10, 5)); } @@ -95,7 +95,7 @@ TEST(PadRoi, VerticalPadding) { auto status_or_value = PadRoi(10, 10, /*keep_aspect_ratio=*/true, &roi); MP_ASSERT_OK(status_or_value); EXPECT_THAT( - status_or_value.ValueOrDie(), + status_or_value.value(), ElementsAre(testing::FloatEq(0.0f), testing::FloatNear(expected_horizontal_padding, 1e-6), testing::FloatEq(0.0f), @@ -115,7 +115,7 @@ TEST(GetValueRangeTransformation, PixelToFloatZeroCenter) { /*from_range_min=*/0.0f, /*from_range_max=*/255.0f, /*to_range_min=*/-1.0f, /*to_range_max=*/1.0f); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), + EXPECT_THAT(status_or_value.value(), EqValueTransformation(/*scale=*/2 / 255.0f, 
/*offset=*/-1.0f)); } @@ -125,7 +125,7 @@ TEST(GetValueRangeTransformation, PixelToFloat) { /*from_range_min=*/0.0f, /*from_range_max=*/255.0f, /*to_range_min=*/0.0f, /*to_range_max=*/1.0f); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), + EXPECT_THAT(status_or_value.value(), EqValueTransformation(/*scale=*/1 / 255.0f, /*offset=*/0.0f)); } @@ -135,7 +135,7 @@ TEST(GetValueRangeTransformation, FloatToFloatNoOp) { /*from_range_min=*/0.0f, /*from_range_max=*/1.0f, /*to_range_min=*/0.0f, /*to_range_max=*/1.0f); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), + EXPECT_THAT(status_or_value.value(), EqValueTransformation(/*scale=*/1.0f, /*offset=*/0.0f)); } @@ -144,7 +144,7 @@ TEST(GetValueRangeTransformation, PixelToPixelNoOp) { /*from_range_min=*/0.0f, /*from_range_max=*/255.0f, /*to_range_min=*/0.0f, /*to_range_max=*/255.0f); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), + EXPECT_THAT(status_or_value.value(), EqValueTransformation(/*scale=*/1.0f, /*offset=*/0.0f)); } @@ -153,7 +153,7 @@ TEST(GetValueRangeTransformation, FloatToPixel) { /*from_range_min=*/0.0f, /*from_range_max=*/1.0f, /*to_range_min=*/0.0f, /*to_range_max=*/255.0f); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), + EXPECT_THAT(status_or_value.value(), EqValueTransformation(/*scale=*/255.0f, /*offset=*/0.0f)); } diff --git a/mediapipe/calculators/tensor/inference_calculator.cc b/mediapipe/calculators/tensor/inference_calculator.cc index 2c4e33fb3..89a02b713 100644 --- a/mediapipe/calculators/tensor/inference_calculator.cc +++ b/mediapipe/calculators/tensor/inference_calculator.cc @@ -12,822 +12,59 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#include "mediapipe/calculators/tensor/inference_calculator.h" + #include #include #include #include #include "absl/memory/memory.h" -#include "mediapipe/calculators/tensor/inference_calculator.pb.h" -#include "mediapipe/framework/api2/node.h" -#include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/formats/tensor.h" -#include "mediapipe/framework/port/ret_check.h" -#include "mediapipe/util/tflite/config.h" -#include "mediapipe/util/tflite/tflite_model_loader.h" - -#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) -#include "mediapipe/util/cpu_util.h" -#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ - -#include "tensorflow/lite/error_reporter.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" - -#if defined(MEDIAPIPE_ANDROID) -#include "mediapipe/util/android/file/base/file.h" -#include "mediapipe/util/android/file/base/filesystem.h" -#include "mediapipe/util/android/file/base/helpers.h" -#endif // ANDROID - -#if MEDIAPIPE_TFLITE_GL_INFERENCE -#include "mediapipe/gpu/gl_calculator_helper.h" -#include "mediapipe/gpu/gpu_buffer.h" -#include "mediapipe/util/tflite/tflite_gpu_runner.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" -#include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - -#if MEDIAPIPE_TFLITE_METAL_INFERENCE -#import -#import -#import - -#import "mediapipe/gpu/MPPMetalHelper.h" -#include "mediapipe/gpu/MPPMetalUtil.h" -#include "mediapipe/gpu/gpu_buffer.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" -#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h" -#include "tensorflow/lite/delegates/gpu/metal_delegate.h" -#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h" -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - -#if !defined(MEDIAPIPE_EDGE_TPU) -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" -#endif // !EDGETPU -#if 
defined(MEDIAPIPE_ANDROID) -#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" -#endif // ANDROID - -namespace { -// Commonly used to compute the number of blocks to launch in a kernel. -int NumGroups(const int size, const int group_size) { // NOLINT - return (size + group_size - 1) / group_size; -} - -// Round up n to next multiple of m. -template -T RoundUp(T n, T m) { - return ((n + m - T{1}) / m) * m; -} - -bool ShouldUseGpu(const mediapipe::InferenceCalculatorOptions& options) { - return ( - !options.has_delegate() || // Use GPU delegate if delegate not specified - (options.has_delegate() && options.delegate().has_gpu())); -} - -constexpr char kTensorsTag[] = "TENSORS"; - -#if defined(MEDIAPIPE_EDGE_TPU) -#include "edgetpu.h" - -// Creates and returns an Edge TPU interpreter to run the given edgetpu model. -std::unique_ptr BuildEdgeTpuInterpreter( - const tflite::FlatBufferModel& model, - tflite::ops::builtin::BuiltinOpResolver* resolver, - edgetpu::EdgeTpuContext* edgetpu_context) { - resolver->AddCustom(edgetpu::kCustomOp, edgetpu::RegisterCustomOp()); - std::unique_ptr interpreter; - if (tflite::InterpreterBuilder(model, *resolver)(&interpreter) != kTfLiteOk) { - std::cerr << "Failed to build edge TPU interpreter." << std::endl; - } - interpreter->SetExternalContext(kTfLiteEdgeTpuContext, edgetpu_context); - interpreter->SetNumThreads(1); - if (interpreter->AllocateTensors() != kTfLiteOk) { - std::cerr << "Failed to allocate edge TPU tensors." << std::endl; - } - return interpreter; -} -#endif // MEDIAPIPE_EDGE_TPU - -} // namespace +#include "absl/strings/string_view.h" +#include "mediapipe/framework/tool/subgraph_expansion.h" namespace mediapipe { namespace api2 { -#if MEDIAPIPE_TFLITE_METAL_INFERENCE -namespace { -tflite::gpu::BHWC BhwcFromTensorShape(const Tensor::Shape& shape) { - tflite::gpu::BHWC result; - result.b = shape.dims[0]; - switch (shape.dims.size()) { - case 1: - // result.b is already filled. 
- break; - case 2: - result.h = 1; - result.w = 1; - result.c = shape.dims[1]; - break; - case 3: - result.h = 1; - result.w = shape.dims[1]; - result.c = shape.dims[2]; - break; - case 4: - result.h = shape.dims[1]; - result.w = shape.dims[2]; - result.c = shape.dims[3]; - break; - default: - // Handles 0 and >4. - LOG(FATAL) - << "Dimensions size must be in range [1,4] for GPU inference, but " - << shape.dims.size() << " is provided"; - } - return result; -} -} // namespace -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - -// Returns number of threads to configure XNNPACK delegate with. -// (Equal to user provided value if specified. Otherwise, it returns number of -// high cores (hard-coded to 1 for Emscripten without Threads extension)) -int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) { - static constexpr int kDefaultNumThreads = -1; - if (opts.has_delegate() && opts.delegate().has_xnnpack() && - opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) { - return opts.delegate().xnnpack().num_threads(); - } -#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) - return InferHigherCoreIds().size(); -#else - return 1; -#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ -} - -// Calculator Header Section - -// Runs inference on the provided input Tensors and TFLite model. -// -// Creates an interpreter with given model and calls invoke(). -// Optionally run inference on CPU/GPU. -// -// This calculator can be used with TensorConverterCalculator to get the -// appropriate inputs. -// -// When the input tensors are on CPU, gpu inference is optional and can be -// specified in the calculator options. -// When the input tensors are on GPU, inference is GPU and output can be CPU or -// GPU. -// -// Input: -// TENSORS - Vector of Tensors -// -// Output: -// TENSORS - Vector of Tensors -// -// Input side packet: -// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver, -// instead of the builtin one. 
-// MODEL (optional) - Use to specify TfLite model -// (std::unique_ptr>) -// -// Example use: -// node { -// calculator: "InferenceCalculator" -// input_stream: "TENSORS:tensor_image" -// output_stream: "TENSORS:tensors" -// options: { -// [mediapipe.InferenceCalculatorOptions.ext] { -// model_path: "modelname.tflite" -// } -// } -// } -// -// or -// -// node { -// calculator: "InferenceCalculator" -// input_stream: "TENSORS:tensor_image" -// input_side_packet: "MODEL:model" -// output_stream: "TENSORS:tensors" -// options: { -// [mediapipe.InferenceCalculatorOptions.ext] { -// model_path: "modelname.tflite" -// delegate { gpu {} } -// } -// } -// } -// -// IMPORTANT Notes: -// Tensors are assumed to be ordered correctly (sequentially added to model). -// Input tensors are assumed to be of the correct size and already normalized. - -class InferenceCalculator : public Node { +class InferenceCalculatorSelectorImpl + : public SubgraphImpl { public: - using TfLiteDelegatePtr = - std::unique_ptr>; - - static constexpr Input> kInTensors{"TENSORS"}; - static constexpr SideInput::Optional - kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; - static constexpr SideInput::Optional kSideInModel{"MODEL"}; - static constexpr Output> kOutTensors{"TENSORS"}; - MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, - kOutTensors); - static mediapipe::Status UpdateContract(CalculatorContract* cc); - - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; - - private: - mediapipe::Status ReadKernelsFromFile(); - mediapipe::Status WriteKernelsToFile(); - mediapipe::Status LoadModel(CalculatorContext* cc); - mediapipe::StatusOr GetModelAsPacket( - const CalculatorContext& cc); - mediapipe::Status LoadDelegate(CalculatorContext* cc); - mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc); - - mediapipe::Packet model_packet_; - 
std::unique_ptr interpreter_; - TfLiteDelegatePtr delegate_; - -#if MEDIAPIPE_TFLITE_GL_INFERENCE - mediapipe::GlCalculatorHelper gpu_helper_; - std::unique_ptr tflite_gpu_runner_; - bool allow_precision_loss_ = false; - mediapipe::InferenceCalculatorOptions::Delegate::Gpu::API - tflite_gpu_runner_api_; -#elif MEDIAPIPE_TFLITE_METAL_INFERENCE - MPPMetalHelper* gpu_helper_ = nullptr; - TFLBufferConvert* converter_to_BPHWC4_ = nil; - TFLBufferConvert* converter_from_BPHWC4_ = nil; -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - -#if MEDIAPIPE_TFLITE_GPU_SUPPORTED - std::vector output_shapes_; - std::vector> gpu_buffers_in_; - std::vector> gpu_buffers_out_; -#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED - -#if defined(MEDIAPIPE_EDGE_TPU) - std::shared_ptr edgetpu_context_ = - edgetpu::EdgeTpuManager::GetSingleton()->OpenDevice(); -#endif - - bool use_advanced_gpu_api_ = false; - bool use_gpu_delegate_ = false; - - bool use_kernel_caching_ = false; - std::string cached_kernel_filename_; + absl::StatusOr GetConfig( + const CalculatorGraphConfig::Node& subgraph_node) { + const auto& options = + Subgraph::GetOptions<::mediapipe::InferenceCalculatorOptions>( + subgraph_node); + std::vector impls; + const bool should_use_gpu = + !options.has_delegate() || // Use GPU delegate if not specified + (options.has_delegate() && options.delegate().has_gpu()); + if (should_use_gpu) { + impls.emplace_back("Metal"); + impls.emplace_back("MlDrift"); + impls.emplace_back("Gl"); + } + impls.emplace_back("Cpu"); + for (const auto& suffix : impls) { + const auto impl = absl::StrCat("InferenceCalculator", suffix); + if (!mediapipe::CalculatorBaseRegistry::IsRegistered(impl)) continue; + CalculatorGraphConfig::Node impl_node = subgraph_node; + impl_node.set_calculator(impl); + return tool::MakeSingleNodeGraph(std::move(impl_node)); + } + return absl::UnimplementedError("no implementation available"); + } }; -MEDIAPIPE_REGISTER_NODE(InferenceCalculator); - -mediapipe::Status 
InferenceCalculator::UpdateContract(CalculatorContract* cc) { - const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); - RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) - << "Either model as side packet or model path in options is required."; - - if (ShouldUseGpu(options)) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#elif MEDIAPIPE_TFLITE_METAL_INFERENCE - MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); -#endif - } - return mediapipe::OkStatus(); -} - -mediapipe::Status InferenceCalculator::Open(CalculatorContext* cc) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE || MEDIAPIPE_TFLITE_METAL_INFERENCE - const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); - if (ShouldUseGpu(options)) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE - use_advanced_gpu_api_ = options.has_delegate() && - options.delegate().has_gpu() && - options.delegate().gpu().use_advanced_gpu_api(); - allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); - tflite_gpu_runner_api_ = options.delegate().gpu().api(); - use_kernel_caching_ = - use_advanced_gpu_api_ && options.delegate().gpu().use_kernel_caching(); -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - use_gpu_delegate_ = !use_advanced_gpu_api_; - } -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE || MEDIAPIPE_TFLITE_METAL_INFERENCE - - if (use_kernel_caching_) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) - cached_kernel_filename_ = - "/sdcard/" + mediapipe::File::Basename(options.model_path()) + ".ker"; -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID - } - - // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner - // for everything. 
- if (!use_advanced_gpu_api_) { - MP_RETURN_IF_ERROR(LoadModel(cc)); - } - - if (use_gpu_delegate_ || use_advanced_gpu_api_) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE - MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { - return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) - : LoadDelegate(cc); - })); -#elif MEDIAPIPE_TFLITE_METAL_INFERENCE - gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; - RET_CHECK(gpu_helper_); - MP_RETURN_IF_ERROR(LoadDelegate(cc)); -#endif - } else { - MP_RETURN_IF_ERROR(LoadDelegate(cc)); - } - return mediapipe::OkStatus(); -} - -mediapipe::Status InferenceCalculator::Process(CalculatorContext* cc) { - if (kInTensors(cc).IsEmpty()) { - return mediapipe::OkStatus(); - } - const auto& input_tensors = *kInTensors(cc); - RET_CHECK(!input_tensors.empty()); - auto output_tensors = absl::make_unique>(); -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - id command_buffer; - id compute_encoder; -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - - if (use_gpu_delegate_ || use_advanced_gpu_api_) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE - if (use_advanced_gpu_api_) { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors, &output_tensors]() -> ::mediapipe::Status { - for (int i = 0; i < input_tensors.size(); ++i) { - MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( - input_tensors[i].GetOpenGlBufferReadView().name(), i)); - } - output_tensors->reserve(output_shapes_.size()); - for (int i = 0; i < output_shapes_.size(); ++i) { - output_tensors->emplace_back(Tensor::ElementType::kFloat32, - output_shapes_[i]); - MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor( - output_tensors->back().GetOpenGlBufferWriteView().name(), i)); - } - return mediapipe::OkStatus(); - })); - } else { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &input_tensors]() -> ::mediapipe::Status { - // Explicitly copy input. 
- for (int i = 0; i < input_tensors.size(); ++i) { - glBindBuffer(GL_COPY_READ_BUFFER, - input_tensors[i].GetOpenGlBufferReadView().name()); - glBindBuffer( - GL_COPY_WRITE_BUFFER, - gpu_buffers_in_[i]->GetOpenGlBufferWriteView().name()); - glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, - 0, input_tensors[i].bytes()); - } - return mediapipe::OkStatus(); - })); - } -#elif MEDIAPIPE_TFLITE_METAL_INFERENCE - command_buffer = [gpu_helper_ commandBuffer]; - command_buffer.label = @"InferenceCalculator"; - compute_encoder = [command_buffer computeCommandEncoder]; - // Explicit copy input with conversion float 32 bits to 16 bits. - for (int i = 0; i < input_tensors.size(); ++i) { - auto input_view = input_tensors[i].GetMtlBufferReadView(command_buffer); - // Reshape tensor. - tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape()); - auto gpu_buffer_view = - gpu_buffers_in_[i]->GetMtlBufferWriteView(command_buffer); - [converter_to_BPHWC4_ convertWithEncoder:compute_encoder - shape:shape - sourceBuffer:input_view.buffer() - convertedBuffer:gpu_buffer_view.buffer()]; - } -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - } else { - // Read CPU input into tensors. - for (int i = 0; i < input_tensors.size(); ++i) { - const Tensor* input_tensor = &input_tensors[i]; - auto input_tensor_view = input_tensor->GetCpuReadView(); - auto input_tensor_buffer = input_tensor_view.buffer(); - float* local_tensor_buffer = interpreter_->typed_input_tensor(i); - std::memcpy(local_tensor_buffer, input_tensor_buffer, - input_tensor->bytes()); - } - } - - // Run inference. 
-#if MEDIAPIPE_TFLITE_GL_INFERENCE - if (use_advanced_gpu_api_) { - RET_CHECK(tflite_gpu_runner_->Invoke().ok()); - } else { - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); - } -#else -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - if (use_gpu_delegate_) { - RET_CHECK( - TFLGpuDelegateSetCommandEncoder(delegate_.get(), compute_encoder)); - } -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - - if (use_gpu_delegate_ || use_advanced_gpu_api_) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE - if (use_gpu_delegate_) { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &output_tensors]() -> ::mediapipe::Status { - output_tensors->reserve(output_shapes_.size()); - for (int i = 0; i < output_shapes_.size(); ++i) { - const auto& t = gpu_buffers_out_[i]; - output_tensors->emplace_back(Tensor::ElementType::kFloat32, - gpu_buffers_out_[i]->shape()); - auto read_view = t->GetOpenGlBufferReadView(); - glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); - auto write_view = - output_tensors->back().GetOpenGlBufferWriteView(); - glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); - glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, - 0, t->bytes()); - } - return mediapipe::OkStatus(); - })); - } - // Output tensors are already bound if use_advanced_gpu_api_ is true. -#elif MEDIAPIPE_TFLITE_METAL_INFERENCE - output_tensors->reserve(output_shapes_.size()); - for (int i = 0; i < output_shapes_.size(); ++i) { - output_tensors->emplace_back(Tensor::ElementType::kFloat32, - output_shapes_[i]); - // Reshape tensor. 
- tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]); - auto read_view = - gpu_buffers_out_[i]->GetMtlBufferReadView(command_buffer); - auto write_view = - output_tensors->at(i).GetMtlBufferWriteView(command_buffer); - [converter_from_BPHWC4_ convertWithEncoder:compute_encoder - shape:shape - sourceBuffer:read_view.buffer() - convertedBuffer:write_view.buffer()]; - } - [compute_encoder endEncoding]; - [command_buffer commit]; -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - } else { - // Output result tensors (CPU). - const auto& tensor_indexes = interpreter_->outputs(); - output_tensors->reserve(tensor_indexes.size()); - for (int i = 0; i < tensor_indexes.size(); ++i) { - TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]); - output_tensors->emplace_back( - Tensor::ElementType::kFloat32, - Tensor::Shape{std::vector{ - tensor->dims->data, tensor->dims->data + tensor->dims->size}}); - auto cpu_view = output_tensors->back().GetCpuWriteView(); - std::memcpy(cpu_view.buffer(), tensor->data.f, - output_tensors->back().bytes()); - } - } - kOutTensors(cc).Send(std::move(output_tensors)); - return mediapipe::OkStatus(); -} - -mediapipe::Status InferenceCalculator::WriteKernelsToFile() { -#if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) - if (use_kernel_caching_) { - // Save kernel file. 
- auto kernel_cache = absl::make_unique>( - tflite_gpu_runner_->GetSerializedBinaryCache()); - std::string cache_str(kernel_cache->begin(), kernel_cache->end()); - MP_RETURN_IF_ERROR( - mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); - } -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID - return mediapipe::OkStatus(); -} - -mediapipe::Status InferenceCalculator::Close(CalculatorContext* cc) { - MP_RETURN_IF_ERROR(WriteKernelsToFile()); -#if MEDIAPIPE_TFLITE_GL_INFERENCE - if (use_gpu_delegate_) { - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { - gpu_buffers_in_.clear(); - gpu_buffers_out_.clear(); - return mediapipe::OkStatus(); - })); - } -#elif MEDIAPIPE_TFLITE_METAL_INFERENCE - converter_to_BPHWC4_ = nil; - converter_from_BPHWC4_ = nil; - gpu_buffers_in_.clear(); - gpu_buffers_out_.clear(); -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - -#if defined(MEDIAPIPE_EDGE_TPU) - edgetpu_context_.reset(); -#endif - interpreter_ = nullptr; - delegate_ = nullptr; - return mediapipe::OkStatus(); -} - -mediapipe::Status InferenceCalculator::ReadKernelsFromFile() { -#if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) - if (use_kernel_caching_) { - // Load pre-compiled kernel file. 
- if (mediapipe::File::Exists(cached_kernel_filename_)) { - std::string cache_str; - MP_RETURN_IF_ERROR( - mediapipe::file::GetContents(cached_kernel_filename_, &cache_str)); - std::vector cache_vec(cache_str.begin(), cache_str.end()); - tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec)); - } - } -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID - return mediapipe::OkStatus(); -} - -mediapipe::Status InferenceCalculator::InitTFLiteGPURunner( +absl::StatusOr> InferenceCalculator::GetModelAsPacket( CalculatorContext* cc) { -#if MEDIAPIPE_TFLITE_GL_INFERENCE - ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); - const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver = - kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolver()); - - // Create runner - tflite::gpu::InferenceOptions options; - options.priority1 = allow_precision_loss_ - ? tflite::gpu::InferencePriority::MIN_LATENCY - : tflite::gpu::InferencePriority::MAX_PRECISION; - options.priority2 = tflite::gpu::InferencePriority::AUTO; - options.priority3 = tflite::gpu::InferencePriority::AUTO; - options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; - tflite_gpu_runner_ = std::make_unique(options); - switch (tflite_gpu_runner_api_) { - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { - tflite_gpu_runner_->ForceOpenGL(); - break; - } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: { - tflite_gpu_runner_->ForceOpenCL(); - break; - } - case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { - // Do not need to force any specific API. - break; - } - } - MP_RETURN_IF_ERROR( - tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); - - // Create and bind OpenGL buffers for outputs. 
- // The buffers are created once and their ids are passed to calculator outputs - output_shapes_.resize(tflite_gpu_runner_->outputs_size()); - for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) { - output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b, - tflite_gpu_runner_->GetOutputShapes()[i].h, - tflite_gpu_runner_->GetOutputShapes()[i].w, - tflite_gpu_runner_->GetOutputShapes()[i].c}; - } - - MP_RETURN_IF_ERROR(ReadKernelsFromFile()); - - MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - - return mediapipe::OkStatus(); -} - -mediapipe::Status InferenceCalculator::LoadModel(CalculatorContext* cc) { - ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); - const auto& model = *model_packet_.Get(); - tflite::ops::builtin::BuiltinOpResolver op_resolver = - kSideInCustomOpResolver(cc).GetOr( - tflite::ops::builtin::BuiltinOpResolver()); - -#if defined(MEDIAPIPE_EDGE_TPU) - interpreter_ = - BuildEdgeTpuInterpreter(model, &op_resolver, edgetpu_context_.get()); -#else - tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); -#endif // MEDIAPIPE_EDGE_TPU - RET_CHECK(interpreter_); - -#if defined(__EMSCRIPTEN__) || defined(MEDIAPIPE_EDGE_TPU) - interpreter_->SetNumThreads(1); -#else - interpreter_->SetNumThreads( - cc->Options().cpu_num_thread()); -#endif // __EMSCRIPTEN__ - - RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); - // TODO: Support quantized tensors. 
- CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != - kTfLiteAffineQuantization); - - return mediapipe::OkStatus(); -} - -mediapipe::StatusOr InferenceCalculator::GetModelAsPacket( - const CalculatorContext& cc) { - const auto& options = cc.Options(); + const auto& options = cc->Options(); if (!options.model_path().empty()) { return TfLiteModelLoader::LoadFromPath(options.model_path()); } - if (cc.InputSidePackets().HasTag("MODEL")) { - return cc.InputSidePackets().Tag("MODEL"); - } - return mediapipe::Status( - mediapipe::StatusCode::kNotFound, - "Must specify TFLite model as path or loaded model."); -} - -mediapipe::Status InferenceCalculator::LoadDelegate(CalculatorContext* cc) { - const auto& calculator_opts = - cc->Options(); - if (calculator_opts.has_delegate() && - calculator_opts.delegate().has_tflite()) { - // Default tflite inference requeqsted - no need to modify graph. - return mediapipe::OkStatus(); - } - - if (!use_gpu_delegate_) { -#if defined(MEDIAPIPE_ANDROID) - const bool nnapi_requested = calculator_opts.has_delegate() - ? calculator_opts.delegate().has_nnapi() - : calculator_opts.use_nnapi(); - if (nnapi_requested) { - // Attempt to use NNAPI. - // If not supported, the default CPU delegate will be created and used. - interpreter_->SetAllowFp16PrecisionForFp32(1); - delegate_ = - TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { - // No need to free according to tflite::NnApiDelegate() - // documentation. 
- }); - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); - return mediapipe::OkStatus(); - } -#endif // MEDIAPIPE_ANDROID - -#if defined(__EMSCRIPTEN__) - const bool xnnpack_requested = true; -#else - const bool xnnpack_requested = calculator_opts.has_delegate() && - calculator_opts.delegate().has_xnnpack(); -#endif // __EMSCRIPTEN__ - -#if !defined(MEDIAPIPE_EDGE_TPU) - if (xnnpack_requested) { - TfLiteXNNPackDelegateOptions xnnpack_opts{}; - xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); - delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), - &TfLiteXNNPackDelegateDelete); - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); - return mediapipe::OkStatus(); - } -#endif // !EDGETPU - - // Return, no need for GPU delegate below. - return mediapipe::OkStatus(); - } else { -#if MEDIAPIPE_TFLITE_GL_INFERENCE - // Configure and create the delegate. - TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); - options.compile_options.precision_loss_allowed = 1; - options.compile_options.preferred_gl_object_type = - TFLITE_GL_OBJECT_TYPE_FASTEST; - options.compile_options.dynamic_batch_enabled = 0; - options.compile_options.inline_parameters = 1; - delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options), - &TfLiteGpuDelegateDelete); - - // Get input image sizes. 
- const auto& input_indices = interpreter_->inputs(); - for (int i = 0; i < input_indices.size(); ++i) { - const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); - gpu_buffers_in_.emplace_back(absl::make_unique( - Tensor::ElementType::kFloat32, - Tensor::Shape{std::vector{ - tensor->dims->data, tensor->dims->data + tensor->dims->size}})); - RET_CHECK_EQ( - TfLiteGpuDelegateBindBufferToTensor( - delegate_.get(), - gpu_buffers_in_.back()->GetOpenGlBufferWriteView().name(), - interpreter_->inputs()[i]), - kTfLiteOk); - } - interpreter_->SetAllowBufferHandleOutput(true); - // Get output image sizes. - const auto& output_indices = interpreter_->outputs(); - output_shapes_.resize(output_indices.size()); - // Create and bind output buffers. - for (int i = 0; i < output_shapes_.size(); ++i) { - const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); - gpu_buffers_out_.emplace_back(absl::make_unique( - Tensor::ElementType::kFloat32, - Tensor::Shape{std::vector{ - tensor->dims->data, tensor->dims->data + tensor->dims->size}})); - RET_CHECK_EQ( - TfLiteGpuDelegateBindBufferToTensor( - delegate_.get(), - gpu_buffers_out_.back()->GetOpenGlBufferWriteView().name(), - output_indices[i]), - kTfLiteOk); - } - - // Must call this last. - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); -#elif MEDIAPIPE_TFLITE_METAL_INFERENCE - // Configure and create the delegate. - TFLGpuDelegateOptions options; - options.allow_precision_loss = true; - options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; - delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), - &TFLGpuDelegateDelete); - RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), - kTfLiteOk); - id device = gpu_helper_.mtlDevice; - - // Get input image sizes. 
- const auto& input_indices = interpreter_->inputs(); - for (int i = 0; i < input_indices.size(); ++i) { - const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); - // Create and bind input buffer. - std::vector dims{tensor->dims->data, - tensor->dims->data + tensor->dims->size}; - dims.back() = RoundUp(dims.back(), 4); - gpu_buffers_in_.emplace_back(absl::make_unique( - Tensor::ElementType::kFloat16, Tensor::Shape{dims})); - auto buffer_view = - gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice); - RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( - delegate_.get(), input_indices[i], buffer_view.buffer()), - true); - } - - interpreter_->SetAllowBufferHandleOutput(true); - // Get output image sizes. - const auto& output_indices = interpreter_->outputs(); - output_shapes_.resize(output_indices.size()); - for (int i = 0; i < output_shapes_.size(); ++i) { - const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); - RET_CHECK(tensor->dims->size <= 4); - // Create and bind output buffers. - // Channels are always padded to multiple of 4. - std::vector dims{tensor->dims->data, - tensor->dims->data + tensor->dims->size}; - output_shapes_[i] = {dims}; - dims.back() = RoundUp(dims.back(), 4); - gpu_buffers_out_.emplace_back(absl::make_unique( - Tensor::ElementType::kFloat16, Tensor::Shape{dims})); - RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( - delegate_.get(), output_indices[i], - gpu_buffers_out_[i] - ->GetMtlBufferWriteView(gpu_helper_.mtlDevice) - .buffer()), - true); - } - - // Create converter for GPU input. - converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device - isFloat16:true - convertToPBHWC4:true]; - if (converter_to_BPHWC4_ == nil) { - return mediapipe::InternalError( - "Error initializating input buffer converter"); - } - // Create converter for GPU output. 
- converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device - isFloat16:true - convertToPBHWC4:false]; - if (converter_from_BPHWC4_ == nil) { - return mediapipe::InternalError( - "Error initializating output buffer converter"); - } -#endif // MEDIAPIPE_TFLITE_GL_INFERENCE - } - - return mediapipe::OkStatus(); + if (!kSideInModel(cc).IsEmpty()) return kSideInModel(cc); + return absl::Status(mediapipe::StatusCode::kNotFound, + "Must specify TFLite model as path or loaded model."); } } // namespace api2 diff --git a/mediapipe/calculators/tensor/inference_calculator.h b/mediapipe/calculators/tensor/inference_calculator.h new file mode 100644 index 000000000..a746684ff --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator.h @@ -0,0 +1,136 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_H_ +#define MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_H_ + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/util/tflite/tflite_model_loader.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" + +namespace mediapipe { +namespace api2 { + +// Runs inference on the provided input Tensors and TFLite model. +// +// Creates an interpreter with given model and calls invoke(). +// Optionally run inference on CPU/GPU. +// +// This calculator can be used with TensorConverterCalculator to get the +// appropriate inputs. +// +// When the input tensors are on CPU, gpu inference is optional and can be +// specified in the calculator options. +// When the input tensors are on GPU, inference is GPU and output can be CPU or +// GPU. +// +// Input: +// TENSORS - Vector of Tensors +// +// Output: +// TENSORS - Vector of Tensors +// +// Input side packet: +// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver, +// instead of the builtin one. 
+// MODEL (optional) - Use to specify TfLite model +// (std::unique_ptr>) +// +// Example use: +// node { +// calculator: "InferenceCalculator" +// input_stream: "TENSORS:tensor_image" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.InferenceCalculatorOptions.ext] { +// model_path: "modelname.tflite" +// } +// } +// } +// +// or +// +// node { +// calculator: "InferenceCalculator" +// input_stream: "TENSORS:tensor_image" +// input_side_packet: "MODEL:model" +// output_stream: "TENSORS:tensors" +// options: { +// [mediapipe.InferenceCalculatorOptions.ext] { +// model_path: "modelname.tflite" +// delegate { gpu {} } +// } +// } +// } +// +// IMPORTANT Notes: +// Tensors are assumed to be ordered correctly (sequentially added to model). +// Input tensors are assumed to be of the correct size and already normalized. + +class InferenceCalculator : public NodeIntf { + public: + static constexpr Input> kInTensors{"TENSORS"}; + static constexpr SideInput::Optional + kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; + static constexpr SideInput::Optional kSideInModel{"MODEL"}; + static constexpr Output> kOutTensors{"TENSORS"}; + MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, + kOutTensors); + + protected: + using TfLiteDelegatePtr = + std::unique_ptr>; + + absl::StatusOr> GetModelAsPacket( + CalculatorContext* cc); +}; + +struct InferenceCalculatorSelector : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculator"; +}; + +struct InferenceCalculatorGl : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorGl"; +}; + +struct InferenceCalculatorMlDrift : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorMlDrift"; +}; + +struct InferenceCalculatorMetal : public InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorMetal"; +}; + +struct InferenceCalculatorCpu : public 
InferenceCalculator { + static constexpr char kCalculatorName[] = "InferenceCalculatorCpu"; +}; + +} // namespace api2 +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TENSOR_INFERENCE_CALCULATOR_H_ diff --git a/mediapipe/calculators/tensor/inference_calculator.proto b/mediapipe/calculators/tensor/inference_calculator.proto index 07201f9d5..0efb61d4a 100644 --- a/mediapipe/calculators/tensor/inference_calculator.proto +++ b/mediapipe/calculators/tensor/inference_calculator.proto @@ -65,7 +65,8 @@ message InferenceCalculatorOptions { // Load pre-compiled serialized binary cache to accelerate init process. // Only available for OpenCL delegate on Android. - optional bool use_kernel_caching = 2 [default = false]; + // Kernel caching will only be enabled if this path is set. + optional string cached_kernel_path = 2; } // Android only. message Nnapi {} @@ -104,7 +105,11 @@ message InferenceCalculatorOptions { optional int32 cpu_num_thread = 4 [default = -1]; // TfLite delegate to run inference. - // NOTE: calculator is free to choose delegate if not specified explicitly. + // If not specified, TFLite GPU delegate is used by default (as if "gpu {}" + // is specified) unless GPU support is disabled in the build (i.e., with + // --define MEDIAPIPE_DISABLE_GPU=1), in which case regular TFLite on CPU is + // used (as if "tflite {}" is specified) except when building with emscripten + // where xnnpack is used. // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes // precedence over use_* deprecated options.) optional Delegate delegate = 5; diff --git a/mediapipe/calculators/tensor/inference_calculator_cpu.cc b/mediapipe/calculators/tensor/inference_calculator_cpu.cc new file mode 100644 index 000000000..d931b93fa --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_cpu.cc @@ -0,0 +1,205 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" + +#if defined(MEDIAPIPE_ANDROID) +#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#endif // ANDROID + +#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) +#include "mediapipe/util/cpu_util.h" +#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ + +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" + +namespace mediapipe { +namespace api2 { + +namespace { + +// Returns number of threads to configure XNNPACK delegate with. +// (Equal to user provided value if specified. 
Otherwise, it returns number of +// high cores (hard-coded to 1 for Emscripten without Threads extension)) +int GetXnnpackNumThreads(const mediapipe::InferenceCalculatorOptions& opts) { + static constexpr int kDefaultNumThreads = -1; + if (opts.has_delegate() && opts.delegate().has_xnnpack() && + opts.delegate().xnnpack().num_threads() != kDefaultNumThreads) { + return opts.delegate().xnnpack().num_threads(); + } +#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__) + return InferHigherCoreIds().size(); +#else + return 1; +#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ +} + +} // namespace + +class InferenceCalculatorCpuImpl + : public NodeImpl { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status LoadModel(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc); + + // TfLite requires us to keep the model alive as long as the interpreter is. 
+ Packet model_packet_; + std::unique_ptr interpreter_; + TfLiteDelegatePtr delegate_; +}; + +absl::Status InferenceCalculatorCpuImpl::UpdateContract( + CalculatorContract* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadModel(cc)); + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + auto output_tensors = absl::make_unique>(); + + // Read CPU input into tensors. + for (int i = 0; i < input_tensors.size(); ++i) { + const Tensor* input_tensor = &input_tensors[i]; + auto input_tensor_view = input_tensor->GetCpuReadView(); + auto input_tensor_buffer = input_tensor_view.buffer(); + float* local_tensor_buffer = interpreter_->typed_input_tensor(i); + std::memcpy(local_tensor_buffer, input_tensor_buffer, + input_tensor->bytes()); + } + + // Run inference. + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + + // Output result tensors (CPU). 
+ const auto& tensor_indexes = interpreter_->outputs(); + output_tensors->reserve(tensor_indexes.size()); + for (int i = 0; i < tensor_indexes.size(); ++i) { + TfLiteTensor* tensor = interpreter_->tensor(tensor_indexes[i]); + output_tensors->emplace_back( + Tensor::ElementType::kFloat32, + Tensor::Shape{std::vector{ + tensor->dims->data, tensor->dims->data + tensor->dims->size}}); + auto cpu_view = output_tensors->back().GetCpuWriteView(); + std::memcpy(cpu_view.buffer(), tensor->data.f, + output_tensors->back().bytes()); + } + kOutTensors(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) { + interpreter_ = nullptr; + delegate_ = nullptr; + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + RET_CHECK(interpreter_); + +#if defined(__EMSCRIPTEN__) + interpreter_->SetNumThreads(1); +#else + interpreter_->SetNumThreads( + cc->Options().cpu_num_thread()); +#endif // __EMSCRIPTEN__ + + RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + // TODO: Support quantized tensors. + CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != + kTfLiteAffineQuantization); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { + const auto& calculator_opts = + cc->Options(); + if (calculator_opts.has_delegate() && + calculator_opts.delegate().has_tflite()) { + // Default tflite inference requeqsted - no need to modify graph. + return absl::OkStatus(); + } + +#if defined(MEDIAPIPE_ANDROID) + const bool nnapi_requested = calculator_opts.has_delegate() + ? 
calculator_opts.delegate().has_nnapi() + : calculator_opts.use_nnapi(); + if (nnapi_requested) { + // Attempt to use NNAPI. + // If not supported, the default CPU delegate will be created and used. + interpreter_->SetAllowFp16PrecisionForFp32(1); + delegate_ = TfLiteDelegatePtr(tflite::NnApiDelegate(), [](TfLiteDelegate*) { + // No need to free according to tflite::NnApiDelegate() documentation. + }); + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + return absl::OkStatus(); + } +#endif // MEDIAPIPE_ANDROID + +#if defined(__EMSCRIPTEN__) + const bool use_xnnpack = true; +#else + const bool use_xnnpack = calculator_opts.has_delegate() && + calculator_opts.delegate().has_xnnpack(); +#endif // defined(__EMSCRIPTEN__) + + if (use_xnnpack) { + TfLiteXNNPackDelegateOptions xnnpack_opts{}; + xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); + delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), + &TfLiteXNNPackDelegateDelete); + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + } + + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc new file mode 100644 index 000000000..5fb7c974a --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_face_detection_test.cc @@ -0,0 +1,186 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/memory/memory.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/graph_test_base.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/subgraph_expansion.h" +#include "mediapipe/framework/tool/test_util.h" + +namespace mediapipe { +namespace api2 { +namespace { + +using mediapipe::Detection; +using mediapipe::InferenceCalculatorOptions_Delegate; +using testing::ElementsAre; +using testing::EqualsProto; +using testing::proto::Approximately; + +struct Param { + std::string name; // Appended to the test name. + std::string impl_suffix; // Expected InferenceCalculator backend. + InferenceCalculatorOptions_Delegate delegate; +}; + +const std::vector& GetParams() { + static auto all_params = [] { + static std::vector p; + p.push_back({"TfLite", "Cpu"}); + p.back().delegate.mutable_tflite(); +#if TARGET_OS_IPHONE && !TARGET_IPHONE_SIMULATOR + // Metal is not available on the iOS simulator. 
+ p.push_back({"Metal", "Metal"}); + p.back().delegate.mutable_gpu(); +#endif // TARGET_IPHONE_SIMULATOR +#if __EMSCRIPTEN__ + p.push_back({"MlDrift", "MlDrift"}); + p.back().delegate.mutable_gpu(); +#endif // __EMSCRIPTEN__ +#if __ANDROID__ && 0 // Disabled for now since emulator can't go GLESv3 + p.push_back({"Gl", "Gl"}); + p.back().delegate.mutable_gpu(); + // This requires API level 27 + p.push_back({"NnApi", "Cpu"}); + p.back().delegate.mutable_nnapi(); +#endif // __ANDROID__ + p.push_back({"XnnPack", "Cpu"}); + p.back().delegate.mutable_xnnpack(); + return p; + }(); + return all_params; +} + +class InferenceCalculatorTest : public testing::TestWithParam { + protected: +#if __EMSCRIPTEN__ + // TODO: fix Tensor locking. + // The MlDrift backend currently fails in debug mode without this, + // because of Tensor locking issues. I am adding this temporarily since + // the calculator is already being used and it's better to have test + // coverage for it. Also, the issue doesn't apply to our Emscripten + // build in practice since it's single-threaded. 
+ void SetUp(void) override { + absl::SetMutexDeadlockDetectionMode(absl::OnDeadlockCycle::kIgnore); + } +#endif // __EMSCRIPTEN__ + + void SetDelegateForParam(mediapipe::CalculatorGraphConfig_Node* node) { + *node->mutable_options() + ->MutableExtension(mediapipe::InferenceCalculatorOptions::ext) + ->mutable_delegate() = GetParam().delegate; + } +}; + +TEST_P(InferenceCalculatorTest, TestBackendSelection) { + CalculatorGraphConfig config; + auto node = config.add_node(); + node->set_calculator("InferenceCalculator"); + SetDelegateForParam(node); + MP_ASSERT_OK(tool::ExpandSubgraphs(&config)); + EXPECT_EQ(config.node(0).calculator(), + absl::StrCat("InferenceCalculator", GetParam().impl_suffix)); +} + +TEST_P(InferenceCalculatorTest, TestFaceDetection) { + CalculatorGraphConfig config; + ASSERT_TRUE(LoadTestGraph( + &config, file::JoinPath(GetTestRootDir(), + "mediapipe/calculators/tensor/" + "testdata/face_detection_test.binarypb"))); + + // Expand subgraphs to find any nested instances of InferenceCalculator. + MP_ASSERT_OK(tool::ExpandSubgraphs(&config)); + int found = 0; + for (auto& node : *config.mutable_node()) { + // The InferenceCalculator subgraph itself will have expanded to a specific + // implementation. Replace it. + // TODO: make it possible to exclude it from expansion above. + if (absl::StartsWith(node.calculator(), "InferenceCalculator")) { + ++found; + node.set_calculator("InferenceCalculator"); + SetDelegateForParam(&node); + } + } + ASSERT_EQ(found, 1); + + std::vector detection_packets; + tool::AddVectorSink("detections", &config, &detection_packets); + std::vector rendering_packets; + tool::AddVectorSink("rendering", &config, &rendering_packets); + + // Load test image. 
+ std::unique_ptr input_image = LoadTestPng( + file::JoinPath(GetTestRootDir(), "mediapipe/objc/testdata/sergey.png")); + ASSERT_THAT(input_image, testing::NotNull()); + + std::unique_ptr expected_image = + LoadTestPng(file::JoinPath(GetTestRootDir(), + "mediapipe/calculators/tensor/" + "testdata/face_detection_expected.png")); + ASSERT_THAT(expected_image, testing::NotNull()); + + std::string binary; + Detection expected_detection; + MP_ASSERT_OK( + file::GetContents(file::JoinPath(GetTestRootDir(), + "mediapipe/calculators/tensor/" + "testdata/expected_detection.binarypb"), + &binary)); + expected_detection.ParseFromArray(binary.data(), binary.size()); + + // Prepare test inputs. + std::unordered_map> input_streams; + input_streams.insert(std::make_pair("image", std::move(input_image))); + std::string output_stream = "rendering"; + + // Test graph with relaxed color difference tolerance. + // Compare with CPU generated image. + Timestamp ts0 = Timestamp(0); + TestGraphConfig(config, input_streams, output_stream, expected_image, {}, ts0, + 2.0, 2.0, 1.0); + + ASSERT_EQ(detection_packets.size(), 1); + std::vector dets = + detection_packets[0].Get>(); +#if !defined(MEDIAPIPE_PROTO_LITE) + // Approximately is not available with lite protos (b/178137094). + EXPECT_THAT(dets, + ElementsAre(Approximately(EqualsProto(expected_detection)))); +#endif +} + +INSTANTIATE_TEST_SUITE_P(Implementation, InferenceCalculatorTest, + testing::ValuesIn(GetParams()), + [](const testing::TestParamInfo& info) { + return info.param.name; + }); + +} // namespace +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_gl.cc b/mediapipe/calculators/tensor/inference_calculator_gl.cc new file mode 100644 index 000000000..081b12d3c --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_gl.cc @@ -0,0 +1,368 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" +#include "mediapipe/util/tflite/config.h" + +#if MEDIAPIPE_TFLITE_GL_INFERENCE +#include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/util/tflite/tflite_gpu_runner.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/gl_delegate.h" +#endif // MEDIAPIPE_TFLITE_GL_INFERENCE + +#if defined(MEDIAPIPE_ANDROID) +#include "mediapipe/util/android/file/base/file.h" +#include "mediapipe/util/android/file/base/filesystem.h" +#include "mediapipe/util/android/file/base/helpers.h" +#endif // ANDROID + +namespace mediapipe { +namespace api2 { + +class InferenceCalculatorGlImpl + : public NodeImpl { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status ReadKernelsFromFile(); + absl::Status WriteKernelsToFile(); + absl::Status LoadModel(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc); + absl::Status InitTFLiteGPURunner(CalculatorContext* cc); + + // TfLite requires us to keep the model alive as long as the interpreter is. 
+ Packet model_packet_; + std::unique_ptr interpreter_; + TfLiteDelegatePtr delegate_; + +#if MEDIAPIPE_TFLITE_GL_INFERENCE + mediapipe::GlCalculatorHelper gpu_helper_; + std::unique_ptr tflite_gpu_runner_; + bool allow_precision_loss_ = false; + mediapipe::InferenceCalculatorOptions::Delegate::Gpu::API + tflite_gpu_runner_api_; +#endif // MEDIAPIPE_TFLITE_GL_INFERENCE + +#if MEDIAPIPE_TFLITE_GPU_SUPPORTED + std::vector output_shapes_; + std::vector> gpu_buffers_in_; + std::vector> gpu_buffers_out_; +#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED + + bool use_advanced_gpu_api_ = false; + bool use_gpu_delegate_ = false; + + bool use_kernel_caching_ = false; + std::string cached_kernel_filename_; +}; + +absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + use_advanced_gpu_api_ = options.has_delegate() && + options.delegate().has_gpu() && + options.delegate().gpu().use_advanced_gpu_api(); + allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); + tflite_gpu_runner_api_ = options.delegate().gpu().api(); + use_kernel_caching_ = use_advanced_gpu_api_ && + options.delegate().gpu().has_cached_kernel_path(); + use_gpu_delegate_ = !use_advanced_gpu_api_; + + if (use_kernel_caching_) { +#ifdef MEDIAPIPE_ANDROID + cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() + + mediapipe::File::Basename(options.model_path()) + + ".ker"; +#endif // MEDIAPIPE_ANDROID + } + + // When use_advanced_gpu_api_, model loading is handled in 
InitTFLiteGPURunner + // for everything. + if (!use_advanced_gpu_api_) { + MP_RETURN_IF_ERROR(LoadModel(cc)); + } + + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, + &cc]() -> ::mediapipe::Status { + return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) : LoadDelegate(cc); + })); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + auto output_tensors = absl::make_unique>(); + + if (use_advanced_gpu_api_) { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors, &output_tensors]() -> ::mediapipe::Status { + for (int i = 0; i < input_tensors.size(); ++i) { + MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor( + input_tensors[i].GetOpenGlBufferReadView().name(), i)); + } + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + output_shapes_[i]); + MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor( + output_tensors->back().GetOpenGlBufferWriteView().name(), i)); + } + return absl::OkStatus(); + })); + } else { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &input_tensors]() -> ::mediapipe::Status { + // Explicitly copy input. + for (int i = 0; i < input_tensors.size(); ++i) { + glBindBuffer(GL_COPY_READ_BUFFER, + input_tensors[i].GetOpenGlBufferReadView().name()); + glBindBuffer(GL_COPY_WRITE_BUFFER, + gpu_buffers_in_[i]->GetOpenGlBufferWriteView().name()); + glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + input_tensors[i].bytes()); + } + return absl::OkStatus(); + })); + } + + // Run inference. 
+ if (use_advanced_gpu_api_) { + RET_CHECK(tflite_gpu_runner_->Invoke().ok()); + } else { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + } + + if (use_gpu_delegate_) { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( + [this, &output_tensors]() -> ::mediapipe::Status { + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + const auto& t = gpu_buffers_out_[i]; + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + gpu_buffers_out_[i]->shape()); + auto read_view = t->GetOpenGlBufferReadView(); + glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); + auto write_view = output_tensors->back().GetOpenGlBufferWriteView(); + glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); + glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, + t->bytes()); + } + return absl::OkStatus(); + })); + } + // Output tensors are already bound if use_advanced_gpu_api_ is true. + + kOutTensors(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::WriteKernelsToFile() { +#ifdef MEDIAPIPE_ANDROID + if (use_kernel_caching_) { + // Save kernel file. 
+ auto kernel_cache = absl::make_unique>( + tflite_gpu_runner_->GetSerializedBinaryCache()); + std::string cache_str(kernel_cache->begin(), kernel_cache->end()); + MP_RETURN_IF_ERROR( + mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); + } +#endif // MEDIAPIPE_ANDROID + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(WriteKernelsToFile()); + if (use_gpu_delegate_) { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { + gpu_buffers_in_.clear(); + gpu_buffers_out_.clear(); + return absl::OkStatus(); + })); + } + + interpreter_ = nullptr; + delegate_ = nullptr; + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::ReadKernelsFromFile() { +#ifdef MEDIAPIPE_ANDROID + if (use_kernel_caching_) { + // Load pre-compiled kernel file. + if (mediapipe::File::Exists(cached_kernel_filename_)) { + std::string cache_str; + MP_RETURN_IF_ERROR( + mediapipe::file::GetContents(cached_kernel_filename_, &cache_str)); + std::vector cache_vec(cache_str.begin(), cache_str.end()); + tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec)); + } + } +#endif // MEDIAPIPE_ANDROID + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( + CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + // Create runner + tflite::gpu::InferenceOptions options; + options.priority1 = allow_precision_loss_ + ? 
tflite::gpu::InferencePriority::MIN_LATENCY + : tflite::gpu::InferencePriority::MAX_PRECISION; + options.priority2 = tflite::gpu::InferencePriority::AUTO; + options.priority3 = tflite::gpu::InferencePriority::AUTO; + options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED; + tflite_gpu_runner_ = std::make_unique(options); + switch (tflite_gpu_runner_api_) { + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: { + tflite_gpu_runner_->ForceOpenGL(); + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: { + tflite_gpu_runner_->ForceOpenCL(); + break; + } + case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: { + // Do not need to force any specific API. + break; + } + } + MP_RETURN_IF_ERROR( + tflite_gpu_runner_->InitializeWithModel(model, op_resolver)); + + // Create and bind OpenGL buffers for outputs. + // The buffers are created once and their ids are passed to calculator outputs + output_shapes_.resize(tflite_gpu_runner_->outputs_size()); + for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) { + output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b, + tflite_gpu_runner_->GetOutputShapes()[i].h, + tflite_gpu_runner_->GetOutputShapes()[i].w, + tflite_gpu_runner_->GetOutputShapes()[i].c}; + } + + MP_RETURN_IF_ERROR(ReadKernelsFromFile()); + + MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + RET_CHECK(interpreter_); + +#if defined(__EMSCRIPTEN__) + interpreter_->SetNumThreads(1); +#else + interpreter_->SetNumThreads( + cc->Options().cpu_num_thread()); +#endif // __EMSCRIPTEN__ + 
+ RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + // TODO: Support quantized tensors. + CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != + kTfLiteAffineQuantization); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { + // Configure and create the delegate. + TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); + options.compile_options.precision_loss_allowed = 1; + options.compile_options.preferred_gl_object_type = + TFLITE_GL_OBJECT_TYPE_FASTEST; + options.compile_options.dynamic_batch_enabled = 0; + options.compile_options.inline_parameters = 1; + delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options), + &TfLiteGpuDelegateDelete); + + // Get input image sizes. + const auto& input_indices = interpreter_->inputs(); + for (int i = 0; i < input_indices.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + gpu_buffers_in_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat32, + Tensor::Shape{std::vector{ + tensor->dims->data, tensor->dims->data + tensor->dims->size}})); + RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( + delegate_.get(), + gpu_buffers_in_.back()->GetOpenGlBufferWriteView().name(), + interpreter_->inputs()[i]), + kTfLiteOk); + } + interpreter_->SetAllowBufferHandleOutput(true); + // Get output image sizes. + const auto& output_indices = interpreter_->outputs(); + output_shapes_.resize(output_indices.size()); + // Create and bind output buffers. 
+ for (int i = 0; i < output_shapes_.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + gpu_buffers_out_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat32, + Tensor::Shape{std::vector{ + tensor->dims->data, tensor->dims->data + tensor->dims->size}})); + RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( + delegate_.get(), + gpu_buffers_out_.back()->GetOpenGlBufferWriteView().name(), + output_indices[i]), + kTfLiteOk); + } + + // Must call this last. + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_metal.cc b/mediapipe/calculators/tensor/inference_calculator_metal.cc new file mode 100644 index 000000000..490189aec --- /dev/null +++ b/mediapipe/calculators/tensor/inference_calculator_metal.cc @@ -0,0 +1,293 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#import +#import +#import + +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "mediapipe/calculators/tensor/inference_calculator.h" +#import "mediapipe/gpu/MPPMetalHelper.h" +#include "mediapipe/gpu/MPPMetalUtil.h" +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/util/tflite/config.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h" +#include "tensorflow/lite/delegates/gpu/metal_delegate.h" +#include "tensorflow/lite/delegates/gpu/metal_delegate_internal.h" + +namespace { + +// Round up n to next multiple of m. +template +T RoundUp(T n, T m) { + return ((n + m - T{1}) / m) * m; +} + +} // namespace + +namespace mediapipe { +namespace api2 { + +#if MEDIAPIPE_TFLITE_METAL_INFERENCE +namespace { +tflite::gpu::BHWC BhwcFromTensorShape(const Tensor::Shape& shape) { + tflite::gpu::BHWC result; + result.b = shape.dims[0]; + switch (shape.dims.size()) { + case 1: + // result.b is already filled. + break; + case 2: + result.h = 1; + result.w = 1; + result.c = shape.dims[1]; + break; + case 3: + result.h = 1; + result.w = shape.dims[1]; + result.c = shape.dims[2]; + break; + case 4: + result.h = shape.dims[1]; + result.w = shape.dims[2]; + result.c = shape.dims[3]; + break; + default: + // Handles 0 and >4. 
+ LOG(FATAL) + << "Dimensions size must be in range [1,4] for GPU inference, but " + << shape.dims.size() << " is provided"; + } + return result; +} +} // namespace +#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE + +class InferenceCalculatorMetalImpl + : public NodeImpl { + public: + static absl::Status UpdateContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status LoadModel(CalculatorContext* cc); + absl::Status LoadDelegate(CalculatorContext* cc); + + // TfLite requires us to keep the model alive as long as the interpreter is. + Packet model_packet_; + std::unique_ptr interpreter_; + TfLiteDelegatePtr delegate_; + +#if MEDIAPIPE_TFLITE_METAL_INFERENCE + MPPMetalHelper* gpu_helper_ = nullptr; + TFLBufferConvert* converter_to_BPHWC4_ = nil; + TFLBufferConvert* converter_from_BPHWC4_ = nil; +#endif // MEDIAPIPE_TFLITE_GL_INFERENCE + +#if MEDIAPIPE_TFLITE_GPU_SUPPORTED + std::vector output_shapes_; + std::vector> gpu_buffers_in_; + std::vector> gpu_buffers_out_; +#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED +}; + +absl::Status InferenceCalculatorMetalImpl::UpdateContract( + CalculatorContract* cc) { + const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); + RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) + << "Either model as side packet or model path in options is required."; + + MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) { + MP_RETURN_IF_ERROR(LoadModel(cc)); + + gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; + RET_CHECK(gpu_helper_); + MP_RETURN_IF_ERROR(LoadDelegate(cc)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { + if (kInTensors(cc).IsEmpty()) { + 
return absl::OkStatus(); + } + const auto& input_tensors = *kInTensors(cc); + RET_CHECK(!input_tensors.empty()); + auto output_tensors = absl::make_unique>(); + + id command_buffer; + + command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"InferenceCalculator"; + // Explicit copy input with conversion float 32 bits to 16 bits. + for (int i = 0; i < input_tensors.size(); ++i) { + auto input_view = input_tensors[i].GetMtlBufferReadView(command_buffer); + // Reshape tensor. + tflite::gpu::BHWC shape = BhwcFromTensorShape(input_tensors[i].shape()); + auto gpu_buffer_view = + gpu_buffers_in_[i]->GetMtlBufferWriteView(command_buffer); + id input_encoder = + [command_buffer computeCommandEncoder]; + [converter_to_BPHWC4_ convertWithEncoder:input_encoder + shape:shape + sourceBuffer:input_view.buffer() + convertedBuffer:gpu_buffer_view.buffer()]; + [input_encoder endEncoding]; + } + + // Run inference. + RET_CHECK(TFLGpuDelegateSetCommandBuffer(delegate_.get(), command_buffer)); + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + + output_tensors->reserve(output_shapes_.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + output_tensors->emplace_back(Tensor::ElementType::kFloat32, + output_shapes_[i]); + // Reshape tensor. 
+ tflite::gpu::BHWC shape = BhwcFromTensorShape(output_shapes_[i]); + auto read_view = gpu_buffers_out_[i]->GetMtlBufferReadView(command_buffer); + auto write_view = + output_tensors->at(i).GetMtlBufferWriteView(command_buffer); + id output_encoder = + [command_buffer computeCommandEncoder]; + [converter_from_BPHWC4_ convertWithEncoder:output_encoder + shape:shape + sourceBuffer:read_view.buffer() + convertedBuffer:write_view.buffer()]; + [output_encoder endEncoding]; + } + [command_buffer commit]; + + kOutTensors(cc).Send(std::move(output_tensors)); + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::Close(CalculatorContext* cc) { + converter_to_BPHWC4_ = nil; + converter_from_BPHWC4_ = nil; + gpu_buffers_in_.clear(); + gpu_buffers_out_.clear(); + interpreter_ = nullptr; + delegate_ = nullptr; + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { + ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); + const auto& model = *model_packet_.Get(); + tflite::ops::builtin::BuiltinOpResolver op_resolver = + kSideInCustomOpResolver(cc).GetOr( + tflite::ops::builtin::BuiltinOpResolver()); + + tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); + RET_CHECK(interpreter_); + + interpreter_->SetNumThreads( + cc->Options().cpu_num_thread()); + + RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); + // TODO: Support quantized tensors. + CHECK(interpreter_->tensor(interpreter_->inputs()[0])->quantization.type != + kTfLiteAffineQuantization); + + return absl::OkStatus(); +} + +absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { + const auto& calculator_opts = + cc->Options(); + + // Configure and create the delegate. 
+ TFLGpuDelegateOptions options; + options.allow_precision_loss = true; + options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; + delegate_ = + TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); + RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), + kTfLiteOk); + id device = gpu_helper_.mtlDevice; + + // Get input image sizes. + const auto& input_indices = interpreter_->inputs(); + for (int i = 0; i < input_indices.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(input_indices[i]); + // Create and bind input buffer. + std::vector dims{tensor->dims->data, + tensor->dims->data + tensor->dims->size}; + dims.back() = RoundUp(dims.back(), 4); + gpu_buffers_in_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat16, Tensor::Shape{dims})); + auto buffer_view = + gpu_buffers_in_[i]->GetMtlBufferWriteView(gpu_helper_.mtlDevice); + RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( + delegate_.get(), input_indices[i], buffer_view.buffer()), + true); + } + + interpreter_->SetAllowBufferHandleOutput(true); + // Get output image sizes. + const auto& output_indices = interpreter_->outputs(); + output_shapes_.resize(output_indices.size()); + for (int i = 0; i < output_shapes_.size(); ++i) { + const TfLiteTensor* tensor = interpreter_->tensor(output_indices[i]); + RET_CHECK(tensor->dims->size <= 4); + // Create and bind output buffers. + // Channels are always padded to multiple of 4. + std::vector dims{tensor->dims->data, + tensor->dims->data + tensor->dims->size}; + output_shapes_[i] = {dims}; + dims.back() = RoundUp(dims.back(), 4); + gpu_buffers_out_.emplace_back(absl::make_unique( + Tensor::ElementType::kFloat16, Tensor::Shape{dims})); + RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( + delegate_.get(), output_indices[i], + gpu_buffers_out_[i] + ->GetMtlBufferWriteView(gpu_helper_.mtlDevice) + .buffer()), + true); + } + + // Create converter for GPU input. 
+ converter_to_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:true + convertToPBHWC4:true]; + if (converter_to_BPHWC4_ == nil) { + return mediapipe::InternalError( + "Error initializating input buffer converter"); + } + // Create converter for GPU output. + converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device + isFloat16:true + convertToPBHWC4:false]; + if (converter_from_BPHWC4_ == nil) { + return absl::InternalError("Error initializating output buffer converter"); + } + + return absl::OkStatus(); +} + +} // namespace api2 +} // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_test.cc b/mediapipe/calculators/tensor/inference_calculator_test.cc index 248d799e5..882a5e81e 100644 --- a/mediapipe/calculators/tensor/inference_calculator_test.cc +++ b/mediapipe/calculators/tensor/inference_calculator_test.cc @@ -111,9 +111,9 @@ TEST(InferenceCalculatorTest, SmokeTest) { // Test CPU inference only. DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { tflite {} }"}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( - graph_proto, {{"$delegate", "delegate { xnnpack {} }"}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + DoSmokeTest(absl::StrReplaceAll(graph_proto, + {{"$delegate", "delegate { xnnpack {} }"}})); + DoSmokeTest(absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}})); } diff --git a/mediapipe/calculators/tensor/tensor_converter_calculator.cc b/mediapipe/calculators/tensor/tensor_converter_calculator.cc index 4da199c7a..82180fe52 100644 --- a/mediapipe/calculators/tensor/tensor_converter_calculator.cc +++ b/mediapipe/calculators/tensor/tensor_converter_calculator.cc @@ -100,21 +100,21 @@ namespace mediapipe { class TensorConverterCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - 
mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status InitGpu(CalculatorContext* cc); - mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); template - mediapipe::Status NormalizeImage(const ImageFrame& image_frame, - bool flip_vertically, float* tensor_ptr); - mediapipe::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr); - mediapipe::Status ProcessCPU(CalculatorContext* cc); - mediapipe::Status ProcessGPU(CalculatorContext* cc); + absl::Status NormalizeImage(const ImageFrame& image_frame, + bool flip_vertically, float* tensor_ptr); + absl::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr); + absl::Status ProcessCPU(CalculatorContext* cc); + absl::Status ProcessGPU(CalculatorContext* cc); #if MEDIAPIPE_METAL_ENABLED MPPMetalHelper* gpu_helper_ = nullptr; @@ -139,8 +139,7 @@ class TensorConverterCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TensorConverterCalculator); -mediapipe::Status TensorConverterCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TensorConverterCalculator::GetContract(CalculatorContract* cc) { // Confirm only one of the input streams is present. 
RET_CHECK(static_cast(cc->Inputs().HasTag(kImageFrameTag)) + static_cast(cc->Inputs().HasTag(kGpuBufferTag)) + @@ -167,10 +166,10 @@ mediapipe::Status TensorConverterCalculator::GetContract( RET_CHECK(cc->Outputs().HasTag(kTensorsTag)); cc->Outputs().Tag(kTensorsTag).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorConverterCalculator::Open(CalculatorContext* cc) { +absl::Status TensorConverterCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); MP_RETURN_IF_ERROR(LoadOptions(cc)); @@ -187,13 +186,13 @@ mediapipe::Status TensorConverterCalculator::Open(CalculatorContext* cc) { } #endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorConverterCalculator::Process(CalculatorContext* cc) { +absl::Status TensorConverterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { if (cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Convert to GPU tensors type. MP_RETURN_IF_ERROR(ProcessGPU(cc)); @@ -201,10 +200,10 @@ mediapipe::Status TensorConverterCalculator::Process(CalculatorContext* cc) { // Convert to CPU tensors or Matrix type. 
MP_RETURN_IF_ERROR(ProcessCPU(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorConverterCalculator::Close(CalculatorContext* cc) { +absl::Status TensorConverterCalculator::Close(CalculatorContext* cc) { #if !MEDIAPIPE_DISABLE_GPU if (use_gpu_) { #if MEDIAPIPE_METAL_ENABLED @@ -221,14 +220,14 @@ mediapipe::Status TensorConverterCalculator::Close(CalculatorContext* cc) { #endif // MEDIAPIPE_METAL_ENABLED } #endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorConverterCalculator::ProcessCPU(CalculatorContext* cc) { +absl::Status TensorConverterCalculator::ProcessCPU(CalculatorContext* cc) { auto output_tensors = absl::make_unique>(); if (cc->Inputs().HasTag(kImageFrameTag)) { if (cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& image_frame = cc->Inputs().Tag(kImageFrameTag).Get(); @@ -257,12 +256,12 @@ mediapipe::Status TensorConverterCalculator::ProcessCPU(CalculatorContext* cc) { MP_RETURN_IF_ERROR(NormalizeImage(image_frame, flip_vertically_, cpu_view.buffer())); } else { - return mediapipe::InternalError( + return absl::InternalError( "Only byte-based (8 bit) and float (32 bit) images supported."); } } else if (cc->Inputs().HasTag(kMatrixTag)) { if (cc->Inputs().Tag(kMatrixTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& matrix = cc->Inputs().Tag(kMatrixTag).Get(); const int height = matrix.rows(); @@ -273,16 +272,16 @@ mediapipe::Status TensorConverterCalculator::ProcessCPU(CalculatorContext* cc) { MP_RETURN_IF_ERROR(CopyMatrixToTensor( matrix, output_tensors->back().GetCpuWriteView().buffer())); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } cc->Outputs() .Tag(kTensorsTag) .Add(output_tensors.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status 
TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { +absl::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { #if !MEDIAPIPE_DISABLE_GPU if (!initialized_) { MP_RETURN_IF_ERROR(InitGpu(cc)); @@ -318,7 +317,7 @@ mediapipe::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { [command_buffer commit]; #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &output_tensors, &input]() -> mediapipe::Status { + [this, &output_tensors, &input]() -> absl::Status { auto src = gpu_helper_.CreateSourceTexture(input); #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Convert GL texture into SSBO. @@ -361,7 +360,7 @@ mediapipe::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { glBindTexture(GL_TEXTURE_2D, 0); #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 src.Release(); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); #endif // MEDIAPIPE_METAL_ENABLED cc->Outputs() @@ -371,10 +370,10 @@ mediapipe::Status TensorConverterCalculator::ProcessGPU(CalculatorContext* cc) { RET_CHECK_FAIL() << "GPU processing is not enabled."; #endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { +absl::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { #if !MEDIAPIPE_DISABLE_GPU // Get input image sizes. const auto& input = @@ -448,7 +447,7 @@ mediapipe::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { &input, #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 &single_channel]() - -> mediapipe::Status { + -> absl::Status { #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), // with normalization to either: [0,1] or [-1,1]. 
@@ -558,15 +557,14 @@ mediapipe::Status TensorConverterCalculator::InitGpu(CalculatorContext* cc) { glGenFramebuffers(1, &framebuffer_); #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - return mediapipe::OkStatus(); + return absl::OkStatus(); })); #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorConverterCalculator::LoadOptions( - CalculatorContext* cc) { +absl::Status TensorConverterCalculator::LoadOptions(CalculatorContext* cc) { // Get calculator options specified in the graph. const auto& options = cc->Options<::mediapipe::TensorConverterCalculatorOptions>(); @@ -604,11 +602,11 @@ mediapipe::Status TensorConverterCalculator::LoadOptions( CHECK_GE(max_num_channels_, 1); CHECK_LE(max_num_channels_, 4); CHECK_NE(max_num_channels_, 2); - return mediapipe::OkStatus(); + return absl::OkStatus(); } template -mediapipe::Status TensorConverterCalculator::NormalizeImage( +absl::Status TensorConverterCalculator::NormalizeImage( const ImageFrame& image_frame, bool flip_vertically, float* tensor_ptr) { const int height = image_frame.Height(); const int width = image_frame.Width(); @@ -652,11 +650,11 @@ mediapipe::Status TensorConverterCalculator::NormalizeImage( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorConverterCalculator::CopyMatrixToTensor( - const Matrix& matrix, float* tensor_ptr) { +absl::Status TensorConverterCalculator::CopyMatrixToTensor(const Matrix& matrix, + float* tensor_ptr) { if (row_major_matrix_) { auto matrix_map = Eigen::Map(tensor_ptr, matrix.rows(), matrix.cols()); @@ -667,7 +665,7 @@ mediapipe::Status TensorConverterCalculator::CopyMatrixToTensor( matrix_map = matrix; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc 
b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc index e076e2451..c3b91de71 100644 --- a/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_classification_calculator.cc @@ -16,6 +16,7 @@ #include #include +#include "absl/container/node_hash_map.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h" @@ -66,20 +67,19 @@ class TensorsToClassificationCalculator : public Node { "CLASSIFICATIONS"}; MEDIAPIPE_NODE_CONTRACT(kInTensors, kOutClassificationList); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: ::mediapipe::TensorsToClassificationCalculatorOptions options_; int top_k_ = 0; - std::unordered_map label_map_; + absl::node_hash_map label_map_; bool label_map_loaded_ = false; }; MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator); -mediapipe::Status TensorsToClassificationCalculator::Open( - CalculatorContext* cc) { +absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) { options_ = cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>(); @@ -100,11 +100,10 @@ mediapipe::Status TensorsToClassificationCalculator::Open( label_map_loaded_ = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToClassificationCalculator::Process( - CalculatorContext* cc) { +absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) { const auto& input_tensors = *kInTensors(cc); RET_CHECK_EQ(input_tensors.size(), 1); @@ -168,12 +167,11 @@ mediapipe::Status TensorsToClassificationCalculator::Process( top_k_, 
raw_classification_list->size() - top_k_); } kOutClassificationList(cc).Send(std::move(classification_list)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToClassificationCalculator::Close( - CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status TensorsToClassificationCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); } } // namespace api2 diff --git a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc index 52c2adb14..1a27cafce 100644 --- a/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_detections_calculator.cc @@ -134,26 +134,27 @@ class TensorsToDetectionsCalculator : public Node { "ANCHORS"}; static constexpr Output> kOutDetections{"DETECTIONS"}; MEDIAPIPE_NODE_CONTRACT(kInTensors, kInAnchors, kOutDetections); - static mediapipe::Status UpdateContract(CalculatorContract* cc); + static absl::Status UpdateContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status ProcessCPU(CalculatorContext* cc, - std::vector* output_detections); - mediapipe::Status ProcessGPU(CalculatorContext* cc, - std::vector* output_detections); + absl::Status ProcessCPU(CalculatorContext* cc, + std::vector* output_detections); + absl::Status ProcessGPU(CalculatorContext* cc, + std::vector* output_detections); - mediapipe::Status LoadOptions(CalculatorContext* cc); - mediapipe::Status GpuInit(CalculatorContext* cc); - mediapipe::Status DecodeBoxes(const float* raw_boxes, - const std::vector& anchors, - std::vector* boxes); - 
mediapipe::Status ConvertToDetections( - const float* detection_boxes, const float* detection_scores, - const int* detection_classes, std::vector* output_detections); + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status GpuInit(CalculatorContext* cc); + absl::Status DecodeBoxes(const float* raw_boxes, + const std::vector& anchors, + std::vector* boxes); + absl::Status ConvertToDetections(const float* detection_boxes, + const float* detection_scores, + const int* detection_classes, + std::vector* output_detections); Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, int class_id, bool flip_vertically); @@ -179,12 +180,13 @@ class TensorsToDetectionsCalculator : public Node { std::unique_ptr decoded_boxes_buffer_; std::unique_ptr scored_boxes_buffer_; + bool gpu_inited_ = false; bool gpu_input_ = false; bool anchors_init_ = false; }; MEDIAPIPE_REGISTER_NODE(TensorsToDetectionsCalculator); -mediapipe::Status TensorsToDetectionsCalculator::UpdateContract( +absl::Status TensorsToDetectionsCalculator::UpdateContract( CalculatorContract* cc) { if (CanUseGpu()) { #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE @@ -194,10 +196,10 @@ mediapipe::Status TensorsToDetectionsCalculator::UpdateContract( #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) { +absl::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(LoadOptions(cc)); if (CanUseGpu()) { @@ -207,14 +209,12 @@ mediapipe::Status TensorsToDetectionsCalculator::Open(CalculatorContext* cc) { gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - MP_RETURN_IF_ERROR(GpuInit(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::Process( - 
CalculatorContext* cc) { +absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) { auto output_detections = absl::make_unique>(); bool gpu_processing = false; if (CanUseGpu()) { @@ -229,16 +229,20 @@ mediapipe::Status TensorsToDetectionsCalculator::Process( } if (gpu_processing) { + if (!gpu_inited_) { + MP_RETURN_IF_ERROR(GpuInit(cc)); + gpu_inited_ = true; + } MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get())); } else { MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); } kOutDetections(cc).Send(std::move(output_detections)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU( +absl::Status TensorsToDetectionsCalculator::ProcessCPU( CalculatorContext* cc, std::vector* output_detections) { const auto& input_tensors = *kInTensors(cc); @@ -275,7 +279,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU( } else if (!kInAnchors(cc).IsEmpty()) { anchors_ = *kInAnchors(cc); } else { - return mediapipe::UnavailableError("No anchor data available."); + return absl::UnavailableError("No anchor data available."); } anchors_init_ = true; } @@ -362,10 +366,10 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessCPU( detection_classes.data(), output_detections)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( +absl::Status TensorsToDetectionsCalculator::ProcessGPU( CalculatorContext* cc, std::vector* output_detections) { const auto& input_tensors = *kInTensors(cc); RET_CHECK_GE(input_tensors.size(), 2); @@ -373,7 +377,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc, &output_detections]() - -> mediapipe::Status { + -> absl::Status { if (!anchors_init_) { if (input_tensors.size() == kNumInputTensorsWithAnchors) { auto read_view = input_tensors[2].GetOpenGlBufferReadView(); @@ -388,7 
+392,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( auto raw_anchors = anchors_view.buffer(); ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors); } else { - return mediapipe::UnavailableError("No anchor data available."); + return absl::UnavailableError("No anchor data available."); } anchors_init_ = true; } @@ -414,7 +418,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( glUseProgram(score_program_); glDispatchCompute(num_boxes_, 1, 1); } - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // TODO: b/138851969. Is it possible to output a float vector @@ -459,7 +463,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors_view.buffer()); } else { - return mediapipe::UnavailableError("No anchor data available."); + return absl::UnavailableError("No anchor data available."); } anchors_init_ = true; } @@ -520,10 +524,10 @@ mediapipe::Status TensorsToDetectionsCalculator::ProcessGPU( #else LOG(ERROR) << "GPU input on non-Android not supported yet."; #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::Close(CalculatorContext* cc) { +absl::Status TensorsToDetectionsCalculator::Close(CalculatorContext* cc) { #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE gpu_helper_.RunInGlContext([this] { decoded_boxes_buffer_ = nullptr; @@ -540,11 +544,10 @@ mediapipe::Status TensorsToDetectionsCalculator::Close(CalculatorContext* cc) { score_program_ = nil; #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::LoadOptions( - CalculatorContext* cc) { +absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) { // Get calculator options specified in the graph. 
options_ = cc->Options<::mediapipe::TensorsToDetectionsCalculatorOptions>(); RET_CHECK(options_.has_num_classes()); @@ -567,10 +570,10 @@ mediapipe::Status TensorsToDetectionsCalculator::LoadOptions( ignore_classes_.insert(options_.ignore_classes(i)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::DecodeBoxes( +absl::Status TensorsToDetectionsCalculator::DecodeBoxes( const float* raw_boxes, const std::vector& anchors, std::vector* boxes) { for (int i = 0; i < num_boxes_; ++i) { @@ -631,10 +634,10 @@ mediapipe::Status TensorsToDetectionsCalculator::DecodeBoxes( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToDetectionsCalculator::ConvertToDetections( +absl::Status TensorsToDetectionsCalculator::ConvertToDetections( const float* detection_boxes, const float* detection_scores, const int* detection_classes, std::vector* output_detections) { for (int i = 0; i < num_boxes_; ++i) { @@ -671,7 +674,7 @@ mediapipe::Status TensorsToDetectionsCalculator::ConvertToDetections( } output_detections->emplace_back(detection); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } Detection TensorsToDetectionsCalculator::ConvertToDetection( @@ -694,10 +697,9 @@ Detection TensorsToDetectionsCalculator::ConvertToDetection( return detection; } -mediapipe::Status TensorsToDetectionsCalculator::GpuInit( - CalculatorContext* cc) { +absl::Status TensorsToDetectionsCalculator::GpuInit(CalculatorContext* cc) { #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> mediapipe::Status { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { // A shader to decode detection boxes. 
const std::string decode_src = absl::Substitute( R"( #version 310 es @@ -801,7 +803,14 @@ void main() { glCompileShader(shader); GLint compiled = GL_FALSE; glGetShaderiv(shader, GL_COMPILE_STATUS, &compiled); - RET_CHECK(compiled == GL_TRUE); + RET_CHECK(compiled == GL_TRUE) << "Shader compilation error: " << [shader] { + GLint length; + glGetShaderiv(shader, GL_INFO_LOG_LENGTH, &length); + std::string str; + str.reserve(length); + glGetShaderInfoLog(shader, length, nullptr, str.data()); + return str; + }(); decode_program_ = glCreateProgram(); glAttachShader(decode_program_, shader); glDeleteShader(shader); @@ -910,7 +919,7 @@ void main() { scored_boxes_buffer_ = absl::make_unique( Tensor::ElementType::kFloat32, Tensor::Shape{1, num_boxes_ * 2}); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); #elif MEDIAPIPE_METAL_ENABLED @@ -1128,7 +1137,7 @@ kernel void scoreKernel( #endif // !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace api2 diff --git a/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc b/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc index a95a9da8d..5ec3b4dea 100644 --- a/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_floats_calculator.cc @@ -53,28 +53,27 @@ class TensorsToFloatsCalculator : public Node { MEDIAPIPE_NODE_INTERFACE(TensorsToFloatsCalculator, kInTensors, kOutFloat, kOutFloats); - static mediapipe::Status UpdateContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) final; - mediapipe::Status Process(CalculatorContext* cc) final; + static absl::Status UpdateContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; private: ::mediapipe::TensorsToFloatsCalculatorOptions options_; }; MEDIAPIPE_REGISTER_NODE(TensorsToFloatsCalculator); -mediapipe::Status 
TensorsToFloatsCalculator::UpdateContract( - CalculatorContract* cc) { +absl::Status TensorsToFloatsCalculator::UpdateContract(CalculatorContract* cc) { // Only exactly a single output allowed. RET_CHECK(kOutFloat(cc).IsConnected() ^ kOutFloats(cc).IsConnected()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToFloatsCalculator::Open(CalculatorContext* cc) { +absl::Status TensorsToFloatsCalculator::Open(CalculatorContext* cc) { options_ = cc->Options<::mediapipe::TensorsToFloatsCalculatorOptions>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) { +absl::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) { const auto& input_tensors = *kInTensors(cc); RET_CHECK(!input_tensors.empty()); // TODO: Add option to specify which tensor to take from. @@ -101,7 +100,7 @@ mediapipe::Status TensorsToFloatsCalculator::Process(CalculatorContext* cc) { } else { kOutFloats(cc).Send(std::move(output_floats)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc index ca69d1344..8e4066bee 100644 --- a/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc +++ b/mediapipe/calculators/tensor/tensors_to_landmarks_calculator.cc @@ -100,23 +100,23 @@ class TensorsToLandmarksCalculator : public Node { MEDIAPIPE_NODE_CONTRACT(kInTensors, kFlipHorizontally, kFlipVertically, kOutLandmarkList, kOutNormalizedLandmarkList); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status 
LoadOptions(CalculatorContext* cc); int num_landmarks_ = 0; ::mediapipe::TensorsToLandmarksCalculatorOptions options_; }; MEDIAPIPE_REGISTER_NODE(TensorsToLandmarksCalculator); -mediapipe::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) { +absl::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(LoadOptions(cc)); if (kOutNormalizedLandmarkList(cc).IsConnected()) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input with/height for getting normalized landmarks."; + << "Must provide input width/height for getting normalized landmarks."; } if (kOutLandmarkList(cc).IsConnected() && (options_.flip_horizontally() || options_.flip_vertically() || @@ -124,15 +124,15 @@ mediapipe::Status TensorsToLandmarksCalculator::Open(CalculatorContext* cc) { kFlipVertically(cc).IsConnected())) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input with/height for using flipping when outputing " + << "Must provide input width/height for using flipping when outputing " "landmarks in absolute coordinates."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { +absl::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { if (kInTensors(cc).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool flip_horizontally = kFlipHorizontally(cc).GetOr(options_.flip_horizontally()); @@ -204,17 +204,16 @@ mediapipe::Status TensorsToLandmarksCalculator::Process(CalculatorContext* cc) { kOutLandmarkList(cc).Send(std::move(output_landmarks)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToLandmarksCalculator::LoadOptions( - CalculatorContext* cc) { +absl::Status TensorsToLandmarksCalculator::LoadOptions(CalculatorContext* cc) { // Get calculator options specified in the graph. 
options_ = cc->Options<::mediapipe::TensorsToLandmarksCalculatorOptions>(); RET_CHECK(options_.has_num_landmarks()); num_landmarks_ = options_.num_landmarks(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/testdata/expected_detection.pbtxt b/mediapipe/calculators/tensor/testdata/expected_detection.pbtxt new file mode 100644 index 000000000..f1739ce86 --- /dev/null +++ b/mediapipe/calculators/tensor/testdata/expected_detection.pbtxt @@ -0,0 +1,35 @@ +label_id: 0 +score: 0.92843366 +location_data { + format: RELATIVE_BOUNDING_BOX + relative_bounding_box { + xmin: 0.21061149 + ymin: 0.29150677 + width: 0.5657704 + height: 0.5657307 + } + relative_keypoints { + x: 0.37730268 + y: 0.44038114 + } + relative_keypoints { + x: 0.6250565 + y: 0.44425336 + } + relative_keypoints { + x: 0.50687385 + y: 0.5767085 + } + relative_keypoints { + x: 0.50173956 + y: 0.6991459 + } + relative_keypoints { + x: 0.2383742 + y: 0.49879026 + } + relative_keypoints { + x: 0.7404449 + y: 0.50361776 + } +} diff --git a/mediapipe/calculators/tensor/testdata/face_detection_expected.png b/mediapipe/calculators/tensor/testdata/face_detection_expected.png new file mode 100644 index 000000000..df38abf70 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/face_detection_expected.png differ diff --git a/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt b/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt new file mode 100644 index 000000000..b0e00346c --- /dev/null +++ b/mediapipe/calculators/tensor/testdata/face_detection_test.pbtxt @@ -0,0 +1,31 @@ +input_stream: "image" +output_stream: "rendering" +output_stream: "detections" + +# Subgraph that detects faces. +node { + calculator: "FaceDetectionFrontCpu" + input_stream: "IMAGE:image" + output_stream: "DETECTIONS:detections" +} + +# Converts the detections to drawing primitives for annotation overlay. 
+node { + calculator: "DetectionsToRenderDataCalculator" + input_stream: "DETECTIONS:detections" + output_stream: "RENDER_DATA:render_data" + options: { + [mediapipe.DetectionsToRenderDataCalculatorOptions.ext] { + thickness: 4.0 + color { r: 255 g: 0 b: 0 } + } + } +} + +# Draws annotations and overlays them on top of the input images. +node { + calculator: "AnnotationOverlayCalculator" + input_stream: "IMAGE:image" + input_stream: "render_data" + output_stream: "IMAGE:rendering" +} diff --git a/mediapipe/calculators/tensorflow/BUILD b/mediapipe/calculators/tensorflow/BUILD index 4b6d244e3..4c2b90a59 100644 --- a/mediapipe/calculators/tensorflow/BUILD +++ b/mediapipe/calculators/tensorflow/BUILD @@ -106,6 +106,13 @@ proto_library( deps = ["//mediapipe/framework:calculator_proto"], ) +proto_library( + name = "vector_string_to_tensor_calculator_options_proto", + srcs = ["vector_string_to_tensor_calculator_options.proto"], + visibility = ["//visibility:public"], + deps = ["//mediapipe/framework:calculator_proto"], +) + mediapipe_cc_proto_library( name = "graph_tensors_packet_generator_cc_proto", srcs = ["graph_tensors_packet_generator.proto"], @@ -281,6 +288,14 @@ mediapipe_cc_proto_library( deps = [":vector_float_to_tensor_calculator_options_proto"], ) +mediapipe_cc_proto_library( + name = "vector_string_to_tensor_calculator_options_cc_proto", + srcs = ["vector_string_to_tensor_calculator_options.proto"], + cc_deps = ["//mediapipe/framework:calculator_cc_proto"], + visibility = ["//visibility:public"], + deps = [":vector_string_to_tensor_calculator_options_proto"], +) + cc_library( name = "graph_tensors_packet_generator", srcs = ["graph_tensors_packet_generator.cc"], @@ -728,6 +743,20 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "vector_string_to_tensor_calculator", + srcs = ["vector_string_to_tensor_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + ":vector_string_to_tensor_calculator_options_cc_proto", + 
"//mediapipe/framework:calculator_framework", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + "@org_tensorflow//tensorflow/core:framework", + ], + alwayslink = 1, +) + cc_library( name = "unpack_yt8m_sequence_example_calculator", srcs = ["unpack_yt8m_sequence_example_calculator.cc"], @@ -858,11 +887,12 @@ cc_test( ":tensorflow_inference_calculator", ":tensorflow_session", ":tensorflow_session_from_frozen_graph_calculator", - "//mediapipe/calculators/tensorflow:tensorflow_session_from_frozen_graph_calculator_cc_proto", + ":tensorflow_session_from_frozen_graph_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework:packet", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", @@ -893,6 +923,7 @@ cc_test( "//mediapipe/framework:packet", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", @@ -923,6 +954,7 @@ cc_test( "//mediapipe/framework:packet", "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:tag_map_helper", @@ -949,6 +981,7 @@ cc_test( "//mediapipe/framework:calculator_runner", "//mediapipe/framework:packet", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/tool:tag_map_helper", @@ -1075,6 +1108,21 @@ cc_test( ], ) 
+cc_test( + name = "vector_string_to_tensor_calculator_test", + srcs = ["vector_string_to_tensor_calculator_test.cc"], + deps = [ + ":vector_string_to_tensor_calculator", + ":vector_string_to_tensor_calculator_options_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/strings", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:protos_all_cc", + ], +) + test_suite( name = "ios", tags = ["ios"], @@ -1097,13 +1145,13 @@ cc_test( ":tensorflow_session_from_frozen_graph_generator", ":tensorflow_session_from_frozen_graph_generator_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", + "//mediapipe/framework/port:integral_types", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/tool:sink", "//mediapipe/framework/tool:validate_type", "//mediapipe/framework/port:gtest_main", - "//mediapipe/framework/port:status", - "//mediapipe/framework/port:ret_check", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:testlib", diff --git a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc index 5d449f037..310d412bf 100644 --- a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc +++ b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator.cc @@ -33,7 +33,7 @@ namespace tf = ::tensorflow; class GraphTensorsPacketGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { RET_CHECK(extendable_options.HasExtension( @@ -45,10 +45,10 @@ class GraphTensorsPacketGenerator : public PacketGenerator { /* "A map of 
tensor tags and tensors" */); RET_CHECK_EQ(options.tensor_tag_size(), options.tensor_num_nodes_size()); RET_CHECK_GT(options.tensor_tag_size(), 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( + static absl::Status Generate( const PacketGeneratorOptions& packet_generator_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { const GraphTensorsPacketGeneratorOptions& options = @@ -65,7 +65,7 @@ class GraphTensorsPacketGenerator : public PacketGenerator { (*tensor_map)[tensor_tag].flat().setZero(); } output_side_packets->Index(0) = AdoptAsUniquePtr(tensor_map.release()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(GraphTensorsPacketGenerator); diff --git a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc index 829994e3c..ef77fb918 100644 --- a/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc +++ b/mediapipe/calculators/tensorflow/graph_tensors_packet_generator_test.cc @@ -72,7 +72,7 @@ TEST_F(GraphTensorsPacketGeneratorTest, VerifyTensorSizeShapeAndValue) { PacketSet inputs({}); PacketSet outputs(1); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "GraphTensorsPacketGenerator", extendable_options_, inputs, &outputs); MP_EXPECT_OK(run_status) << run_status.message(); VerifyTensorMap(&outputs); diff --git a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc index e5c0601e5..0db193bcc 100644 --- a/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/image_frame_to_tensor_calculator.cc @@ -78,18 +78,17 @@ std::unique_ptr ImageFrameToNormalizedTensor( // } class ImageFrameToTensorCalculator : public 
CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: ImageFrameToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(ImageFrameToTensorCalculator); -mediapipe::Status ImageFrameToTensorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ImageFrameToTensorCalculator::GetContract(CalculatorContract* cc) { // Start with only one input packet. RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; @@ -101,18 +100,18 @@ mediapipe::Status ImageFrameToTensorCalculator::GetContract( cc->Outputs().Index(0).Set( // Output TensorFlow Tensor. ); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageFrameToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status ImageFrameToTensorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); // Inform the framework that we always output at the same timestamp // as we receive a packet at. 
cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ImageFrameToTensorCalculator::Process(CalculatorContext* cc) { +absl::Status ImageFrameToTensorCalculator::Process(CalculatorContext* cc) { const Packet& input_item = cc->Inputs().Index(0).Value(); RET_CHECK(!input_item.IsEmpty()) << "Input cannot be empty."; @@ -146,7 +145,7 @@ mediapipe::Status ImageFrameToTensorCalculator::Process(CalculatorContext* cc) { } else if (bytes_per_pixel == 4) { data_type = tf::DT_FLOAT; } else { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Unsupported image format (", bytes_per_pixel, " bytes per pixel)")); } @@ -173,7 +172,7 @@ mediapipe::Status ImageFrameToTensorCalculator::Process(CalculatorContext* cc) { } cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc index 7332439ef..a07b95ccc 100644 --- a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc +++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator.cc @@ -84,18 +84,18 @@ namespace tf = tensorflow; class LappedTensorBufferCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: // Adds a batch dimension to the input tensor if specified in the // calculator options. 
- mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor); + absl::Status AddBatchDimension(tf::Tensor* input_tensor); // Sends the current buffer downstream. - mediapipe::Status ProcessBuffer(CalculatorContext* cc); + absl::Status ProcessBuffer(CalculatorContext* cc); int steps_until_output_; int buffer_size_; @@ -110,8 +110,7 @@ class LappedTensorBufferCalculator : public CalculatorBase { REGISTER_CALCULATOR(LappedTensorBufferCalculator); -mediapipe::Status LappedTensorBufferCalculator::GetContract( - CalculatorContract* cc) { +absl::Status LappedTensorBufferCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; cc->Inputs().Index(0).Set( @@ -141,10 +140,10 @@ mediapipe::Status LappedTensorBufferCalculator::GetContract( if (cc->Outputs().NumEntries() > 1) { cc->Outputs().Index(1).Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) { +absl::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); if (cc->InputSidePackets().HasTag(kCalculatorOptions)) { options_ = cc->InputSidePackets() @@ -176,10 +175,10 @@ mediapipe::Status LappedTensorBufferCalculator::Open(CalculatorContext* cc) { buffer_ = absl::make_unique>(buffer_size_); steps_until_output_ = buffer_size_ - options_.padding(); initialized_ = false; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LappedTensorBufferCalculator::Process(CalculatorContext* cc) { +absl::Status LappedTensorBufferCalculator::Process(CalculatorContext* cc) { // These are cheap, shallow copies. 
tensorflow::Tensor input_tensor( cc->Inputs().Index(0).Get()); @@ -201,12 +200,12 @@ mediapipe::Status LappedTensorBufferCalculator::Process(CalculatorContext* cc) { MP_RETURN_IF_ERROR(ProcessBuffer(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LappedTensorBufferCalculator::Close(CalculatorContext* cc) { +absl::Status LappedTensorBufferCalculator::Close(CalculatorContext* cc) { if (!initialized_ || options_.padding() == 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } int last_frame = buffer_size_ - steps_until_output_ - 1; const auto& pad_frame = buffer_->Get(last_frame); @@ -216,12 +215,12 @@ mediapipe::Status LappedTensorBufferCalculator::Close(CalculatorContext* cc) { } MP_RETURN_IF_ERROR(ProcessBuffer(cc)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Adds a batch dimension to the input tensor if specified in the calculator // options. -mediapipe::Status LappedTensorBufferCalculator::AddBatchDimension( +absl::Status LappedTensorBufferCalculator::AddBatchDimension( tf::Tensor* input_tensor) { if (options_.add_batch_dim_to_tensors()) { tf::TensorShape new_shape(input_tensor->shape()); @@ -230,11 +229,11 @@ mediapipe::Status LappedTensorBufferCalculator::AddBatchDimension( << "Could not add 0th dimension to tensor without changing its shape." 
<< " Current shape: " << input_tensor->shape().DebugString(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Process buffer -mediapipe::Status LappedTensorBufferCalculator::ProcessBuffer( +absl::Status LappedTensorBufferCalculator::ProcessBuffer( CalculatorContext* cc) { auto concatenated = ::absl::make_unique(); const tf::Status concat_status = tf::tensor::Concat( @@ -255,7 +254,7 @@ mediapipe::Status LappedTensorBufferCalculator::ProcessBuffer( timestamp_buffer_->Get(timestamp_offset_)); } steps_until_output_ = buffer_size_ - overlap_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc index f34d81ccb..e0e3000d2 100644 --- a/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/lapped_tensor_buffer_calculator_test.cc @@ -153,6 +153,35 @@ TEST_F(LappedTensorBufferCalculatorTest, OneToThreeSkip) { } } +TEST_F(LappedTensorBufferCalculatorTest, OneToThreeNegativeOverlap) { + int buffer_size = 3; + int overlap = -1; + bool add_dim = false; + SetUpCalculator(buffer_size, overlap, add_dim, 0, 0, false); + int num_timesteps = 7; + for (int i = 0; i < num_timesteps; ++i) { + auto input = ::absl::make_unique( + tensorflow::DT_FLOAT, tensorflow::TensorShape({1})); + input->tensor()(0) = i; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(i))); + } + ASSERT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + ASSERT_EQ(2, output_packets.size()); + // The outputs in packet one should be {0, 1, 2}, and in packet two {4, 5, 6} + for (int i = 0; i < 3; ++i) { + float value_0 = output_packets[0].Get().tensor()(i); + ASSERT_NEAR(value_0, i, 0.0001); + } + for (int i = 0; i < 3; ++i) { + float value_1 = 
output_packets[1].Get().tensor()(i); + ASSERT_NEAR(value_1, 4 + i, 0.0001); + } +} + TEST_F(LappedTensorBufferCalculatorTest, OneToThreeBatch) { int buffer_size = 3; int overlap = 2; diff --git a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc index 20e2883b4..32a0eb70b 100644 --- a/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/matrix_to_tensor_calculator.cc @@ -26,19 +26,19 @@ namespace mediapipe { namespace { -mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, - TimeSeriesHeader* header) { +absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { CHECK(header); if (header_packet.IsEmpty()) { - return mediapipe::UnknownError("No header found."); + return absl::UnknownError("No header found."); } if (!header_packet.ValidateAsType().ok()) { - return mediapipe::UnknownError("Packet does not contain TimeSeriesHeader."); + return absl::UnknownError("Packet does not contain TimeSeriesHeader."); } *header = header_packet.Get(); if (header->has_sample_rate() && header->sample_rate() >= 0 && header->has_num_channels() && header->num_channels() >= 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { std::string error_message = "TimeSeriesHeader is missing necessary fields: " @@ -47,7 +47,7 @@ mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, absl::StrAppend(&error_message, "Got header:\n", header->ShortDebugString()); #endif - return mediapipe::InvalidArgumentError(error_message); + return absl::InvalidArgumentError(error_message); } } } // namespace @@ -77,18 +77,17 @@ typedef Eigen::Matrix // } class MatrixToTensorCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) 
override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: MatrixToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(MatrixToTensorCalculator); -mediapipe::Status MatrixToTensorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status MatrixToTensorCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; cc->Inputs().Index(0).Set( @@ -101,15 +100,15 @@ mediapipe::Status MatrixToTensorCalculator::GetContract( // TimeSeriesHeader as the input (or no header if the input has no // header). ); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MatrixToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status MatrixToTensorCalculator::Open(CalculatorContext* cc) { // If the input is part of a time series, then preserve the header so that // downstream consumers can access the sample rate if needed. options_ = cc->Options(); auto input_header = ::absl::make_unique(); - const mediapipe::Status header_status = FillTimeSeriesHeaderIfValid( + const absl::Status header_status = FillTimeSeriesHeaderIfValid( cc->Inputs().Index(0).Header(), input_header.get()); if (header_status.ok()) { cc->Outputs().Index(0).SetHeader(Adopt(input_header.release())); @@ -118,10 +117,10 @@ mediapipe::Status MatrixToTensorCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. 
cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MatrixToTensorCalculator::Process(CalculatorContext* cc) { +absl::Status MatrixToTensorCalculator::Process(CalculatorContext* cc) { const Matrix& matrix = cc->Inputs().Index(0).Get(); tf::TensorShape tensor_shape; if (options_.transpose()) { @@ -150,7 +149,7 @@ mediapipe::Status MatrixToTensorCalculator::Process(CalculatorContext* cc) { << " Current shape: " << tensor->shape().DebugString(); } cc->Outputs().Index(0).Add(tensor.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc index f108d8c40..a8abe10d9 100644 --- a/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tensorflow/object_detection_tensors_to_detections_calculator.cc @@ -93,7 +93,7 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { public: ObjectDetectionTensorsToDetectionsCalculator() = default; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kBoxes).Set(); cc->Inputs().Tag(kScores).Set(); @@ -126,10 +126,10 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { .Tag(kLabelMap) .Set>>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { if (cc->InputSidePackets().HasTag(kLabelMap)) { label_map_ = GetFromUniquePtr>( cc->InputSidePackets().Tag(kLabelMap)); @@ -141,10 +141,10 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { tensor_dim_to_squeeze_field.begin(), 
tensor_dim_to_squeeze_field.end()); std::sort(tensor_dims_to_squeeze_.rbegin(), tensor_dims_to_squeeze_.rend()); cc->SetOffset(0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const auto& options = cc->Options(); @@ -205,15 +205,15 @@ class ObjectDetectionTensorsToDetectionsCalculator : public CalculatorBase { .Tag(kDetections) .Add(output_detections.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: std::map* label_map_; std::vector tensor_dims_to_squeeze_; - mediapipe::StatusOr MaybeSqueezeDims( - const std::string& tensor_tag, const tf::Tensor& input_tensor) { + absl::StatusOr MaybeSqueezeDims(const std::string& tensor_tag, + const tf::Tensor& input_tensor) { if (tensor_dims_to_squeeze_.empty()) { return input_tensor; } diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index aece59e5a..fdf43dcfd 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -94,7 +94,7 @@ uint8 ConvertFloatToByte(const float float_value) { class PackMediaSequenceCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); cc->InputSidePackets().Tag(kSequenceExampleTag).Set(); @@ -167,10 +167,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { .Tag(kSequenceExampleTag) .Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { sequence_ = ::absl::make_unique( cc->InputSidePackets() 
.Tag(kSequenceExampleTag) @@ -248,10 +248,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { .Tag(kSequenceExampleTag) .SetNextTimestampBound(Timestamp::Max()); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status VerifySequence() { + absl::Status VerifySequence() { std::string error_msg = "Missing features - "; bool all_present = true; for (const auto& iter : features_present_) { @@ -261,13 +261,13 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } if (all_present) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { return ::mediapipe::NotFoundErrorBuilder(MEDIAPIPE_LOC) << error_msg; } } - ::mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { auto& options = cc->Options(); if (options.reconcile_metadata()) { RET_CHECK_OK(mpms::ReconcileMetadata( @@ -276,7 +276,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { } if (options.output_only_if_all_present()) { - ::mediapipe::Status status = VerifySequence(); + absl::Status status = VerifySequence(); if (!status.ok()) { cc->GetCounter(status.ToString())->Increment(); return status; @@ -295,10 +295,10 @@ class PackMediaSequenceCalculator : public CalculatorBase { } sequence_.reset(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { int image_height = -1; int image_width = -1; // Because the tag order may vary, we need to loop through tags to get @@ -489,12 +489,12 @@ class PackMediaSequenceCalculator : public CalculatorBase { sequence_.get()); already_has_mask = true; } else { - return ::mediapipe::UnimplementedError( + return absl::UnimplementedError( "Global detections and empty detections are not supported."); } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } std::unique_ptr sequence_; diff --git 
a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index 19d03ecde..09c5a0f24 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -71,9 +71,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); encoded_image.set_height(1); @@ -101,7 +100,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoImages) { ASSERT_EQ(num_images, mpms::GetImageEncodedSize(output_sequence)); for (int i = 0; i < num_images; ++i) { ASSERT_EQ(i, mpms::GetImageTimestampAt(output_sequence, i)); - ASSERT_EQ(test_image_string, mpms::GetImageEncodedAt(output_sequence, i)); + ASSERT_EQ(encoded_image.encoded_image(), + mpms::GetImageEncodedAt(output_sequence, i)); } } @@ -114,9 +114,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); encoded_image.set_height(1); @@ -145,7 +144,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoPrefixedImages) { ASSERT_EQ(num_images, mpms::GetImageEncodedSize(prefix, output_sequence)); for (int i = 0; i < num_images; ++i) { ASSERT_EQ(i, mpms::GetImageTimestampAt(prefix, 
output_sequence, i)); - ASSERT_EQ(test_image_string, + ASSERT_EQ(encoded_image.encoded_image(), mpms::GetImageEncodedAt(prefix, output_sequence, i)); } } @@ -239,9 +238,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksAdditionalContext) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); auto image_ptr = ::absl::make_unique(encoded_image); runner_->MutableInputs()->Tag("IMAGE").packets.push_back( @@ -434,7 +432,7 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) { Adopt(input_sequence.release()); auto status = runner_->Run(); - EXPECT_EQ(mediapipe::StatusCode::kInvalidArgument, status.code()); + EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()); } TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { @@ -480,9 +478,8 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) { cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(width); encoded_image.set_height(height); @@ -694,7 +691,7 @@ TEST_F(PackMediaSequenceCalculatorTest, MissingStreamNotOK) { runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") = Adopt(input_sequence.release()); - ::mediapipe::Status status = runner_->Run(); + absl::Status status = runner_->Run(); EXPECT_FALSE(status.ok()); } @@ -794,9 +791,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestReconcilingAnnotations) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 
255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); encoded_image.set_width(2); encoded_image.set_height(1); @@ -846,9 +842,8 @@ TEST_F(PackMediaSequenceCalculatorTest, TestOverwritingAndReconciling) { cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); std::vector bytes; ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80})); - std::string test_image_string(bytes.begin(), bytes.end()); OpenCvImageEncoderCalculatorResults encoded_image; - encoded_image.set_encoded_image(test_image_string); + encoded_image.set_encoded_image(bytes.data(), bytes.size()); int height = 2; int width = 2; encoded_image.set_width(width); diff --git a/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc index 64a2da016..da85bed94 100644 --- a/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc +++ b/mediapipe/calculators/tensorflow/string_to_sequence_example_calculator.cc @@ -44,15 +44,15 @@ constexpr char kSequenceExample[] = "SEQUENCE_EXAMPLE"; class StringToSequenceExampleCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(StringToSequenceExampleCalculator); -mediapipe::Status StringToSequenceExampleCalculator::GetContract( +absl::Status 
StringToSequenceExampleCalculator::GetContract( CalculatorContract* cc) { if (cc->InputSidePackets().HasTag(kString)) { cc->InputSidePackets().Tag(kString).Set(); @@ -62,11 +62,10 @@ mediapipe::Status StringToSequenceExampleCalculator::GetContract( cc->InputSidePackets().Tag(kSequenceExample).Set(); cc->OutputSidePackets().Tag(kString).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status StringToSequenceExampleCalculator::Open( - CalculatorContext* cc) { +absl::Status StringToSequenceExampleCalculator::Open(CalculatorContext* cc) { if (cc->InputSidePackets().HasTag(kString)) { auto string_value = cc->InputSidePackets().Tag(kString).Get(); auto example = absl::make_unique(); @@ -75,16 +74,14 @@ mediapipe::Status StringToSequenceExampleCalculator::Open( .Tag(kSequenceExample) .Set(mediapipe::Adopt(example.release())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status StringToSequenceExampleCalculator::Process( - CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status StringToSequenceExampleCalculator::Process(CalculatorContext* cc) { + return absl::OkStatus(); } -mediapipe::Status StringToSequenceExampleCalculator::Close( - CalculatorContext* cc) { +absl::Status StringToSequenceExampleCalculator::Close(CalculatorContext* cc) { if (cc->InputSidePackets().HasTag(kSequenceExample)) { const auto& example = cc->InputSidePackets().Tag(kSequenceExample).Get(); @@ -93,7 +90,7 @@ mediapipe::Status StringToSequenceExampleCalculator::Close( cc->OutputSidePackets().Tag(kString).Set( mediapipe::Adopt(string_value.release())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc index a3acc49f3..cbf494245 100644 --- a/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc +++ 
b/mediapipe/calculators/tensorflow/tensor_squeeze_dimensions_calculator.cc @@ -27,7 +27,7 @@ namespace tf = ::tensorflow; // containing identical data (example output dimensions [1024, 5]). class TensorSqueezeDimensionsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Need one input"; cc->Inputs().Index(0).Set( // Input Tensor @@ -36,10 +36,10 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase { cc->Outputs().Index(0).Set( // Output Tensor Reduced Dimensions ); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { options_ = cc->Options(); RET_CHECK(options_.squeeze_all_single_dims() ^ (options_.dim_size() > 0)) << "Must specify dimensions to remove, or set squeeze_all_single_dims, " @@ -52,10 +52,10 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase { remove_dims_initialized_ = true; } cc->SetOffset(0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const tf::Tensor& input_tensor = cc->Inputs().Index(0).Get(); tf::TensorShape tensor_shape = input_tensor.shape(); if (!remove_dims_initialized_) { @@ -78,11 +78,11 @@ class TensorSqueezeDimensionsCalculator : public CalculatorBase { std::unique_ptr output_tensor(new tf::Tensor); RET_CHECK(output_tensor->CopyFrom(input_tensor, tensor_shape)); cc->Outputs().Index(0).Add(output_tensor.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Close(CalculatorContext* cc) override { + return absl::OkStatus(); } private: diff --git 
a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc index 035acb564..d72c75923 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_image_frame_calculator.cc @@ -45,10 +45,10 @@ constexpr char kTensor[] = "TENSOR"; // Possible extensions: support other input ranges, maybe 4D tensors. class TensorToImageFrameCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: float scale_factor_; @@ -56,8 +56,7 @@ class TensorToImageFrameCalculator : public CalculatorBase { REGISTER_CALCULATOR(TensorToImageFrameCalculator); -mediapipe::Status TensorToImageFrameCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TensorToImageFrameCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) << "Only one input stream is supported."; RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) @@ -70,17 +69,17 @@ mediapipe::Status TensorToImageFrameCalculator::GetContract( cc->Outputs().Tag(kImage).Set( // Output ImageFrame. 
); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) { +absl::Status TensorToImageFrameCalculator::Open(CalculatorContext* cc) { scale_factor_ = cc->Options().scale_factor(); cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) { +absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) { const tf::Tensor& input_tensor = cc->Inputs().Tag(kTensor).Get(); int32 depth = 1; if (input_tensor.dims() != 2) { // Depth is 1 for 2D tensors. @@ -113,11 +112,11 @@ mediapipe::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) { ImageFormat::GRAY8, input_tensor.dim_size(1), input_tensor.dim_size(0), input_tensor.dim_size(1), buffer.release()); } else { - return mediapipe::InvalidArgumentError("Unrecognized image depth."); + return absl::InvalidArgumentError("Unrecognized image depth."); } cc->Outputs().Tag(kImage).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc index 270f00982..d52de7404 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_matrix_calculator.cc @@ -34,19 +34,19 @@ constexpr char kMatrix[] = "MATRIX"; constexpr char kTensor[] = "TENSOR"; constexpr char kReference[] = "REFERENCE"; -mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, - TimeSeriesHeader* header) { +absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { CHECK(header); if (header_packet.IsEmpty()) { - return mediapipe::UnknownError("No header found."); + return absl::UnknownError("No header found."); } 
if (!header_packet.ValidateAsType().ok()) { - return mediapipe::UnknownError("Packet does not contain TimeSeriesHeader."); + return absl::UnknownError("Packet does not contain TimeSeriesHeader."); } *header = header_packet.Get(); if (header->has_sample_rate() && header->sample_rate() >= 0 && header->has_num_channels() && header->num_channels() >= 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { std::string error_message = "TimeSeriesHeader is missing necessary fields: " @@ -55,7 +55,7 @@ mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, absl::StrAppend(&error_message, "Got header:\n", header->ShortDebugString()); #endif - return mediapipe::InvalidArgumentError(error_message); + return absl::InvalidArgumentError(error_message); } } @@ -109,18 +109,17 @@ mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, // } class TensorToMatrixCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Store header information so that we can verify the inputs in process(). TimeSeriesHeader header_; }; REGISTER_CALCULATOR(TensorToMatrixCalculator); -mediapipe::Status TensorToMatrixCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TensorToMatrixCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_LE(cc->Inputs().NumEntries(), 2) << "Only one or two input streams are supported."; RET_CHECK_GT(cc->Inputs().NumEntries(), 0) @@ -146,12 +145,12 @@ mediapipe::Status TensorToMatrixCalculator::GetContract( cc->Outputs().Tag(kMatrix).Set( // Output Matrix. 
); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { +absl::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { auto input_header = absl::make_unique(); - mediapipe::Status header_status; + absl::Status header_status; if (cc->Inputs().HasTag(kReference)) { header_status = FillTimeSeriesHeaderIfValid( cc->Inputs().Tag(kReference).Header(), input_header.get()); @@ -183,10 +182,10 @@ mediapipe::Status TensorToMatrixCalculator::Open(CalculatorContext* cc) { cc->Outputs().Tag(kMatrix).SetHeader(Adopt(input_header.release())); } cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { +absl::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { // Daredevil requested CHECK for noisy failures rather than quieter RET_CHECK // failures. These are absolute conditions of the graph for the graph to be // valid, and if it is violated by any input anywhere, the graph will be @@ -220,7 +219,7 @@ mediapipe::Status TensorToMatrixCalculator::Process(CalculatorContext* cc) { *output = Eigen::MatrixXf::Map(input_tensor.flat().data(), length, width); cc->Outputs().Tag(kMatrix).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc index e50df9276..cd807b87b 100644 --- a/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensor_to_vector_float_calculator.cc @@ -28,17 +28,17 @@ namespace tf = ::tensorflow; class TensorToVectorFloatCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status 
GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: TensorToVectorFloatCalculatorOptions options_; }; REGISTER_CALCULATOR(TensorToVectorFloatCalculator); -mediapipe::Status TensorToVectorFloatCalculator::GetContract( +absl::Status TensorToVectorFloatCalculator::GetContract( CalculatorContract* cc) { // Start with only one input packet. RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) @@ -58,16 +58,22 @@ mediapipe::Status TensorToVectorFloatCalculator::GetContract( // Output vector. ); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorToVectorFloatCalculator::Open(CalculatorContext* cc) { +absl::Status TensorToVectorFloatCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); - return mediapipe::OkStatus(); + + // Inform mediapipe that this calculator produces an output at time t for + // each input received at time t (i.e. this calculator does not buffer + // inputs). This enables mediapipe to propagate time of arrival estimates in + // mediapipe graphs through this calculator. 
+ cc->SetOffset(/*offset=*/0); + + return absl::OkStatus(); } -mediapipe::Status TensorToVectorFloatCalculator::Process( - CalculatorContext* cc) { +absl::Status TensorToVectorFloatCalculator::Process(CalculatorContext* cc) { const tf::Tensor& input_tensor = cc->Inputs().Index(0).Value().Get(); RET_CHECK(tf::DT_FLOAT == input_tensor.dtype()) @@ -103,7 +109,7 @@ mediapipe::Status TensorToVectorFloatCalculator::Process( cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc index eb9891a37..d78a53053 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator.cc @@ -98,546 +98,388 @@ class InferenceState { // This calculator performs inference on a trained TensorFlow model. // -// Additional documentation and examples at -// go/mediapipe/tensorflow_in_mediapipe. -// -// TensorFlow Sessions can be created from checkpoint paths, frozen models, or -// the SavedModel system (go/saved-model). See the TensorFlowSessionFrom* -// packet generators for details. Each of these methods defines a mapping -// between MediaPipe streams and TensorFlow tensors. All of this information is -// passed in as an input_side_packet. -// -// The input and output streams are TensorFlow tensors labeled by tags. The tags -// for the streams are matched to feeds and fetchs in a TensorFlow session using -// a named_signature.generic_signature in the ModelManifest. The -// generic_signature is used as key-value pairs between the MediaPipe tag and -// the TensorFlow tensor. The signature_name in the options proto determines -// which named_signature is used. 
The keys in the generic_signature must be -// valid MediaPipe tags ([A-Z0-9_]*, no lowercase or special characters). All of -// the tensors corresponding to tags in the signature for input_streams are fed -// to the model and for output_streams the tensors are fetched from the model. -// -// Other calculators are used to convert data to and from tensors, this op only -// handles the TensorFlow session and batching. Batching occurs by concatenating -// input tensors along the 0th dimension across timestamps. If the 0th dimension -// is not a batch dimension, this calculator will add a 0th dimension by -// default. Setting add_batch_dim_to_tensors to false disables the dimension -// addition. Once batch_size inputs have been provided, the batch will be run -// and the output tensors sent out on the output streams with timestamps -// corresponding to the input stream packets. Setting the batch_size to 1 -// completely disables batching, but is indepdent of add_batch_dim_to_tensors. -// -// The TensorFlowInferenceCalculator also support feeding states recurrently for -// RNNs and LSTMs. Simply set the recurrent_tag_pair options to define the -// recurrent tensors. Initializing the recurrent state can be handled by the -// GraphTensorsPacketGenerator. -// -// The calculator updates two Counters to report timing information: -// ---TotalTimeUsecs = Total time spent running inference (in usecs), -// ---TotalProcessedTimestamps = # of instances processed -// (approximately batches processed * batch_size), -// where is replaced with CalculatorGraphConfig::Node::name() if it -// exists, or with TensorFlowInferenceCalculator if the name is not set. The -// name must be set for timing information to be instance-specific in graphs -// with multiple TensorFlowInferenceCalculators. 
-// -// Example config: -// packet_generator { -// packet_generator: "TensorFlowSessionFromSavedModelGenerator" -// output_side_packet: "tensorflow_session" -// options { -// [mediapipe.TensorFlowSessionFromSavedModelGeneratorOptions.ext]: { -// saved_model_path: "/path/to/saved/model" -// signature_name: "mediapipe" -// } -// } -// } -// node { -// calculator: "TensorFlowInferenceCalculator" -// input_stream: "IMAGES:image_tensors_keyed_in_signature_by_tag" -// input_stream: "AUDIO:audio_tensors_keyed_in_signature_by_tag" -// output_stream: "LABELS:softmax_tensor_keyed_in_signature_by_tag" -// input_side_packet: "SESSION:tensorflow_session" -// } -// -// Where the input and output streams are treated as Packet and -// the mediapipe_signature has tensor bindings between "IMAGES", "AUDIO", and -// "LABELS" and their respective tensors exported to /path/to/bundle. For an -// example of how this model was exported, see -// tensorflow_inference_test_graph_generator.py -// -// It is possible to use a GraphDef proto that was not exported by exporter (i.e -// without MetaGraph with bindings). Such GraphDef could contain all of its -// parameters in-lined (for example, it can be the output of freeze_graph.py). -// To instantiate a TensorFlow model from a GraphDef file, replace the -// packet_factory above with TensorFlowSessionFromFrozenGraphGenerator: -// -// packet_generator { -// packet_generator: "TensorFlowSessionFromFrozenGraphGenerator" -// output_side_packet: "SESSION:tensorflow_session" -// options { -// [mediapipe.TensorFlowSessionFromFrozenGraphGeneratorOptions.ext]: { -// graph_proto_path: "[PATH]" -// tag_to_tensor_names { -// key: "JPG_STRING" -// value: "input:0" -// } -// tag_to_tensor_names { -// key: "SOFTMAX" -// value: "softmax:0" -// } -// } -// } -// } -// -// It is also possible to use a GraphDef proto and checkpoint file that have not -// been frozen. This can be used to load graphs directly as they have been -// written from training. 
However, it is more brittle and you are encouraged to -// use a one of the more perminent formats described above. To instantiate a -// TensorFlow model from a GraphDef file and checkpoint, replace the -// packet_factory above with TensorFlowSessionFromModelCheckpointGenerator: -// -// packet_generator { -// packet_generator: "TensorFlowSessionFromModelCheckpointGenerator" -// output_side_packet: "SESSION:tensorflow_session" -// options { -// [mediapipe.TensorFlowSessionFromModelCheckpointGeneratorOptions.ext]: { -// graph_proto_path: "[PATH]" -// model_options { -// checkpoint_path: "[PATH2]" -// } -// tag_to_tensor_names { -// key: "JPG_STRING" -// value: "input:0" -// } -// tag_to_tensor_names { -// key: "SOFTMAX" -// value: "softmax:0" -// } -// } -// } -// } -class TensorFlowInferenceCalculator : public CalculatorBase { - public: - // Counters for recording timing information. The actual names have the value - // of CalculatorGraphConfig::Node::name() prepended. - static constexpr char kTotalUsecsCounterSuffix[] = "TotalTimeUsecs"; - static constexpr char kTotalProcessedTimestampsCounterSuffix[] = - "TotalProcessedTimestamps"; - static constexpr char kTotalSessionRunsTimeUsecsCounterSuffix[] = - "TotalSessionRunsTimeUsecs"; - static constexpr char kTotalNumSessionRunsCounterSuffix[] = - "TotalNumSessionRuns"; +// A mediapipe::TensorFlowSession with a model loaded and ready for use. +// For this calculator it must include a tag_to_tensor_map. 
+cc->InputSidePackets().Tag("SESSION").Set(); +if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { + cc->InputSidePackets() + .Tag("RECURRENT_INIT_TENSORS") + .Set>>(); +} +return absl::OkStatus(); +} - TensorFlowInferenceCalculator() : session_(nullptr) { - clock_ = std::unique_ptr( - mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); +std::unique_ptr CreateInferenceState(CalculatorContext* cc) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + std::unique_ptr inference_state = + absl::make_unique(); + if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && + !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { + std::map* init_tensor_map; + init_tensor_map = GetFromUniquePtr>( + cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); + for (const auto& p : *init_tensor_map) { + inference_state->input_tensor_batches_[p.first].emplace_back(p.second); + } + } + return inference_state; +} + +absl::Status Open(CalculatorContext* cc) override { + options_ = cc->Options(); + + RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); + session_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .session.get(); + tag_to_tensor_map_ = cc->InputSidePackets() + .Tag("SESSION") + .Get() + .tag_to_tensor_map; + + // Validate and store the recurrent tags + RET_CHECK(options_.has_batch_size()); + RET_CHECK(options_.batch_size() == 1 || options_.recurrent_tag_pair().empty()) + << "To use recurrent_tag_pairs, batch_size must be 1."; + for (const auto& tag_pair : options_.recurrent_tag_pair()) { + const std::vector tags = absl::StrSplit(tag_pair, ':'); + RET_CHECK_EQ(tags.size(), 2) + << "recurrent_tag_pair must be a colon " + "separated std::string with two components: " + << tag_pair; + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) + << "Can't find tag '" << tags[0] << "' in signature " + << options_.signature_name(); + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) + << "Can't find tag '" << tags[1] << "' in 
signature " + << options_.signature_name(); + recurrent_feed_tags_.insert(tags[0]); + recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; } - static mediapipe::Status GetContract(CalculatorContract* cc) { - const auto& options = cc->Options(); - RET_CHECK(!cc->Inputs().GetTags().empty()); - for (const std::string& tag : cc->Inputs().GetTags()) { - // The tensorflow::Tensor with the tag equal to the graph node. May - // have a TimeSeriesHeader if all present TimeSeriesHeaders match. - if (!options.batched_input()) { - cc->Inputs().Tag(tag).Set(); + // Check that all tags are present in this signature bound to tensors. + for (const std::string& tag : cc->Inputs().GetTags()) { + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + for (const std::string& tag : cc->Outputs().GetTags()) { + RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) + << "Can't find tag '" << tag << "' in signature " + << options_.signature_name(); + } + + { + absl::WriterMutexLock l(&mutex_); + inference_state_ = std::unique_ptr(); + } + + if (options_.batch_size() == 1 || options_.batched_input()) { + cc->SetOffset(0); + } + + return absl::OkStatus(); +} + +// Adds a batch dimension to the input tensor if specified in the calculator +// options. +absl::Status AddBatchDimension(tf::Tensor* input_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(input_tensor->shape()); + new_shape.InsertDim(0, 1); + RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) + << "Could not add 0th dimension to tensor without changing its shape." 
+ << " Current shape: " << input_tensor->shape().DebugString(); + } + return absl::OkStatus(); +} + +absl::Status AggregateTensorPacket( + const std::string& tag_name, const Packet& packet, + std::map>* + input_tensors_by_tag_by_timestamp, + InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + tf::Tensor input_tensor(packet.Get()); + RET_CHECK_OK(AddBatchDimension(&input_tensor)); + if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) { + // If we receive an input on a recurrent tag, override the state. + // It's OK to override the global state because there is just one + // input stream allowed for recurrent tensors. + inference_state_->input_tensor_batches_[tag_name].clear(); + } + (*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert( + std::make_pair(tag_name, input_tensor)); + return absl::OkStatus(); +} + +// Removes the batch dimension of the output tensor if specified in the +// calculator options. +absl::Status RemoveBatchDimension(tf::Tensor* output_tensor) { + if (options_.add_batch_dim_to_tensors()) { + tf::TensorShape new_shape(output_tensor->shape()); + new_shape.RemoveDim(0); + RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape)) + << "Could not remove 0th dimension from tensor without changing its " + << "shape. Current shape: " << output_tensor->shape().DebugString() + << " (The expected first dimension is 1 for a batch element.)"; + } + return absl::OkStatus(); +} + +absl::Status Process(CalculatorContext* cc) override { + std::unique_ptr inference_state_to_process; + { + absl::WriterMutexLock l(&mutex_); + if (inference_state_ == nullptr) { + inference_state_ = CreateInferenceState(cc); + } + std::map> + input_tensors_by_tag_by_timestamp; + for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { + if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { + // Recurrent tensors can be empty. 
+ if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { + if (options_.skip_on_missing_features()) { + return absl::OkStatus(); + } else { + return absl::InvalidArgumentError(absl::StrCat( + "Tag ", tag_as_node_name, + " not present at timestamp: ", cc->InputTimestamp().Value())); + } + } + } else if (options_.batched_input()) { + const auto& tensor_packets = + cc->Inputs().Tag(tag_as_node_name).Get>(); + if (tensor_packets.size() > options_.batch_size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Batch for tag ", tag_as_node_name, + " has more packets than batch capacity. batch_size: ", + options_.batch_size(), " packets: ", tensor_packets.size())); + } + for (const auto& packet : tensor_packets) { + RET_CHECK_OK(AggregateTensorPacket(tag_as_node_name, packet, + &input_tensors_by_tag_by_timestamp, + inference_state_.get())); + } } else { - cc->Inputs().Tag(tag).Set>(); + RET_CHECK_OK(AggregateTensorPacket( + tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(), + &input_tensors_by_tag_by_timestamp, inference_state_.get())); } } - RET_CHECK(!cc->Outputs().GetTags().empty()); - for (const std::string& tag : cc->Outputs().GetTags()) { - // The tensorflow::Tensor with tag equal to the graph node to - // output. Any TimeSeriesHeader from the inputs will be forwarded - // with channels set to 0. - cc->Outputs().Tag(tag).Set(); - } - // A mediapipe::TensorFlowSession with a model loaded and ready for use. - // For this calculator it must include a tag_to_tensor_map. 
- cc->InputSidePackets().Tag("SESSION").Set(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS")) { - cc->InputSidePackets() - .Tag("RECURRENT_INIT_TENSORS") - .Set>>(); - } - return mediapipe::OkStatus(); - } - - std::unique_ptr CreateInferenceState(CalculatorContext* cc) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - std::unique_ptr inference_state = - absl::make_unique(); - if (cc->InputSidePackets().HasTag("RECURRENT_INIT_TENSORS") && - !cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS").IsEmpty()) { - std::map* init_tensor_map; - init_tensor_map = GetFromUniquePtr>( - cc->InputSidePackets().Tag("RECURRENT_INIT_TENSORS")); - for (const auto& p : *init_tensor_map) { - inference_state->input_tensor_batches_[p.first].emplace_back(p.second); + for (const auto& timestamp_and_input_tensors_by_tag : + input_tensors_by_tag_by_timestamp) { + inference_state_->batch_timestamps_.emplace_back( + timestamp_and_input_tensors_by_tag.first); + for (const auto& input_tensor_and_tag : + timestamp_and_input_tensors_by_tag.second) { + inference_state_->input_tensor_batches_[input_tensor_and_tag.first] + .emplace_back(input_tensor_and_tag.second); } } - return inference_state; - } - - mediapipe::Status Open(CalculatorContext* cc) override { - options_ = cc->Options(); - - RET_CHECK(cc->InputSidePackets().HasTag("SESSION")); - session_ = cc->InputSidePackets() - .Tag("SESSION") - .Get() - .session.get(); - tag_to_tensor_map_ = cc->InputSidePackets() - .Tag("SESSION") - .Get() - .tag_to_tensor_map; - - // Validate and store the recurrent tags - RET_CHECK(options_.has_batch_size()); - RET_CHECK(options_.batch_size() == 1 || - options_.recurrent_tag_pair().empty()) - << "To use recurrent_tag_pairs, batch_size must be 1."; - for (const auto& tag_pair : options_.recurrent_tag_pair()) { - const std::vector tags = absl::StrSplit(tag_pair, ':'); - RET_CHECK_EQ(tags.size(), 2) - << "recurrent_tag_pair must be a colon " - "separated std::string with two components: " - << tag_pair; 
- RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) - << "Can't find tag '" << tags[0] << "' in signature " - << options_.signature_name(); - RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) - << "Can't find tag '" << tags[1] << "' in signature " - << options_.signature_name(); - recurrent_feed_tags_.insert(tags[0]); - recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; - } - - // Check that all tags are present in this signature bound to tensors. - for (const std::string& tag : cc->Inputs().GetTags()) { - RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) - << "Can't find tag '" << tag << "' in signature " - << options_.signature_name(); - } - for (const std::string& tag : cc->Outputs().GetTags()) { - RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) - << "Can't find tag '" << tag << "' in signature " - << options_.signature_name(); - } - - { - absl::WriterMutexLock l(&mutex_); + if (inference_state_->batch_timestamps_.size() == options_.batch_size() || + options_.batched_input()) { + inference_state_to_process = std::move(inference_state_); inference_state_ = std::unique_ptr(); } - - if (options_.batch_size() == 1 || options_.batched_input()) { - cc->SetOffset(0); - } - - return mediapipe::OkStatus(); } - // Adds a batch dimension to the input tensor if specified in the calculator - // options. - mediapipe::Status AddBatchDimension(tf::Tensor* input_tensor) { - if (options_.add_batch_dim_to_tensors()) { - tf::TensorShape new_shape(input_tensor->shape()); - new_shape.InsertDim(0, 1); - RET_CHECK(input_tensor->CopyFrom(*input_tensor, new_shape)) - << "Could not add 0th dimension to tensor without changing its shape." 
- << " Current shape: " << input_tensor->shape().DebugString(); - } - return mediapipe::OkStatus(); + if (inference_state_to_process) { + MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process))); } - mediapipe::Status AggregateTensorPacket( - const std::string& tag_name, const Packet& packet, - std::map>* - input_tensors_by_tag_by_timestamp, - InferenceState* inference_state) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mutex_) { - tf::Tensor input_tensor(packet.Get()); - RET_CHECK_OK(AddBatchDimension(&input_tensor)); - if (mediapipe::ContainsKey(recurrent_feed_tags_, tag_name)) { - // If we receive an input on a recurrent tag, override the state. - // It's OK to override the global state because there is just one - // input stream allowed for recurrent tensors. - inference_state_->input_tensor_batches_[tag_name].clear(); - } - (*input_tensors_by_tag_by_timestamp)[packet.Timestamp()].insert( - std::make_pair(tag_name, input_tensor)); - return mediapipe::OkStatus(); - } - - // Removes the batch dimension of the output tensor if specified in the - // calculator options. - mediapipe::Status RemoveBatchDimension(tf::Tensor* output_tensor) { - if (options_.add_batch_dim_to_tensors()) { - tf::TensorShape new_shape(output_tensor->shape()); - new_shape.RemoveDim(0); - RET_CHECK(output_tensor->CopyFrom(*output_tensor, new_shape)) - << "Could not remove 0th dimension from tensor without changing its " - << "shape. 
Current shape: " << output_tensor->shape().DebugString() - << " (The expected first dimension is 1 for a batch element.)"; - } - return mediapipe::OkStatus(); - } - - mediapipe::Status Process(CalculatorContext* cc) override { - std::unique_ptr inference_state_to_process; - { - absl::WriterMutexLock l(&mutex_); - if (inference_state_ == nullptr) { - inference_state_ = CreateInferenceState(cc); - } - std::map> - input_tensors_by_tag_by_timestamp; - for (const std::string& tag_as_node_name : cc->Inputs().GetTags()) { - if (cc->Inputs().Tag(tag_as_node_name).IsEmpty()) { - // Recurrent tensors can be empty. - if (!mediapipe::ContainsKey(recurrent_feed_tags_, tag_as_node_name)) { - if (options_.skip_on_missing_features()) { - return mediapipe::OkStatus(); - } else { - return mediapipe::InvalidArgumentError(absl::StrCat( - "Tag ", tag_as_node_name, - " not present at timestamp: ", cc->InputTimestamp().Value())); - } - } - } else if (options_.batched_input()) { - const auto& tensor_packets = - cc->Inputs().Tag(tag_as_node_name).Get>(); - if (tensor_packets.size() > options_.batch_size()) { - return mediapipe::InvalidArgumentError(absl::StrCat( - "Batch for tag ", tag_as_node_name, - " has more packets than batch capacity. 
batch_size: ", - options_.batch_size(), " packets: ", tensor_packets.size())); - } - for (const auto& packet : tensor_packets) { - RET_CHECK_OK(AggregateTensorPacket( - tag_as_node_name, packet, &input_tensors_by_tag_by_timestamp, - inference_state_.get())); - } - } else { - RET_CHECK_OK(AggregateTensorPacket( - tag_as_node_name, cc->Inputs().Tag(tag_as_node_name).Value(), - &input_tensors_by_tag_by_timestamp, inference_state_.get())); - } - } - for (const auto& timestamp_and_input_tensors_by_tag : - input_tensors_by_tag_by_timestamp) { - inference_state_->batch_timestamps_.emplace_back( - timestamp_and_input_tensors_by_tag.first); - for (const auto& input_tensor_and_tag : - timestamp_and_input_tensors_by_tag.second) { - inference_state_->input_tensor_batches_[input_tensor_and_tag.first] - .emplace_back(input_tensor_and_tag.second); - } - } - if (inference_state_->batch_timestamps_.size() == options_.batch_size() || - options_.batched_input()) { - inference_state_to_process = std::move(inference_state_); - inference_state_ = std::unique_ptr(); - } - } - - if (inference_state_to_process) { - MP_RETURN_IF_ERROR( - OutputBatch(cc, std::move(inference_state_to_process))); - } - - return mediapipe::OkStatus(); - } - - mediapipe::Status Close(CalculatorContext* cc) override { - std::unique_ptr inference_state_to_process = nullptr; - { - absl::WriterMutexLock l(&mutex_); - if (cc->GraphStatus().ok() && inference_state_ != nullptr && - !inference_state_->batch_timestamps_.empty()) { - inference_state_to_process = std::move(inference_state_); - inference_state_ = std::unique_ptr(); - } - } - if (inference_state_to_process) { - MP_RETURN_IF_ERROR( - OutputBatch(cc, std::move(inference_state_to_process))); - } - return mediapipe::OkStatus(); - } - - // When a batch of input tensors is ready to be run, runs TensorFlow and - // outputs the output tensors. The output tensors have timestamps matching - // the input tensor that formed that batch element. 
Any requested - // batch_dimension is added and removed. This code takes advantage of the fact - // that copying a tensor shares the same reference-counted, heap allocated - // memory buffer. Therefore, copies are cheap and should not cause the memory - // buffer to fall out of scope. In contrast, concat is only used where - // necessary. - mediapipe::Status OutputBatch( - CalculatorContext* cc, std::unique_ptr inference_state) { - const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); - std::vector> input_tensors; - - for (auto& keyed_tensors : inference_state->input_tensor_batches_) { - if (options_.batch_size() == 1) { - // Short circuit to avoid the cost of deep copying tensors in concat. - if (!keyed_tensors.second.empty()) { - input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], - keyed_tensors.second[0]); - } else { - // The input buffer can be empty for recurrent tensors. - RET_CHECK( - mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first)) - << "A non-recurrent tensor does not have an input: " - << keyed_tensors.first; - } - } else { - // Pad by replicating the first tens or, then ignore the values. 
- keyed_tensors.second.resize(options_.batch_size()); - std::fill(keyed_tensors.second.begin() + - inference_state->batch_timestamps_.size(), - keyed_tensors.second.end(), keyed_tensors.second[0]); - tf::Tensor concated; - const tf::Status concat_status = - tf::tensor::Concat(keyed_tensors.second, &concated); - CHECK(concat_status.ok()) << concat_status.ToString(); - input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], - concated); - } - } - inference_state->input_tensor_batches_.clear(); - std::vector output_tensor_names; - std::vector output_name_in_signature; - for (const std::string& tag : cc->Outputs().GetTags()) { - output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); - output_name_in_signature.emplace_back(tag); - } - for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { - // Ensure that we always fetch the recurrent state tensors. - if (std::find(output_name_in_signature.begin(), - output_name_in_signature.end(), - tag_pair.first) == output_name_in_signature.end()) { - output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]); - output_name_in_signature.emplace_back(tag_pair.first); - } - } - std::vector outputs; - - SimpleSemaphore* session_run_throttle = nullptr; - if (options_.max_concurrent_session_runs() > 0) { - session_run_throttle = - get_session_run_throttle(options_.max_concurrent_session_runs()); - session_run_throttle->Acquire(1); - } - const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); - tf::Status tf_status; - { -#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) - tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName())); -#endif - tf_status = session_->Run(input_tensors, output_tensor_names, - {} /* target_node_names */, &outputs); - } - - if (session_run_throttle != nullptr) { - session_run_throttle->Release(1); - } - - // RET_CHECK on the tf::Status object itself in order to print an - // informative error message. 
- RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); - - const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); - cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) - ->IncrementBy(run_end_time - run_start_time); - cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); - - // Feed back the recurrent state. - for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { - int pos = std::find(output_name_in_signature.begin(), - output_name_in_signature.end(), tag_pair.first) - - output_name_in_signature.begin(); - inference_state->input_tensor_batches_[tag_pair.second].emplace_back( - outputs[pos]); - } + return absl::OkStatus(); +} +absl::Status Close(CalculatorContext* cc) override { + std::unique_ptr inference_state_to_process = nullptr; + { absl::WriterMutexLock l(&mutex_); - // Set that we want to split on each index of the 0th dimension. - std::vector split_vector(options_.batch_size(), 1); - for (int i = 0; i < output_tensor_names.size(); ++i) { - if (options_.batch_size() == 1) { - if (cc->Outputs().HasTag(output_name_in_signature[i])) { - tf::Tensor output_tensor(outputs[i]); - RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); - cc->Outputs() - .Tag(output_name_in_signature[i]) - .Add(new tf::Tensor(output_tensor), - inference_state->batch_timestamps_[0]); - } + if (cc->GraphStatus().ok() && inference_state_ != nullptr && + !inference_state_->batch_timestamps_.empty()) { + inference_state_to_process = std::move(inference_state_); + inference_state_ = std::unique_ptr(); + } + } + if (inference_state_to_process) { + MP_RETURN_IF_ERROR(OutputBatch(cc, std::move(inference_state_to_process))); + } + return absl::OkStatus(); +} + +// When a batch of input tensors is ready to be run, runs TensorFlow and +// outputs the output tensors. The output tensors have timestamps matching +// the input tensor that formed that batch element. Any requested +// batch_dimension is added and removed. 
This code takes advantage of the fact +// that copying a tensor shares the same reference-counted, heap allocated +// memory buffer. Therefore, copies are cheap and should not cause the memory +// buffer to fall out of scope. In contrast, concat is only used where +// necessary. +absl::Status OutputBatch(CalculatorContext* cc, + std::unique_ptr inference_state) { + const int64 start_time = absl::ToUnixMicros(clock_->TimeNow()); + std::vector> input_tensors; + + for (auto& keyed_tensors : inference_state->input_tensor_batches_) { + if (options_.batch_size() == 1) { + // Short circuit to avoid the cost of deep copying tensors in concat. + if (!keyed_tensors.second.empty()) { + input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], + keyed_tensors.second[0]); } else { - std::vector split_tensors; - const tf::Status split_status = - tf::tensor::Split(outputs[i], split_vector, &split_tensors); - CHECK(split_status.ok()) << split_status.ToString(); - // Loop over timestamps so that we don't copy the padding. - for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) { - tf::Tensor output_tensor(split_tensors[j]); - RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); - cc->Outputs() - .Tag(output_name_in_signature[i]) - .Add(new tf::Tensor(output_tensor), - inference_state->batch_timestamps_[j]); - } + // The input buffer can be empty for recurrent tensors. + RET_CHECK( + mediapipe::ContainsKey(recurrent_feed_tags_, keyed_tensors.first)) + << "A non-recurrent tensor does not have an input: " + << keyed_tensors.first; + } + } else { + // Pad by replicating the first tens or, then ignore the values. 
+ keyed_tensors.second.resize(options_.batch_size()); + std::fill(keyed_tensors.second.begin() + + inference_state->batch_timestamps_.size(), + keyed_tensors.second.end(), keyed_tensors.second[0]); + tf::Tensor concated; + const tf::Status concat_status = + tf::tensor::Concat(keyed_tensors.second, &concated); + CHECK(concat_status.ok()) << concat_status.ToString(); + input_tensors.emplace_back(tag_to_tensor_map_[keyed_tensors.first], + concated); + } + } + inference_state->input_tensor_batches_.clear(); + std::vector output_tensor_names; + std::vector output_name_in_signature; + for (const std::string& tag : cc->Outputs().GetTags()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag]); + output_name_in_signature.emplace_back(tag); + } + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + // Ensure that we always fetch the recurrent state tensors. + if (std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), + tag_pair.first) == output_name_in_signature.end()) { + output_tensor_names.emplace_back(tag_to_tensor_map_[tag_pair.first]); + output_name_in_signature.emplace_back(tag_pair.first); + } + } + std::vector outputs; + + SimpleSemaphore* session_run_throttle = nullptr; + if (options_.max_concurrent_session_runs() > 0) { + session_run_throttle = + get_session_run_throttle(options_.max_concurrent_session_runs()); + session_run_throttle->Acquire(1); + } + const int64 run_start_time = absl::ToUnixMicros(clock_->TimeNow()); + tf::Status tf_status; + { +#if !defined(MEDIAPIPE_MOBILE) && !defined(__APPLE__) + tensorflow::profiler::TraceMe trace(absl::string_view(cc->NodeName())); +#endif + tf_status = session_->Run(input_tensors, output_tensor_names, + {} /* target_node_names */, &outputs); + } + + if (session_run_throttle != nullptr) { + session_run_throttle->Release(1); + } + + // RET_CHECK on the tf::Status object itself in order to print an + // informative error message. 
+ RET_CHECK(tf_status.ok()) << "Run failed: " << tf_status.ToString(); + + const int64 run_end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalSessionRunsTimeUsecsCounterSuffix) + ->IncrementBy(run_end_time - run_start_time); + cc->GetCounter(kTotalNumSessionRunsCounterSuffix)->Increment(); + + // Feed back the recurrent state. + for (const auto& tag_pair : recurrent_fetch_tags_to_feed_tags_) { + int pos = std::find(output_name_in_signature.begin(), + output_name_in_signature.end(), tag_pair.first) - + output_name_in_signature.begin(); + inference_state->input_tensor_batches_[tag_pair.second].emplace_back( + outputs[pos]); + } + + absl::WriterMutexLock l(&mutex_); + // Set that we want to split on each index of the 0th dimension. + std::vector split_vector(options_.batch_size(), 1); + for (int i = 0; i < output_tensor_names.size(); ++i) { + if (options_.batch_size() == 1) { + if (cc->Outputs().HasTag(output_name_in_signature[i])) { + tf::Tensor output_tensor(outputs[i]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), + inference_state->batch_timestamps_[0]); + } + } else { + std::vector split_tensors; + const tf::Status split_status = + tf::tensor::Split(outputs[i], split_vector, &split_tensors); + CHECK(split_status.ok()) << split_status.ToString(); + // Loop over timestamps so that we don't copy the padding. + for (int j = 0; j < inference_state->batch_timestamps_.size(); ++j) { + tf::Tensor output_tensor(split_tensors[j]); + RET_CHECK_OK(RemoveBatchDimension(&output_tensor)); + cc->Outputs() + .Tag(output_name_in_signature[i]) + .Add(new tf::Tensor(output_tensor), + inference_state->batch_timestamps_[j]); } } - - // Get end time and report. 
- const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); - cc->GetCounter(kTotalUsecsCounterSuffix) - ->IncrementBy(end_time - start_time); - cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) - ->IncrementBy(inference_state->batch_timestamps_.size()); - - // Make sure we hold on to the recursive state. - if (!options_.recurrent_tag_pair().empty()) { - inference_state_ = std::move(inference_state); - inference_state_->batch_timestamps_.clear(); - } - - return mediapipe::OkStatus(); } - private: - // The Session object is provided by a packet factory and is owned by the - // MediaPipe framework. Individual calls are thread-safe, but session state - // may be shared across threads. - tf::Session* session_; + // Get end time and report. + const int64 end_time = absl::ToUnixMicros(clock_->TimeNow()); + cc->GetCounter(kTotalUsecsCounterSuffix)->IncrementBy(end_time - start_time); + cc->GetCounter(kTotalProcessedTimestampsCounterSuffix) + ->IncrementBy(inference_state->batch_timestamps_.size()); - // A mapping between stream tags and the tensor names they are bound to. - std::map tag_to_tensor_map_; - - absl::Mutex mutex_; - std::unique_ptr inference_state_ ABSL_GUARDED_BY(mutex_); - - // The options for the calculator. - TensorFlowInferenceCalculatorOptions options_; - - // Store the feed and fetch tags for feed/fetch recurrent networks. - std::set recurrent_feed_tags_; - std::map recurrent_fetch_tags_to_feed_tags_; - - // Clock used to measure the computation time in OutputBatch(). - std::unique_ptr clock_; - - // The static singleton semaphore to throttle concurrent session runs. - static SimpleSemaphore* get_session_run_throttle( - int32 max_concurrent_session_runs) { - static SimpleSemaphore* session_run_throttle = - new SimpleSemaphore(max_concurrent_session_runs); - return session_run_throttle; + // Make sure we hold on to the recursive state. 
+ if (!options_.recurrent_tag_pair().empty()) { + inference_state_ = std::move(inference_state); + inference_state_->batch_timestamps_.clear(); } -}; + + return absl::OkStatus(); +} + +private: +// The Session object is provided by a packet factory and is owned by the +// MediaPipe framework. Individual calls are thread-safe, but session state may +// be shared across threads. +tf::Session* session_; + +// A mapping between stream tags and the tensor names they are bound to. +std::map tag_to_tensor_map_; + +absl::Mutex mutex_; +std::unique_ptr inference_state_ ABSL_GUARDED_BY(mutex_); + +// The options for the calculator. +TensorFlowInferenceCalculatorOptions options_; + +// Store the feed and fetch tags for feed/fetch recurrent networks. +std::set recurrent_feed_tags_; +std::map recurrent_fetch_tags_to_feed_tags_; + +// Clock used to measure the computation time in OutputBatch(). +std::unique_ptr clock_; + +// The static singleton semaphore to throttle concurrent session runs. +static SimpleSemaphore* get_session_run_throttle( + int32 max_concurrent_session_runs) { + static SimpleSemaphore* session_run_throttle = + new SimpleSemaphore(max_concurrent_session_runs); + return session_run_throttle; +} +} +; REGISTER_CALCULATOR(TensorFlowInferenceCalculator); constexpr char TensorFlowInferenceCalculator::kTotalUsecsCounterSuffix[]; diff --git a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc index 557d46ff8..20e80bf33 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_inference_calculator_test.cc @@ -21,6 +21,7 @@ #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include 
"mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc index 2650447ca..2c1d169bc 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator.cc @@ -59,7 +59,7 @@ void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); bool has_exactly_one_model = @@ -89,10 +89,10 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { // a map from tags to tensor names. ); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { auto clock = std::unique_ptr( mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); const uint64 start_time = absl::ToUnixMicros(clock->TimeNow()); @@ -151,11 +151,11 @@ class TensorFlowSessionFromFrozenGraphCalculator : public CalculatorBase { const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TensorFlowSessionFromFrozenGraphCalculator); diff --git 
a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc index 097f5534b..8d3d3fdff 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_calculator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -153,7 +154,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphCalculatorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc index b2dc3a8d5..9f5b9e06b 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator.cc @@ -55,7 +55,7 @@ void SetPreferredDevice(tf::GraphDef* graph_def, absl::string_view device_id) { class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { RET_CHECK(extendable_options.HasExtension( @@ 
-87,10 +87,10 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { // a map from tags to tensor names. ); RET_CHECK_GT(options.tag_to_tensor_names().size(), 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( + static absl::Status Generate( const PacketGeneratorOptions& packet_generator_options, const PacketSet& input_side_packets, PacketSet* output_side_packets) { auto clock = std::unique_ptr( @@ -151,7 +151,7 @@ class TensorFlowSessionFromFrozenGraphGenerator : public PacketGenerator { const uint64 end_time = absl::ToUnixMicros(clock->TimeNow()); LOG(INFO) << "Loaded frozen model in: " << end_time - start_time << " microseconds."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(TensorFlowSessionFromFrozenGraphGenerator); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc index 793f58163..c7f06bbc4 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_frozen_graph_generator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -101,10 +102,10 @@ class TensorFlowSessionFromFrozenGraphGeneratorTest : public ::testing::Test { TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CreatesPacketWithGraphAndBindings) { - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - 
tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -149,7 +150,7 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( @@ -171,16 +172,16 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CreatesPacketWithGraphAndBindingsFromInputSidePacket) { PacketSet input_side_packets( - tool::CreateTagMap({"STRING_MODEL:model"}).ValueOrDie()); + tool::CreateTagMap({"STRING_MODEL:model"}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); generator_options_->clear_graph_proto_path(); input_side_packets.Tag("STRING_MODEL") = Adopt(new std::string(serialized_graph_contents)); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -191,13 +192,13 @@ TEST_F( TensorFlowSessionFromFrozenGraphGeneratorTest, CreatesPacketWithGraphAndBindingsFromInputSidePacketStringModelFilePath) { PacketSet 
input_side_packets( - tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie()); + tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); generator_options_->clear_graph_proto_path(); input_side_packets.Tag("STRING_MODEL_FILE_PATH") = Adopt(new std::string(GetGraphDefPath())); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -207,15 +208,15 @@ TEST_F( TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CheckFailureForOptionsAndInputsProvideGraphDefProto) { PacketSet input_side_packets( - tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).ValueOrDie()); + tool::CreateTagMap({"STRING_MODEL_FILE_PATH:model_path"}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); input_side_packets.Tag("STRING_MODEL_FILE_PATH") = Adopt(new std::string(GetGraphDefPath())); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); - EXPECT_EQ(run_status.code(), mediapipe::StatusCode::kInternal); + EXPECT_EQ(run_status.code(), absl::StatusCode::kInternal); EXPECT_THAT( run_status.message(), ::testing::HasSubstr("Must have exactly one of graph_proto_path")); @@ -226,9 +227,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, PacketSet input_side_packets( tool::CreateTagMap( {"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"}) - .ValueOrDie()); + .value()); PacketSet output_side_packets( 
- tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); @@ -237,10 +238,10 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, input_side_packets.Tag("STRING_MODEL_FILE_PATH") = Adopt(new std::string(GetGraphDefPath())); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); - EXPECT_EQ(run_status.code(), mediapipe::StatusCode::kInternal); + EXPECT_EQ(run_status.code(), absl::StatusCode::kInternal); EXPECT_THAT( run_status.message(), ::testing::HasSubstr("Must have exactly one of graph_proto_path")); @@ -251,9 +252,9 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, PacketSet input_side_packets( tool::CreateTagMap( {"STRING_MODEL_FILE_PATH:model_path", "STRING_MODEL:model"}) - .ValueOrDie()); + .value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); std::string serialized_graph_contents; MP_EXPECT_OK(mediapipe::file::GetContents(GetGraphDefPath(), &serialized_graph_contents)); @@ -263,10 +264,10 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, Adopt(new std::string(GetGraphDefPath())); generator_options_->clear_graph_proto_path(); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); - EXPECT_EQ(run_status.code(), mediapipe::StatusCode::kInternal); + EXPECT_EQ(run_status.code(), absl::StatusCode::kInternal); EXPECT_THAT( run_status.message(), ::testing::HasSubstr("Must have exactly one of graph_proto_path")); 
@@ -274,11 +275,11 @@ TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, TEST_F(TensorFlowSessionFromFrozenGraphGeneratorTest, CheckInitializationOpName) { - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); + tool::CreateTagMap({"SESSION:session"}).value()); generator_options_->add_initialization_op_names("multiplied:0"); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromFrozenGraphGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc index 5852f5655..6aedb138f 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.cc @@ -35,9 +35,9 @@ static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models // in subdirectories, replaces path with the alphabetically last subdirectory. 
-mediapipe::Status GetLatestDirectory(std::string* path) { +absl::Status GetLatestDirectory(std::string* path) { #if defined(__ANDROID__) - return mediapipe::UnimplementedError( + return absl::UnimplementedError( "GetLatestDirectory is not implemented on Android"); #else std::vector saved_models; @@ -47,7 +47,7 @@ mediapipe::Status GetLatestDirectory(std::string* path) { << "No exported bundles found in " << path; ::std::sort(saved_models.begin(), saved_models.end()); *path = std::string(file::Dirname(saved_models.back())); - return mediapipe::OkStatus(); + return absl::OkStatus(); #endif } @@ -75,10 +75,10 @@ const std::string MaybeConvertSignatureToTag( } // namespace // TensorFlowSessionFromSavedModelCalculator is a MediaPipe packet calculator -// that loads a trained TensorFlow model exported via SavedModel's exporter (see -// go/savedmodel) and returns a Packet containing a unique_ptr to a -// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session -// ready for execution and a map between tags and tensor names. +// that loads a trained TensorFlow model exported via SavedModel's exporter and +// returns a Packet containing a unique_ptr to a mediapipe::TensorFlowSession, +// which in turn contains a TensorFlow Session ready for execution and a map +// between tags and tensor names. 
// // Example usage: // node { @@ -93,7 +93,7 @@ const std::string MaybeConvertSignatureToTag( // } class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); const bool has_exactly_one_model = @@ -108,10 +108,10 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { } // A TensorFlow model loaded and ready for use along with tensor cc->OutputSidePackets().Tag("SESSION").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const auto& options = cc->Options(); std::string path = cc->InputSidePackets().HasTag(kStringSavedModelPath) @@ -140,8 +140,8 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { ::tensorflow::Status status = tensorflow::LoadSavedModel( session_options, run_options, path, tags_set, saved_model.get()); if (!status.ok()) { - return mediapipe::Status( - static_cast(status.code()), status.ToString()); + return absl::Status(static_cast(status.code()), + status.ToString()); } auto session = absl::make_unique(); @@ -160,11 +160,11 @@ class TensorFlowSessionFromSavedModelCalculator : public CalculatorBase { } cc->OutputSidePackets().Tag("SESSION").Set(Adopt(session.release())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto index a8839ef52..927d3b51f 100644 --- 
a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator.proto @@ -28,7 +28,7 @@ message TensorFlowSessionFromSavedModelCalculatorOptions { // SavedModels, include a flag to load the most recent model. // Path to a directory containing a trained TensorFlow model as prepared - // by SavedModel (go/saved-model). + // by SavedModel. optional string saved_model_path = 1; // The name of the generic signature to load into the mapping from tags to // tensor names. diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc index 516a50d8e..912d71600 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_calculator_test.cc @@ -20,6 +20,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -164,7 +165,7 @@ TEST_F(TensorFlowSessionFromSavedModelCalculatorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc index 7e5fb289e..6489b0267 100644 --- 
a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.cc @@ -37,9 +37,9 @@ static constexpr char kStringSavedModelPath[] = "STRING_SAVED_MODEL_PATH"; // Given the path to a directory containing multiple tensorflow saved models // in subdirectories, replaces path with the alphabetically last subdirectory. -mediapipe::Status GetLatestDirectory(std::string* path) { +absl::Status GetLatestDirectory(std::string* path) { #if defined(__ANDROID__) - return mediapipe::UnimplementedError( + return absl::UnimplementedError( "GetLatestDirectory is not implemented on Android"); #else std::vector saved_models; @@ -49,7 +49,7 @@ mediapipe::Status GetLatestDirectory(std::string* path) { << "No exported bundles found in " << path; ::std::sort(saved_models.begin(), saved_models.end()); *path = std::string(file::Dirname(saved_models.back())); - return mediapipe::OkStatus(); + return absl::OkStatus(); #endif } @@ -77,13 +77,13 @@ const std::string MaybeConvertSignatureToTag( } // namespace // TensorFlowSessionFromSavedModelGenerator is a MediaPipe packet generator -// that loads a trained TensorFlow model exported via SavedModel's exporter (see -// go/savedmodel) and returns a Packet containing a unique_ptr to a -// mediapipe::TensorFlowSession, which in turn contains a TensorFlow Session -// ready for execution and a map between tags and tensor names. +// that loads a trained TensorFlow model exported via SavedModel's exporter and +// returns a Packet containing a unique_ptr to a mediapipe::TensorFlowSession, +// which in turn contains a TensorFlow Session ready for execution and a map +// between tags and tensor names. 
class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { const TensorFlowSessionFromSavedModelGeneratorOptions& options = @@ -101,12 +101,12 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { } // A TensorFlow model loaded and ready for use along with tensor output_side_packets->Tag("SESSION").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { const TensorFlowSessionFromSavedModelGeneratorOptions& options = extendable_options.GetExtension( TensorFlowSessionFromSavedModelGeneratorOptions::ext); @@ -135,8 +135,8 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { ::tensorflow::Status status = tensorflow::LoadSavedModel( session_options, run_options, path, tags_set, saved_model.get()); if (!status.ok()) { - return mediapipe::Status( - static_cast(status.code()), status.ToString()); + return absl::Status(static_cast(status.code()), + status.ToString()); } auto session = absl::make_unique(); session->session = std::move(saved_model->session); @@ -154,7 +154,7 @@ class TensorFlowSessionFromSavedModelGenerator : public PacketGenerator { } output_side_packets->Tag("SESSION") = Adopt(session.release()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(TensorFlowSessionFromSavedModelGenerator); diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto 
b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto index 88ce93435..d24a1cd73 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator.proto @@ -28,7 +28,7 @@ message TensorFlowSessionFromSavedModelGeneratorOptions { // SavedModels, include a flag to load the most recent model. // Path to a directory containing a trained TensorFlow model as prepared - // by SavedModel (go/saved-model). + // by SavedModel. optional string saved_model_path = 1; // The name of the generic signature to load into the mapping from tags to // tensor names. diff --git a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc index ffe9d1fc5..92d0d5de4 100644 --- a/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc +++ b/mediapipe/calculators/tensorflow/tensorflow_session_from_saved_model_generator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_generator.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" @@ -66,10 +67,10 @@ class TensorFlowSessionFromSavedModelGeneratorTest : public ::testing::Test { TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, CreatesPacketWithGraphAndBindings) { - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + 
absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -105,13 +106,12 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, CreateSessionFromSidePacket) { generator_options_->clear_saved_model_path(); PacketSet input_side_packets( - tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}) - .ValueOrDie()); + tool::CreateTagMap({"STRING_SAVED_MODEL_PATH:saved_model_dir"}).value()); input_side_packets.Tag("STRING_SAVED_MODEL_PATH") = Adopt(new std::string(GetSavedModelDir())); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -159,7 +159,7 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("multiplied_tensor"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( @@ -184,10 +184,10 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, std::string(file::SplitPath(GetSavedModelDir()).first)); generator_options_->set_load_latest_model(true); - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + 
tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); @@ -205,10 +205,10 @@ TEST_F(TensorFlowSessionFromSavedModelGeneratorTest, generator_options_->mutable_session_config()->mutable_device_count()->insert( {"CPU", 10}); - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); PacketSet output_side_packets( - tool::CreateTagMap({"SESSION:session"}).ValueOrDie()); - mediapipe::Status run_status = tool::RunGenerateAndValidateTypes( + tool::CreateTagMap({"SESSION:session"}).value()); + absl::Status run_status = tool::RunGenerateAndValidateTypes( "TensorFlowSessionFromSavedModelGenerator", extendable_options_, input_side_packets, &output_side_packets); MP_EXPECT_OK(run_status) << run_status.message(); diff --git a/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc b/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc index 46c03a1be..28271f3a7 100644 --- a/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc +++ b/mediapipe/calculators/tensorflow/tfrecord_reader_calculator.cc @@ -49,14 +49,13 @@ const char kSequenceExampleTag[] = "SEQUENCE_EXAMPLE"; // } class TFRecordReaderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; -mediapipe::Status TFRecordReaderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TFRecordReaderCalculator::GetContract(CalculatorContract* cc) { 
cc->InputSidePackets().Tag(kTFRecordPath).Set(); if (cc->InputSidePackets().HasTag(kRecordIndex)) { cc->InputSidePackets().Tag(kRecordIndex).Set(); @@ -73,10 +72,10 @@ mediapipe::Status TFRecordReaderCalculator::GetContract( .Tag(kSequenceExampleTag) .Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) { +absl::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) { std::unique_ptr file; auto tf_status = tensorflow::Env::Default()->NewRandomAccessFile( cc->InputSidePackets().Tag(kTFRecordPath).Get(), &file); @@ -114,11 +113,11 @@ mediapipe::Status TFRecordReaderCalculator::Open(CalculatorContext* cc) { ++current_idx; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status TFRecordReaderCalculator::Process(CalculatorContext* cc) { + return absl::OkStatus(); } REGISTER_CALCULATOR(TFRecordReaderCalculator); diff --git a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc index 1c4fc9218..1f4cda359 100644 --- a/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_media_sequence_calculator.cc @@ -118,7 +118,7 @@ namespace mpms = mediapipe::mediasequence; // } class UnpackMediaSequenceCalculator : public CalculatorBase { public: - static ::mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); RET_CHECK(cc->InputSidePackets().HasTag(kSequenceExampleTag)); cc->InputSidePackets().Tag(kSequenceExampleTag).Set(); @@ -183,10 +183,10 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { cc->Outputs().Tag(tag).Set>(); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - 
::mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { // Copy the packet to copy the otherwise inaccessible shared ptr. example_packet_holder_ = cc->InputSidePackets().Tag(kSequenceExampleTag); sequence_ = &example_packet_holder_.Get(); @@ -335,10 +335,10 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { .Set(MakePacket(mpms::GetImageFrameRate(sequence))); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (timestamps_.empty()) { // This occurs when we only have metadata to unpack. LOG(INFO) << "only unpacking metadata because there are no timestamps."; @@ -435,7 +435,7 @@ class UnpackMediaSequenceCalculator : public CalculatorBase { ++current_timestamp_index_; if (current_timestamp_index_ < timestamps_[last_timestamp_key_].size()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { return tool::StatusStop(); } diff --git a/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc b/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc index d03d2c0e0..efb3037f8 100644 --- a/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc +++ b/mediapipe/calculators/tensorflow/unpack_yt8m_sequence_example_calculator.cc @@ -64,7 +64,7 @@ std::string GetQuantizedFeature( // } class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets() .Tag(kYt8mSequenceExample) .Set(); @@ -84,10 +84,10 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { if (cc->OutputSidePackets().HasTag(kSegmentSize)) { cc->OutputSidePackets().Tag(kSegmentSize).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - 
mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const tensorflow::SequenceExample& sequence_example = cc->InputSidePackets() .Tag(kYt8mSequenceExample) @@ -108,7 +108,7 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { .feature_size(); if (rgb_feature_list_length != audio_feature_list_length) { - return mediapipe::FailedPreconditionError(absl::StrCat( + return absl::FailedPreconditionError(absl::StrCat( "Data corruption: the length of audio features and rgb features are " "not equal. Please check the sequence example that contains yt8m " "id: ", @@ -151,10 +151,10 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { } LOG(INFO) << "Reading the sequence example that contains yt8m id: " << yt8m_id << ". Feature list length: " << feature_list_length_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (current_index_ >= feature_list_length_) { return mediapipe::tool::StatusStop(); } @@ -179,7 +179,7 @@ class UnpackYt8mSequenceExampleCalculator : public CalculatorBase { GetQuantizedFeature(sequence_example, kAudio, current_index_)) .At(timestamp)); ++current_index_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc index d75348daa..96208b3e5 100644 --- a/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/vector_float_to_tensor_calculator.cc @@ -44,17 +44,17 @@ namespace tf = ::tensorflow; // } class VectorFloatToTensorCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status 
Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: VectorFloatToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(VectorFloatToTensorCalculator); -mediapipe::Status VectorFloatToTensorCalculator::GetContract( +absl::Status VectorFloatToTensorCalculator::GetContract( CalculatorContract* cc) { const auto& options = cc->Options(); // Start with only one input packet. @@ -75,16 +75,16 @@ mediapipe::Status VectorFloatToTensorCalculator::GetContract( cc->Outputs().Index(0).Set( // Output stream with data as tf::Tensor and the same TimeSeriesHeader. ); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VectorFloatToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status VectorFloatToTensorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); - return mediapipe::OkStatus(); + cc->SetOffset(0); + return absl::OkStatus(); } -mediapipe::Status VectorFloatToTensorCalculator::Process( - CalculatorContext* cc) { +absl::Status VectorFloatToTensorCalculator::Process(CalculatorContext* cc) { tf::TensorShape tensor_shape; if (options_.input_size() == INPUT_2D) { const std::vector>& input = @@ -127,7 +127,7 @@ mediapipe::Status VectorFloatToTensorCalculator::Process( } else { LOG(FATAL) << "input size not supported"; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc index c05bccd70..f5bf7661e 100644 --- a/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensorflow/vector_int_to_tensor_calculator.cc @@ -62,18 +62,17 @@ void AssignMatrixValue(int r, int c, int value, tf::Tensor* output_tensor) { // } class VectorIntToTensorCalculator : 
public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: VectorIntToTensorCalculatorOptions options_; }; REGISTER_CALCULATOR(VectorIntToTensorCalculator); -mediapipe::Status VectorIntToTensorCalculator::GetContract( - CalculatorContract* cc) { +absl::Status VectorIntToTensorCalculator::GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); // Start with only one input packet. RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) @@ -92,19 +91,19 @@ mediapipe::Status VectorIntToTensorCalculator::GetContract( RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) << "Only one output stream is supported."; cc->Outputs().Tag(kTensorOut).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) { +absl::Status VectorIntToTensorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK(options_.tensor_data_type() == tf::DT_UINT8 || options_.tensor_data_type() == tf::DT_INT32 || options_.tensor_data_type() == tf::DT_INT64) << "Output tensor data type is not supported."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) { +absl::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) { tf::TensorShape tensor_shape; if (options_.input_size() == INPUT_2D) { const std::vector>& input = @@ -196,7 +195,7 @@ mediapipe::Status VectorIntToTensorCalculator::Process(CalculatorContext* cc) { } else { LOG(FATAL) << "input size not supported"; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git 
a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc new file mode 100644 index 000000000..0e579009b --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator.cc @@ -0,0 +1,137 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Converts vector (or vector>) to 1D (or 2D) +// tf::Tensor. + +#include "mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" + +namespace mediapipe { + +namespace { +auto& INPUT_1D = VectorStringToTensorCalculatorOptions::INPUT_1D; +auto& INPUT_2D = VectorStringToTensorCalculatorOptions::INPUT_2D; +} // namespace + +namespace tf = ::tensorflow; + +// The calculator expects one input (a packet containing a vector +// or vector>) and generates one output (a packet containing +// a tf::Tensor containing the same data). The output tensor will be either 1D +// or 2D with dimensions corresponding to the input vector std::string. It will +// hold DT_STRING values. 
+// +// Example config: +// node { +// calculator: "VectorStringToTensorCalculator" +// input_stream: "vector_string_features" +// output_stream: "tensor_features" +// } +class VectorStringToTensorCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc); + + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + + private: + VectorStringToTensorCalculatorOptions options_; +}; +REGISTER_CALCULATOR(VectorStringToTensorCalculator); + +absl::Status VectorStringToTensorCalculator::GetContract( + CalculatorContract* cc) { + const auto& options = cc->Options(); + // Start with only one input packet. + RET_CHECK_EQ(cc->Inputs().NumEntries(), 1) + << "Only one input stream is supported."; + if (options.input_size() == INPUT_2D) { + cc->Inputs().Index(0).Set>>( + /* "Input vector>." */); + } else if (options.input_size() == INPUT_1D) { + cc->Inputs().Index(0).Set>( + // Input vector. + ); + } else { + LOG(FATAL) << "input size not supported"; + } + RET_CHECK_EQ(cc->Outputs().NumEntries(), 1) + << "Only one output stream is supported."; + cc->Outputs().Index(0).Set( + // Output stream with data as tf::Tensor and the same TimeSeriesHeader. 
+ ); + return absl::OkStatus(); +} + +absl::Status VectorStringToTensorCalculator::Open(CalculatorContext* cc) { + options_ = cc->Options(); + cc->SetOffset(0); + return absl::OkStatus(); +} + +absl::Status VectorStringToTensorCalculator::Process(CalculatorContext* cc) { + tf::TensorShape tensor_shape; + if (options_.input_size() == INPUT_2D) { + const std::vector>& input = + cc->Inputs() + .Index(0) + .Value() + .Get>>(); + + const int32 rows = input.size(); + RET_CHECK_GE(rows, 1); + const int32 cols = input[0].size(); + RET_CHECK_GE(cols, 1); + for (int i = 1; i < rows; ++i) { + RET_CHECK_EQ(input[i].size(), cols); + } + if (options_.transpose()) { + tensor_shape = tf::TensorShape({cols, rows}); + } else { + tensor_shape = tf::TensorShape({rows, cols}); + } + auto output = ::absl::make_unique(tf::DT_STRING, tensor_shape); + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + if (options_.transpose()) { + output->tensor()(c, r) = input[r][c]; + } else { + output->tensor()(r, c) = input[r][c]; + } + } + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else if (options_.input_size() == INPUT_1D) { + const std::vector& input = + cc->Inputs().Index(0).Value().Get>(); + RET_CHECK_GE(input.size(), 1); + const int32 length = input.size(); + tensor_shape = tf::TensorShape({length}); + auto output = ::absl::make_unique(tf::DT_STRING, tensor_shape); + for (int i = 0; i < length; ++i) { + output->tensor()(i) = input.at(i); + } + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + } else { + LOG(FATAL) << "input size not supported"; + } + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.proto b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.proto new file mode 100644 index 000000000..908d98dff --- /dev/null +++ 
b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.proto @@ -0,0 +1,40 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message VectorStringToTensorCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional VectorStringToTensorCalculatorOptions ext = 357221188; + } + enum InputSize { + UNKNOWN = 0; + INPUT_1D = 1; + INPUT_2D = 2; + } + + // If input_size is INPUT_2D, unpack a vector> to a + // 2d tensor (matrix). If INPUT_1D, + // convert a vector into a 1d tensor (vector). + optional InputSize input_size = 1 [default = INPUT_1D]; + + // If true, the output tensor is transposed. + // Otherwise, the output tensor is not transposed. + // It will be ignored if input_size is INPUT_1D. + optional bool transpose = 2 [default = false]; +} diff --git a/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_test.cc b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_test.cc new file mode 100644 index 000000000..5921bd1b0 --- /dev/null +++ b/mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_test.cc @@ -0,0 +1,120 @@ +// Copyright 2018 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/tensorflow/vector_string_to_tensor_calculator_options.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mediapipe { + +namespace { + +namespace tf = ::tensorflow; + +class VectorStringToTensorCalculatorTest : public ::testing::Test { + protected: + void SetUpRunner( + const VectorStringToTensorCalculatorOptions::InputSize input_size, + const bool transpose) { + CalculatorGraphConfig::Node config; + config.set_calculator("VectorStringToTensorCalculator"); + config.add_input_stream("input_string"); + config.add_output_stream("output_tensor"); + auto options = config.mutable_options()->MutableExtension( + VectorStringToTensorCalculatorOptions::ext); + options->set_input_size(input_size); + options->set_transpose(transpose); + runner_ = ::absl::make_unique(config); + } + + void TestConvertFromVectoVectorString(const bool transpose) { + SetUpRunner(VectorStringToTensorCalculatorOptions::INPUT_2D, transpose); + auto input = ::absl::make_unique>>( + 2, std::vector(2)); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + input->at(i).at(j) = absl::StrCat(i, j); + } + } + + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(time))); + + 
EXPECT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + + EXPECT_EQ(2, output_tensor.dims()); + EXPECT_EQ(tf::DT_STRING, output_tensor.dtype()); + const auto matrix = output_tensor.matrix(); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + if (!transpose) { + EXPECT_EQ(absl::StrCat(i, j), matrix(i, j)); + } else { + EXPECT_EQ(absl::StrCat(j, i), matrix(i, j)); + } + } + } + } + + std::unique_ptr runner_; +}; + +TEST_F(VectorStringToTensorCalculatorTest, ConvertsFromVectorString) { + SetUpRunner(VectorStringToTensorCalculatorOptions::INPUT_1D, false); + auto input = ::absl::make_unique>(5); + for (int i = 0; i < 5; ++i) { + input->at(i) = absl::StrCat(i); + } + const int64 time = 1234; + runner_->MutableInputs()->Index(0).packets.push_back( + Adopt(input.release()).At(Timestamp(time))); + + EXPECT_TRUE(runner_->Run().ok()); + + const std::vector& output_packets = + runner_->Outputs().Index(0).packets; + EXPECT_EQ(1, output_packets.size()); + EXPECT_EQ(time, output_packets[0].Timestamp().Value()); + const tf::Tensor& output_tensor = output_packets[0].Get(); + + EXPECT_EQ(1, output_tensor.dims()); + EXPECT_EQ(tf::DT_STRING, output_tensor.dtype()); + const auto vec = output_tensor.vec(); + + for (int i = 0; i < 5; ++i) { + EXPECT_EQ(absl::StrCat(i), vec(i)); + } +} + +TEST_F(VectorStringToTensorCalculatorTest, ConvertsFromVectorVectorString) { + for (bool transpose : {false, true}) { + TestConvertFromVectoVectorString(transpose); + } +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 18138b8d7..6bc4636b1 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -143,16 +143,15 @@ cc_test( data = 
[":anchor_golden_files"], deps = [ ":ssd_anchors_calculator", - ":ssd_anchors_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats/object_detection:anchor_cc_proto", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:parse_text_proto", - "//mediapipe/framework/tool:validate_type", ], ) @@ -460,7 +459,10 @@ cc_library( # bazel test //mediapipe/calculators/tflite:tflite_inference_calculator_test --copt=-DTFLITE_GPU_EXTRA_GLES_DEPS --copt=-DMESA_EGL_NO_X11_HEADERS --copt=-DEGL_NO_X11 --config=grte_v5 --test_strategy=local cc_test( name = "tflite_inference_calculator_test", - srcs = ["tflite_inference_calculator_test.cc"], + srcs = [ + "tflite_inference_calculator_test.cc", + "tflite_inference_calculator_test_common.h", + ], data = ["testdata/add.bin"], linkstatic = 1, deps = [ @@ -480,7 +482,9 @@ cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:type_to_tflitetype", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels/internal:tensor", ], ) diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc index 07a91ecd8..f618b2f6a 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator.cc @@ -71,12 +71,12 @@ float CalculateScale(float min_scale, float max_scale, int stride_index, // } class SsdAnchorsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { 
cc->OutputSidePackets().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const SsdAnchorsCalculatorOptions& options = @@ -85,24 +85,24 @@ class SsdAnchorsCalculator : public CalculatorBase { auto anchors = absl::make_unique>(); MP_RETURN_IF_ERROR(GenerateAnchors(anchors.get(), options)); cc->OutputSidePackets().Index(0).Set(Adopt(anchors.release())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } private: - static mediapipe::Status GenerateAnchors( + static absl::Status GenerateAnchors( std::vector* anchors, const SsdAnchorsCalculatorOptions& options); }; REGISTER_CALCULATOR(SsdAnchorsCalculator); -mediapipe::Status SsdAnchorsCalculator::GenerateAnchors( +absl::Status SsdAnchorsCalculator::GenerateAnchors( std::vector* anchors, const SsdAnchorsCalculatorOptions& options) { // Verify the options. if (!options.feature_map_height_size() && !options.strides_size()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Both feature map shape and strides are missing. 
Must provide either " "one."); } @@ -206,7 +206,7 @@ mediapipe::Status SsdAnchorsCalculator::GenerateAnchors( } layer_id = last_same_stride_layer; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc index 7a5b555db..906eeed21 100644 --- a/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc +++ b/mediapipe/calculators/tflite/ssd_anchors_calculator_test.cc @@ -16,6 +16,7 @@ #include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/object_detection/anchor.pb.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/calculators/tflite/testdata/README.md b/mediapipe/calculators/tflite/testdata/README.md new file mode 100644 index 000000000..c0efdcf07 --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/README.md @@ -0,0 +1,2 @@ +The model files add.bin, add_quantized.bin +(and corresponding metatada json files) come from tensorflow/lite/testdata/ diff --git a/mediapipe/calculators/tflite/testdata/add_quantized.bin b/mediapipe/calculators/tflite/testdata/add_quantized.bin new file mode 100644 index 000000000..07d48b93e Binary files /dev/null and b/mediapipe/calculators/tflite/testdata/add_quantized.bin differ diff --git a/mediapipe/calculators/tflite/testdata/add_quantized.json b/mediapipe/calculators/tflite/testdata/add_quantized.json new file mode 100644 index 000000000..f70ed8143 --- /dev/null +++ b/mediapipe/calculators/tflite/testdata/add_quantized.json @@ -0,0 +1,123 @@ +{ + version: 3, + operator_codes: [ + { + } + ], + subgraphs: [ + { + tensors: [ + { + shape: [ + 1, + 8, + 8, + 3 + ], + name: "add", + quantization: { + min: [ + 0.0 + ], + max: 
[ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + type: "UINT8", + name: "input", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + }, + { + shape: [ + 1, + 8, + 8, + 3 + ], + type: "UINT8", + name: "output", + quantization: { + min: [ + 0.0 + ], + max: [ + 1.0 + ], + scale: [ + 0.003922 + ], + zero_point: [ + 0 + ] + } + } + ], + inputs: [ + 1 + ], + outputs: [ + 2 + ], + operators: [ + { + inputs: [ + 1, + 1 + ], + outputs: [ + 0 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + }, + { + inputs: [ + 0, + 1 + ], + outputs: [ + 2 + ], + builtin_options_type: "AddOptions", + builtin_options: { + } + } + ] + } + ], + buffers: [ + { + data: [ + + ] + } + ] +} diff --git a/mediapipe/calculators/tflite/tflite_converter_calculator.cc b/mediapipe/calculators/tflite/tflite_converter_calculator.cc index ccb7d3744..d9dfd1526 100644 --- a/mediapipe/calculators/tflite/tflite_converter_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_converter_calculator.cc @@ -25,9 +25,9 @@ #include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/interpreter.h" -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_buffer.h" -#endif // MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU #if MEDIAPIPE_TFLITE_GL_INFERENCE #include "mediapipe/gpu/gl_calculator_helper.h" @@ -134,21 +134,21 @@ struct GPUData { // class TfLiteConverterCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status 
Close(CalculatorContext* cc) override; private: - mediapipe::Status InitGpu(CalculatorContext* cc); - mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); template - mediapipe::Status NormalizeImage(const ImageFrame& image_frame, - bool flip_vertically, float* tensor_ptr); - mediapipe::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr); - mediapipe::Status ProcessCPU(CalculatorContext* cc); - mediapipe::Status ProcessGPU(CalculatorContext* cc); + absl::Status NormalizeImage(const ImageFrame& image_frame, + bool flip_vertically, float* tensor_ptr); + absl::Status CopyMatrixToTensor(const Matrix& matrix, float* tensor_ptr); + absl::Status ProcessCPU(CalculatorContext* cc); + absl::Status ProcessGPU(CalculatorContext* cc); std::unique_ptr interpreter_ = nullptr; @@ -182,8 +182,7 @@ bool ShouldUseGpu(CC* cc) { } } // namespace -mediapipe::Status TfLiteConverterCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TfLiteConverterCalculator::GetContract(CalculatorContract* cc) { // Confirm only one of the input streams is present. RET_CHECK(cc->Inputs().HasTag(kImageFrameTag) ^ cc->Inputs().HasTag(kGpuBufferTag) ^ @@ -199,11 +198,11 @@ mediapipe::Status TfLiteConverterCalculator::GetContract( if (cc->Inputs().HasTag(kMatrixTag)) { cc->Inputs().Tag(kMatrixTag).Set(); } -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set(); } -#endif // MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kTensorsTag)) { cc->Outputs().Tag(kTensorsTag).Set>(); @@ -223,10 +222,10 @@ mediapipe::Status TfLiteConverterCalculator::GetContract( // Assign this calculator's default InputStreamHandler. 
cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); MP_RETURN_IF_ERROR(LoadOptions(cc)); @@ -251,13 +250,13 @@ mediapipe::Status TfLiteConverterCalculator::Open(CalculatorContext* cc) { interpreter_->SetInputs({0}); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { if (use_gpu_) { if (cc->Inputs().Tag(kGpuBufferTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (!initialized_) { MP_RETURN_IF_ERROR(InitGpu(cc)); @@ -269,23 +268,23 @@ mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { // Convert to CPU tensors or Matrix type. MP_RETURN_IF_ERROR(ProcessCPU(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { interpreter_.reset(); #if MEDIAPIPE_TFLITE_GL_INFERENCE gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE gpu_data_out_.reset(); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::ProcessCPU(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::ProcessCPU(CalculatorContext* cc) { if (cc->Inputs().HasTag(kImageFrameTag)) { if (cc->Inputs().Tag(kImageFrameTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // CPU ImageFrame to TfLiteTensor conversion. 
@@ -361,7 +360,7 @@ mediapipe::Status TfLiteConverterCalculator::ProcessCPU(CalculatorContext* cc) { MP_RETURN_IF_ERROR(NormalizeImage(image_frame, flip_vertically_, tensor_buffer)); } else { - return mediapipe::InternalError( + return absl::InternalError( "Only byte-based (8 bit) and float (32 bit) images supported."); } } @@ -373,7 +372,7 @@ mediapipe::Status TfLiteConverterCalculator::ProcessCPU(CalculatorContext* cc) { .Add(output_tensors.release(), cc->InputTimestamp()); } else if (cc->Inputs().HasTag(kMatrixTag)) { if (cc->Inputs().Tag(kMatrixTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // CPU Matrix to TfLiteTensor conversion. const auto& matrix = cc->Inputs().Tag(kMatrixTag).Get(); @@ -405,16 +404,16 @@ mediapipe::Status TfLiteConverterCalculator::ProcessCPU(CalculatorContext* cc) { .Add(output_tensors.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::ProcessGPU(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::ProcessGPU(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE // GpuBuffer to tflite::gpu::GlBuffer conversion. const auto& input = cc->Inputs().Tag(kGpuBufferTag).Get(); MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &input]() -> mediapipe::Status { + gpu_helper_.RunInGlContext([this, &input]() -> absl::Status { // Convert GL texture into TfLite GlBuffer (SSBO). auto src = gpu_helper_.CreateSourceTexture(input); glActiveTexture(GL_TEXTURE0 + 0); @@ -427,13 +426,13 @@ mediapipe::Status TfLiteConverterCalculator::ProcessGPU(CalculatorContext* cc) { glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); glBindTexture(GL_TEXTURE_2D, 0); src.Release(); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Copy into outputs. 
auto output_tensors = absl::make_unique>(); - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &output_tensors]() -> mediapipe::Status { + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, &output_tensors]() -> absl::Status { output_tensors->resize(1); { GpuTensor& tensor = output_tensors->at(0); @@ -441,7 +440,7 @@ mediapipe::Status TfLiteConverterCalculator::ProcessGPU(CalculatorContext* cc) { gpu_data_out_->elements, &tensor)); MP_RETURN_IF_ERROR(CopyBuffer(gpu_data_out_->buffer, tensor)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); })); cc->Outputs() .Tag(kTensorsGpuTag) @@ -487,10 +486,10 @@ mediapipe::Status TfLiteConverterCalculator::ProcessGPU(CalculatorContext* cc) { RET_CHECK_FAIL() << "GPU processing is not enabled."; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GPU_SUPPORTED // Get input image sizes. const auto& input = @@ -511,7 +510,7 @@ mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, &include_alpha, &input, &single_channel]() -> mediapipe::Status { + [this, &include_alpha, &input, &single_channel]() -> absl::Status { // Device memory. 
MP_RETURN_IF_ERROR( ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer( @@ -557,7 +556,7 @@ mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { GL_COMPUTE_SHADER, shader_source, &gpu_data_out_->shader)); MP_RETURN_IF_ERROR(GlProgram::CreateWithShader( gpu_data_out_->shader, &gpu_data_out_->program)); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE @@ -624,11 +623,10 @@ mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { << [[error localizedDescription] UTF8String]; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::LoadOptions( - CalculatorContext* cc) { +absl::Status TfLiteConverterCalculator::LoadOptions(CalculatorContext* cc) { // Get calculator options specified in the graph. const auto& options = cc->Options<::mediapipe::TfLiteConverterCalculatorOptions>(); @@ -676,11 +674,11 @@ mediapipe::Status TfLiteConverterCalculator::LoadOptions( // Get tensor type, float or quantized. 
use_quantized_tensors_ = options.use_quantized_tensors(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } template -mediapipe::Status TfLiteConverterCalculator::NormalizeImage( +absl::Status TfLiteConverterCalculator::NormalizeImage( const ImageFrame& image_frame, bool flip_vertically, float* tensor_ptr) { const int height = image_frame.Height(); const int width = image_frame.Width(); @@ -724,11 +722,11 @@ mediapipe::Status TfLiteConverterCalculator::NormalizeImage( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteConverterCalculator::CopyMatrixToTensor( - const Matrix& matrix, float* tensor_ptr) { +absl::Status TfLiteConverterCalculator::CopyMatrixToTensor(const Matrix& matrix, + float* tensor_ptr) { if (row_major_matrix_) { auto matrix_map = Eigen::Map(tensor_ptr, matrix.rows(), matrix.cols()); @@ -739,7 +737,7 @@ mediapipe::Status TfLiteConverterCalculator::CopyMatrixToTensor( matrix_map = matrix; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc index 7b61f3c8f..11e27dff1 100644 --- a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc @@ -39,14 +39,14 @@ namespace mediapipe { // } class TfLiteCustomOpResolverCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->OutputSidePackets() .Index(0) .Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const TfLiteCustomOpResolverCalculatorOptions& options = @@ -60,11 +60,11 @@ class 
TfLiteCustomOpResolverCalculator : public CalculatorBase { } cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TfLiteCustomOpResolverCalculator); diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.cc b/mediapipe/calculators/tflite/tflite_inference_calculator.cc index 2c956d63a..a2fc7ec3a 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.cc @@ -215,40 +215,32 @@ class TfLiteInferenceCalculator : public CalculatorBase { using TfLiteDelegatePtr = std::unique_ptr>; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status ReadKernelsFromFile(); - mediapipe::Status WriteKernelsToFile(); - mediapipe::Status LoadModel(CalculatorContext* cc); - mediapipe::StatusOr GetModelAsPacket(const CalculatorContext& cc); - mediapipe::Status LoadDelegate(CalculatorContext* cc); - mediapipe::Status InitTFLiteGPURunner(CalculatorContext* cc); - mediapipe::Status ProcessInputsCpu( - CalculatorContext* cc, std::vector* output_tensors_cpu); - mediapipe::Status ProcessOutputsCpu( + absl::Status ReadKernelsFromFile(); + absl::Status WriteKernelsToFile(); + absl::Status LoadModel(CalculatorContext* cc); + absl::StatusOr GetModelAsPacket(const CalculatorContext& cc); + absl::Status 
LoadDelegate(CalculatorContext* cc); + absl::Status InitTFLiteGPURunner(CalculatorContext* cc); + absl::Status ProcessInputsCpu(CalculatorContext* cc, + std::vector* output_tensors_cpu); + absl::Status ProcessOutputsCpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu); - mediapipe::Status ProcessInputsGpu( - CalculatorContext* cc, -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - id compute_encoder, -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - std::vector* output_tensors_gpu); - mediapipe::Status ProcessOutputsGpu( + absl::Status ProcessInputsGpu(CalculatorContext* cc, + std::vector* output_tensors_gpu); + absl::Status ProcessOutputsGpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu, -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - id compute_encoder, -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE std::unique_ptr> output_tensors_gpu); - mediapipe::Status RunInContextIfNeeded( - std::function<::mediapipe::Status(void)> f) { + absl::Status RunInContextIfNeeded(std::function f) { if (gpu_inference_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE return gpu_helper_.RunInGlContext(std::move(f)); @@ -312,8 +304,7 @@ bool ShouldUseGpu(CC* cc) { } } // namespace -mediapipe::Status TfLiteInferenceCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TfLiteInferenceCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kTensorsTag) ^ cc->Inputs().HasTag(kTensorsGpuTag)); RET_CHECK(cc->Outputs().HasTag(kTensorsTag) ^ @@ -355,10 +346,10 @@ mediapipe::Status TfLiteInferenceCalculator::GetContract( // Assign this calculator's default InputStreamHandler. 
cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { +absl::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -375,13 +366,14 @@ mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); tflite_gpu_runner_api_ = options.delegate().gpu().api(); - use_kernel_caching_ = - use_advanced_gpu_api_ && options.delegate().gpu().use_kernel_caching(); + use_kernel_caching_ = use_advanced_gpu_api_ && + options.delegate().gpu().has_cached_kernel_path(); if (use_kernel_caching_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) - cached_kernel_filename_ = - "/sdcard/" + mediapipe::File::Basename(options.model_path()) + ".ker"; + cached_kernel_filename_ = options.delegate().gpu().cached_kernel_path() + + mediapipe::File::Basename(options.model_path()) + + ".ker"; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID } @@ -397,11 +389,10 @@ mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { if (gpu_inference_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { - return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) - : LoadDelegate(cc); - })); + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, + &cc]() -> absl::Status { + return use_advanced_gpu_api_ ? 
InitTFLiteGPURunner(cc) : LoadDelegate(cc); + })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; RET_CHECK(gpu_helper_); @@ -410,33 +401,18 @@ mediapipe::Status TfLiteInferenceCalculator::Open(CalculatorContext* cc) { } else { MP_RETURN_IF_ERROR(LoadDelegate(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { - return RunInContextIfNeeded([this, cc]() -> ::mediapipe::Status { +absl::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { + return RunInContextIfNeeded([this, cc]() -> absl::Status { // 0. Declare outputs auto output_tensors_gpu = absl::make_unique>(); auto output_tensors_cpu = absl::make_unique>(); -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - id command_buffer; - id compute_encoder; - if (gpu_inference_) { - command_buffer = [gpu_helper_ commandBuffer]; - command_buffer.label = @"TfLiteInferenceCalculator"; - compute_encoder = [command_buffer computeCommandEncoder]; - } -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - // 1. Receive pre-processed tensor inputs. if (gpu_input_) { - MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - compute_encoder, -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - - output_tensors_gpu.get())); + MP_RETURN_IF_ERROR(ProcessInputsGpu(cc, output_tensors_gpu.get())); } else { MP_RETURN_IF_ERROR(ProcessInputsCpu(cc, output_tensors_cpu.get())); } @@ -448,40 +424,36 @@ mediapipe::Status TfLiteInferenceCalculator::Process(CalculatorContext* cc) { } else { RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); } -#else -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - // Metal delegate supports external command encoder only if all input and +#elif MEDIAPIPE_TFLITE_METAL_INFERENCE + // Metal delegate supports external command buffer only if all input and // output buffers are on GPU. 
if (gpu_inference_ && gpu_input_ && gpu_output_) { + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteInferenceCalculator"; RET_CHECK( - TFLGpuDelegateSetCommandEncoder(delegate_.get(), compute_encoder)); + TFLGpuDelegateSetCommandBuffer(delegate_.get(), command_buffer)); + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); + [command_buffer commit]; + } else { + RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); } -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE +#else // MEDIAPIPE_TFLITE_GL_INFERENCE RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE // 3. Output processed tensors. if (gpu_output_ || use_advanced_gpu_api_) { MP_RETURN_IF_ERROR(ProcessOutputsGpu(cc, std::move(output_tensors_cpu), -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - compute_encoder, -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE std::move(output_tensors_gpu))); } else { MP_RETURN_IF_ERROR(ProcessOutputsCpu(cc, std::move(output_tensors_cpu))); } -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - if (gpu_inference_) { - [compute_encoder endEncoding]; - [command_buffer commit]; - } -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); }); } -mediapipe::Status TfLiteInferenceCalculator::WriteKernelsToFile() { +absl::Status TfLiteInferenceCalculator::WriteKernelsToFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Save kernel file. 
@@ -492,13 +464,13 @@ mediapipe::Status TfLiteInferenceCalculator::WriteKernelsToFile() { mediapipe::file::SetContents(cached_kernel_filename_, cache_str)); } #endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { +absl::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { MP_RETURN_IF_ERROR(WriteKernelsToFile()); - return RunInContextIfNeeded([this]() -> ::mediapipe::Status { + return RunInContextIfNeeded([this]() -> absl::Status { if (delegate_) { interpreter_ = nullptr; delegate_ = nullptr; @@ -516,16 +488,16 @@ mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { #if defined(MEDIAPIPE_EDGE_TPU) edgetpu_context_.reset(); #endif - return mediapipe::OkStatus(); + return absl::OkStatus(); }); } // Calculator Auxiliary Section -mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu( +absl::Status TfLiteInferenceCalculator::ProcessInputsCpu( CalculatorContext* cc, std::vector* output_tensors_cpu) { if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Read CPU input into tensors. 
const auto& input_tensors = @@ -547,17 +519,13 @@ mediapipe::Status TfLiteInferenceCalculator::ProcessInputsCpu( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu( - CalculatorContext* cc, -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - id compute_encoder, -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - std::vector* output_tensors_gpu) { +absl::Status TfLiteInferenceCalculator::ProcessInputsGpu( + CalculatorContext* cc, std::vector* output_tensors_gpu) { if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (use_advanced_gpu_api_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE @@ -603,6 +571,10 @@ mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu( RET_CHECK_GT(input_tensors.size(), 0); // Explicit copy input with conversion float 32 bits to 16 bits. gpu_data_in_.resize(input_tensors.size()); + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteInferenceCalculatorConvert"; + id compute_encoder = + [command_buffer computeCommandEncoder]; [compute_encoder setComputePipelineState:fp32_to_fp16_program_]; for (int i = 0; i < input_tensors.size(); ++i) { [compute_encoder setBuffer:input_tensors[i] offset:0 atIndex:0]; @@ -614,13 +586,15 @@ mediapipe::Status TfLiteInferenceCalculator::ProcessInputsGpu( [compute_encoder dispatchThreadgroups:MTLSizeMake(threadgroups, 1, 1) threadsPerThreadgroup:threads_per_group]; } + [compute_encoder endEncoding]; + [command_buffer commit]; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu( +absl::Status TfLiteInferenceCalculator::ProcessOutputsCpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu) { // Output result tensors (CPU). 
@@ -633,15 +607,12 @@ mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsCpu( .Tag(kTensorsTag) .Add(output_tensors_cpu.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu( +absl::Status TfLiteInferenceCalculator::ProcessOutputsGpu( CalculatorContext* cc, std::unique_ptr> output_tensors_cpu, -#if MEDIAPIPE_TFLITE_METAL_INFERENCE - id compute_encoder, -#endif // MEDIAPIPE_TFLITE_METAL_INFERENCE std::unique_ptr> output_tensors_gpu) { if (use_advanced_gpu_api_) { #if MEDIAPIPE_TFLITE_GL_INFERENCE @@ -684,27 +655,33 @@ mediapipe::Status TfLiteInferenceCalculator::ProcessOutputsGpu( // Output result tensors (GPU). output_tensors_gpu->resize(gpu_data_out_.size()); id device = gpu_helper_.mtlDevice; + id command_buffer = [gpu_helper_ commandBuffer]; + command_buffer.label = @"TfLiteInferenceBPHWC4Convert"; + id convert_command = + [command_buffer computeCommandEncoder]; for (int i = 0; i < gpu_data_out_.size(); ++i) { // Allocate output tensor. output_tensors_gpu->at(i) = [device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float) options:MTLResourceStorageModeShared]; // Reshape tensor. - [converter_from_BPHWC4_ convertWithEncoder:compute_encoder + [converter_from_BPHWC4_ convertWithEncoder:convert_command shape:gpu_data_out_[i]->shape sourceBuffer:gpu_data_out_[i]->buffer convertedBuffer:output_tensors_gpu->at(i)]; } + [convert_command endEncoding]; + [command_buffer commit]; cc->Outputs() .Tag(kTensorsGpuTag) .Add(output_tensors_gpu.release(), cc->InputTimestamp()); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::ReadKernelsFromFile() { +absl::Status TfLiteInferenceCalculator::ReadKernelsFromFile() { #if MEDIAPIPE_TFLITE_GL_INFERENCE && defined(MEDIAPIPE_ANDROID) if (use_kernel_caching_) { // Load pre-compiled kernel file. 
@@ -717,10 +694,10 @@ mediapipe::Status TfLiteInferenceCalculator::ReadKernelsFromFile() { } } #endif // MEDIAPIPE_TFLITE_GL_INFERENCE && MEDIAPIPE_ANDROID - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( +absl::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); @@ -798,13 +775,13 @@ mediapipe::Status TfLiteInferenceCalculator::InitTFLiteGPURunner( MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build()); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) { +absl::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) { if (use_advanced_gpu_api_) { // Use InitTFLiteGPURunner for everything. - return mediapipe::OkStatus(); + return absl::OkStatus(); } ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(*cc)); @@ -843,10 +820,10 @@ mediapipe::Status TfLiteInferenceCalculator::LoadModel(CalculatorContext* cc) { if (use_quantized_tensors_) gpu_inference_ = false; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::StatusOr TfLiteInferenceCalculator::GetModelAsPacket( +absl::StatusOr TfLiteInferenceCalculator::GetModelAsPacket( const CalculatorContext& cc) { const auto& options = cc.Options(); @@ -856,19 +833,17 @@ mediapipe::StatusOr TfLiteInferenceCalculator::GetModelAsPacket( if (cc.InputSidePackets().HasTag("MODEL")) { return cc.InputSidePackets().Tag("MODEL"); } - return mediapipe::Status( - mediapipe::StatusCode::kNotFound, - "Must specify TFLite model as path or loaded model."); + return absl::Status(absl::StatusCode::kNotFound, + "Must specify TFLite model as path or loaded model."); } -mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( - CalculatorContext* cc) { +absl::Status 
TfLiteInferenceCalculator::LoadDelegate(CalculatorContext* cc) { const auto& calculator_opts = cc->Options(); if (calculator_opts.has_delegate() && calculator_opts.delegate().has_tflite()) { // Default tflite inference requeqsted - no need to modify graph. - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (!gpu_inference_) { @@ -887,31 +862,32 @@ mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( }); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); - return mediapipe::OkStatus(); + return absl::OkStatus(); } #endif // MEDIAPIPE_ANDROID #if defined(__EMSCRIPTEN__) - const bool xnnpack_requested = true; + const bool use_xnnpack = true; #else - const bool xnnpack_requested = calculator_opts.has_delegate() && - calculator_opts.delegate().has_xnnpack(); -#endif // __EMSCRIPTEN__ + const bool use_xnnpack = calculator_opts.has_delegate() && + calculator_opts.delegate().has_xnnpack(); +#endif // defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_EDGE_TPU) - if (xnnpack_requested) { + if (use_xnnpack) { TfLiteXNNPackDelegateOptions xnnpack_opts{}; xnnpack_opts.num_threads = GetXnnpackNumThreads(calculator_opts); delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), &TfLiteXNNPackDelegateDelete); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), kTfLiteOk); - return mediapipe::OkStatus(); + return absl::OkStatus(); } #endif // !EDGETPU - // Return, no need for GPU delegate below. - return mediapipe::OkStatus(); + // Return and use default tflite inference (on CPU). No need for GPU + // delegate below. + return absl::OkStatus(); } #if MEDIAPIPE_TFLITE_GL_INFERENCE @@ -982,7 +958,7 @@ mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( // Configure and create the delegate.
TFLGpuDelegateOptions options; options.allow_precision_loss = true; - options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; + options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeActive; if (!delegate_) delegate_ = TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); @@ -1077,7 +1053,7 @@ mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( gpu_data_out_[i]->shape.c = tensor->dims->data[3]; break; default: - return mediapipe::InternalError("Unsupported tensor shape."); + return absl::InternalError("Unsupported tensor shape."); } } // Create and bind output buffers. @@ -1097,13 +1073,13 @@ mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( isFloat16:true convertToPBHWC4:false]; if (converter_from_BPHWC4_ == nil) { - return mediapipe::InternalError( + return absl::InternalError( "Error initializating output buffer converter"); } } #endif // MEDIAPIPE_TFLITE_METAL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator.proto b/mediapipe/calculators/tflite/tflite_inference_calculator.proto index 5b42e9512..862de8b0b 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator.proto +++ b/mediapipe/calculators/tflite/tflite_inference_calculator.proto @@ -65,7 +65,8 @@ message TfLiteInferenceCalculatorOptions { // Load pre-compiled serialized binary cache to accelerate init process. // Only available for OpenCL delegate on Android. - optional bool use_kernel_caching = 2 [default = false]; + // Kernel caching will only be enabled if this path is set. + optional string cached_kernel_path = 2; } // Android only. message Nnapi {} @@ -104,7 +105,10 @@ message TfLiteInferenceCalculatorOptions { optional int32 cpu_num_thread = 4 [default = -1]; // TfLite delegate to run inference. - // NOTE: calculator is free to choose delegate if not specified explicitly. 
+ // If not specified, when any of the input and output is on GPU (i.e, using + // the TENSORS_GPU tag) TFLite GPU delegate is used (as if "gpu {}" is + // specified), or otherwise regular TFLite on CPU is used (as if "tflite {}" + // is specified) except when building with emscripten where xnnpack is used. // NOTE: use_gpu/use_nnapi are ignored if specified. (Delegate takes // precedence over use_* deprecated options.) optional Delegate delegate = 5; diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc b/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc index 60ea1a860..ec16d1842 100644 --- a/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc +++ b/mediapipe/calculators/tflite/tflite_inference_calculator_test.cc @@ -12,96 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include - #include "absl/strings/str_replace.h" -#include "absl/strings/string_view.h" -#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" -#include "mediapipe/framework/calculator_framework.h" -#include "mediapipe/framework/calculator_runner.h" -#include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/port/gmock.h" -#include "mediapipe/framework/port/gtest.h" -#include "mediapipe/framework/port/integral_types.h" -#include "mediapipe/framework/port/parse_text_proto.h" -#include "mediapipe/framework/port/status_matchers.h" // NOLINT -#include "mediapipe/framework/tool/validate_type.h" -#include "tensorflow/lite/error_reporter.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/model.h" - -#ifdef __APPLE__ -#include -#endif // defined(__APPLE__) +#include "mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h" namespace mediapipe { -using ::tflite::Interpreter; - -void DoSmokeTest(const std::string& graph_proto) { - const int width = 
8; - const int height = 8; - const int channels = 3; - - // Prepare input tensor. - std::unique_ptr interpreter(new Interpreter); - ASSERT_NE(interpreter, nullptr); - - interpreter->AddTensors(1); - interpreter->SetInputs({0}); - interpreter->SetOutputs({0}); - interpreter->SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, - TfLiteQuantization()); - int t = interpreter->inputs()[0]; - TfLiteTensor* input_tensor = interpreter->tensor(t); - interpreter->ResizeInputTensor(t, {width, height, channels}); - interpreter->AllocateTensors(); - - float* input_tensor_buffer = input_tensor->data.f; - ASSERT_NE(input_tensor_buffer, nullptr); - for (int i = 0; i < width * height * channels - 1; i++) { - input_tensor_buffer[i] = 1; - } - - auto input_vec = absl::make_unique>(); - input_vec->emplace_back(*input_tensor); - - // Prepare single calculator graph to and wait for packets. - CalculatorGraphConfig graph_config = - ParseTextProtoOrDie(graph_proto); - std::vector output_packets; - tool::AddVectorSink("tensor_out", &graph_config, &output_packets); - CalculatorGraph graph(graph_config); - MP_ASSERT_OK(graph.StartRun({})); - - // Push the tensor into the graph. - MP_ASSERT_OK(graph.AddPacketToInputStream( - "tensor_in", Adopt(input_vec.release()).At(Timestamp(0)))); - // Wait until the calculator done processing. - MP_ASSERT_OK(graph.WaitUntilIdle()); - ASSERT_EQ(1, output_packets.size()); - - // Get and process results. - const std::vector& result_vec = - output_packets[0].Get>(); - ASSERT_EQ(1, result_vec.size()); - - const TfLiteTensor* result = &result_vec[0]; - float* result_buffer = result->data.f; - ASSERT_NE(result_buffer, nullptr); - for (int i = 0; i < width * height * channels - 1; i++) { - ASSERT_EQ(3, result_buffer[i]); - } - - // Fully close graph at end, otherwise calculator+tensors are destroyed - // after calling WaitUntilDone(). 
- MP_ASSERT_OK(graph.CloseInputStream("tensor_in")); - MP_ASSERT_OK(graph.WaitUntilDone()); -} - // Tests a simple add model that adds an input tensor to itself. TEST(TfLiteInferenceCalculatorTest, SmokeTest) { std::string graph_proto = R"( @@ -118,13 +33,12 @@ TEST(TfLiteInferenceCalculatorTest, SmokeTest) { } } )"; - DoSmokeTest( - /*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + // Test CPU inference only. + DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { tflite {} }"}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + DoSmokeTest(absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { xnnpack {} }"}})); - DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + DoSmokeTest(absl::StrReplaceAll( graph_proto, {{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}})); } @@ -163,11 +77,12 @@ TEST(TfLiteInferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) { options { [mediapipe.TfLiteInferenceCalculatorOptions.ext] { use_gpu: false + delegate { tflite {} } } } } )"; - DoSmokeTest(graph_proto); + DoSmokeTest(graph_proto); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h b/mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h new file mode 100644 index 000000000..cf995f47b --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h @@ -0,0 +1,128 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_CALCULATORS_TFLITE_TFLITE_INFERENCE_CALCULATOR_TEST_H_ +#define MEDIAPIPE_CALCULATORS_TFLITE_TFLITE_INFERENCE_CALCULATOR_TEST_H_ + +#include +#include +#include + +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" // NOLINT +#include "mediapipe/framework/tool/validate_type.h" +#include "tensorflow/lite/error_reporter.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/portable_type_to_tflitetype.h" + +#ifdef __APPLE__ +#include +#endif // defined(__APPLE__) + +namespace mediapipe { + +using ::tflite::Interpreter; + +template +void DoSmokeTest(const std::string& graph_proto) { + const int width = 8; + const int height = 8; + const int channels = 3; + + static_assert(std::is_same_v || std::is_same_v, + "Only float & uint8 currently supported."); + + // Prepare interpreter and input tensor. 
+ std::unique_ptr interpreter(new Interpreter); + ASSERT_NE(interpreter, nullptr); + + interpreter->AddTensors(1); + interpreter->SetInputs({0}); + interpreter->SetOutputs({0}); + TfLiteQuantization quant; + if (std::is_integral_v) { + auto* affine_quant = static_cast( + malloc(sizeof(TfLiteAffineQuantization))); + affine_quant->scale = TfLiteFloatArrayCreate(1); + affine_quant->zero_point = TfLiteIntArrayCreate(1); + affine_quant->scale->data[0] = 1.0; + affine_quant->zero_point->data[0] = 0; + quant.type = kTfLiteAffineQuantization; + quant.params = affine_quant; + } + interpreter->SetTensorParametersReadWrite(0, tflite::typeToTfLiteType(), + "", {3}, quant); + + int t = interpreter->inputs()[0]; + TfLiteTensor* input_tensor = interpreter->tensor(t); + interpreter->ResizeInputTensor(t, {width, height, channels}); + interpreter->AllocateTensors(); + + T* input_tensor_buffer = tflite::GetTensorData(input_tensor); + ASSERT_NE(input_tensor_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + input_tensor_buffer[i] = 1; + } + + auto input_vec = absl::make_unique>(); + input_vec->emplace_back(*input_tensor); + + // Prepare single calculator graph and wait for packets. + CalculatorGraphConfig graph_config = + ParseTextProtoOrDie(graph_proto); + std::vector output_packets; + tool::AddVectorSink("tensor_out", &graph_config, &output_packets); + CalculatorGraph graph(graph_config); + MP_ASSERT_OK(graph.StartRun({})); + + // Push the tensor into the graph. + MP_ASSERT_OK(graph.AddPacketToInputStream( + "tensor_in", Adopt(input_vec.release()).At(Timestamp(0)))); + // Wait until the calculator is done processing. + MP_ASSERT_OK(graph.WaitUntilIdle()); + ASSERT_EQ(1, output_packets.size()); + + // Get and process results.
+ const std::vector& result_vec = + output_packets[0].Get>(); + ASSERT_EQ(1, result_vec.size()); + + const TfLiteTensor* result = &result_vec[0]; + const T* result_buffer = tflite::GetTensorData(result); + ASSERT_NE(result_buffer, nullptr); + for (int i = 0; i < width * height * channels - 1; i++) { + ASSERT_EQ(3, result_buffer[i]); + } + + // Fully close graph at end, otherwise calculator+tensors are destroyed + // after calling WaitUntilDone(). + MP_ASSERT_OK(graph.CloseInputStream("tensor_in")); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_CALCULATORS_TFLITE_TFLITE_INFERENCE_CALCULATOR_TEST_H_ diff --git a/mediapipe/calculators/tflite/tflite_inference_calculator_tpu_test.cc b/mediapipe/calculators/tflite/tflite_inference_calculator_tpu_test.cc new file mode 100644 index 000000000..eac0d361c --- /dev/null +++ b/mediapipe/calculators/tflite/tflite_inference_calculator_tpu_test.cc @@ -0,0 +1,42 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "absl/strings/str_replace.h" +#include "mediapipe/calculators/tflite/tflite_inference_calculator_test_common.h" + +namespace mediapipe { + +// Tests a simple add model that adds an input tensor to itself. 
+TEST(TfLiteInferenceCalculatorTpuTest, SmokeTest) { + std::string graph_proto = R"( + input_stream: "tensor_in" + node { + calculator: "TfLiteInferenceCalculator" + input_stream: "TENSORS:tensor_in" + output_stream: "TENSORS:tensor_out" + options { + [mediapipe.TfLiteInferenceCalculatorOptions.ext] { + model_path: "mediapipe/calculators/tflite/testdata/add_quantized.bin" + $delegate + } + } + } + )"; + DoSmokeTest( + /*graph_proto=*/absl::StrReplaceAll(graph_proto, {{"$delegate", ""}})); + DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll( + graph_proto, {{"$delegate", "delegate { tflite {} }"}})); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_model_calculator.cc b/mediapipe/calculators/tflite/tflite_model_calculator.cc index c8e4fc36b..ca28910e5 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator.cc @@ -51,13 +51,13 @@ class TfLiteModelCalculator : public CalculatorBase { std::unique_ptr>; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("MODEL_BLOB").Set(); cc->OutputSidePackets().Tag("MODEL").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const Packet& model_packet = cc->InputSidePackets().Tag("MODEL_BLOB"); const std::string& model_blob = model_packet.Get(); std::unique_ptr model = @@ -74,11 +74,11 @@ class TfLiteModelCalculator : public CalculatorBase { delete model; }))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TfLiteModelCalculator); diff --git 
a/mediapipe/calculators/tflite/tflite_model_calculator_test.cc b/mediapipe/calculators/tflite/tflite_model_calculator_test.cc index fed3743a5..a76d322ee 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator_test.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator_test.cc @@ -58,7 +58,7 @@ TEST(TfLiteModelCalculatorTest, SmokeTest) { MP_ASSERT_OK(graph.WaitUntilIdle()); auto status_or_packet = graph.GetOutputSidePacket("model"); MP_ASSERT_OK(status_or_packet); - auto model_packet = status_or_packet.ValueOrDie(); + auto model_packet = status_or_packet.value(); const auto& model = model_packet.Get< std::unique_ptr>>(); diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc index 56c9d05f3..4d28b91e9 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_classification_calculator.cc @@ -60,11 +60,11 @@ namespace mediapipe { // } class TfLiteTensorsToClassificationCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: ::mediapipe::TfLiteTensorsToClassificationCalculatorOptions options_; @@ -74,7 +74,7 @@ class TfLiteTensorsToClassificationCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TfLiteTensorsToClassificationCalculator); -mediapipe::Status TfLiteTensorsToClassificationCalculator::GetContract( +absl::Status TfLiteTensorsToClassificationCalculator::GetContract( CalculatorContract* 
cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -87,10 +87,10 @@ mediapipe::Status TfLiteTensorsToClassificationCalculator::GetContract( cc->Outputs().Tag("CLASSIFICATIONS").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToClassificationCalculator::Open( +absl::Status TfLiteTensorsToClassificationCalculator::Open( CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); @@ -114,10 +114,10 @@ mediapipe::Status TfLiteTensorsToClassificationCalculator::Open( label_map_loaded_ = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToClassificationCalculator::Process( +absl::Status TfLiteTensorsToClassificationCalculator::Process( CalculatorContext* cc) { const auto& input_tensors = cc->Inputs().Tag("TENSORS").Get>(); @@ -190,12 +190,12 @@ mediapipe::Status TfLiteTensorsToClassificationCalculator::Process( .Tag("CLASSIFICATIONS") .Add(classification_list.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToClassificationCalculator::Close( +absl::Status TfLiteTensorsToClassificationCalculator::Close( CalculatorContext* cc) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc index 7747c4357..2ed62c46d 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc @@ -143,26 +143,27 @@ void ConvertAnchorsToRawValues(const std::vector& anchors, // } class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - 
mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status ProcessCPU(CalculatorContext* cc, - std::vector* output_detections); - mediapipe::Status ProcessGPU(CalculatorContext* cc, - std::vector* output_detections); + absl::Status ProcessCPU(CalculatorContext* cc, + std::vector* output_detections); + absl::Status ProcessGPU(CalculatorContext* cc, + std::vector* output_detections); - mediapipe::Status LoadOptions(CalculatorContext* cc); - mediapipe::Status GpuInit(CalculatorContext* cc); - mediapipe::Status DecodeBoxes(const float* raw_boxes, - const std::vector& anchors, - std::vector* boxes); - mediapipe::Status ConvertToDetections( - const float* detection_boxes, const float* detection_scores, - const int* detection_classes, std::vector* output_detections); + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status GpuInit(CalculatorContext* cc); + absl::Status DecodeBoxes(const float* raw_boxes, + const std::vector& anchors, + std::vector* boxes); + absl::Status ConvertToDetections(const float* detection_boxes, + const float* detection_scores, + const int* detection_classes, + std::vector* output_detections); Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, float box_xmax, float score, int class_id, bool flip_vertically); @@ -189,7 +190,7 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator); -mediapipe::Status TfLiteTensorsToDetectionsCalculator::GetContract( +absl::Status TfLiteTensorsToDetectionsCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -223,11 
+224,10 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::GetContract( #endif // MEDIAPIPE_TFLITE_GL_INFERENCE } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::Open( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToDetectionsCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->Inputs().HasTag(kTensorsGpuTag)) { @@ -247,14 +247,14 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::Open( MP_RETURN_IF_ERROR(GpuInit(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::Process( +absl::Status TfLiteTensorsToDetectionsCalculator::Process( CalculatorContext* cc) { if ((!gpu_input_ && cc->Inputs().Tag(kTensorsTag).IsEmpty()) || (gpu_input_ && cc->Inputs().Tag(kTensorsGpuTag).IsEmpty())) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } auto output_detections = absl::make_unique>(); @@ -272,10 +272,10 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::Process( .Add(output_detections.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU( +absl::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU( CalculatorContext* cc, std::vector* output_detections) { const auto& input_tensors = cc->Inputs().Tag(kTensorsTag).Get>(); @@ -313,7 +313,7 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU( anchors_ = cc->InputSidePackets().Tag("ANCHORS").Get>(); } else { - return mediapipe::UnavailableError("No anchor data available."); + return absl::UnavailableError("No anchor data available."); } anchors_init_ = true; } @@ -390,9 +390,9 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessCPU( detection_classes.data(), output_detections)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status 
TfLiteTensorsToDetectionsCalculator::ProcessGPU( +absl::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( CalculatorContext* cc, std::vector* output_detections) { #if MEDIAPIPE_TFLITE_GL_INFERENCE const auto& input_tensors = @@ -401,7 +401,7 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc, &output_detections]() - -> mediapipe::Status { + -> absl::Status { // Copy inputs. MP_RETURN_IF_ERROR( CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer)); @@ -458,7 +458,7 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( ConvertToDetections(boxes.data(), detection_scores.data(), detection_classes.data(), output_detections)); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE @@ -543,21 +543,20 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( #else LOG(ERROR) << "GPU input on non-Android not supported yet."; #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToDetectionsCalculator::Close(CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); }); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE gpu_data_.reset(); #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::LoadOptions( +absl::Status TfLiteTensorsToDetectionsCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. 
options_ = @@ -579,10 +578,10 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::LoadOptions( ignore_classes_.insert(options_.ignore_classes(i)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::DecodeBoxes( +absl::Status TfLiteTensorsToDetectionsCalculator::DecodeBoxes( const float* raw_boxes, const std::vector& anchors, std::vector* boxes) { for (int i = 0; i < num_boxes_; ++i) { @@ -643,10 +642,10 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::DecodeBoxes( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections( +absl::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections( const float* detection_boxes, const float* detection_scores, const int* detection_classes, std::vector* output_detections) { for (int i = 0; i < num_boxes_; ++i) { @@ -684,7 +683,7 @@ mediapipe::Status TfLiteTensorsToDetectionsCalculator::ConvertToDetections( } output_detections->emplace_back(detection); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( @@ -707,10 +706,10 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection( return detection; } -mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit( +absl::Status TfLiteTensorsToDetectionsCalculator::GpuInit( CalculatorContext* cc) { #if MEDIAPIPE_TFLITE_GL_INFERENCE - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> mediapipe::Status { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { gpu_data_ = absl::make_unique(); // A shader to decode detection boxes. 
@@ -918,7 +917,7 @@ void main() { MP_RETURN_IF_ERROR(CreateReadWriteShaderStorageBuffer( raw_scores_length, &gpu_data_->raw_scores_buffer)); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); #elif MEDIAPIPE_TFLITE_METAL_INFERENCE @@ -1153,7 +1152,7 @@ kernel void scoreKernel( #endif // MEDIAPIPE_TFLITE_GL_INFERENCE - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc index f3a8e3ffe..ef2946c32 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_floats_calculator.cc @@ -38,15 +38,15 @@ namespace mediapipe { // } class TfLiteTensorsToFloatsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(TfLiteTensorsToFloatsCalculator); -mediapipe::Status TfLiteTensorsToFloatsCalculator::GetContract( +absl::Status TfLiteTensorsToFloatsCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("TENSORS")); RET_CHECK(cc->Outputs().HasTag("FLOATS") || cc->Outputs().HasTag("FLOAT")); @@ -59,17 +59,16 @@ mediapipe::Status TfLiteTensorsToFloatsCalculator::GetContract( cc->Outputs().Tag("FLOAT").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) { +absl::Status TfLiteTensorsToFloatsCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status 
TfLiteTensorsToFloatsCalculator::Process( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToFloatsCalculator::Process(CalculatorContext* cc) { RET_CHECK(!cc->Inputs().Tag("TENSORS").IsEmpty()); const auto& input_tensors = @@ -97,6 +96,6 @@ mediapipe::Status TfLiteTensorsToFloatsCalculator::Process( cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc index af3f6684c..1be83bbe1 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_landmarks_calculator.cc @@ -89,13 +89,13 @@ float ApplyActivation( // } class TfLiteTensorsToLandmarksCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); int num_landmarks_ = 0; bool flip_vertically_ = false; bool flip_horizontally_ = false; @@ -104,7 +104,7 @@ class TfLiteTensorsToLandmarksCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator); -mediapipe::Status TfLiteTensorsToLandmarksCalculator::GetContract( +absl::Status TfLiteTensorsToLandmarksCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -137,11 +137,10 @@ mediapipe::Status TfLiteTensorsToLandmarksCalculator::GetContract( cc->Outputs().Tag("NORM_LANDMARKS").Set(); } - return mediapipe::OkStatus(); 
+ return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToLandmarksCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); MP_RETURN_IF_ERROR(LoadOptions(cc)); @@ -149,7 +148,7 @@ mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open( if (cc->Outputs().HasTag("NORM_LANDMARKS")) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input with/height for getting normalized landmarks."; + << "Must provide input width/height for getting normalized landmarks."; } if (cc->Outputs().HasTag("LANDMARKS") && (options_.flip_vertically() || options_.flip_horizontally() || @@ -157,7 +156,7 @@ mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open( cc->InputSidePackets().HasTag("FLIP_VERTICALLY"))) { RET_CHECK(options_.has_input_image_height() && options_.has_input_image_width()) - << "Must provide input with/height for using flip_vertically option " + << "Must provide input width/height for using flip_vertically option " "when outputing landmarks in absolute coordinates."; } @@ -171,10 +170,10 @@ mediapipe::Status TfLiteTensorsToLandmarksCalculator::Open( ? cc->InputSidePackets().Tag("FLIP_VERTICALLY").Get() : options_.flip_vertically(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToLandmarksCalculator::Process( +absl::Status TfLiteTensorsToLandmarksCalculator::Process( CalculatorContext* cc) { // Override values if specified so. 
if (cc->Inputs().HasTag("FLIP_HORIZONTALLY") && @@ -187,7 +186,7 @@ mediapipe::Status TfLiteTensorsToLandmarksCalculator::Process( } if (cc->Inputs().Tag("TENSORS").IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_tensors = @@ -268,16 +267,16 @@ mediapipe::Status TfLiteTensorsToLandmarksCalculator::Process( .At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToLandmarksCalculator::LoadOptions( +absl::Status TfLiteTensorsToLandmarksCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = cc->Options<::mediapipe::TfLiteTensorsToLandmarksCalculatorOptions>(); num_landmarks_ = options_.num_landmarks(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc index 4190a05cb..ec4945201 100644 --- a/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.cc @@ -36,7 +36,7 @@ #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace { constexpr int kWorkgroupSize = 8; // Block size for GPU shader. @@ -69,7 +69,7 @@ using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlShader; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU // Converts TFLite tensors from a tflite segmentation model to an image mask. 
// @@ -121,17 +121,17 @@ using ::tflite::gpu::gl::GlShader; // class TfLiteTensorsToSegmentationCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status LoadOptions(CalculatorContext* cc); - mediapipe::Status InitGpu(CalculatorContext* cc); - mediapipe::Status ProcessGpu(CalculatorContext* cc); - mediapipe::Status ProcessCpu(CalculatorContext* cc); + absl::Status LoadOptions(CalculatorContext* cc); + absl::Status InitGpu(CalculatorContext* cc); + absl::Status ProcessGpu(CalculatorContext* cc); + absl::Status ProcessCpu(CalculatorContext* cc); void GlRender(); ::mediapipe::TfLiteTensorsToSegmentationCalculatorOptions options_; @@ -147,12 +147,12 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase { std::unique_ptr mask_program_no_prev_; std::unique_ptr tensor_buffer_; GLuint upsample_program_; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); // static -mediapipe::Status TfLiteTensorsToSegmentationCalculator::GetContract( +absl::Status TfLiteTensorsToSegmentationCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -184,7 +184,7 @@ mediapipe::Status TfLiteTensorsToSegmentationCalculator::GetContract( cc->Inputs().Tag(kSizeImageGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU // Outputs. 
if (cc->Outputs().HasTag(kMaskTag)) { @@ -195,17 +195,17 @@ mediapipe::Status TfLiteTensorsToSegmentationCalculator::GetContract( cc->Outputs().Tag(kMaskGpuTag).Set(); use_gpu |= true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (use_gpu) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToSegmentationCalculator::Open( +absl::Status TfLiteTensorsToSegmentationCalculator::Open( CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); @@ -213,44 +213,42 @@ mediapipe::Status TfLiteTensorsToSegmentationCalculator::Open( use_gpu_ = true; #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } MP_RETURN_IF_ERROR(LoadOptions(cc)); if (use_gpu_) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> mediapipe::Status { - MP_RETURN_IF_ERROR(InitGpu(cc)); - return mediapipe::OkStatus(); - })); + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + MP_RETURN_IF_ERROR(InitGpu(cc)); + return absl::OkStatus(); + })); #else RET_CHECK_FAIL() << "GPU processing not enabled."; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process( +absl::Status TfLiteTensorsToSegmentationCalculator::Process( CalculatorContext* cc) { if (use_gpu_) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> mediapipe::Status { - MP_RETURN_IF_ERROR(ProcessGpu(cc)); - return mediapipe::OkStatus(); - })); -#endif // !MEDIAPIPE_DISABLE_GPU + 
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { + MP_RETURN_IF_ERROR(ProcessGpu(cc)); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU } else { MP_RETURN_IF_ERROR(ProcessCpu(cc)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close( +absl::Status TfLiteTensorsToSegmentationCalculator::Close( CalculatorContext* cc) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) gpu_helper_.RunInGlContext([this] { @@ -260,15 +258,15 @@ mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close( mask_program_no_prev_.reset(); tensor_buffer_.reset(); }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu( +absl::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu( CalculatorContext* cc) { if (cc->Inputs().Tag(kTensorsTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Get input streams. @@ -366,17 +364,17 @@ mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessCpu( large_mask_mat.copyTo(output_mat); cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Steps: // 1. receive tensor and optional previous mask // 2. process segmentation tensor into small mask // 3. upsample small mask into output mask to be same size as input image -mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu( +absl::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu( CalculatorContext* cc) { if (cc->Inputs().Tag(kTensorsGpuTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) // Get input streams. 
@@ -458,9 +456,9 @@ mediapipe::Status TfLiteTensorsToSegmentationCalculator::ProcessGpu( // Cleanup input_mask_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } void TfLiteTensorsToSegmentationCalculator::GlRender() { @@ -512,10 +510,10 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } -mediapipe::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( +absl::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = @@ -531,14 +529,13 @@ mediapipe::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( RET_CHECK_EQ(tensor_channels_, 2) << "Only 2 channel segmentation tensor currently supported"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu( +absl::Status TfLiteTensorsToSegmentationCalculator::InitGpu( CalculatorContext* cc) { #if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() - -> ::mediapipe::Status { + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> absl::Status { // A shader to process a segmentation tensor into an output mask, // and use an optional previous mask as input. 
// Currently uses 4 channels for output, @@ -698,11 +695,11 @@ void main() { glUseProgram(upsample_program_); glUniform1i(glGetUniformLocation(upsample_program_, "input_data"), 1); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index e4bbc9145..df6d5c6d6 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -357,6 +357,21 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "detection_to_landmarks_calculator", + srcs = ["detection_to_landmarks_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:location_data_cc_proto", + "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:status", + ], + alwayslink = 1, +) + cc_library( name = "landmarks_to_detection_calculator", srcs = ["landmarks_to_detection_calculator.cc"], @@ -1135,3 +1150,47 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "to_image_calculator", + srcs = ["to_image_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + ], + }), + alwayslink = 1, +) + +cc_library( + name = "from_image_calculator", + srcs = 
["from_image_calculator.cc"], + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:status", + "//mediapipe/framework/port:vector", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gl_calculator_helper", + ], + }), + alwayslink = 1, +) diff --git a/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc b/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc index be79fda20..edfa4196a 100644 --- a/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc +++ b/mediapipe/calculators/util/alignment_points_to_rects_calculator.cc @@ -40,7 +40,7 @@ namespace {} // namespace // } class AlignmentPointsRectsCalculator : public DetectionsToRectsCalculator { public: - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { RET_CHECK_OK(DetectionsToRectsCalculator::Open(cc)); // Make sure that start and end keypoints are provided. 
@@ -52,18 +52,18 @@ class AlignmentPointsRectsCalculator : public DetectionsToRectsCalculator { RET_CHECK(options_.has_rotation_vector_end_keypoint_index()) << "End keypoint is required to calculate rect size and rotation"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: - mediapipe::Status DetectionToNormalizedRect( + absl::Status DetectionToNormalizedRect( const ::mediapipe::Detection& detection, const DetectionSpec& detection_spec, ::mediapipe::NormalizedRect* rect) override; }; REGISTER_CALCULATOR(AlignmentPointsRectsCalculator); -mediapipe::Status AlignmentPointsRectsCalculator::DetectionToNormalizedRect( +absl::Status AlignmentPointsRectsCalculator::DetectionToNormalizedRect( const Detection& detection, const DetectionSpec& detection_spec, NormalizedRect* rect) { const auto& location_data = detection.location_data(); @@ -96,7 +96,7 @@ mediapipe::Status AlignmentPointsRectsCalculator::DetectionToNormalizedRect( rect->set_width(box_size / image_size->first); rect->set_height(box_size / image_size->second); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/annotation_overlay_calculator.cc b/mediapipe/calculators/util/annotation_overlay_calculator.cc index 89ecd1cee..7c5aadc55 100644 --- a/mediapipe/calculators/util/annotation_overlay_calculator.cc +++ b/mediapipe/calculators/util/annotation_overlay_calculator.cc @@ -31,12 +31,12 @@ #include "mediapipe/util/color.pb.h" #include "mediapipe/util/render_data.pb.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/shader_util.h" -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -124,29 +124,29 @@ class AnnotationOverlayCalculator : public CalculatorBase { AnnotationOverlayCalculator() = default; 
~AnnotationOverlayCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // From Calculator. - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status CreateRenderTargetCpu(CalculatorContext* cc, - std::unique_ptr& image_mat, - ImageFormat::Format* target_format); + absl::Status CreateRenderTargetCpu(CalculatorContext* cc, + std::unique_ptr& image_mat, + ImageFormat::Format* target_format); template - mediapipe::Status CreateRenderTargetGpu(CalculatorContext* cc, - std::unique_ptr& image_mat); + absl::Status CreateRenderTargetGpu(CalculatorContext* cc, + std::unique_ptr& image_mat); template - mediapipe::Status RenderToGpu(CalculatorContext* cc, uchar* overlay_image); - mediapipe::Status RenderToCpu(CalculatorContext* cc, - const ImageFormat::Format& target_format, - uchar* data_image); + absl::Status RenderToGpu(CalculatorContext* cc, uchar* overlay_image); + absl::Status RenderToCpu(CalculatorContext* cc, + const ImageFormat::Format& target_format, + uchar* data_image); - mediapipe::Status GlRender(CalculatorContext* cc); + absl::Status GlRender(CalculatorContext* cc); template - mediapipe::Status GlSetup(CalculatorContext* cc); + absl::Status GlSetup(CalculatorContext* cc); // Options for the calculator. 
AnnotationOverlayCalculatorOptions options_; @@ -159,7 +159,7 @@ class AnnotationOverlayCalculator : public CalculatorBase { bool use_gpu_ = false; bool gpu_initialized_ = false; -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU mediapipe::GlCalculatorHelper gpu_helper_; GLuint program_ = 0; GLuint image_mat_tex_ = 0; // Overlay drawing image for GPU. @@ -167,33 +167,32 @@ class AnnotationOverlayCalculator : public CalculatorBase { int height_ = 0; int width_canvas_ = 0; // Size of overlay drawing texture canvas. int height_canvas_ = 0; -#endif // MEDIAPIPE_DISABLE_GPU +#endif // MEDIAPIPE_DISABLE_GPU }; REGISTER_CALCULATOR(AnnotationOverlayCalculator); -mediapipe::Status AnnotationOverlayCalculator::GetContract( - CalculatorContract* cc) { +absl::Status AnnotationOverlayCalculator::GetContract(CalculatorContract* cc) { CHECK_GE(cc->Inputs().NumEntries(), 1); bool use_gpu = false; if (cc->Inputs().HasTag(kImageFrameTag) && cc->Inputs().HasTag(kGpuBufferTag)) { - return mediapipe::InternalError("Cannot have multiple input images."); + return absl::InternalError("Cannot have multiple input images."); } if (cc->Inputs().HasTag(kGpuBufferTag) != cc->Outputs().HasTag(kGpuBufferTag)) { - return mediapipe::InternalError("GPU output must have GPU input."); + return absl::InternalError("GPU output must have GPU input."); } // Input image to render onto copy of. Should be same type as output. -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kGpuBufferTag)) { cc->Inputs().Tag(kGpuBufferTag).Set(); CHECK(cc->Outputs().HasTag(kGpuBufferTag)); use_gpu = true; } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Inputs().HasTag(kImageFrameTag)) { cc->Inputs().Tag(kImageFrameTag).Set(); CHECK(cc->Outputs().HasTag(kImageFrameTag)); @@ -213,32 +212,32 @@ mediapipe::Status AnnotationOverlayCalculator::GetContract( } // Rendered image. Should be same type as input. 
-#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kGpuBufferTag)) { cc->Outputs().Tag(kGpuBufferTag).Set(); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kImageFrameTag)) { cc->Outputs().Tag(kImageFrameTag).Set(); } if (use_gpu) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { +absl::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); if (cc->Inputs().HasTag(kGpuBufferTag) || HasImageTag(cc)) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU use_gpu_ = true; -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } if (cc->Inputs().HasTag(kGpuBufferTag) || @@ -264,23 +263,23 @@ mediapipe::Status AnnotationOverlayCalculator::Open(CalculatorContext* cc) { } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { +absl::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { // Initialize render target, drawn with OpenCV. 
std::unique_ptr image_mat; ImageFormat::Format target_format; if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (!gpu_initialized_) { MP_RETURN_IF_ERROR( - gpu_helper_.RunInGlContext([this, cc]() -> mediapipe::Status { + gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { return GlSetup(cc); })); gpu_initialized_ = true; @@ -290,7 +289,7 @@ mediapipe::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { (CreateRenderTargetGpu( cc, image_mat))); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } else { if (cc->Outputs().HasTag(kImageFrameTag)) { MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); @@ -326,44 +325,44 @@ mediapipe::Status AnnotationOverlayCalculator::Process(CalculatorContext* cc) { } if (use_gpu_) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU // Overlay rendered image in OpenGL, onto a copy of input. uchar* image_mat_ptr = image_mat->data; - MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( - [this, cc, image_mat_ptr]() -> mediapipe::Status { + MP_RETURN_IF_ERROR( + gpu_helper_.RunInGlContext([this, cc, image_mat_ptr]() -> absl::Status { return RenderToGpu( cc, image_mat_ptr); })); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU } else { // Copy the rendered image to output. 
uchar* image_mat_ptr = image_mat->data; MP_RETURN_IF_ERROR(RenderToCpu(cc, target_format, image_mat_ptr)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([this] { if (program_) glDeleteProgram(program_); program_ = 0; if (image_mat_tex_) glDeleteTextures(1, &image_mat_tex_); image_mat_tex_ = 0; }); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationOverlayCalculator::RenderToCpu( +absl::Status AnnotationOverlayCalculator::RenderToCpu( CalculatorContext* cc, const ImageFormat::Format& target_format, uchar* data_image) { auto output_frame = absl::make_unique( target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight()); -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kGlDefaultAlignmentBoundary); @@ -371,7 +370,7 @@ mediapipe::Status AnnotationOverlayCalculator::RenderToCpu( output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight(), data_image, ImageFrame::kDefaultAlignmentBoundary); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU if (cc->Outputs().HasTag(kImageFrameTag)) { cc->Outputs() @@ -379,13 +378,13 @@ mediapipe::Status AnnotationOverlayCalculator::RenderToCpu( .Add(output_frame.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } template -mediapipe::Status AnnotationOverlayCalculator::RenderToGpu( - CalculatorContext* cc, uchar* overlay_image) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status 
AnnotationOverlayCalculator::RenderToGpu(CalculatorContext* cc, + uchar* overlay_image) { +#if !MEDIAPIPE_DISABLE_GPU // Source and destination textures. const auto& input_frame = cc->Inputs().Tag(Tag).Get(); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); @@ -426,12 +425,12 @@ mediapipe::Status AnnotationOverlayCalculator::RenderToGpu( // Cleanup input_texture.Release(); output_texture.Release(); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( +absl::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( CalculatorContext* cc, std::unique_ptr& image_mat, ImageFormat::Format* target_format) { if (image_frame_available_) { @@ -453,7 +452,7 @@ mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( target_mat_type = CV_8UC3; break; default: - return mediapipe::UnknownError("Unexpected image frame format."); + return absl::UnknownError("Unexpected image frame format."); break; } @@ -476,13 +475,13 @@ mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetCpu( *target_format = ImageFormat::SRGB; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } template -mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( +absl::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( CalculatorContext* cc, std::unique_ptr& image_mat) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU if (image_frame_available_) { const auto& input_frame = cc->Inputs().Tag(Tag).Get(); const mediapipe::ImageFormat::Format format = @@ -500,13 +499,13 @@ mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), options_.canvas_color().b())); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } 
-mediapipe::Status AnnotationOverlayCalculator::GlRender(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status AnnotationOverlayCalculator::GlRender(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU static const GLfloat square_vertices[] = { -1.0f, -1.0f, // bottom left 1.0f, -1.0f, // bottom right @@ -554,14 +553,14 @@ mediapipe::Status AnnotationOverlayCalculator::GlRender(CalculatorContext* cc) { glBindVertexArray(0); glDeleteVertexArrays(1, &vao); glDeleteBuffers(2, vbo); -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } template -mediapipe::Status AnnotationOverlayCalculator::GlSetup(CalculatorContext* cc) { -#if !defined(MEDIAPIPE_DISABLE_GPU) +absl::Status AnnotationOverlayCalculator::GlSetup(CalculatorContext* cc) { +#if !MEDIAPIPE_DISABLE_GPU const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ -658,9 +657,9 @@ mediapipe::Status AnnotationOverlayCalculator::GlSetup(CalculatorContext* cc) { glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); glBindTexture(GL_TEXTURE_2D, 0); } -#endif // !MEDIAPIPE_DISABLE_GPU +#endif // !MEDIAPIPE_DISABLE_GPU - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/association_calculator.h b/mediapipe/calculators/util/association_calculator.h index 37a77400a..6e5b480ce 100644 --- a/mediapipe/calculators/util/association_calculator.h +++ b/mediapipe/calculators/util/association_calculator.h @@ -56,7 +56,7 @@ inline float OverlapSimilarity(const Rectangle_f& rect1, template class AssociationCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { // Atmost one input stream can be tagged with "PREV". 
RET_CHECK_LE(cc->Inputs().NumEntries("PREV"), 1); @@ -71,10 +71,10 @@ class AssociationCalculator : public CalculatorBase { cc->Outputs().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); has_prev_input_stream_ = cc->Inputs().HasTag("PREV"); @@ -84,15 +84,15 @@ class AssociationCalculator : public CalculatorBase { options_ = cc->Options<::mediapipe::AssociationCalculatorOptions>(); CHECK_GE(options_.min_similarity_threshold(), 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { auto get_non_overlapping_elements = GetNonOverlappingElements(cc); if (!get_non_overlapping_elements.ok()) { return get_non_overlapping_elements.status(); } - std::list result = get_non_overlapping_elements.ValueOrDie(); + std::list result = get_non_overlapping_elements.value(); if (has_prev_input_stream_ && !cc->Inputs().Get(prev_input_stream_id_).IsEmpty()) { @@ -114,7 +114,7 @@ class AssociationCalculator : public CalculatorBase { } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } protected: @@ -123,8 +123,8 @@ class AssociationCalculator : public CalculatorBase { bool has_prev_input_stream_; CollectionItemId prev_input_stream_id_; - virtual mediapipe::StatusOr GetRectangle(const T& input) { - return mediapipe::OkStatus(); + virtual absl::StatusOr GetRectangle(const T& input) { + return absl::OkStatus(); } virtual std::pair GetId(const T& input) { return {false, -1}; } @@ -134,7 +134,7 @@ class AssociationCalculator : public CalculatorBase { private: // Get a list of non-overlapping elements from all input streams, with // increasing order of priority based on input stream index. 
- mediapipe::StatusOr> GetNonOverlappingElements( + absl::StatusOr> GetNonOverlappingElements( CalculatorContext* cc) { std::list result; @@ -176,7 +176,7 @@ class AssociationCalculator : public CalculatorBase { return result; } - mediapipe::Status AddElementToList(T element, std::list* current) { + absl::Status AddElementToList(T element, std::list* current) { // Compare this element with elements of the input collection. If this // element has high overlap with elements of the collection, remove // those elements from the collection and add this element. @@ -207,20 +207,20 @@ class AssociationCalculator : public CalculatorBase { } current->push_back(element); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Compare elements of the current list with elements in from the collection // of elements from the previous input stream, and propagate IDs from the // previous input stream as appropriate. - mediapipe::Status PropagateIdsFromPreviousToCurrent( + absl::Status PropagateIdsFromPreviousToCurrent( const std::vector& prev_input_vec, std::list* current) { for (auto vit = current->begin(); vit != current->end(); ++vit) { auto get_cur_rectangle = GetRectangle(*vit); if (!get_cur_rectangle.ok()) { return get_cur_rectangle.status(); } - const Rectangle_f& cur_rect = get_cur_rectangle.ValueOrDie(); + const Rectangle_f& cur_rect = get_cur_rectangle.value(); bool change_id = false; int id_for_vi = -1; @@ -230,7 +230,7 @@ class AssociationCalculator : public CalculatorBase { if (!get_prev_rectangle.ok()) { return get_prev_rectangle.status(); } - const Rectangle_f& prev_rect = get_prev_rectangle.ValueOrDie(); + const Rectangle_f& prev_rect = get_prev_rectangle.value(); if (OverlapSimilarity(cur_rect, prev_rect) > options_.min_similarity_threshold()) { @@ -250,7 +250,7 @@ class AssociationCalculator : public CalculatorBase { *vit = element; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git 
a/mediapipe/calculators/util/association_detection_calculator.cc b/mediapipe/calculators/util/association_detection_calculator.cc index 59d052769..35112aee7 100644 --- a/mediapipe/calculators/util/association_detection_calculator.cc +++ b/mediapipe/calculators/util/association_detection_calculator.cc @@ -37,27 +37,27 @@ namespace mediapipe { class AssociationDetectionCalculator : public AssociationCalculator<::mediapipe::Detection> { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return AssociationCalculator<::mediapipe::Detection>::GetContract(cc); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::Detection>::Open(cc); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::Detection>::Process(cc); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::Detection>::Close(cc); } protected: - mediapipe::StatusOr GetRectangle( + absl::StatusOr GetRectangle( const ::mediapipe::Detection& input) override { if (!input.has_location_data()) { - return mediapipe::InternalError("Missing location_data in Detection"); + return absl::InternalError("Missing location_data in Detection"); } const Location location(input.location_data()); return location.GetRelativeBBox(); diff --git a/mediapipe/calculators/util/association_norm_rect_calculator.cc b/mediapipe/calculators/util/association_norm_rect_calculator.cc index a77d65f0d..a9194604a 100644 --- a/mediapipe/calculators/util/association_norm_rect_calculator.cc +++ b/mediapipe/calculators/util/association_norm_rect_calculator.cc @@ -36,28 +36,28 @@ namespace mediapipe { class AssociationNormRectCalculator : public 
AssociationCalculator<::mediapipe::NormalizedRect> { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { return AssociationCalculator<::mediapipe::NormalizedRect>::GetContract(cc); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::NormalizedRect>::Open(cc); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::NormalizedRect>::Process(cc); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { return AssociationCalculator<::mediapipe::NormalizedRect>::Close(cc); } protected: - mediapipe::StatusOr GetRectangle( + absl::StatusOr GetRectangle( const ::mediapipe::NormalizedRect& input) override { if (!input.has_x_center() || !input.has_y_center() || !input.has_width() || !input.has_height()) { - return mediapipe::InternalError("Missing dimensions in NormalizedRect."); + return absl::InternalError("Missing dimensions in NormalizedRect."); } const float xmin = input.x_center() - input.width() / 2.0; const float ymin = input.y_center() - input.height() / 2.0; diff --git a/mediapipe/calculators/util/clock_latency_calculator.cc b/mediapipe/calculators/util/clock_latency_calculator.cc index d852c68c7..5c5711731 100644 --- a/mediapipe/calculators/util/clock_latency_calculator.cc +++ b/mediapipe/calculators/util/clock_latency_calculator.cc @@ -60,17 +60,17 @@ class ClockLatencyCalculator : public CalculatorBase { public: ClockLatencyCalculator() {} - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* 
cc) override; + absl::Status Process(CalculatorContext* cc) override; private: int64 num_packet_streams_ = -1; }; REGISTER_CALCULATOR(ClockLatencyCalculator); -mediapipe::Status ClockLatencyCalculator::GetContract(CalculatorContract* cc) { +absl::Status ClockLatencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_GT(cc->Inputs().NumEntries(), 1); int64 num_packet_streams = cc->Inputs().NumEntries() - 1; @@ -82,17 +82,17 @@ mediapipe::Status ClockLatencyCalculator::GetContract(CalculatorContract* cc) { } cc->Inputs().Tag(kReferenceTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ClockLatencyCalculator::Open(CalculatorContext* cc) { +absl::Status ClockLatencyCalculator::Open(CalculatorContext* cc) { // Direct passthrough, as far as timestamp and bounds are concerned. cc->SetOffset(TimestampDiff(0)); num_packet_streams_ = cc->Inputs().NumEntries() - 1; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ClockLatencyCalculator::Process(CalculatorContext* cc) { +absl::Status ClockLatencyCalculator::Process(CalculatorContext* cc) { // Get reference time. 
RET_CHECK(!cc->Inputs().Tag(kReferenceTag).IsEmpty()); const absl::Time& reference_time = @@ -109,7 +109,7 @@ mediapipe::Status ClockLatencyCalculator::Process(CalculatorContext* cc) { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/clock_timestamp_calculator.cc b/mediapipe/calculators/util/clock_timestamp_calculator.cc index 82bdb41b2..4ba56cfd0 100644 --- a/mediapipe/calculators/util/clock_timestamp_calculator.cc +++ b/mediapipe/calculators/util/clock_timestamp_calculator.cc @@ -52,10 +52,10 @@ class ClockTimestampCalculator : public CalculatorBase { public: ClockTimestampCalculator() {} - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Clock object. @@ -63,8 +63,7 @@ class ClockTimestampCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(ClockTimestampCalculator); -mediapipe::Status ClockTimestampCalculator::GetContract( - CalculatorContract* cc) { +absl::Status ClockTimestampCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 1); RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); @@ -78,10 +77,10 @@ mediapipe::Status ClockTimestampCalculator::GetContract( .Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ClockTimestampCalculator::Open(CalculatorContext* cc) { +absl::Status ClockTimestampCalculator::Open(CalculatorContext* cc) { // Direct passthrough, as far as timestamp and bounds are concerned. 
cc->SetOffset(TimestampDiff(0)); @@ -95,14 +94,14 @@ mediapipe::Status ClockTimestampCalculator::Open(CalculatorContext* cc) { ::mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ClockTimestampCalculator::Process(CalculatorContext* cc) { +absl::Status ClockTimestampCalculator::Process(CalculatorContext* cc) { // Push the Time packet to output. auto timestamp_packet = MakePacket(clock_->TimeNow()); cc->Outputs().Index(0).AddPacket(timestamp_packet.At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/collection_has_min_size_calculator.h b/mediapipe/calculators/util/collection_has_min_size_calculator.h index 6cbc63e51..4d4b6a678 100644 --- a/mediapipe/calculators/util/collection_has_min_size_calculator.h +++ b/mediapipe/calculators/util/collection_has_min_size_calculator.h @@ -42,7 +42,7 @@ namespace mediapipe { template class CollectionHasMinSizeCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("ITERABLE")); RET_CHECK_EQ(1, cc->Inputs().NumEntries()); @@ -60,10 +60,10 @@ class CollectionHasMinSizeCalculator : public CalculatorBase { if (cc->InputSidePackets().NumEntries() > 0) { cc->InputSidePackets().Index(0).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); min_size_ = cc->Options<::mediapipe::CollectionHasMinSizeCalculatorOptions>() @@ -73,17 +73,17 @@ class CollectionHasMinSizeCalculator : public CalculatorBase { !cc->InputSidePackets().Index(0).IsEmpty()) { min_size_ = cc->InputSidePackets().Index(0).Get(); } - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const IterableT& input = cc->Inputs().Tag("ITERABLE").Get(); bool has_min_size = input.size() >= min_size_; cc->Outputs().Index(0).AddPacket( MakePacket(has_min_size).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc index 779e9785b..0de1e53b2 100644 --- a/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc +++ b/mediapipe/calculators/util/detection_label_id_to_text_calculator.cc @@ -47,26 +47,25 @@ namespace mediapipe { // } class DetectionLabelIdToTextCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: absl::node_hash_map label_map_; }; REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); -mediapipe::Status DetectionLabelIdToTextCalculator::GetContract( +absl::Status DetectionLabelIdToTextCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set>(); cc->Outputs().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionLabelIdToTextCalculator::Open( - CalculatorContext* cc) { +absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -90,11 +89,10 @@ mediapipe::Status DetectionLabelIdToTextCalculator::Open( label_map_[i] = options.label(i); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status 
DetectionLabelIdToTextCalculator::Process( - CalculatorContext* cc) { +absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) { std::vector output_detections; for (const auto& input_detection : cc->Inputs().Index(0).Get>()) { @@ -115,7 +113,7 @@ mediapipe::Status DetectionLabelIdToTextCalculator::Process( cc->Outputs().Index(0).AddPacket( MakePacket>(output_detections) .At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc index a23a1d225..8f8025576 100644 --- a/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc +++ b/mediapipe/calculators/util/detection_letterbox_removal_calculator.cc @@ -70,7 +70,7 @@ constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; // } class DetectionLetterboxRemovalCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionsTag) && cc->Inputs().HasTag(kLetterboxPaddingTag)) << "Missing one or more input streams."; @@ -80,19 +80,19 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase { cc->Outputs().Tag(kDetectionsTag).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { // Only process if there's input detections. 
if (cc->Inputs().Tag(kDetectionsTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_detections = @@ -146,7 +146,7 @@ class DetectionLetterboxRemovalCalculator : public CalculatorBase { cc->Outputs() .Tag("DETECTIONS") .Add(output_detections.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(DetectionLetterboxRemovalCalculator); diff --git a/mediapipe/calculators/util/detection_projection_calculator.cc b/mediapipe/calculators/util/detection_projection_calculator.cc index 9200ebfe3..211fd204c 100644 --- a/mediapipe/calculators/util/detection_projection_calculator.cc +++ b/mediapipe/calculators/util/detection_projection_calculator.cc @@ -51,9 +51,9 @@ namespace mediapipe { // } class DetectionProjectionCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(DetectionProjectionCalculator); @@ -62,7 +62,7 @@ namespace { constexpr char kDetections[] = "DETECTIONS"; constexpr char kProjectionMatrix[] = "PROJECTION_MATRIX"; -mediapipe::Status ProjectDetection( +absl::Status ProjectDetection( const std::function& project_fn, Detection* detection) { auto* location_data = detection->mutable_location_data(); @@ -107,12 +107,12 @@ mediapipe::Status ProjectDetection( box->set_width(right_bottom.x() - left_top.x()); box->set_height(right_bottom.y() - left_top.y()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace -mediapipe::Status DetectionProjectionCalculator::GetContract( +absl::Status DetectionProjectionCalculator::GetContract( CalculatorContract* cc) { 
RET_CHECK(cc->Inputs().HasTag(kDetections) && cc->Inputs().HasTag(kProjectionMatrix)) @@ -133,18 +133,17 @@ mediapipe::Status DetectionProjectionCalculator::GetContract( cc->Outputs().Get(id).Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionProjectionCalculator::Open(CalculatorContext* cc) { +absl::Status DetectionProjectionCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionProjectionCalculator::Process( - CalculatorContext* cc) { +absl::Status DetectionProjectionCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kProjectionMatrix).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& project_mat = cc->Inputs().Tag(kProjectionMatrix).Get>(); @@ -173,7 +172,7 @@ mediapipe::Status DetectionProjectionCalculator::Process( MakePacket>(std::move(output_detections)) .At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_projection_calculator_test.cc b/mediapipe/calculators/util/detection_projection_calculator_test.cc index bf8d49263..0437e6f96 100644 --- a/mediapipe/calculators/util/detection_projection_calculator_test.cc +++ b/mediapipe/calculators/util/detection_projection_calculator_test.cc @@ -64,7 +64,7 @@ std::vector GetPoints(const Detection& detection) { } // Test helper function to run "DetectionProjectionCalculator". 
-mediapipe::StatusOr RunProjectionCalculator( +absl::StatusOr RunProjectionCalculator( Detection detection, std::array project_mat) { CalculatorRunner runner(ParseTextProtoOrDie(R"( calculator: "DetectionProjectionCalculator" @@ -132,7 +132,7 @@ TEST(DetectionProjectionCalculatorTest, ProjectionFullRoiNoOp) { auto status_or_result = RunProjectionCalculator(std::move(detection), std::move(projection_matrix)); MP_ASSERT_OK(status_or_result); - const auto& result = status_or_result.ValueOrDie(); + const auto& result = status_or_result.value(); ASSERT_EQ(result.location_data().format(), LocationData::RELATIVE_BOUNDING_BOX); EXPECT_THAT(result.location_data().relative_bounding_box(), @@ -178,7 +178,7 @@ TEST(DetectionProjectionCalculatorTest, ProjectionFullRoi90Rotation) { auto status_or_result = RunProjectionCalculator(std::move(detection), std::move(projection_matrix)); MP_ASSERT_OK(status_or_result); - const auto& result = status_or_result.ValueOrDie(); + const auto& result = status_or_result.value(); ASSERT_EQ(result.location_data().format(), LocationData::RELATIVE_BOUNDING_BOX); EXPECT_THAT(result.location_data().relative_bounding_box(), @@ -224,7 +224,7 @@ TEST(DetectionProjectionCalculatorTest, ProjectionSmallerRoi) { auto status_or_result = RunProjectionCalculator(std::move(detection), std::move(projection_matrix)); MP_ASSERT_OK(status_or_result); - const auto& result = status_or_result.ValueOrDie(); + const auto& result = status_or_result.value(); ASSERT_EQ(result.location_data().format(), LocationData::RELATIVE_BOUNDING_BOX); EXPECT_THAT(result.location_data().relative_bounding_box(), @@ -293,7 +293,7 @@ TEST(DetectionProjectionCalculatorTest, ProjectionSmallerRoi30Rotation) { auto status_or_result = RunProjectionCalculator(std::move(detection), std::move(projection_matrix)); MP_ASSERT_OK(status_or_result); - const auto& result = status_or_result.ValueOrDie(); + const auto& result = status_or_result.value(); ASSERT_EQ(result.location_data().format(), 
LocationData::RELATIVE_BOUNDING_BOX); EXPECT_THAT(result.location_data().relative_bounding_box(), diff --git a/mediapipe/calculators/util/detection_to_landmarks_calculator.cc b/mediapipe/calculators/util/detection_to_landmarks_calculator.cc new file mode 100644 index 000000000..549298bad --- /dev/null +++ b/mediapipe/calculators/util/detection_to_landmarks_calculator.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/location_data.pb.h" +#include "mediapipe/framework/port/ret_check.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_macros.h" + +namespace mediapipe { + +namespace { + +constexpr char kDetectionTag[] = "DETECTION"; +constexpr char kLandmarksTag[] = "LANDMARKS"; + +absl::Status ConvertDetectionToLandmarks(const Detection& detection, + NormalizedLandmarkList* landmarks) { + const auto& location_data = detection.location_data(); + for (int i = 0; i < location_data.relative_keypoints_size(); ++i) { + const auto& keypoint = location_data.relative_keypoints(i); + + auto* landmark = landmarks->add_landmark(); + landmark->set_x(keypoint.x()); + landmark->set_y(keypoint.y()); + } + + return absl::OkStatus(); +} + +} // namespace + +// 
Converts a detection into a normalized landmark list by extracting the +// location data relative keypoints as landmarks. +// +// Input: +// DETECTION - `Detection` +// A detection to be converted. +// +// Output: +// LANDMARKS - `NormalizedLandmarkList` +// A converted normalized landmark list. +// +// Example: +// +// node { +// calculator: "DetectionToLandmarksCalculator" +// input_stream: "DETECTION:detection" +// output_stream: "LANDMARKS:landmarks" +// } +// +class DetectionToLandmarksCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + RET_CHECK(cc->Inputs().HasTag(kDetectionTag)); + RET_CHECK(cc->Outputs().HasTag(kLandmarksTag)); + + cc->Inputs().Tag(kDetectionTag).Set(); + cc->Outputs().Tag(kLandmarksTag).Set(); + + return absl::OkStatus(); + } + + absl::Status Open(CalculatorContext* cc) override { + cc->SetOffset(TimestampDiff(0)); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + const auto& detection = cc->Inputs().Tag(kDetectionTag).Get(); + + auto landmarks = absl::make_unique(); + MP_RETURN_IF_ERROR(ConvertDetectionToLandmarks(detection, landmarks.get())); + + cc->Outputs() + .Tag(kLandmarksTag) + .Add(landmarks.release(), cc->InputTimestamp()); + + return absl::OkStatus(); + } +}; + +REGISTER_CALCULATOR(DetectionToLandmarksCalculator); + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/detection_unique_id_calculator.cc b/mediapipe/calculators/util/detection_unique_id_calculator.cc index 9a4d1afa4..ac8889ffb 100644 --- a/mediapipe/calculators/util/detection_unique_id_calculator.cc +++ b/mediapipe/calculators/util/detection_unique_id_calculator.cc @@ -44,7 +44,7 @@ inline int GetNextDetectionId() { return ++detection_id; } // } class DetectionUniqueIdCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { 
RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || cc->Inputs().HasTag(kDetectionsTag)) << "None of the input streams are provided."; @@ -60,24 +60,24 @@ class DetectionUniqueIdCalculator : public CalculatorBase { cc->Outputs().Tag(kDetectionsTag).Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(mediapipe::TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(DetectionUniqueIdCalculator); -mediapipe::Status DetectionUniqueIdCalculator::Process(CalculatorContext* cc) { +absl::Status DetectionUniqueIdCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kDetectionListTag) && !cc->Inputs().Tag(kDetectionListTag).IsEmpty()) { auto result = cc->Inputs().Tag(kDetectionListTag).Value().Consume(); if (result.ok()) { - auto detection_list = std::move(result).ValueOrDie(); + auto detection_list = std::move(result).value(); for (Detection& detection : *detection_list->mutable_detection()) { detection.set_detection_id(GetNextDetectionId()); } @@ -94,7 +94,7 @@ mediapipe::Status DetectionUniqueIdCalculator::Process(CalculatorContext* cc) { .Value() .Consume>(); if (result.ok()) { - auto detections = std::move(result).ValueOrDie(); + auto detections = std::move(result).value(); for (Detection& detection : *detections) { detection.set_detection_id(GetNextDetectionId()); } @@ -103,7 +103,7 @@ mediapipe::Status DetectionUniqueIdCalculator::Process(CalculatorContext* cc) { .Add(detections.release(), cc->InputTimestamp()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.cc b/mediapipe/calculators/util/detections_to_rects_calculator.cc index 
27c0460e2..29836cb59 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator.cc @@ -40,8 +40,8 @@ constexpr char kNormRectsTag[] = "NORM_RECTS"; constexpr float kMinFloat = std::numeric_limits::lowest(); constexpr float kMaxFloat = std::numeric_limits::max(); -mediapipe::Status NormRectFromKeyPoints(const LocationData& location_data, - NormalizedRect* rect) { +absl::Status NormRectFromKeyPoints(const LocationData& location_data, + NormalizedRect* rect) { RET_CHECK_GT(location_data.relative_keypoints_size(), 1) << "2 or more key points required to calculate a rect."; float xmin = kMaxFloat; @@ -59,7 +59,7 @@ mediapipe::Status NormRectFromKeyPoints(const LocationData& location_data, rect->set_y_center((ymin + ymax) / 2); rect->set_width(xmax - xmin); rect->set_height(ymax - ymin); - return mediapipe::OkStatus(); + return absl::OkStatus(); } template @@ -72,7 +72,7 @@ void RectFromBox(B box, R* rect) { } // namespace -mediapipe::Status DetectionsToRectsCalculator::DetectionToRect( +absl::Status DetectionsToRectsCalculator::DetectionToRect( const Detection& detection, const DetectionSpec& detection_spec, Rect* rect) { const LocationData location_data = detection.location_data(); @@ -101,10 +101,10 @@ mediapipe::Status DetectionsToRectsCalculator::DetectionToRect( break; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionsToRectsCalculator::DetectionToNormalizedRect( +absl::Status DetectionsToRectsCalculator::DetectionToNormalizedRect( const Detection& detection, const DetectionSpec& detection_spec, NormalizedRect* rect) { const LocationData location_data = detection.location_data(); @@ -124,11 +124,10 @@ mediapipe::Status DetectionsToRectsCalculator::DetectionToNormalizedRect( break; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionsToRectsCalculator::GetContract( - CalculatorContract* cc) { 
+absl::Status DetectionsToRectsCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionTag) ^ cc->Inputs().HasTag(kDetectionsTag)) << "Exactly one of DETECTION or DETECTIONS input stream should be " @@ -164,10 +163,10 @@ mediapipe::Status DetectionsToRectsCalculator::GetContract( cc->Outputs().Tag(kNormRectsTag).Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionsToRectsCalculator::Open(CalculatorContext* cc) { +absl::Status DetectionsToRectsCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -192,17 +191,17 @@ mediapipe::Status DetectionsToRectsCalculator::Open(CalculatorContext* cc) { output_zero_rect_for_empty_detections_ = options_.output_zero_rect_for_empty_detections(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionsToRectsCalculator::Process(CalculatorContext* cc) { +absl::Status DetectionsToRectsCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kDetectionTag) && cc->Inputs().Tag(kDetectionTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().HasTag(kDetectionsTag) && cc->Inputs().Tag(kDetectionsTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } std::vector detections; @@ -230,7 +229,7 @@ mediapipe::Status DetectionsToRectsCalculator::Process(CalculatorContext* cc) { .Add(rect_vector.release(), cc->InputTimestamp()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -297,10 +296,10 @@ mediapipe::Status DetectionsToRectsCalculator::Process(CalculatorContext* cc) { .Add(output_rects.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionsToRectsCalculator::ComputeRotation( +absl::Status DetectionsToRectsCalculator::ComputeRotation( const Detection& detection, const DetectionSpec& detection_spec, float* 
rotation) { const auto& location_data = detection.location_data(); @@ -318,7 +317,7 @@ mediapipe::Status DetectionsToRectsCalculator::ComputeRotation( *rotation = NormalizeRadians(target_angle_ - std::atan2(-(y1 - y0), x1 - x0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } DetectionSpec DetectionsToRectsCalculator::GetDetectionSpec( diff --git a/mediapipe/calculators/util/detections_to_rects_calculator.h b/mediapipe/calculators/util/detections_to_rects_calculator.h index 333f9bfdd..e91441bc6 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator.h +++ b/mediapipe/calculators/util/detections_to_rects_calculator.h @@ -83,21 +83,21 @@ struct DetectionSpec { // } class DetectionsToRectsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; protected: - virtual mediapipe::Status DetectionToRect( - const ::mediapipe::Detection& detection, - const DetectionSpec& detection_spec, ::mediapipe::Rect* rect); - virtual mediapipe::Status DetectionToNormalizedRect( + virtual absl::Status DetectionToRect(const ::mediapipe::Detection& detection, + const DetectionSpec& detection_spec, + ::mediapipe::Rect* rect); + virtual absl::Status DetectionToNormalizedRect( const ::mediapipe::Detection& detection, const DetectionSpec& detection_spec, ::mediapipe::NormalizedRect* rect); - virtual mediapipe::Status ComputeRotation( - const ::mediapipe::Detection& detection, - const DetectionSpec& detection_spec, float* rotation); + virtual absl::Status ComputeRotation(const ::mediapipe::Detection& detection, + const DetectionSpec& detection_spec, + float* rotation); virtual DetectionSpec GetDetectionSpec(const CalculatorContext* 
cc); static inline float NormalizeRadians(float angle) { diff --git a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc index cebe64153..85c2bd72f 100644 --- a/mediapipe/calculators/util/detections_to_rects_calculator_test.cc +++ b/mediapipe/calculators/util/detections_to_rects_calculator_test.cc @@ -105,7 +105,7 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToRect) { EXPECT_THAT(rect, RectEq(250, 400, 300, 400)); } -mediapipe::StatusOr RunDetectionKeyPointsToRectCalculation( +absl::StatusOr RunDetectionKeyPointsToRectCalculation( Detection detection, std::pair image_size) { CalculatorRunner runner(ParseTextProtoOrDie(R"( calculator: "DetectionsToRectsCalculator" @@ -138,25 +138,25 @@ TEST(DetectionsToRectsCalculatorTest, DetectionKeyPointsToRect) { auto status_or_value = RunDetectionKeyPointsToRectCalculation( /*detection=*/DetectionWithKeyPoints({{0.0f, 0.0f}, {1.0f, 1.0f}}), /*image_size=*/{640, 480}); - EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(320, 240, 640, 480)); + EXPECT_THAT(status_or_value.value(), RectEq(320, 240, 640, 480)); status_or_value = RunDetectionKeyPointsToRectCalculation( /*detection=*/DetectionWithKeyPoints({{0.25f, 0.25f}, {0.75f, 0.75f}}), /*image_size=*/{640, 480}); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(320, 240, 320, 240)); + EXPECT_THAT(status_or_value.value(), RectEq(320, 240, 320, 240)); status_or_value = RunDetectionKeyPointsToRectCalculation( /*detection=*/DetectionWithKeyPoints({{0.0f, 0.0f}, {0.5f, 0.5f}}), /*image_size=*/{640, 480}); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(160, 120, 320, 240)); + EXPECT_THAT(status_or_value.value(), RectEq(160, 120, 320, 240)); status_or_value = RunDetectionKeyPointsToRectCalculation( /*detection=*/DetectionWithKeyPoints({{0.5f, 0.5f}, {1.0f, 1.0f}}), /*image_size=*/{640, 480}); MP_ASSERT_OK(status_or_value); - 
EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(480, 360, 320, 240)); + EXPECT_THAT(status_or_value.value(), RectEq(480, 360, 320, 240)); } TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { @@ -181,7 +181,7 @@ TEST(DetectionsToRectsCalculatorTest, DetectionToNormalizedRect) { EXPECT_THAT(rect, NormRectEq(0.25f, 0.4f, 0.3f, 0.4f)); } -mediapipe::StatusOr RunDetectionKeyPointsToNormRectCalculation( +absl::StatusOr RunDetectionKeyPointsToNormRectCalculation( Detection detection) { CalculatorRunner runner(ParseTextProtoOrDie(R"( calculator: "DetectionsToRectsCalculator" @@ -212,22 +212,22 @@ TEST(DetectionsToRectsCalculatorTest, DetectionKeyPointsToNormalizedRect) { /*detection=*/DetectionWithKeyPoints( {{0.0f, 0.0f}, {0.5f, 0.5f}, {1.0f, 1.0f}})); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(0.5f, 0.5f, 1.0f, 1.0f)); + EXPECT_THAT(status_or_value.value(), RectEq(0.5f, 0.5f, 1.0f, 1.0f)); status_or_value = RunDetectionKeyPointsToNormRectCalculation( /*detection=*/DetectionWithKeyPoints( {{0.25f, 0.25f}, {0.75f, 0.25f}, {0.75f, 0.75f}})); - EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(0.5f, 0.5f, 0.5f, 0.5f)); + EXPECT_THAT(status_or_value.value(), RectEq(0.5f, 0.5f, 0.5f, 0.5f)); status_or_value = RunDetectionKeyPointsToNormRectCalculation( /*detection=*/DetectionWithKeyPoints({{0.0f, 0.0f}, {0.5f, 0.5f}})); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(0.25f, 0.25f, 0.5f, 0.5f)); + EXPECT_THAT(status_or_value.value(), RectEq(0.25f, 0.25f, 0.5f, 0.5f)); status_or_value = RunDetectionKeyPointsToNormRectCalculation( /*detection=*/DetectionWithKeyPoints({{0.5f, 0.5f}, {1.0f, 1.0f}})); MP_ASSERT_OK(status_or_value); - EXPECT_THAT(status_or_value.ValueOrDie(), RectEq(0.75f, 0.75f, 0.5f, 0.5f)); + EXPECT_THAT(status_or_value.value(), RectEq(0.75f, 0.75f, 0.5f, 0.5f)); } TEST(DetectionsToRectsCalculatorTest, DetectionsToRect) { diff --git 
a/mediapipe/calculators/util/detections_to_render_data_calculator.cc b/mediapipe/calculators/util/detections_to_render_data_calculator.cc index 94099f603..25d74ba68 100644 --- a/mediapipe/calculators/util/detections_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/detections_to_render_data_calculator.cc @@ -82,11 +82,11 @@ class DetectionsToRenderDataCalculator : public CalculatorBase { DetectionsToRenderDataCalculator& operator=( const DetectionsToRenderDataCalculator&) = delete; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // These utility methods are supposed to be used only by this class. No @@ -122,7 +122,7 @@ class DetectionsToRenderDataCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(DetectionsToRenderDataCalculator); -mediapipe::Status DetectionsToRenderDataCalculator::GetContract( +absl::Status DetectionsToRenderDataCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || cc->Inputs().HasTag(kDetectionsTag) || @@ -139,18 +139,16 @@ mediapipe::Status DetectionsToRenderDataCalculator::GetContract( cc->Inputs().Tag(kDetectionsTag).Set>(); } cc->Outputs().Tag(kRenderDataTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionsToRenderDataCalculator::Open( - CalculatorContext* cc) { +absl::Status DetectionsToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status DetectionsToRenderDataCalculator::Process( - CalculatorContext* cc) { +absl::Status DetectionsToRenderDataCalculator::Process(CalculatorContext* cc) { const auto& options = 
cc->Options(); const bool has_detection_from_list = cc->Inputs().HasTag(kDetectionListTag) && !cc->Inputs() @@ -165,7 +163,7 @@ mediapipe::Status DetectionsToRenderDataCalculator::Process( !cc->Inputs().Tag(kDetectionTag).IsEmpty(); if (!options.produce_empty_packet() && !has_detection_from_list && !has_detection_from_vector && !has_single_detection) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // TODO: Add score threshold to @@ -191,7 +189,7 @@ mediapipe::Status DetectionsToRenderDataCalculator::Process( cc->Outputs() .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void DetectionsToRenderDataCalculator::SetRenderAnnotationColorThickness( diff --git a/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc b/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc index 38907d6e7..4b4742b18 100644 --- a/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc +++ b/mediapipe/calculators/util/detections_to_timed_box_list_calculator.cc @@ -42,7 +42,7 @@ constexpr char kBoxesTag[] = "BOXES"; // } class DetectionsToTimedBoxListCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kDetectionListTag) || cc->Inputs().HasTag(kDetectionsTag)) << "None of the input streams are provided."; @@ -53,14 +53,14 @@ class DetectionsToTimedBoxListCalculator : public CalculatorBase { cc->Inputs().Tag(kDetectionsTag).Set>(); } cc->Outputs().Tag(kBoxesTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status 
Process(CalculatorContext* cc) override; private: void ConvertDetectionToTimedBox(const Detection& detection, @@ -68,7 +68,7 @@ class DetectionsToTimedBoxListCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(DetectionsToTimedBoxListCalculator); -mediapipe::Status DetectionsToTimedBoxListCalculator::Process( +absl::Status DetectionsToTimedBoxListCalculator::Process( CalculatorContext* cc) { auto output_timed_box_list = absl::make_unique(); @@ -91,7 +91,7 @@ mediapipe::Status DetectionsToTimedBoxListCalculator::Process( cc->Outputs().Tag(kBoxesTag).Add(output_timed_box_list.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void DetectionsToTimedBoxListCalculator::ConvertDetectionToTimedBox( diff --git a/mediapipe/calculators/util/filter_collection_calculator.h b/mediapipe/calculators/util/filter_collection_calculator.h index f3799dd23..60a6255c9 100644 --- a/mediapipe/calculators/util/filter_collection_calculator.h +++ b/mediapipe/calculators/util/filter_collection_calculator.h @@ -42,7 +42,7 @@ namespace mediapipe { template class FilterCollectionCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("ITERABLE")); RET_CHECK(cc->Inputs().HasTag("CONDITION")); RET_CHECK(cc->Outputs().HasTag("ITERABLE")); @@ -52,20 +52,20 @@ class FilterCollectionCalculator : public CalculatorBase { cc->Outputs().Tag("ITERABLE").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag("ITERABLE").IsEmpty()) { - return 
mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().Tag("CONDITION").IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const std::vector& filter_by = @@ -77,11 +77,11 @@ class FilterCollectionCalculator : public CalculatorBase { } template - mediapipe::Status FilterCollection(std::true_type, CalculatorContext* cc, - const std::vector& filter_by) { + absl::Status FilterCollection(std::true_type, CalculatorContext* cc, + const std::vector& filter_by) { const IterableU& input = cc->Inputs().Tag("ITERABLE").Get(); if (input.size() != filter_by.size()) { - return mediapipe::InternalError(absl::StrCat( + return absl::InternalError(absl::StrCat( "Input vector size: ", input.size(), " doesn't mach condition vector size: ", filter_by.size())); } @@ -93,14 +93,13 @@ class FilterCollectionCalculator : public CalculatorBase { } } cc->Outputs().Tag("ITERABLE").Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } template - mediapipe::Status FilterCollection(std::false_type, CalculatorContext* cc, - const std::vector& filter_by) { - return mediapipe::InternalError( - "Cannot copy input collection to filter it."); + absl::Status FilterCollection(std::false_type, CalculatorContext* cc, + const std::vector& filter_by) { + return absl::InternalError("Cannot copy input collection to filter it."); } }; diff --git a/mediapipe/calculators/util/from_image_calculator.cc b/mediapipe/calculators/util/from_image_calculator.cc new file mode 100644 index 000000000..7484d9257 --- /dev/null +++ b/mediapipe/calculators/util/from_image_calculator.cc @@ -0,0 +1,164 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/vector.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +namespace { +constexpr char kImageFrameTag[] = "IMAGE_CPU"; +constexpr char kGpuBufferTag[] = "IMAGE_GPU"; +constexpr char kImageTag[] = "IMAGE"; +} // namespace + +// A calculator for converting the unified image container into +// legacy MediaPipe datatypes. +// +// Inputs: +// IMAGE: An Image containing input image. +// +// Output: +// One of the following two tags: +// IMAGE_CPU: An ImageFrame containing output image. +// IMAGE_GPU: A GpuBuffer containing output image. +// +// Note: +// Data is automatically transferred to/from the CPU or GPU +// depending on output type. +// +class FromImageCalculator : public CalculatorBase { + public: + FromImageCalculator() = default; + ~FromImageCalculator() override = default; + + static absl::Status GetContract(CalculatorContract* cc); + + // From Calculator. 
+ absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); + + bool gpu_output_ = false; + bool gpu_initialized_ = false; +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GlCalculatorHelper gpu_helper_; +#endif // !MEDIAPIPE_DISABLE_GPU +}; +REGISTER_CALCULATOR(FromImageCalculator); + +absl::Status FromImageCalculator::GetContract(CalculatorContract* cc) { + cc->Inputs().Tag(kImageTag).Set(); + + bool gpu_output = false; + + if (cc->Outputs().HasTag(kImageFrameTag) && + cc->Outputs().HasTag(kGpuBufferTag)) { + return absl::InternalError("Cannot have multiple outputs."); + } + + if (cc->Outputs().HasTag(kGpuBufferTag)) { +#if !MEDIAPIPE_DISABLE_GPU + cc->Outputs().Tag(kGpuBufferTag).Set(); + gpu_output = true; +#else + RET_CHECK_FAIL() << "GPU is disabled. Cannot use IMAGE_GPU stream."; +#endif // !MEDIAPIPE_DISABLE_GPU + } + if (cc->Outputs().HasTag(kImageFrameTag)) { + cc->Outputs().Tag(kImageFrameTag).Set(); + } + + if (gpu_output) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status FromImageCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + if (cc->Outputs().HasTag(kGpuBufferTag)) { + gpu_output_ = true; + } + + if (gpu_output_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status FromImageCalculator::Process(CalculatorContext* cc) { + if (gpu_output_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&cc]() -> absl::Status { + auto& input = cc->Inputs().Tag(kImageTag).Get(); + // Unwrap texture pointer; shallow copy. 
+ auto output = + std::make_unique(input.GetGpuBuffer()); + cc->Outputs() + .Tag(kGpuBufferTag) + .Add(output.release(), cc->InputTimestamp()); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU + } else { + // The input Image. + auto& input = cc->Inputs().Tag(kImageTag).Get(); + // Make a copy of the input packet to co-own the input Image. + Packet* packet_copy_ptr = new Packet(cc->Inputs().Tag(kImageTag).Value()); + // Create an output ImageFrame that points to the same pixel data as the + // input Image and also owns the packet copy. As a result, the output + // ImageFrame indirectly co-owns the input Image. This ensures a correct + // life span of the shared pixel data. + std::unique_ptr output = + std::make_unique( + input.image_format(), input.width(), input.height(), input.step(), + const_cast(input.GetImageFrameSharedPtr()->PixelData()), + [packet_copy_ptr](uint8*) { delete packet_copy_ptr; }); + cc->Outputs() + .Tag(kImageFrameTag) + .Add(output.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); +} + +absl::Status FromImageCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/labels_to_render_data_calculator.cc b/mediapipe/calculators/util/labels_to_render_data_calculator.cc index 2e63d29f7..cf448cff1 100644 --- a/mediapipe/calculators/util/labels_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/labels_to_render_data_calculator.cc @@ -59,9 +59,9 @@ constexpr float kFontHeightScale = 1.25f; // } class LabelsToRenderDataCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: 
LabelsToRenderDataCalculatorOptions options_; @@ -73,8 +73,7 @@ class LabelsToRenderDataCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(LabelsToRenderDataCalculator); -mediapipe::Status LabelsToRenderDataCalculator::GetContract( - CalculatorContract* cc) { +absl::Status LabelsToRenderDataCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("CLASSIFICATIONS")) { cc->Inputs().Tag("CLASSIFICATIONS").Set(); } else { @@ -89,25 +88,25 @@ mediapipe::Status LabelsToRenderDataCalculator::GetContract( cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); } cc->Outputs().Tag("RENDER_DATA").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) { +absl::Status LabelsToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); num_colors_ = options_.color_size(); label_height_px_ = std::ceil(options_.font_height_px() * kFontHeightScale); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { +absl::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag("VIDEO_PRESTREAM") && cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); video_width_ = video_header.width; video_height_ = video_header.height; - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { CHECK_EQ(options_.location(), LabelsToRenderDataCalculatorOptions::TOP_LEFT) << "Only TOP_LEFT is supported without VIDEO_PRESTREAM."; @@ -179,6 +178,6 @@ mediapipe::Status LabelsToRenderDataCalculator::Process(CalculatorContext* cc) { .Tag("RENDER_DATA") .AddPacket(MakePacket(render_data).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git 
a/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc index d07f76e1e..d3c7a6453 100644 --- a/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc +++ b/mediapipe/calculators/util/landmark_letterbox_removal_calculator.cc @@ -64,7 +64,7 @@ constexpr char kLetterboxPaddingTag[] = "LETTERBOX_PADDING"; // } class LandmarkLetterboxRemovalCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && cc->Inputs().HasTag(kLetterboxPaddingTag)) << "Missing one or more input streams."; @@ -84,18 +84,18 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase { cc->Outputs().Get(id).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag(kLetterboxPaddingTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& letterbox_padding = cc->Inputs().Tag(kLetterboxPaddingTag).Get>(); @@ -134,7 +134,7 @@ class LandmarkLetterboxRemovalCalculator : public CalculatorBase { MakePacket(output_landmarks) .At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(LandmarkLetterboxRemovalCalculator); diff --git a/mediapipe/calculators/util/landmark_projection_calculator.cc b/mediapipe/calculators/util/landmark_projection_calculator.cc index aaf28c02a..59b7c020c 100644 --- a/mediapipe/calculators/util/landmark_projection_calculator.cc +++ 
b/mediapipe/calculators/util/landmark_projection_calculator.cc @@ -60,7 +60,7 @@ constexpr char kRectTag[] = "NORM_RECT"; // } class LandmarkProjectionCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) && cc->Inputs().HasTag(kRectTag)) << "Missing one or more input streams."; @@ -80,18 +80,18 @@ class LandmarkProjectionCalculator : public CalculatorBase { cc->Outputs().Get(id).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag(kRectTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_rect = cc->Inputs().Tag(kRectTag).Get(); @@ -136,7 +136,7 @@ class LandmarkProjectionCalculator : public CalculatorBase { MakePacket(output_landmarks) .At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(LandmarkProjectionCalculator); diff --git a/mediapipe/calculators/util/landmark_visibility_calculator.cc b/mediapipe/calculators/util/landmark_visibility_calculator.cc index e2239a5ee..f22d2ac57 100644 --- a/mediapipe/calculators/util/landmark_visibility_calculator.cc +++ b/mediapipe/calculators/util/landmark_visibility_calculator.cc @@ -44,31 +44,30 @@ constexpr char kVisibilityTag[] = "VISIBILITY"; // class LandmarkVisibilityCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static 
absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(LandmarkVisibilityCalculator); -mediapipe::Status LandmarkVisibilityCalculator::GetContract( - CalculatorContract* cc) { +absl::Status LandmarkVisibilityCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); cc->Outputs().Tag(kVisibilityTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarkVisibilityCalculator::Open(CalculatorContext* cc) { +absl::Status LandmarkVisibilityCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarkVisibilityCalculator::Process(CalculatorContext* cc) { +absl::Status LandmarkVisibilityCalculator::Process(CalculatorContext* cc) { // Check that landmark is not empty. // Don't emit an empty packet for this timestamp. 
if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& landmarks = @@ -80,7 +79,7 @@ mediapipe::Status LandmarkVisibilityCalculator::Process(CalculatorContext* cc) { .Tag(kVisibilityTag) .AddPacket(MakePacket(visibility).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc index f08dfe026..4f1d4a608 100644 --- a/mediapipe/calculators/util/landmarks_smoothing_calculator.cc +++ b/mediapipe/calculators/util/landmarks_smoothing_calculator.cc @@ -61,23 +61,23 @@ class LandmarksFilter { public: virtual ~LandmarksFilter() = default; - virtual mediapipe::Status Reset() { return mediapipe::OkStatus(); } + virtual absl::Status Reset() { return absl::OkStatus(); } - virtual mediapipe::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, - const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) = 0; + virtual absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const std::pair& image_size, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) = 0; }; // Returns landmarks as is without smoothing. 
class NoFilter : public LandmarksFilter { public: - mediapipe::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, - const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) override { + absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const std::pair& image_size, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) override { *out_landmarks = in_landmarks; - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; @@ -90,17 +90,17 @@ class VelocityFilter : public LandmarksFilter { velocity_scale_(velocity_scale), min_allowed_object_scale_(min_allowed_object_scale) {} - mediapipe::Status Reset() override { + absl::Status Reset() override { x_filters_.clear(); y_filters_.clear(); z_filters_.clear(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Apply(const NormalizedLandmarkList& in_landmarks, - const std::pair& image_size, - const absl::Duration& timestamp, - NormalizedLandmarkList* out_landmarks) override { + absl::Status Apply(const NormalizedLandmarkList& in_landmarks, + const std::pair& image_size, + const absl::Duration& timestamp, + NormalizedLandmarkList* out_landmarks) override { // Get image size. int image_width; int image_height; @@ -113,7 +113,7 @@ class VelocityFilter : public LandmarksFilter { GetObjectScale(in_landmarks, image_width, image_height); if (object_scale < min_allowed_object_scale_) { *out_landmarks = in_landmarks; - return mediapipe::OkStatus(); + return absl::OkStatus(); } const float value_scale = 1.0f / object_scale; @@ -138,18 +138,18 @@ class VelocityFilter : public LandmarksFilter { image_width); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: // Initializes filters for the first time or after Reset. If initialized then // check the size. 
- mediapipe::Status InitializeFiltersIfEmpty(const int n_landmarks) { + absl::Status InitializeFiltersIfEmpty(const int n_landmarks) { if (!x_filters_.empty()) { RET_CHECK_EQ(x_filters_.size(), n_landmarks); RET_CHECK_EQ(y_filters_.size(), n_landmarks); RET_CHECK_EQ(z_filters_.size(), n_landmarks); - return mediapipe::OkStatus(); + return absl::OkStatus(); } x_filters_.resize(n_landmarks, @@ -159,7 +159,7 @@ class VelocityFilter : public LandmarksFilter { z_filters_.resize(n_landmarks, RelativeVelocityFilter(window_size_, velocity_scale_)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } int window_size_; @@ -202,27 +202,26 @@ class VelocityFilter : public LandmarksFilter { // class LandmarksSmoothingCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: LandmarksFilter* landmarks_filter_; }; REGISTER_CALCULATOR(LandmarksSmoothingCalculator); -mediapipe::Status LandmarksSmoothingCalculator::GetContract( - CalculatorContract* cc) { +absl::Status LandmarksSmoothingCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs() .Tag(kNormalizedFilteredLandmarksTag) .Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { +absl::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); // Pick landmarks filter. 
@@ -239,15 +238,15 @@ mediapipe::Status LandmarksSmoothingCalculator::Open(CalculatorContext* cc) { << "Landmarks filter is either not specified or not supported"; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { +absl::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { // Check that landmarks are not empty and reset the filter if so. // Don't emit an empty packet for this timestamp. if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) { MP_RETURN_IF_ERROR(landmarks_filter_->Reset()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& in_landmarks = @@ -265,7 +264,7 @@ mediapipe::Status LandmarksSmoothingCalculator::Process(CalculatorContext* cc) { .Tag(kNormalizedFilteredLandmarksTag) .Add(out_landmarks.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_to_detection_calculator.cc b/mediapipe/calculators/util/landmarks_to_detection_calculator.cc index 5f9f81061..ffa359877 100644 --- a/mediapipe/calculators/util/landmarks_to_detection_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_detection_calculator.cc @@ -80,17 +80,17 @@ Detection ConvertLandmarksToDetection(const NormalizedLandmarkList& landmarks) { // } class LandmarksToDetectionCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: ::mediapipe::LandmarksToDetectionCalculatorOptions options_; }; REGISTER_CALCULATOR(LandmarksToDetectionCalculator); -mediapipe::Status 
LandmarksToDetectionCalculator::GetContract( +absl::Status LandmarksToDetectionCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kNormalizedLandmarksTag)); RET_CHECK(cc->Outputs().HasTag(kDetectionTag)); @@ -98,18 +98,17 @@ mediapipe::Status LandmarksToDetectionCalculator::GetContract( cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); cc->Outputs().Tag(kDetectionTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksToDetectionCalculator::Open(CalculatorContext* cc) { +absl::Status LandmarksToDetectionCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options<::mediapipe::LandmarksToDetectionCalculatorOptions>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksToDetectionCalculator::Process( - CalculatorContext* cc) { +absl::Status LandmarksToDetectionCalculator::Process(CalculatorContext* cc) { const auto& landmarks = cc->Inputs().Tag(kNormalizedLandmarksTag).Get(); RET_CHECK_GT(landmarks.landmark_size(), 0) @@ -133,7 +132,7 @@ mediapipe::Status LandmarksToDetectionCalculator::Process( .Tag(kDetectionTag) .Add(detection.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/landmarks_to_floats_calculator.cc b/mediapipe/calculators/util/landmarks_to_floats_calculator.cc index edfbc93f1..fe8dd3ab1 100644 --- a/mediapipe/calculators/util/landmarks_to_floats_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_floats_calculator.cc @@ -62,7 +62,7 @@ constexpr char kMatrixTag[] = "MATRIX"; // } class LandmarksToFloatsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kLandmarksTag).Set(); RET_CHECK(cc->Outputs().HasTag(kFloatsTag) || cc->Outputs().HasTag(kMatrixTag)); @@ 
-73,10 +73,10 @@ class LandmarksToFloatsCalculator : public CalculatorBase { cc->Outputs().Tag(kMatrixTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); const auto& options = cc->Options<::mediapipe::LandmarksToFloatsCalculatorOptions>(); @@ -84,13 +84,13 @@ class LandmarksToFloatsCalculator : public CalculatorBase { // Currently number of dimensions must be within [1, 3]. RET_CHECK_GE(num_dimensions_, 1); RET_CHECK_LE(num_dimensions_, 3); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { // Only process if there's input landmarks. if (cc->Inputs().Tag(kLandmarksTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input_landmarks = @@ -128,7 +128,7 @@ class LandmarksToFloatsCalculator : public CalculatorBase { .Tag(kMatrixTag) .Add(output_matrix.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc index f0fd165fc..7818ad8cd 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.cc @@ -177,7 +177,7 @@ RenderAnnotation* AddPointRenderData(const Color& landmark_color, } // namespace -mediapipe::Status LandmarksToRenderDataCalculator::GetContract( +absl::Status LandmarksToRenderDataCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kLandmarksTag) || cc->Inputs().HasTag(kNormLandmarksTag)) @@ -197,10 +197,10 @@ mediapipe::Status LandmarksToRenderDataCalculator::GetContract( cc->Inputs().Tag(kRenderScaleTag).Set(); } 
cc->Outputs().Tag(kRenderDataTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksToRenderDataCalculator::Open(CalculatorContext* cc) { +absl::Status LandmarksToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -212,20 +212,19 @@ mediapipe::Status LandmarksToRenderDataCalculator::Open(CalculatorContext* cc) { landmark_connections_.push_back(options_.landmark_connections(i)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksToRenderDataCalculator::Process( - CalculatorContext* cc) { +absl::Status LandmarksToRenderDataCalculator::Process(CalculatorContext* cc) { // Check that landmarks are not empty and skip rendering if so. // Don't emit an empty packet for this timestamp. if (cc->Inputs().HasTag(kLandmarksTag) && cc->Inputs().Tag(kLandmarksTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (cc->Inputs().HasTag(kNormLandmarksTag) && cc->Inputs().Tag(kNormLandmarksTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } auto render_data = absl::make_unique(); @@ -341,7 +340,7 @@ mediapipe::Status LandmarksToRenderDataCalculator::Process( cc->Outputs() .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(LandmarksToRenderDataCalculator); diff --git a/mediapipe/calculators/util/landmarks_to_render_data_calculator.h b/mediapipe/calculators/util/landmarks_to_render_data_calculator.h index ce31ef9c7..0fbe9700c 100644 --- a/mediapipe/calculators/util/landmarks_to_render_data_calculator.h +++ b/mediapipe/calculators/util/landmarks_to_render_data_calculator.h @@ -54,11 +54,11 @@ class LandmarksToRenderDataCalculator : public CalculatorBase { LandmarksToRenderDataCalculator& operator=( const LandmarksToRenderDataCalculator&) = delete; - static mediapipe::Status 
GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; protected: ::mediapipe::LandmarksToRenderDataCalculatorOptions options_; diff --git a/mediapipe/calculators/util/local_file_contents_calculator.cc b/mediapipe/calculators/util/local_file_contents_calculator.cc index b9ec9e496..4ad066f69 100644 --- a/mediapipe/calculators/util/local_file_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_contents_calculator.cc @@ -53,7 +53,7 @@ constexpr char kContentsTag[] = "CONTENTS"; // } class LocalFileContentsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->InputSidePackets().HasTag(kFilePathTag)) << "Missing PATH input side packet(s)"; RET_CHECK(cc->OutputSidePackets().HasTag(kContentsTag)) @@ -73,10 +73,10 @@ class LocalFileContentsCalculator : public CalculatorBase { cc->OutputSidePackets().Get(id).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { CollectionItemId input_id = cc->InputSidePackets().BeginId(kFilePathTag); CollectionItemId output_id = cc->OutputSidePackets().BeginId(kContentsTag); auto options = cc->Options(); @@ -89,16 +89,16 @@ class LocalFileContentsCalculator : public CalculatorBase { ASSIGN_OR_RETURN(file_path, PathToResourceAsFile(file_path)); std::string contents; - MP_RETURN_IF_ERROR( - GetResourceContents(file_path, &contents, options.read_as_binary())); + MP_RETURN_IF_ERROR(GetResourceContents( + file_path, &contents, /*read_as_binary=*/!options.text_mode())); cc->OutputSidePackets().Get(output_id).Set( 
MakePacket(std::move(contents))); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/calculators/util/local_file_contents_calculator.proto b/mediapipe/calculators/util/local_file_contents_calculator.proto index ca700fc58..17876c89f 100644 --- a/mediapipe/calculators/util/local_file_contents_calculator.proto +++ b/mediapipe/calculators/util/local_file_contents_calculator.proto @@ -23,6 +23,6 @@ message LocalFileContentsCalculatorOptions { optional LocalFileContentsCalculatorOptions ext = 346849340; } - // If true, set the file open mode to 'rb'. Otherwise, set the mode to 'r'. - optional bool read_as_binary = 1 [default = true]; + // By default, set the file open mode to 'rb'. Otherwise, set the mode to 'r'. + optional bool text_mode = 1; } diff --git a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc index 8bfb49af2..fcba83a49 100644 --- a/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc +++ b/mediapipe/calculators/util/local_file_pattern_contents_calculator.cc @@ -34,22 +34,22 @@ namespace mediapipe { // } class LocalFilePatternContentsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("FILE_DIRECTORY").Set(); cc->InputSidePackets().Tag("FILE_SUFFIX").Set(); cc->Outputs().Tag("CONTENTS").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { MP_RETURN_IF_ERROR(mediapipe::file::MatchFileTypeInDirectory( cc->InputSidePackets().Tag("FILE_DIRECTORY").Get(), 
cc->InputSidePackets().Tag("FILE_SUFFIX").Get(), &filenames_)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (current_output_ < filenames_.size()) { auto contents = absl::make_unique(); LOG(INFO) << filenames_[current_output_]; @@ -62,7 +62,7 @@ class LocalFilePatternContentsCalculator : public CalculatorBase { } else { return tool::StatusStop(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/logic_calculator.cc b/mediapipe/calculators/util/logic_calculator.cc index 3b6a3e6c8..d9bb9281a 100644 --- a/mediapipe/calculators/util/logic_calculator.cc +++ b/mediapipe/calculators/util/logic_calculator.cc @@ -45,7 +45,7 @@ using mediapipe::LogicCalculatorOptions; // } class LogicCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int k = 0; k < cc->Inputs().NumEntries(""); ++k) { cc->Inputs().Index(k).Set(); } @@ -58,13 +58,13 @@ class LogicCalculator : public CalculatorBase { 1); RET_CHECK_EQ(cc->Outputs().NumEntries(""), 1); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { options_ = cc->Options(); cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool LogicalOp(bool b1, bool b2) { @@ -79,7 +79,7 @@ class LogicCalculator : public CalculatorBase { return false; } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { bool result = options_.op() == LogicCalculatorOptions::AND ? 
true : false; for (int k = 0; k < options_.input_value_size(); ++k) { result = LogicalOp(result, options_.input_value(k)); @@ -94,7 +94,7 @@ class LogicCalculator : public CalculatorBase { result = !result; } cc->Outputs().Index(0).Add(new bool(result), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/non_max_suppression_calculator.cc b/mediapipe/calculators/util/non_max_suppression_calculator.cc index 048ba33fe..535e2a719 100644 --- a/mediapipe/calculators/util/non_max_suppression_calculator.cc +++ b/mediapipe/calculators/util/non_max_suppression_calculator.cc @@ -52,6 +52,7 @@ bool RetainMaxScoringLabelOnly(Detection* detection) { << "Number of scores must be equal to number of detections."; std::vector> indexed_scores; + indexed_scores.reserve(detection->score_size()); for (int k = 0; k < detection->score_size(); ++k) { indexed_scores.push_back(std::make_pair(k, detection->score(k))); } @@ -154,7 +155,7 @@ class NonMaxSuppressionCalculator : public CalculatorBase { NonMaxSuppressionCalculator() = default; ~NonMaxSuppressionCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { const auto& options = cc->Options(); if (cc->Inputs().HasTag(kImageTag)) { cc->Inputs().Tag(kImageTag).Set(); @@ -163,10 +164,10 @@ class NonMaxSuppressionCalculator : public CalculatorBase { cc->Inputs().Index(k).Set(); } cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -176,10 +177,10 @@ class NonMaxSuppressionCalculator : public CalculatorBase { << "max_num_detections=0 is not a valid value. 
Please choose a " << "positive number of you want to limit the number of output " << "detections, or set -1 if you do not want any limit."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { // Add all input detections to the same vector. Detections input_detections; for (int i = 0; i < options_.num_detection_streams(); ++i) { @@ -199,7 +200,7 @@ class NonMaxSuppressionCalculator : public CalculatorBase { if (options_.return_empty_detections()) { cc->Outputs().Index(0).Add(new Detections(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Remove all but the maximum scoring label from each input detection. This @@ -244,7 +245,7 @@ class NonMaxSuppressionCalculator : public CalculatorBase { cc->Outputs().Index(0).Add(retained_detections, cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/util/packet_frequency_calculator.cc b/mediapipe/calculators/util/packet_frequency_calculator.cc index cc407efd8..19ffae70e 100644 --- a/mediapipe/calculators/util/packet_frequency_calculator.cc +++ b/mediapipe/calculators/util/packet_frequency_calculator.cc @@ -70,26 +70,25 @@ class PacketFrequencyCalculator : public CalculatorBase { public: PacketFrequencyCalculator() {} - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Outputs the given framerate on the specified output stream as a // PacketFrequency proto. 
- mediapipe::Status OutputPacketFrequency(CalculatorContext* cc, int stream_id, - double framerate_hz, - const std::string& label, - const Timestamp& input_timestamp); + absl::Status OutputPacketFrequency(CalculatorContext* cc, int stream_id, + double framerate_hz, + const std::string& label, + const Timestamp& input_timestamp); // Adds the input timestamp in the particular stream's timestamp buffer. - mediapipe::Status AddPacketTimestampForStream(int stream_id, int64 timestamp); + absl::Status AddPacketTimestampForStream(int stream_id, int64 timestamp); // For the specified input stream, clears timestamps from buffer that are // older than the configured time_window_sec. - mediapipe::Status ClearOldpacketTimestamps(int stream_id, - int64 current_timestamp); + absl::Status ClearOldpacketTimestamps(int stream_id, int64 current_timestamp); // Options for the calculator. PacketFrequencyCalculatorOptions options_; @@ -105,17 +104,16 @@ class PacketFrequencyCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(PacketFrequencyCalculator); -mediapipe::Status PacketFrequencyCalculator::GetContract( - CalculatorContract* cc) { +absl::Status PacketFrequencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Outputs().NumEntries(), cc->Inputs().NumEntries()); for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketFrequencyCalculator::Open(CalculatorContext* cc) { +absl::Status PacketFrequencyCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK_EQ(options_.label_size(), cc->Inputs().NumEntries()); RET_CHECK_GT(options_.time_window_sec(), 0); @@ -127,10 +125,10 @@ mediapipe::Status PacketFrequencyCalculator::Open(CalculatorContext* cc) { previous_timestamps_for_stream_id_[i] = {}; first_timestamp_for_stream_id_usec_[i] = -1; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } 
-mediapipe::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) { +absl::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { if (cc->Inputs().Index(i).IsEmpty()) { continue; @@ -164,26 +162,26 @@ mediapipe::Status PacketFrequencyCalculator::Process(CalculatorContext* cc) { options_.label(i), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketFrequencyCalculator::AddPacketTimestampForStream( +absl::Status PacketFrequencyCalculator::AddPacketTimestampForStream( int stream_id, int64 timestamp_usec) { if (previous_timestamps_for_stream_id_.find(stream_id) == previous_timestamps_for_stream_id_.end()) { - return mediapipe::InvalidArgumentError("Input stream id is invalid"); + return absl::InvalidArgumentError("Input stream id is invalid"); } previous_timestamps_for_stream_id_[stream_id].push_back(timestamp_usec); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketFrequencyCalculator::ClearOldpacketTimestamps( +absl::Status PacketFrequencyCalculator::ClearOldpacketTimestamps( int stream_id, int64 current_timestamp_usec) { if (previous_timestamps_for_stream_id_.find(stream_id) == previous_timestamps_for_stream_id_.end()) { - return mediapipe::InvalidArgumentError("Input stream id is invalid"); + return absl::InvalidArgumentError("Input stream id is invalid"); } auto& timestamps_buffer = previous_timestamps_for_stream_id_[stream_id]; @@ -198,10 +196,10 @@ mediapipe::Status PacketFrequencyCalculator::ClearOldpacketTimestamps( }), timestamps_buffer.end()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketFrequencyCalculator::OutputPacketFrequency( +absl::Status PacketFrequencyCalculator::OutputPacketFrequency( CalculatorContext* cc, int stream_id, double framerate_hz, const std::string& label, const Timestamp& input_timestamp) { auto packet_frequency = 
absl::make_unique(); @@ -211,7 +209,7 @@ mediapipe::Status PacketFrequencyCalculator::OutputPacketFrequency( cc->Outputs().Index(stream_id).Add(packet_frequency.release(), input_timestamp); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/packet_latency_calculator.cc b/mediapipe/calculators/util/packet_latency_calculator.cc index 19ff6fc24..35e415505 100644 --- a/mediapipe/calculators/util/packet_latency_calculator.cc +++ b/mediapipe/calculators/util/packet_latency_calculator.cc @@ -101,10 +101,10 @@ class PacketLatencyCalculator : public CalculatorBase { public: PacketLatencyCalculator() {} - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Resets the histogram and running average variables by initializing them to @@ -139,7 +139,7 @@ class PacketLatencyCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(PacketLatencyCalculator); -mediapipe::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) { +absl::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_GT(cc->Inputs().NumEntries(), 1); // Input and output streams. 
@@ -160,7 +160,7 @@ mediapipe::Status PacketLatencyCalculator::GetContract(CalculatorContract* cc) { .Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void PacketLatencyCalculator::ResetStatistics() { @@ -177,7 +177,7 @@ void PacketLatencyCalculator::ResetStatistics() { } } -mediapipe::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { +absl::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); num_packet_streams_ = cc->Inputs().NumEntries() - 1; @@ -224,10 +224,10 @@ mediapipe::Status PacketLatencyCalculator::Open(CalculatorContext* cc) { ::mediapipe::MonotonicClock::CreateSynchronizedMonotonicClock()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { +absl::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { // Record first process timestamp if this is the first call. if (first_process_time_usec_ < 0 && !cc->Inputs().Tag(kReferenceSignalTag).IsEmpty()) { @@ -238,7 +238,7 @@ mediapipe::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { if (first_process_time_usec_ < 0) { LOG(WARNING) << "No reference packet received."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (options_.reset_duration_usec() > 0) { @@ -292,7 +292,7 @@ mediapipe::Status PacketLatencyCalculator::Process(CalculatorContext* cc) { } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_projection_calculator.cc b/mediapipe/calculators/util/rect_projection_calculator.cc index 364f26629..dcc6e7391 100644 --- a/mediapipe/calculators/util/rect_projection_calculator.cc +++ b/mediapipe/calculators/util/rect_projection_calculator.cc @@ -47,29 +47,28 @@ constexpr char kNormReferenceRectTag[] = "NORM_REFERENCE_RECT"; // class RectProjectionCalculator : public CalculatorBase { public: - static mediapipe::Status 
GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(RectProjectionCalculator); -mediapipe::Status RectProjectionCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectProjectionCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormRectTag).Set(); cc->Inputs().Tag(kNormReferenceRectTag).Set(); cc->Outputs().Tag(kNormRectTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectProjectionCalculator::Open(CalculatorContext* cc) { +absl::Status RectProjectionCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectProjectionCalculator::Process(CalculatorContext* cc) { +absl::Status RectProjectionCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kNormRectTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& rect = cc->Inputs().Tag(kNormRectTag).Get(); @@ -101,7 +100,7 @@ mediapipe::Status RectProjectionCalculator::Process(CalculatorContext* cc) { cc->Outputs().Tag(kNormRectTag).Add(new_rect.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_to_render_data_calculator.cc b/mediapipe/calculators/util/rect_to_render_data_calculator.cc index 4b85e232a..3b395818f 100644 --- a/mediapipe/calculators/util/rect_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/rect_to_render_data_calculator.cc @@ -94,19 +94,18 @@ void SetRect(bool normalized, double xmin, double ymin, double width, // } class RectToRenderDataCalculator : public CalculatorBase { 
public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: RectToRenderDataCalculatorOptions options_; }; REGISTER_CALCULATOR(RectToRenderDataCalculator); -mediapipe::Status RectToRenderDataCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectToRenderDataCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) + (cc->Inputs().HasTag(kRectTag) ? 1 : 0) + (cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) + @@ -130,18 +129,18 @@ mediapipe::Status RectToRenderDataCalculator::GetContract( } cc->Outputs().Tag(kRenderDataTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectToRenderDataCalculator::Open(CalculatorContext* cc) { +absl::Status RectToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { +absl::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { auto render_data = absl::make_unique(); if (cc->Inputs().HasTag(kNormRectTag) && @@ -185,7 +184,7 @@ mediapipe::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc index fdc209359..79a740315 100644 --- a/mediapipe/calculators/util/rect_to_render_scale_calculator.cc +++ 
b/mediapipe/calculators/util/rect_to_render_scale_calculator.cc @@ -51,38 +51,37 @@ constexpr char kRenderScaleTag[] = "RENDER_SCALE"; // } class RectToRenderScaleCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: RectToRenderScaleCalculatorOptions options_; }; REGISTER_CALCULATOR(RectToRenderScaleCalculator); -mediapipe::Status RectToRenderScaleCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectToRenderScaleCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormRectTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs().Tag(kRenderScaleTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectToRenderScaleCalculator::Open(CalculatorContext* cc) { +absl::Status RectToRenderScaleCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectToRenderScaleCalculator::Process(CalculatorContext* cc) { +absl::Status RectToRenderScaleCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kNormRectTag).IsEmpty()) { cc->Outputs() .Tag(kRenderScaleTag) .AddPacket( MakePacket(options_.multiplier()).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Get image size. 
@@ -104,7 +103,7 @@ mediapipe::Status RectToRenderScaleCalculator::Process(CalculatorContext* cc) { .Tag(kRenderScaleTag) .AddPacket(MakePacket(render_scale).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/rect_transformation_calculator.cc b/mediapipe/calculators/util/rect_transformation_calculator.cc index 5132ca8b3..7c71dd5a1 100644 --- a/mediapipe/calculators/util/rect_transformation_calculator.cc +++ b/mediapipe/calculators/util/rect_transformation_calculator.cc @@ -57,10 +57,10 @@ inline float NormalizeRadians(float angle) { // } class RectTransformationCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: RectTransformationCalculatorOptions options_; @@ -72,8 +72,7 @@ class RectTransformationCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(RectTransformationCalculator); -mediapipe::Status RectTransformationCalculator::GetContract( - CalculatorContract* cc) { +absl::Status RectTransformationCalculator::GetContract(CalculatorContract* cc) { RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) + (cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) + (cc->Inputs().HasTag(kRectTag) ? 
1 : 0) + @@ -100,20 +99,20 @@ mediapipe::Status RectTransformationCalculator::GetContract( cc->Outputs().Index(0).Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectTransformationCalculator::Open(CalculatorContext* cc) { +absl::Status RectTransformationCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); RET_CHECK(!(options_.has_rotation() && options_.has_rotation_degrees())); RET_CHECK(!(options_.has_square_long() && options_.has_square_short())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RectTransformationCalculator::Process(CalculatorContext* cc) { +absl::Status RectTransformationCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kRectTag) && !cc->Inputs().Tag(kRectTag).IsEmpty()) { auto rect = cc->Inputs().Tag(kRectTag).Get(); TransformRect(&rect); @@ -156,7 +155,7 @@ mediapipe::Status RectTransformationCalculator::Process(CalculatorContext* cc) { cc->Outputs().Index(0).Add(output_rects.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } float RectTransformationCalculator::ComputeNewRotation(float rotation) { diff --git a/mediapipe/calculators/util/set_landmark_visibility_calculator.cc b/mediapipe/calculators/util/set_landmark_visibility_calculator.cc index 90ce06bca..233c3a0cb 100644 --- a/mediapipe/calculators/util/set_landmark_visibility_calculator.cc +++ b/mediapipe/calculators/util/set_landmark_visibility_calculator.cc @@ -50,34 +50,33 @@ constexpr char kVisibilityTag[] = "VISIBILITY"; // class SetLandmarkVisibilityCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status 
Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(SetLandmarkVisibilityCalculator); -mediapipe::Status SetLandmarkVisibilityCalculator::GetContract( +absl::Status SetLandmarkVisibilityCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); cc->Inputs().Tag(kVisibilityTag).Set(); cc->Outputs().Tag(kNormalizedLandmarksTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetLandmarkVisibilityCalculator::Open(CalculatorContext* cc) { +absl::Status SetLandmarkVisibilityCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetLandmarkVisibilityCalculator::Process( - CalculatorContext* cc) { +absl::Status SetLandmarkVisibilityCalculator::Process(CalculatorContext* cc) { // Check that landmark and visibility are not empty. // Don't emit an empty packet for this timestamp. if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty() || cc->Inputs().Tag(kVisibilityTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& in_landmarks = @@ -97,7 +96,7 @@ mediapipe::Status SetLandmarkVisibilityCalculator::Process( .Tag(kNormalizedLandmarksTag) .Add(out_landmarks.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/thresholding_calculator.cc b/mediapipe/calculators/util/thresholding_calculator.cc index 4ee5fc4b6..65876c075 100644 --- a/mediapipe/calculators/util/thresholding_calculator.cc +++ b/mediapipe/calculators/util/thresholding_calculator.cc @@ -50,17 +50,17 @@ namespace mediapipe { // } class ThresholdingCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status 
Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: double threshold_{}; }; REGISTER_CALCULATOR(ThresholdingCalculator); -mediapipe::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) { +absl::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("FLOAT")); cc->Inputs().Tag("FLOAT").Set(); @@ -83,10 +83,10 @@ mediapipe::Status ThresholdingCalculator::GetContract(CalculatorContract* cc) { "supported."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ThresholdingCalculator::Open(CalculatorContext* cc) { +absl::Status ThresholdingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -103,10 +103,10 @@ mediapipe::Status ThresholdingCalculator::Open(CalculatorContext* cc) { if (cc->InputSidePackets().HasTag("THRESHOLD")) { threshold_ = cc->InputSidePackets().Tag("THRESHOLD").Get(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ThresholdingCalculator::Process(CalculatorContext* cc) { +absl::Status ThresholdingCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag("THRESHOLD") && !cc->Inputs().Tag("THRESHOLD").IsEmpty()) { threshold_ = cc->Inputs().Tag("THRESHOLD").Get(); @@ -131,6 +131,6 @@ mediapipe::Status ThresholdingCalculator::Process(CalculatorContext* cc) { MakePacket(false).At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc b/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc index 391a83c67..790b426de 100644 --- a/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc +++ b/mediapipe/calculators/util/timed_box_list_id_to_label_calculator.cc @@ -48,25 +48,25 @@ using 
mediapipe::TimedBoxProtoList; // } class TimedBoxListIdToLabelCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: absl::node_hash_map label_map_; }; REGISTER_CALCULATOR(TimedBoxListIdToLabelCalculator); -mediapipe::Status TimedBoxListIdToLabelCalculator::GetContract( +absl::Status TimedBoxListIdToLabelCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TimedBoxListIdToLabelCalculator::Open(CalculatorContext* cc) { +absl::Status TimedBoxListIdToLabelCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); const auto& options = @@ -83,11 +83,10 @@ mediapipe::Status TimedBoxListIdToLabelCalculator::Open(CalculatorContext* cc) { while (std::getline(stream, line)) { label_map_[i++] = line; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TimedBoxListIdToLabelCalculator::Process( - CalculatorContext* cc) { +absl::Status TimedBoxListIdToLabelCalculator::Process(CalculatorContext* cc) { const auto& input_list = cc->Inputs().Index(0).Get(); auto output_list = absl::make_unique(); for (const auto& input_box : input_list.box()) { @@ -99,7 +98,7 @@ mediapipe::Status TimedBoxListIdToLabelCalculator::Process( } } cc->Outputs().Index(0).Add(output_list.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc b/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc index 
a70c62f49..53c2ffa2f 100644 --- a/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc +++ b/mediapipe/calculators/util/timed_box_list_to_render_data_calculator.cc @@ -120,35 +120,34 @@ class TimedBoxListToRenderDataCalculator : public CalculatorBase { TimedBoxListToRenderDataCalculator& operator=( const TimedBoxListToRenderDataCalculator&) = delete; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: TimedBoxListToRenderDataCalculatorOptions options_; }; REGISTER_CALCULATOR(TimedBoxListToRenderDataCalculator); -mediapipe::Status TimedBoxListToRenderDataCalculator::GetContract( +absl::Status TimedBoxListToRenderDataCalculator::GetContract( CalculatorContract* cc) { if (cc->Inputs().HasTag(kTimedBoxListTag)) { cc->Inputs().Tag(kTimedBoxListTag).Set(); } cc->Outputs().Tag(kRenderDataTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TimedBoxListToRenderDataCalculator::Open( - CalculatorContext* cc) { +absl::Status TimedBoxListToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TimedBoxListToRenderDataCalculator::Process( +absl::Status TimedBoxListToRenderDataCalculator::Process( CalculatorContext* cc) { auto render_data = absl::make_unique(); @@ -164,7 +163,7 @@ mediapipe::Status TimedBoxListToRenderDataCalculator::Process( cc->Outputs() .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/util/to_image_calculator.cc 
b/mediapipe/calculators/util/to_image_calculator.cc new file mode 100644 index 000000000..5e119fca7 --- /dev/null +++ b/mediapipe/calculators/util/to_image_calculator.cc @@ -0,0 +1,160 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_options.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/logging.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/vector.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gl_calculator_helper.h" +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +namespace { +constexpr char kImageFrameTag[] = "IMAGE_CPU"; +constexpr char kGpuBufferTag[] = "IMAGE_GPU"; +constexpr char kImageTag[] = "IMAGE"; +} // namespace + +// A calculator for converting from legacy MediaPipe datatypes into a +// unified image container. +// +// Inputs: +// One of the following two tags: +// IMAGE_CPU: An ImageFrame containing input image. +// IMAGE_GPU: A GpuBuffer containing input image. +// +// Output: +// IMAGE: An Image containing output image. +// +// Note: +// No CPU/GPU conversion is done. 
+// +class ToImageCalculator : public CalculatorBase { + public: + ToImageCalculator() = default; + ~ToImageCalculator() override = default; + + static absl::Status GetContract(CalculatorContract* cc); + + // From Calculator. + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; + + private: + absl::Status RenderGpu(CalculatorContext* cc); + absl::Status RenderCpu(CalculatorContext* cc); + + bool gpu_input_ = false; + bool gpu_initialized_ = false; +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GlCalculatorHelper gpu_helper_; +#endif // !MEDIAPIPE_DISABLE_GPU +}; +REGISTER_CALCULATOR(ToImageCalculator); + +absl::Status ToImageCalculator::GetContract(CalculatorContract* cc) { + cc->Outputs().Tag(kImageTag).Set(); + + bool gpu_input = false; + + if (cc->Inputs().HasTag(kImageFrameTag) && + cc->Inputs().HasTag(kGpuBufferTag)) { + return absl::InternalError("Cannot have multiple inputs."); + } + + if (cc->Inputs().HasTag(kGpuBufferTag)) { +#if !MEDIAPIPE_DISABLE_GPU + cc->Inputs().Tag(kGpuBufferTag).Set(); + gpu_input = true; +#else + RET_CHECK_FAIL() << "GPU is disabled. 
Cannot use IMAGE_GPU stream."; +#endif // !MEDIAPIPE_DISABLE_GPU + } + if (cc->Inputs().HasTag(kImageFrameTag)) { + cc->Inputs().Tag(kImageFrameTag).Set(); + } + + if (gpu_input) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); +#endif // !MEDIAPIPE_DISABLE_GPU + } + + return absl::OkStatus(); +} + +absl::Status ToImageCalculator::Open(CalculatorContext* cc) { + cc->SetOffset(TimestampDiff(0)); + + if (cc->Inputs().HasTag(kGpuBufferTag)) { + gpu_input_ = true; + } + + if (gpu_input_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); +#endif + } // !MEDIAPIPE_DISABLE_GPU + + return absl::OkStatus(); +} + +absl::Status ToImageCalculator::Process(CalculatorContext* cc) { + if (gpu_input_) { +#if !MEDIAPIPE_DISABLE_GPU + MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([&cc]() -> absl::Status { + auto& input = cc->Inputs().Tag(kGpuBufferTag).Get(); + // Wrap texture pointer; shallow copy. + auto output = std::make_unique(input); + cc->Outputs().Tag(kImageTag).Add(output.release(), cc->InputTimestamp()); + return absl::OkStatus(); + })); +#endif // !MEDIAPIPE_DISABLE_GPU + } else { + // The input ImageFrame. + auto& input = cc->Inputs().Tag(kImageFrameTag).Get(); + // Make a copy of the input packet to co-own the input ImageFrame. + Packet* packet_copy_ptr = + new Packet(cc->Inputs().Tag(kImageFrameTag).Value()); + // Create an output Image that (co-)owns a new ImageFrame that points to + // the same pixel data as the input ImageFrame and also owns the packet + // copy. As a result, the output Image indirectly co-owns the input + // ImageFrame. This ensures a correct life span of the shared pixel data. 
+ std::unique_ptr output = + std::make_unique( + std::make_shared( + input.Format(), input.Width(), input.Height(), + input.WidthStep(), const_cast(input.PixelData()), + [packet_copy_ptr](uint8*) { delete packet_copy_ptr; })); + cc->Outputs().Tag(kImageTag).Add(output.release(), cc->InputTimestamp()); + } + + return absl::OkStatus(); +} + +absl::Status ToImageCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); +} + +} // namespace mediapipe diff --git a/mediapipe/calculators/util/top_k_scores_calculator.cc b/mediapipe/calculators/util/top_k_scores_calculator.cc index d2b0d98f7..37d1b2ab2 100644 --- a/mediapipe/calculators/util/top_k_scores_calculator.cc +++ b/mediapipe/calculators/util/top_k_scores_calculator.cc @@ -62,14 +62,14 @@ namespace mediapipe { // } class TopKScoresCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - mediapipe::Status LoadLabelmap(std::string label_map_path); + absl::Status LoadLabelmap(std::string label_map_path); int top_k_ = -1; float threshold_ = 0.0; @@ -78,7 +78,7 @@ class TopKScoresCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TopKScoresCalculator); -mediapipe::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { +absl::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("SCORES")); cc->Inputs().Tag("SCORES").Set>(); if (cc->Outputs().HasTag("TOP_K_INDEXES")) { @@ -96,10 +96,10 @@ mediapipe::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) { if (cc->Outputs().HasTag("SUMMARY")) { cc->Outputs().Tag("SUMMARY").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } 
-mediapipe::Status TopKScoresCalculator::Open(CalculatorContext* cc) { +absl::Status TopKScoresCalculator::Open(CalculatorContext* cc) { const auto& options = cc->Options<::mediapipe::TopKScoresCalculatorOptions>(); RET_CHECK(options.has_top_k() || options.has_threshold()) << "Must specify at least one of the top_k and threshold fields in " @@ -117,10 +117,10 @@ mediapipe::Status TopKScoresCalculator::Open(CalculatorContext* cc) { if (cc->Outputs().HasTag("TOP_K_LABELS")) { RET_CHECK(!label_map_.empty()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TopKScoresCalculator::Process(CalculatorContext* cc) { +absl::Status TopKScoresCalculator::Process(CalculatorContext* cc) { const std::vector& input_vector = cc->Inputs().Tag("SCORES").Get>(); std::vector top_k_indexes; @@ -213,11 +213,10 @@ mediapipe::Status TopKScoresCalculator::Process(CalculatorContext* cc) { } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TopKScoresCalculator::LoadLabelmap( - std::string label_map_path) { +absl::Status TopKScoresCalculator::LoadLabelmap(std::string label_map_path) { std::string string_path; ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(label_map_path)); std::string label_map_string; @@ -230,7 +229,7 @@ mediapipe::Status TopKScoresCalculator::LoadLabelmap( label_map_[i++] = line; } label_map_loaded_ = true; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/video/box_detector_calculator.cc b/mediapipe/calculators/video/box_detector_calculator.cc index db179e125..b7b91d253 100644 --- a/mediapipe/calculators/video/box_detector_calculator.cc +++ b/mediapipe/calculators/video/box_detector_calculator.cc @@ -92,11 +92,11 @@ class BoxDetectorCalculator : public CalculatorBase { public: ~BoxDetectorCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status 
GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: BoxDetectorCalculatorOptions options_; @@ -109,7 +109,7 @@ class BoxDetectorCalculator : public CalculatorBase { REGISTER_CALCULATOR(BoxDetectorCalculator); -mediapipe::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { +absl::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("TRACKING")) { cc->Inputs().Tag("TRACKING").Set(); } @@ -172,10 +172,10 @@ mediapipe::Status BoxDetectorCalculator::GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { +absl::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); box_detector_ = BoxDetectorInterface::Create(options_.detector_options()); @@ -210,10 +210,10 @@ mediapipe::Status BoxDetectorCalculator::Open(CalculatorContext* cc) { frame_alignment_ = cc->InputSidePackets().Tag("FRAME_ALIGNMENT").Get(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { +absl::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { const Timestamp timestamp = cc->InputTimestamp(); const int64 timestamp_msec = timestamp.Value() / 1000; @@ -246,7 +246,7 @@ mediapipe::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { } if (!detector_switch_) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* track_stream = cc->Inputs().HasTag("TRACKING") @@ -274,7 +274,7 @@ mediapipe::Status 
BoxDetectorCalculator::Process(CalculatorContext* cc) { if (track_stream != nullptr) { // Detect from tracking data if (track_stream->IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const TrackingData& tracking_data = track_stream->Get(); @@ -289,7 +289,7 @@ mediapipe::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { } else if (video_stream != nullptr) { // Detect from input frame if (video_stream->IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } TimedBoxProtoList tracked_boxes; @@ -305,7 +305,7 @@ mediapipe::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { detected_boxes.get()); } else { if (feature_stream->IsEmpty() || descriptor_stream->IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& image_size = @@ -377,17 +377,17 @@ mediapipe::Status BoxDetectorCalculator::Process(CalculatorContext* cc) { cc->Outputs().Tag("BOXES").Add(detected_boxes.release(), timestamp); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BoxDetectorCalculator::Close(CalculatorContext* cc) { +absl::Status BoxDetectorCalculator::Close(CalculatorContext* cc) { if (write_index_) { BoxDetectorIndex index = box_detector_->ObtainBoxDetectorIndex(); MEDIAPIPE_CHECK_OK(mediapipe::file::SetContents( cc->InputSidePackets().Tag("OUTPUT_INDEX_FILENAME").Get(), index.SerializeAsString())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/video/box_tracker_calculator.cc b/mediapipe/calculators/video/box_tracker_calculator.cc index 30ac2b26d..7d04d9765 100644 --- a/mediapipe/calculators/video/box_tracker_calculator.cc +++ b/mediapipe/calculators/video/box_tracker_calculator.cc @@ -123,10 +123,10 @@ class BoxTrackerCalculator : public CalculatorBase { public: ~BoxTrackerCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status 
GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; protected: void RenderStates(const std::vector& states, cv::Mat* mat); @@ -373,7 +373,7 @@ void AddStateToPath(const MotionBoxState& state, int64 time_msec, } // namespace. -mediapipe::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { +absl::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("TRACKING")) { cc->Inputs().Tag("TRACKING").Set(); } @@ -452,10 +452,10 @@ mediapipe::Status BoxTrackerCalculator::GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag(kOptionsTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { +absl::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); @@ -515,10 +515,10 @@ mediapipe::Status BoxTrackerCalculator::Open(CalculatorContext* cc) { << "Streaming mode not compatible with cache dir."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { +absl::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { // Batch mode, issue tracking requests. if (box_tracker_ && !tracking_issued_) { for (const auto& pos : initial_pos_.box()) { @@ -530,7 +530,7 @@ mediapipe::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { const Timestamp& timestamp = cc->InputTimestamp(); if (timestamp == Timestamp::PreStream()) { // Indicator packet. 
- return mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* track_stream = cc->Inputs().HasTag("TRACKING") @@ -892,7 +892,7 @@ mediapipe::Status BoxTrackerCalculator::Process(CalculatorContext* cc) { cc->Outputs().Tag("VIZ").Add(viz_frame.release(), timestamp); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void BoxTrackerCalculator::AddSmoothTransitionToOutputBox( diff --git a/mediapipe/calculators/video/flow_packager_calculator.cc b/mediapipe/calculators/video/flow_packager_calculator.cc index ee4181723..a57105928 100644 --- a/mediapipe/calculators/video/flow_packager_calculator.cc +++ b/mediapipe/calculators/video/flow_packager_calculator.cc @@ -59,11 +59,11 @@ class FlowPackagerCalculator : public CalculatorBase { public: ~FlowPackagerCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; // Writes passed chunk to disk. 
void WriteChunk(const TrackingDataChunk& chunk) const; @@ -90,7 +90,7 @@ class FlowPackagerCalculator : public CalculatorBase { REGISTER_CALCULATOR(FlowPackagerCalculator); -mediapipe::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) { +absl::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().HasTag("FLOW")) { return tool::StatusFail("No input flow was specified."); } @@ -114,10 +114,10 @@ mediapipe::Status FlowPackagerCalculator::GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("CACHE_DIR").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FlowPackagerCalculator::Open(CalculatorContext* cc) { +absl::Status FlowPackagerCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); flow_packager_.reset(new FlowPackager(options_.flow_packager_options())); @@ -128,10 +128,10 @@ mediapipe::Status FlowPackagerCalculator::Open(CalculatorContext* cc) { cache_dir_ = cc->InputSidePackets().Tag("CACHE_DIR").Get(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { +absl::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { InputStream* flow_stream = &(cc->Inputs().Tag("FLOW")); const RegionFlowFeatureList& flow = flow_stream->Get(); @@ -193,10 +193,10 @@ mediapipe::Status FlowPackagerCalculator::Process(CalculatorContext* cc) { prev_timestamp_ = timestamp; ++frame_idx_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { +absl::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { if (frame_idx_ > 0) { tracking_chunk_.set_last_chunk(true); if (cc->Outputs().HasTag("TRACKING_CHUNK")) { @@ -215,7 +215,7 @@ mediapipe::Status FlowPackagerCalculator::Close(CalculatorContext* cc) { cc->Outputs().Tag("COMPLETE").Add(new bool(true), Timestamp::PreStream()); } - return 
mediapipe::OkStatus(); + return absl::OkStatus(); } void FlowPackagerCalculator::WriteChunk(const TrackingDataChunk& chunk) const { diff --git a/mediapipe/calculators/video/flow_to_image_calculator.cc b/mediapipe/calculators/video/flow_to_image_calculator.cc index b63163092..6a078ee72 100644 --- a/mediapipe/calculators/video/flow_to_image_calculator.cc +++ b/mediapipe/calculators/video/flow_to_image_calculator.cc @@ -56,27 +56,27 @@ class FlowToImageCalculator : public CalculatorBase { public: FlowToImageCalculator() {} ~FlowToImageCalculator() override {} - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: FlowQuantizerModel model_; }; -mediapipe::Status FlowToImageCalculator::GetContract(CalculatorContract* cc) { +absl::Status FlowToImageCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); // Model sanity check const auto& options = cc->Options(); if (options.min_value() >= options.max_value()) { - return mediapipe::InvalidArgumentError("Invalid quantizer model."); + return absl::InvalidArgumentError("Invalid quantizer model."); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FlowToImageCalculator::Open(CalculatorContext* cc) { +absl::Status FlowToImageCalculator::Open(CalculatorContext* cc) { const auto& options = cc->Options(); // Fill the the model_data, ideally we want to train the model, but we omit // the step for now, and takes the (min, max) range from protobuf. 
@@ -86,10 +86,10 @@ mediapipe::Status FlowToImageCalculator::Open(CalculatorContext* cc) { options.min_value(), options.min_value(), options.max_value(), options.max_value())); model_.LoadFromProto(model_data); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FlowToImageCalculator::Process(CalculatorContext* cc) { +absl::Status FlowToImageCalculator::Process(CalculatorContext* cc) { const auto& input = cc->Inputs().Index(0).Get(); // Input flow is 2-channel with x-dim flow and y-dim flow. // Convert it to a ImageFrame in SRGB space, the 3rd channel is not used (0). @@ -106,7 +106,7 @@ mediapipe::Status FlowToImageCalculator::Process(CalculatorContext* cc) { } } cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(FlowToImageCalculator); diff --git a/mediapipe/calculators/video/motion_analysis_calculator.cc b/mediapipe/calculators/video/motion_analysis_calculator.cc index bce6dbbe0..59673108c 100644 --- a/mediapipe/calculators/video/motion_analysis_calculator.cc +++ b/mediapipe/calculators/video/motion_analysis_calculator.cc @@ -95,11 +95,11 @@ class MotionAnalysisCalculator : public CalculatorBase { public: ~MotionAnalysisCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: // Outputs results to Outputs() if MotionAnalysis buffered sufficient results. 
@@ -107,8 +107,8 @@ class MotionAnalysisCalculator : public CalculatorBase { void OutputMotionAnalyzedFrames(bool flush, CalculatorContext* cc); // Lazy init function to be called on Process. - mediapipe::Status InitOnProcess(InputStream* video_stream, - InputStream* selection_stream); + absl::Status InitOnProcess(InputStream* video_stream, + InputStream* selection_stream); // Parses CSV file contents to homographies. bool ParseModelCSV(const std::string& contents, @@ -189,8 +189,7 @@ class MotionAnalysisCalculator : public CalculatorBase { REGISTER_CALCULATOR(MotionAnalysisCalculator); -mediapipe::Status MotionAnalysisCalculator::GetContract( - CalculatorContract* cc) { +absl::Status MotionAnalysisCalculator::GetContract(CalculatorContract* cc) { if (cc->Inputs().HasTag("VIDEO")) { cc->Inputs().Tag("VIDEO").Set(); } @@ -246,10 +245,10 @@ mediapipe::Status MotionAnalysisCalculator::GetContract( cc->InputSidePackets().Tag(kOptionsTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { +absl::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { options_ = tool::RetrieveOptions(cc->Options(), cc->InputSidePackets(), kOptionsTag); @@ -364,7 +363,7 @@ mediapipe::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { // If no video header is provided, just return and initialize on the first // Process() call. 
if (video_header == nullptr) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } ////////////// EARLY RETURN; ONLY HEADER OUTPUT SHOULD GO HERE /////////////// @@ -397,12 +396,12 @@ mediapipe::Status MotionAnalysisCalculator::Open(CalculatorContext* cc) { .SetHeader(Adopt(new VideoHeader(*video_header))); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { +absl::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { if (options_.bypass_mode()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } InputStream* video_stream = @@ -441,7 +440,7 @@ mediapipe::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { } ++frame_idx_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (motion_analysis_ == nullptr) { @@ -491,7 +490,7 @@ mediapipe::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { cc->Outputs().Tag("VIDEO_OUT").AddPacket(video_stream->Value()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (use_frame) { @@ -574,10 +573,10 @@ mediapipe::Status MotionAnalysisCalculator::Process(CalculatorContext* cc) { OutputMotionAnalyzedFrames(false, cc); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MotionAnalysisCalculator::Close(CalculatorContext* cc) { +absl::Status MotionAnalysisCalculator::Close(CalculatorContext* cc) { // Guard against empty videos. 
if (motion_analysis_) { OutputMotionAnalyzedFrames(true, cc); @@ -588,7 +587,7 @@ mediapipe::Status MotionAnalysisCalculator::Close(CalculatorContext* cc) { << meta_motions_.size(); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( @@ -688,7 +687,7 @@ void MotionAnalysisCalculator::OutputMotionAnalyzedFrames( } } -mediapipe::Status MotionAnalysisCalculator::InitOnProcess( +absl::Status MotionAnalysisCalculator::InitOnProcess( InputStream* video_stream, InputStream* selection_stream) { if (video_stream) { frame_width_ = video_stream->Get().Width(); @@ -761,7 +760,7 @@ mediapipe::Status MotionAnalysisCalculator::InitOnProcess( motion_options->set_filter_initialized_irls_weights(true); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool MotionAnalysisCalculator::ParseModelCSV( diff --git a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc index a70785cb0..bf7ed3e8a 100644 --- a/mediapipe/calculators/video/opencv_video_decoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_decoder_calculator.cc @@ -86,7 +86,7 @@ ImageFormat::Format GetImageFormat(int num_channels) { // class OpenCvVideoDecoderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set(); cc->Outputs().Tag("VIDEO").Set(); if (cc->Outputs().HasTag("VIDEO_PRESTREAM")) { @@ -95,10 +95,10 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { if (cc->OutputSidePackets().HasTag("SAVED_AUDIO_PATH")) { cc->OutputSidePackets().Tag("SAVED_AUDIO_PATH").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const std::string& input_file_path = 
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get(); cap_ = absl::make_unique(input_file_path); @@ -177,10 +177,10 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { "config."; #endif } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { auto image_frame = absl::make_unique(format_, width_, height_, /*alignment_boundary=*/1); // Use microsecond as the unit of time. @@ -213,10 +213,10 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { decoded_frames_++; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { if (cap_ && cap_->isOpened()) { cap_->release(); } @@ -225,7 +225,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase { << frame_count_ << " vs decoded frames: " << decoded_frames_ << ")."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc index ffb546dbd..9a74fb710 100644 --- a/mediapipe/calculators/video/opencv_video_encoder_calculator.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator.cc @@ -76,21 +76,20 @@ namespace mediapipe { // class OpenCvVideoEncoderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status 
SetUpVideoWriter(float frame_rate, int width, int height); + absl::Status SetUpVideoWriter(float frame_rate, int width, int height); std::string output_file_path_; int four_cc_; std::unique_ptr writer_; }; -mediapipe::Status OpenCvVideoEncoderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status OpenCvVideoEncoderCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag("VIDEO")); cc->Inputs().Tag("VIDEO").Set(); if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { @@ -101,10 +100,10 @@ mediapipe::Status OpenCvVideoEncoderCalculator::GetContract( if (cc->InputSidePackets().HasTag("AUDIO_FILE_PATH")) { cc->InputSidePackets().Tag("AUDIO_FILE_PATH").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { +absl::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { OpenCvVideoEncoderCalculatorOptions options = cc->Options(); RET_CHECK(options.has_codec() && options.codec().length() == 4) @@ -128,12 +127,12 @@ mediapipe::Status OpenCvVideoEncoderCalculator::Open(CalculatorContext* cc) { // from the video header directly. The calculator will receive the video // header packet at timestamp prestream. 
if (cc->Inputs().HasTag("VIDEO_PRESTREAM")) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } return SetUpVideoWriter(options.fps(), options.width(), options.height()); } -mediapipe::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { +absl::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { if (cc->InputTimestamp() == Timestamp::PreStream()) { const VideoHeader& video_header = cc->Inputs().Tag("VIDEO_PRESTREAM").Get(); @@ -171,10 +170,10 @@ mediapipe::Status OpenCvVideoEncoderCalculator::Process(CalculatorContext* cc) { } } writer_->write(frame); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { +absl::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { if (writer_ && writer_->isOpened()) { writer_->release(); } @@ -205,11 +204,12 @@ mediapipe::Status OpenCvVideoEncoderCalculator::Close(CalculatorContext* cc) { "config."; #endif } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OpenCvVideoEncoderCalculator::SetUpVideoWriter( - float frame_rate, int width, int height) { +absl::Status OpenCvVideoEncoderCalculator::SetUpVideoWriter(float frame_rate, + int width, + int height) { RET_CHECK(frame_rate > 0 && width > 0 && height > 0) << "Invalid video metadata: frame_rate=" << frame_rate << ", width=" << width << ", height=" << height; @@ -219,7 +219,7 @@ mediapipe::Status OpenCvVideoEncoderCalculator::SetUpVideoWriter( return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Fail to open file at " << output_file_path_; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(OpenCvVideoEncoderCalculator); diff --git a/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc index faf693c60..1a1530331 100644 --- 
a/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc +++ b/mediapipe/calculators/video/opencv_video_encoder_calculator_test.cc @@ -70,7 +70,7 @@ TEST(OpenCvVideoEncoderCalculatorTest, DISABLED_TestMp4Avc720pVideo) { StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("video_prestream"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); Packet packet; @@ -129,7 +129,7 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestFlvH264Video) { StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("video_prestream"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); Packet packet; @@ -190,7 +190,7 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestMkvVp8Video) { StatusOrPoller status_or_poller = graph.AddOutputStreamPoller("video_prestream"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK(graph.StartRun({})); Packet packet; diff --git a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc index ccab6755d..c416fa9b0 100644 --- a/mediapipe/calculators/video/tracked_detection_manager_calculator.cc +++ b/mediapipe/calculators/video/tracked_detection_manager_calculator.cc @@ -106,7 +106,15 @@ Detection GetAxisAlignedDetectionFromTrackedDetection( } else { detection.set_detection_id(tracked_detection.unique_id()); } + + // Sort the labels by descending scores. 
+ std::vector> labels_and_scores; for (const auto& label_and_score : tracked_detection.label_to_score_map()) { + labels_and_scores.push_back(label_and_score); + } + std::sort(labels_and_scores.begin(), labels_and_scores.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + for (const auto& label_and_score : labels_and_scores) { detection.add_label(label_and_score.first); detection.add_score(label_and_score.second); } @@ -139,10 +147,10 @@ Detection GetAxisAlignedDetectionFromTrackedDetection( // } class TrackedDetectionManagerCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Adds new list of detections to |waiting_for_update_detections_|. 
@@ -161,7 +169,7 @@ class TrackedDetectionManagerCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TrackedDetectionManagerCalculator); -mediapipe::Status TrackedDetectionManagerCalculator::GetContract( +absl::Status TrackedDetectionManagerCalculator::GetContract( CalculatorContract* cc) { if (cc->Inputs().HasTag(kDetectionsTag)) { cc->Inputs().Tag(kDetectionsTag).Set>(); @@ -183,20 +191,18 @@ mediapipe::Status TrackedDetectionManagerCalculator::GetContract( cc->Outputs().Tag(kDetectionBoxesTag).Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TrackedDetectionManagerCalculator::Open( - CalculatorContext* cc) { +absl::Status TrackedDetectionManagerCalculator::Open(CalculatorContext* cc) { mediapipe::TrackedDetectionManagerCalculatorOptions options = cc->Options(); tracked_detection_manager_.SetConfig( options.tracked_detection_manager_options()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TrackedDetectionManagerCalculator::Process( - CalculatorContext* cc) { +absl::Status TrackedDetectionManagerCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kTrackingBoxesTag) && !cc->Inputs().Tag(kTrackingBoxesTag).IsEmpty()) { const TimedBoxProtoList& tracked_boxes = @@ -296,7 +302,7 @@ mediapipe::Status TrackedDetectionManagerCalculator::Process( AddDetectionList(detection_list, cc); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void TrackedDetectionManagerCalculator::AddDetectionList( diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc index 56aa86412..cf00da1f7 100644 --- a/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc +++ b/mediapipe/calculators/video/tvl1_optical_flow_calculator.cc @@ -74,14 +74,14 @@ cv::Mat ConvertToGrayscale(const cv::Mat& image) { // num_threads: 10 class Tvl1OpticalFlowCalculator : public CalculatorBase { public: - static 
mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - mediapipe::Status CalculateOpticalFlow(const ImageFrame& current_frame, - const ImageFrame& next_frame, - OpticalFlowField* flow); + absl::Status CalculateOpticalFlow(const ImageFrame& current_frame, + const ImageFrame& next_frame, + OpticalFlowField* flow); bool forward_requested_ = false; bool backward_requested_ = false; // Stores the idle DenseOpticalFlow objects. @@ -93,11 +93,10 @@ class Tvl1OpticalFlowCalculator : public CalculatorBase { absl::Mutex mutex_; }; -mediapipe::Status Tvl1OpticalFlowCalculator::GetContract( - CalculatorContract* cc) { +absl::Status Tvl1OpticalFlowCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().HasTag("FIRST_FRAME") || !cc->Inputs().HasTag("SECOND_FRAME")) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Missing required input streams. 
Both FIRST_FRAME and SECOND_FRAME " "must be specified."); } @@ -109,10 +108,10 @@ mediapipe::Status Tvl1OpticalFlowCalculator::GetContract( if (cc->Outputs().HasTag("BACKWARD_FLOW")) { cc->Outputs().Tag("BACKWARD_FLOW").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { +absl::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { { absl::MutexLock lock(&mutex_); tvl1_computers_.emplace_back(cv::createOptFlow_DualTVL1()); @@ -124,10 +123,10 @@ mediapipe::Status Tvl1OpticalFlowCalculator::Open(CalculatorContext* cc) { backward_requested_ = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { +absl::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { const ImageFrame& first_frame = cc->Inputs().Tag("FIRST_FRAME").Value().Get(); const ImageFrame& second_frame = @@ -148,10 +147,10 @@ mediapipe::Status Tvl1OpticalFlowCalculator::Process(CalculatorContext* cc) { .Tag("BACKWARD_FLOW") .Add(backward_optical_flow_field.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( +absl::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( const ImageFrame& current_frame, const ImageFrame& next_frame, OpticalFlowField* flow) { CHECK(flow); @@ -184,7 +183,7 @@ mediapipe::Status Tvl1OpticalFlowCalculator::CalculateOpticalFlow( absl::MutexLock lock(&mutex_); tvl1_computers_.push_back(tvl1_computer); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(Tvl1OpticalFlowCalculator); diff --git a/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc b/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc index b226dfd87..c9d30b73d 100644 --- a/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc +++ 
b/mediapipe/calculators/video/tvl1_optical_flow_calculator_test.cc @@ -78,11 +78,11 @@ void RunTest(int num_input_packets, int max_in_flight) { StatusOrPoller status_or_poller1 = graph.AddOutputStreamPoller("forward_flow"); ASSERT_TRUE(status_or_poller1.ok()); - OutputStreamPoller poller1 = std::move(status_or_poller1.ValueOrDie()); + OutputStreamPoller poller1 = std::move(status_or_poller1.value()); StatusOrPoller status_or_poller2 = graph.AddOutputStreamPoller("backward_flow"); ASSERT_TRUE(status_or_poller2.ok()); - OutputStreamPoller poller2 = std::move(status_or_poller2.ValueOrDie()); + OutputStreamPoller poller2 = std::move(status_or_poller2.value()); MP_ASSERT_OK(graph.StartRun({})); AddInputPackets(num_input_packets, &graph); diff --git a/mediapipe/calculators/video/video_pre_stream_calculator.cc b/mediapipe/calculators/video/video_pre_stream_calculator.cc index 36547830e..ab9cd22a4 100644 --- a/mediapipe/calculators/video/video_pre_stream_calculator.cc +++ b/mediapipe/calculators/video/video_pre_stream_calculator.cc @@ -45,13 +45,13 @@ namespace mediapipe { // } class VideoPreStreamCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - mediapipe::Status ProcessWithFrameRateInPreStream(CalculatorContext* cc); - mediapipe::Status ProcessWithFrameRateInOptions(CalculatorContext* cc); + absl::Status ProcessWithFrameRateInPreStream(CalculatorContext* cc); + absl::Status ProcessWithFrameRateInOptions(CalculatorContext* cc); std::unique_ptr header_; bool frame_rate_in_prestream_ = false; @@ -60,8 +60,7 @@ class VideoPreStreamCalculator : public CalculatorBase { REGISTER_CALCULATOR(VideoPreStreamCalculator); 
-mediapipe::Status VideoPreStreamCalculator::GetContract( - CalculatorContract* cc) { +absl::Status VideoPreStreamCalculator::GetContract(CalculatorContract* cc) { if (!cc->Inputs().UsesTags()) { cc->Inputs().Index(0).Set(); } else { @@ -69,17 +68,17 @@ mediapipe::Status VideoPreStreamCalculator::GetContract( cc->Inputs().Tag("VIDEO_PRESTREAM").Set(); } cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) { +absl::Status VideoPreStreamCalculator::Open(CalculatorContext* cc) { frame_rate_in_prestream_ = cc->Inputs().UsesTags() && cc->Inputs().HasTag("FRAME") && cc->Inputs().HasTag("VIDEO_PRESTREAM"); header_ = absl::make_unique(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream( +absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream( CalculatorContext* cc) { cc->GetCounter("ProcessWithFrameRateInPreStream")->Increment(); if (cc->InputTimestamp() == Timestamp::PreStream()) { @@ -99,13 +98,13 @@ mediapipe::Status VideoPreStreamCalculator::ProcessWithFrameRateInPreStream( cc->Outputs().Index(0).Add(header_.release(), Timestamp::PreStream()); emitted_ = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VideoPreStreamCalculator::Process(CalculatorContext* cc) { +absl::Status VideoPreStreamCalculator::Process(CalculatorContext* cc) { cc->GetCounter("Process")->Increment(); if (emitted_) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (frame_rate_in_prestream_) { return ProcessWithFrameRateInPreStream(cc); @@ -114,7 +113,7 @@ mediapipe::Status VideoPreStreamCalculator::Process(CalculatorContext* cc) { } } -mediapipe::Status VideoPreStreamCalculator::ProcessWithFrameRateInOptions( +absl::Status VideoPreStreamCalculator::ProcessWithFrameRateInOptions( CalculatorContext* cc) { 
cc->GetCounter("ProcessWithFrameRateInOptions")->Increment(); RET_CHECK_NE(cc->InputTimestamp(), Timestamp::PreStream()); @@ -136,7 +135,7 @@ mediapipe::Status VideoPreStreamCalculator::ProcessWithFrameRateInOptions( RET_CHECK_NE(header_->frame_rate, 0.0) << "frame rate should be non-zero"; cc->Outputs().Index(0).Add(header_.release(), Timestamp::PreStream()); emitted_ = true; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/calculators/video/video_pre_stream_calculator_test.cc b/mediapipe/calculators/video/video_pre_stream_calculator_test.cc index c5d71ff90..38f132e9e 100644 --- a/mediapipe/calculators/video/video_pre_stream_calculator_test.cc +++ b/mediapipe/calculators/video/video_pre_stream_calculator_test.cc @@ -39,7 +39,7 @@ TEST(VideoPreStreamCalculatorTest, ProcessesWithFrameRateInOptions) { MP_ASSERT_OK(graph.Initialize(config)); auto poller_status = graph.AddOutputStreamPoller("output"); MP_ASSERT_OK(poller_status.status()); - OutputStreamPoller& poller = poller_status.ValueOrDie(); + OutputStreamPoller& poller = poller_status.value(); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.AddPacketToInputStream( "input", @@ -79,7 +79,7 @@ TEST(VideoPreStreamCalculatorTest, ProcessesWithFrameRateInPreStream) { MP_ASSERT_OK(graph.Initialize(config)); auto poller_status = graph.AddOutputStreamPoller("output_header"); MP_ASSERT_OK(poller_status.status()); - OutputStreamPoller& poller = poller_status.ValueOrDie(); + OutputStreamPoller& poller = poller_status.value(); MP_ASSERT_OK(graph.StartRun({})); auto input_header = absl::make_unique(); input_header->frame_rate = 3.0; @@ -118,7 +118,7 @@ TEST(VideoPreStreamCalculatorTest, FailsWithoutFrameRateInOptions) { "frame", Adopt(new ImageFrame(ImageFormat::SRGB, 1, 2)).At(Timestamp(0)))); MP_ASSERT_OK(graph.CloseInputStream("frame")); - mediapipe::Status status = graph.WaitUntilDone(); + absl::Status status = graph.WaitUntilDone(); 
EXPECT_FALSE(status.ok()); EXPECT_THAT(status.ToString(), testing::HasSubstr("frame rate should be non-zero")); @@ -144,7 +144,7 @@ TEST(VideoPreStreamCalculatorTest, FailsWithoutFrameRateInPreStream1) { Adopt(new ImageFrame(ImageFormat::SRGB, 1, 2)).At(Timestamp(0)))); MP_ASSERT_OK(graph.CloseInputStream("frame")); MP_ASSERT_OK(graph.CloseInputStream("input_header")); - mediapipe::Status status = graph.WaitUntilDone(); + absl::Status status = graph.WaitUntilDone(); EXPECT_FALSE(status.ok()); EXPECT_THAT(status.ToString(), testing::HasSubstr("frame rate should be non-zero")); @@ -177,7 +177,7 @@ TEST(VideoPreStreamCalculatorTest, FailsWithoutFrameRateInPreStream2) { "frame", Adopt(new ImageFrame(ImageFormat::SRGB, 1, 2)).At(Timestamp(0)))); MP_ASSERT_OK(graph.CloseInputStream("frame")); - mediapipe::Status status = graph.WaitUntilDone(); + absl::Status status = graph.WaitUntilDone(); EXPECT_FALSE(status.ok()); } } diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD index d9b2554dc..8bf6c0a54 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/BUILD @@ -36,12 +36,15 @@ android_binary( name = "faceeffect", srcs = glob(["*.java"]), assets = [ + "//mediapipe/graphs/face_effect/data:axis.binarypb", + "//mediapipe/graphs/face_effect/data:axis.pngblob", "//mediapipe/graphs/face_effect/data:facepaint.pngblob", "//mediapipe/graphs/face_effect/data:glasses.binarypb", "//mediapipe/graphs/face_effect/data:glasses.pngblob", "//mediapipe/graphs/face_effect:face_effect_gpu.binarypb", "//mediapipe/modules/face_detection:face_detection_front.tflite", - "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata.binarypb", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_detection.binarypb", + 
"//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_landmarks.binarypb", "//mediapipe/modules/face_landmark:face_landmark.tflite", ], assets_dir = "", diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/MainActivity.java b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/MainActivity.java index e9c4cb80e..78c220aae 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/MainActivity.java +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/faceeffect/MainActivity.java @@ -29,23 +29,31 @@ import com.google.mediapipe.framework.Packet; import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.modules.facegeometry.FaceGeometryProto.FaceGeometry; import com.google.mediapipe.formats.proto.MatrixDataProto.MatrixData; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** Main activity of MediaPipe face mesh app. */ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { private static final String TAG = "MainActivity"; - // Stream names. - private static final String IS_FACEPAINT_EFFECT_SELECTED_INPUT_STREAM_NAME = - "is_facepaint_effect_selected"; + // Side packet / stream names. 
+ private static final String USE_FACE_DETECTION_INPUT_SOURCE_INPUT_SIDE_PACKET_NAME = + "use_face_detection_input_source"; + private static final String SELECTED_EFFECT_ID_INPUT_STREAM_NAME = "selected_effect_id"; private static final String OUTPUT_FACE_GEOMETRY_STREAM_NAME = "multi_face_geometry"; private static final String EFFECT_SWITCHING_HINT_TEXT = "Tap to switch between effects!"; + private static final boolean USE_FACE_DETECTION_INPUT_SOURCE = false; private static final int MATRIX_TRANSLATION_Z_INDEX = 14; - private final Object isFacepaintEffectSelectedLock = new Object(); - private boolean isFacepaintEffectSelected; + private static final int SELECTED_EFFECT_ID_AXIS = 0; + private static final int SELECTED_EFFECT_ID_FACEPAINT = 1; + private static final int SELECTED_EFFECT_ID_GLASSES = 2; + + private final Object effectSelectionLock = new Object(); + private int selectedEffectId; private View effectSwitchingHintView; private GestureDetector tapGestureDetector; @@ -60,8 +68,20 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { ViewGroup viewGroup = findViewById(R.id.preview_display_layout); viewGroup.addView(effectSwitchingHintView); - // By default, render the glasses effect. - isFacepaintEffectSelected = false; + // By default, render the axis effect for the face detection input source and the glasses effect + // for the face landmark input source. + if (USE_FACE_DETECTION_INPUT_SOURCE) { + selectedEffectId = SELECTED_EFFECT_ID_AXIS; + } else { + selectedEffectId = SELECTED_EFFECT_ID_GLASSES; + } + + // Pass the USE_FACE_DETECTION_INPUT_SOURCE flag value as an input side packet into the graph. 
+ Map inputSidePackets = new HashMap<>(); + inputSidePackets.put( + USE_FACE_DETECTION_INPUT_SOURCE_INPUT_SIDE_PACKET_NAME, + processor.getPacketCreator().createBool(USE_FACE_DETECTION_INPUT_SOURCE)); + processor.setInputSidePackets(inputSidePackets); // This callback demonstrates how the output face geometry packet can be obtained and used // in an Android app. As an example, the Z-translation component of the face pose transform @@ -71,12 +91,9 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { OUTPUT_FACE_GEOMETRY_STREAM_NAME, (packet) -> { effectSwitchingHintView.post( - new Runnable() { - @Override - public void run() { - effectSwitchingHintView.setVisibility(View.VISIBLE); - } - }); + () -> + effectSwitchingHintView.setVisibility( + USE_FACE_DETECTION_INPUT_SOURCE ? View.INVISIBLE : View.VISIBLE)); Log.d(TAG, "Received a multi face geometry packet."); List multiFaceGeometry = @@ -103,30 +120,26 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { + "]"); }); - // Alongside the input camera frame, we also send the `is_facepaint_effect_selected` boolean - // packet to indicate which effect should be rendered on this frame. + // Alongside the input camera frame, we also send the `selected_effect_id` int32 packet to + // indicate which effect should be rendered on this frame. 
processor.setOnWillAddFrameListener( (timestamp) -> { - Packet isFacepaintEffectSelectedPacket = null; + Packet selectedEffectIdPacket = null; try { - synchronized (isFacepaintEffectSelectedLock) { - isFacepaintEffectSelectedPacket = - processor.getPacketCreator().createBool(isFacepaintEffectSelected); + synchronized (effectSelectionLock) { + selectedEffectIdPacket = processor.getPacketCreator().createInt32(selectedEffectId); } processor .getGraph() .addPacketToInputStream( - IS_FACEPAINT_EFFECT_SELECTED_INPUT_STREAM_NAME, - isFacepaintEffectSelectedPacket, - timestamp); + SELECTED_EFFECT_ID_INPUT_STREAM_NAME, selectedEffectIdPacket, timestamp); } catch (RuntimeException e) { Log.e( - TAG, - "Exception while adding packet to input stream while switching effects: " + e); + TAG, "Exception while adding packet to input stream while switching effects: " + e); } finally { - if (isFacepaintEffectSelectedPacket != null) { - isFacepaintEffectSelectedPacket.release(); + if (selectedEffectIdPacket != null) { + selectedEffectIdPacket.release(); } } }); @@ -149,8 +162,35 @@ public class MainActivity extends com.google.mediapipe.apps.basic.MainActivity { } private void switchEffect() { - synchronized (isFacepaintEffectSelectedLock) { - isFacepaintEffectSelected = !isFacepaintEffectSelected; + // Avoid switching the Axis effect for the face detection input source. + if (USE_FACE_DETECTION_INPUT_SOURCE) { + return; + } + + // Looped effect order: glasses -> facepaint -> axis -> glasses -> ... 
+ synchronized (effectSelectionLock) { + switch (selectedEffectId) { + case SELECTED_EFFECT_ID_AXIS: + { + selectedEffectId = SELECTED_EFFECT_ID_GLASSES; + break; + } + + case SELECTED_EFFECT_ID_FACEPAINT: + { + selectedEffectId = SELECTED_EFFECT_ID_AXIS; + break; + } + + case SELECTED_EFFECT_ID_GLASSES: + { + selectedEffectId = SELECTED_EFFECT_ID_FACEPAINT; + break; + } + + default: + break; + } } } }); diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD index 5bf497f42..f629951df 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/iristrackinggpu/BUILD @@ -60,5 +60,6 @@ android_binary( "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "@com_google_protobuf//:protobuf_javalite", ], ) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD index a8114b3f8..783ae200e 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/objectdetection3d/BUILD @@ -90,7 +90,7 @@ genrule( cmd = "cp $< $@", ) -MODELS_DIR = "//mediapipe/models" +MODELS_DIR = "//mediapipe/modules/objectron" genrule( name = "model", @@ -165,7 +165,7 @@ android_binary( ":mesh", ":texture", MODELS_DIR + ":object_detection_ssd_mobilenetv2_oidv4_fp16.tflite", - MODELS_DIR + ":object_detection_oidv4_labelmap.pbtxt", + MODELS_DIR + ":object_detection_oidv4_labelmap.txt", ASSETS_DIR + ":box.obj.uuu", ASSETS_DIR + ":classic_colors.png", ], diff --git 
a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD index 4ed51a556..5eff6a833 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/posetrackinggpu/BUILD @@ -59,5 +59,6 @@ android_binary( "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "@com_google_protobuf//:protobuf_javalite", ], ) diff --git a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD index e4e41741f..50f9d643a 100644 --- a/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD +++ b/mediapipe/examples/android/src/java/com/google/mediapipe/apps/upperbodyposetrackinggpu/BUILD @@ -59,5 +59,6 @@ android_binary( "//mediapipe/examples/android/src/java/com/google/mediapipe/apps/basic:basic_lib", "//mediapipe/framework/formats:landmark_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "@com_google_protobuf//:protobuf_javalite", ], ) diff --git a/mediapipe/examples/coral/demo_run_graph_main.cc b/mediapipe/examples/coral/demo_run_graph_main.cc index db066043a..698955472 100644 --- a/mediapipe/examples/coral/demo_run_graph_main.cc +++ b/mediapipe/examples/coral/demo_run_graph_main.cc @@ -40,10 +40,11 @@ DEFINE_string(output_video_path, "", "Full path of where to save result (.mp4 only). 
" "If not provided, show result in a window."); -mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -56,22 +57,22 @@ mediapipe::Status RunMPPGraph() { LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; - const bool load_video = !FLAGS_input_video_path.empty(); + const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { - capture.open(FLAGS_input_video_path); + capture.open(absl::GetFlag(FLAGS_input_video_path)); } else { capture.open(0); } RET_CHECK(capture.isOpened()); cv::VideoWriter writer; - const bool save_video = !FLAGS_output_video_path.empty(); + const bool save_video = !absl::GetFlag(FLAGS_output_video_path).empty(); if (save_video) { LOG(INFO) << "Prepare video writer."; cv::Mat test_frame; capture.read(test_frame); // Consume first frame. capture.set(cv::CAP_PROP_POS_AVI_RATIO, 0); // Rewind to beginning. 
- writer.open(FLAGS_output_video_path, + writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), test_frame.size()); RET_CHECK(writer.isOpened()); @@ -143,7 +144,7 @@ mediapipe::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/autoflip/calculators/BUILD b/mediapipe/examples/desktop/autoflip/calculators/BUILD index 688084062..99b9d6fff 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/BUILD +++ b/mediapipe/examples/desktop/autoflip/calculators/BUILD @@ -368,6 +368,7 @@ cc_test( "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgcodecs", diff --git a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc index 440620fc9..caaa368a7 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/border_detection_calculator.cc @@ -68,9 +68,9 @@ class BorderDetectionCalculator : public CalculatorBase { BorderDetectionCalculator& operator=(const BorderDetectionCalculator&) = delete; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status 
GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Given a color and image direction, check to see if a border of that color @@ -83,7 +83,7 @@ class BorderDetectionCalculator : public CalculatorBase { double ColorCount(const Color& mask_color, const cv::Mat& image) const; // Set member vars (image size) and confirm no changes frame-to-frame. - mediapipe::Status SetAndCheckInputs(const cv::Mat& frame); + absl::Status SetAndCheckInputs(const cv::Mat& frame); // Find the dominant color for a input image. double FindDominantColor(const cv::Mat& image, Color* dominant_color); @@ -97,15 +97,14 @@ class BorderDetectionCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(BorderDetectionCalculator); -mediapipe::Status BorderDetectionCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status BorderDetectionCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK_LT(options_.vertical_search_distance(), 0.5) << "Search distance must be less than half the full image."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BorderDetectionCalculator::SetAndCheckInputs( +absl::Status BorderDetectionCalculator::SetAndCheckInputs( const cv::Mat& frame) { if (frame_width_ < 0) { frame_width_ = frame.cols; @@ -118,10 +117,10 @@ mediapipe::Status BorderDetectionCalculator::SetAndCheckInputs( RET_CHECK_EQ(frame.rows, frame_height_) << "Input frame dimensions must remain constant throughout the video."; RET_CHECK_EQ(frame.channels(), 3) << "Input video type must be 3-channel"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BorderDetectionCalculator::Process( +absl::Status BorderDetectionCalculator::Process( mediapipe::CalculatorContext* cc) { if (!cc->Inputs().HasTag(kVideoInputTag) || cc->Inputs().Tag(kVideoInputTag).Value().IsEmpty()) { @@ 
-173,7 +172,7 @@ mediapipe::Status BorderDetectionCalculator::Process( .Tag(kDetectedBorders) .AddPacket(Adopt(features.release()).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Find the dominant color within an image. @@ -291,11 +290,11 @@ void BorderDetectionCalculator::DetectBorder( } } -mediapipe::Status BorderDetectionCalculator::GetContract( +absl::Status BorderDetectionCalculator::GetContract( mediapipe::CalculatorContract* cc) { cc->Inputs().Tag(kVideoInputTag).Set(); cc->Outputs().Tag(kDetectedBorders).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc index dad46f924..c2ee6b0ff 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.cc @@ -55,23 +55,25 @@ class ContentZoomingCalculator : public CalculatorBase { ContentZoomingCalculator(const ContentZoomingCalculator&) = delete; ContentZoomingCalculator& operator=(const ContentZoomingCalculator&) = delete; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Converts bounds to tilt offset, pan offset and height. 
- mediapipe::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, - float ymax, int* tilt_offset, - int* pan_offset, int* height); + absl::Status ConvertToPanTiltZoom(float xmin, float xmax, float ymin, + float ymax, int* tilt_offset, + int* pan_offset, int* height); + // Sets max_frame_value_ and target_aspect_ + absl::Status UpdateAspectAndMax(); ContentZoomingCalculatorOptions options_; // Detection frame width/height. int frame_height_; int frame_width_; // Path solver used to smooth top/bottom border crop values. - std::unique_ptr path_solver_height_; - std::unique_ptr path_solver_width_; - std::unique_ptr path_solver_offset_; + std::unique_ptr path_solver_zoom_; + std::unique_ptr path_solver_pan_; + std::unique_ptr path_solver_tilt_; // Are parameters initialized. bool initialized_; // Stores the time of the last "only_required" input. @@ -89,7 +91,7 @@ class ContentZoomingCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(ContentZoomingCalculator); -mediapipe::Status ContentZoomingCalculator::GetContract( +absl::Status ContentZoomingCalculator::GetContract( mediapipe::CalculatorContract* cc) { RET_CHECK( !(cc->Inputs().HasTag(kVideoFrame) && cc->Inputs().HasTag(kVideoSize))) @@ -114,11 +116,10 @@ mediapipe::Status ContentZoomingCalculator::GetContract( if (cc->Outputs().HasTag(kCropRect)) { cc->Outputs().Tag(kCropRect).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ContentZoomingCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status ContentZoomingCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); if (options_.has_kinematic_options()) { return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) @@ -131,10 +132,10 @@ mediapipe::Status ContentZoomingCalculator::Open( "in kinematic_options_zoom and kinematic_options_tilt " "directly."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom( 
+absl::Status ContentZoomingCalculator::ConvertToPanTiltZoom( float xmin, float xmax, float ymin, float ymax, int* tilt_offset, int* pan_offset, int* height) { // Find center of the y-axis offset (for tilt control). @@ -142,10 +143,11 @@ mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom( // Find center of the x-axis offset (for pan control). float x_center = xmin + (xmax - xmin) / 2; // Find size and apply scale factor to y-axis. - float fit_size = fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin); + float fit_size_raw = + fmax((ymax - ymin) / options_.scale_factor(), xmax - xmin); // Apply max frame for cases where the target size is different than input // frame size. - fit_size = fmin(max_frame_value_, fit_size); + float fit_size = fmin(max_frame_value_, fit_size_raw); // Prevent box from extending beyond the image. if (y_center - fit_size / 2 < 0) { y_center = fit_size / 2; @@ -160,8 +162,8 @@ mediapipe::Status ContentZoomingCalculator::ConvertToPanTiltZoom( // Scale to pixel coordinates. 
*tilt_offset = frame_height_ * y_center; *pan_offset = frame_width_ * x_center; - *height = frame_height_ * fit_size; - return mediapipe::OkStatus(); + *height = frame_height_ * fit_size_raw; + return absl::OkStatus(); } namespace { @@ -185,10 +187,10 @@ mediapipe::autoflip::RectF ShiftDetection( relative_bounding_box.width() * x_offset_percent); return shifted_bb; } -mediapipe::Status UpdateRanges(const SalientRegion& region, - const float shift_vertical, - const float shift_horizontal, float* xmin, - float* xmax, float* ymin, float* ymax) { +absl::Status UpdateRanges(const SalientRegion& region, + const float shift_vertical, + const float shift_horizontal, float* xmin, + float* xmax, float* ymin, float* ymax) { if (!region.has_location_normalized()) { return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "SalientRegion did not have location normalized set."; @@ -200,12 +202,12 @@ mediapipe::Status UpdateRanges(const SalientRegion& region, *ymin = fmin(*ymin, location.y()); *ymax = fmax(*ymax, location.y() + location.height()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status UpdateRanges(const mediapipe::Detection& detection, - const float shift_vertical, - const float shift_horizontal, float* xmin, - float* xmax, float* ymin, float* ymax) { +absl::Status UpdateRanges(const mediapipe::Detection& detection, + const float shift_vertical, + const float shift_horizontal, float* xmin, + float* xmax, float* ymin, float* ymax) { RET_CHECK(detection.location_data().format() == mediapipe::LocationData::RELATIVE_BOUNDING_BOX) << "Face detection input is lacking required relative_bounding_box()"; @@ -217,7 +219,7 @@ mediapipe::Status UpdateRanges(const mediapipe::Detection& detection, *ymin = fmin(*ymin, location.ymin()); *ymax = fmax(*ymax, location.ymin() + location.height()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void MakeStaticFeatures(const int top_border, const int bottom_border, const int frame_width, const int 
frame_height, @@ -236,57 +238,97 @@ void MakeStaticFeatures(const int top_border, const int bottom_border, border_bottom->mutable_border_position()->set_width(frame_width); border_bottom->mutable_border_position()->set_height(bottom_border); } -} // namespace - -mediapipe::Status ContentZoomingCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status GetVideoResolution(mediapipe::CalculatorContext* cc, + int* frame_width, int* frame_height) { if (cc->Inputs().HasTag(kVideoFrame)) { - frame_width_ = cc->Inputs().Tag(kVideoFrame).Get().Width(); - frame_height_ = cc->Inputs().Tag(kVideoFrame).Get().Height(); + *frame_width = cc->Inputs().Tag(kVideoFrame).Get().Width(); + *frame_height = cc->Inputs().Tag(kVideoFrame).Get().Height(); } else if (cc->Inputs().HasTag(kVideoSize)) { - if (cc->Inputs().Tag(kVideoSize).IsEmpty()) { - return mediapipe::OkStatus(); - } - frame_width_ = + *frame_width = cc->Inputs().Tag(kVideoSize).Get>().first; - frame_height_ = + *frame_height = cc->Inputs().Tag(kVideoSize).Get>().second; } else { return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Input VIDEO or VIDEO_SIZE must be provided."; } + return absl::OkStatus(); +} +} // namespace +absl::Status ContentZoomingCalculator::UpdateAspectAndMax() { + max_frame_value_ = 1.0; + target_aspect_ = frame_width_ / static_cast(frame_height_); + // If target size is set and wider than input aspect, make sure to always + // crop the min required amount. 
+ if (options_.has_target_size()) { + RET_CHECK_GT(options_.target_size().width(), 0) + << "Provided target width not valid."; + RET_CHECK_GT(options_.target_size().height(), 0) + << "Provided target height not valid."; + float input_aspect = frame_width_ / static_cast(frame_height_); + target_aspect_ = options_.target_size().width() / + static_cast(options_.target_size().height()); + max_frame_value_ = + std::min(input_aspect / target_aspect_, target_aspect_ / input_aspect); + } + return absl::OkStatus(); +} + +absl::Status ContentZoomingCalculator::Process( + mediapipe::CalculatorContext* cc) { + // For async subgraph support, return on empty video size packets. + if (cc->Inputs().HasTag(kVideoSize) && + cc->Inputs().Tag(kVideoSize).IsEmpty()) { + return absl::OkStatus(); + } + int frame_width, frame_height; + MP_RETURN_IF_ERROR(GetVideoResolution(cc, &frame_width, &frame_height)); + + // Init on first call. if (!initialized_) { - path_solver_height_ = std::make_unique( - options_.kinematic_options_zoom(), 0, frame_height_, - static_cast(frame_height_) / kFieldOfView); - path_solver_width_ = std::make_unique( + frame_width_ = frame_width; + frame_height_ = frame_height; + path_solver_pan_ = std::make_unique( options_.kinematic_options_pan(), 0, frame_width_, static_cast(frame_width_) / kFieldOfView); - path_solver_offset_ = std::make_unique( + path_solver_tilt_ = std::make_unique( options_.kinematic_options_tilt(), 0, frame_height_, static_cast(frame_height_) / kFieldOfView); - max_frame_value_ = 1.0; - target_aspect_ = frame_width_ / static_cast(frame_height_); - // If target size is set and wider than input aspect, make sure to always - // crop the min required amount. 
- if (options_.has_target_size()) { - RET_CHECK_GT(options_.target_size().width(), 0) - << "Provided target width not valid."; - RET_CHECK_GT(options_.target_size().height(), 0) - << "Provided target height not valid."; - float input_aspect = frame_width_ / static_cast(frame_height_); - target_aspect_ = options_.target_size().width() / - static_cast(options_.target_size().height()); - max_frame_value_ = std::min(input_aspect / target_aspect_, - target_aspect_ / input_aspect); - } + MP_RETURN_IF_ERROR(UpdateAspectAndMax()); + int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() / + static_cast(kFieldOfView)); + path_solver_zoom_ = std::make_unique( + options_.kinematic_options_zoom(), min_zoom_size, + max_frame_value_ * frame_height_, + static_cast(frame_height_) / kFieldOfView); last_measured_height_ = max_frame_value_ * frame_height_; last_measured_x_offset_ = target_aspect_ * frame_width_; last_measured_y_offset_ = frame_width_ / 2; initialized_ = true; } + // Update state for change in input resolution. 
+ if (frame_width_ != frame_width || frame_height_ != frame_height) { + double width_scale = frame_width / static_cast(frame_width_); + double height_scale = frame_height / static_cast(frame_height_); + last_measured_height_ = last_measured_height_ * height_scale; + last_measured_y_offset_ = last_measured_y_offset_ * height_scale; + last_measured_x_offset_ = last_measured_x_offset_ * width_scale; + frame_width_ = frame_width; + frame_height_ = frame_height; + MP_RETURN_IF_ERROR(UpdateAspectAndMax()); + MP_RETURN_IF_ERROR(path_solver_pan_->UpdateMinMaxLocation(0, frame_width_)); + MP_RETURN_IF_ERROR( + path_solver_tilt_->UpdateMinMaxLocation(0, frame_height_)); + int min_zoom_size = frame_height_ * (options_.max_zoom_value_deg() / + static_cast(kFieldOfView)); + MP_RETURN_IF_ERROR(path_solver_zoom_->UpdateMinMaxLocation( + min_zoom_size, max_frame_value_ * frame_height_)); + MP_RETURN_IF_ERROR(path_solver_zoom_->UpdatePixelsPerDegree( + static_cast(frame_height_) / kFieldOfView)); + } + bool only_required_found = false; // Compute the box that contains all "is_required" detections. @@ -307,11 +349,13 @@ mediapipe::Status ContentZoomingCalculator::Process( if (cc->Inputs().HasTag(kDetections)) { if (cc->Inputs().Tag(kDetections).IsEmpty()) { auto default_rect = absl::make_unique(); + default_rect->set_x_center(frame_width_ / 2); + default_rect->set_y_center(frame_height_ / 2); default_rect->set_width(frame_width_); default_rect->set_height(frame_height_); cc->Outputs().Tag(kCropRect).Add(default_rect.release(), Timestamp(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } auto raw_detections = cc->Inputs().Tag(kDetections).Get>(); @@ -350,31 +394,46 @@ mediapipe::Status ContentZoomingCalculator::Process( offset_y = last_measured_y_offset_; } + // Check if the camera is changing in pan, tilt or zoom. If the camera is in + // motion disable temporal filtering. 
+ bool pan_state, tilt_state, zoom_state; + MP_RETURN_IF_ERROR(path_solver_pan_->PredictMotionState( + offset_x, cc->InputTimestamp().Microseconds(), &pan_state)); + MP_RETURN_IF_ERROR(path_solver_tilt_->PredictMotionState( + offset_y, cc->InputTimestamp().Microseconds(), &tilt_state)); + MP_RETURN_IF_ERROR(path_solver_zoom_->PredictMotionState( + height, cc->InputTimestamp().Microseconds(), &zoom_state)); + if (pan_state || tilt_state || zoom_state) { + path_solver_pan_->ClearHistory(); + path_solver_tilt_->ClearHistory(); + path_solver_zoom_->ClearHistory(); + } + // Compute smoothed zoom camera path. - MP_RETURN_IF_ERROR(path_solver_height_->AddObservation( + MP_RETURN_IF_ERROR(path_solver_zoom_->AddObservation( height, cc->InputTimestamp().Microseconds())); int path_height; - MP_RETURN_IF_ERROR(path_solver_height_->GetState(&path_height)); + MP_RETURN_IF_ERROR(path_solver_zoom_->GetState(&path_height)); int path_width = path_height * target_aspect_; // Update pixel-per-degree value for pan/tilt. int target_height; - MP_RETURN_IF_ERROR(path_solver_height_->GetTargetPosition(&target_height)); + MP_RETURN_IF_ERROR(path_solver_zoom_->GetTargetPosition(&target_height)); int target_width = target_height * target_aspect_; - MP_RETURN_IF_ERROR(path_solver_width_->UpdatePixelsPerDegree( + MP_RETURN_IF_ERROR(path_solver_pan_->UpdatePixelsPerDegree( static_cast(target_width) / kFieldOfView)); - MP_RETURN_IF_ERROR(path_solver_offset_->UpdatePixelsPerDegree( + MP_RETURN_IF_ERROR(path_solver_tilt_->UpdatePixelsPerDegree( static_cast(target_height) / kFieldOfView)); // Compute smoothed pan/tilt paths. 
- MP_RETURN_IF_ERROR(path_solver_width_->AddObservation( + MP_RETURN_IF_ERROR(path_solver_pan_->AddObservation( offset_x, cc->InputTimestamp().Microseconds())); - MP_RETURN_IF_ERROR(path_solver_offset_->AddObservation( + MP_RETURN_IF_ERROR(path_solver_tilt_->AddObservation( offset_y, cc->InputTimestamp().Microseconds())); int path_offset_x; - MP_RETURN_IF_ERROR(path_solver_width_->GetState(&path_offset_x)); + MP_RETURN_IF_ERROR(path_solver_pan_->GetState(&path_offset_x)); int path_offset_y; - MP_RETURN_IF_ERROR(path_solver_offset_->GetState(&path_offset_y)); + MP_RETURN_IF_ERROR(path_solver_tilt_->GetState(&path_offset_y)); // Prevent box from extending beyond the image after camera smoothing. if (path_offset_y - ceil(path_height / 2.0) < 0) { @@ -415,7 +474,7 @@ mediapipe::Status ContentZoomingCalculator::Process( Timestamp(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto index 2634a4afe..c0d4dd78b 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator.proto @@ -19,7 +19,7 @@ package mediapipe.autoflip; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/framework/calculator.proto"; -// NextTag: 13 +// NextTag: 14 message ContentZoomingCalculatorOptions { extend mediapipe.CalculatorOptions { optional ContentZoomingCalculatorOptions ext = 313091992; @@ -52,6 +52,9 @@ message ContentZoomingCalculatorOptions { optional float detection_shift_vertical = 11 [default = 0.0]; optional float detection_shift_horizontal = 12 [default = 0.0]; + // Defines the smallest value in degrees the camera is permitted to zoom. 
+ optional float max_zoom_value_deg = 13 [default = 35]; + // Deprecated parameters optional KinematicOptions kinematic_options = 2 [deprecated = true]; optional int64 min_motion_to_reframe = 4 [deprecated = true]; diff --git a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc index 0d2f77993..0db252fec 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/content_zooming_calculator_test.cc @@ -42,6 +42,20 @@ const char kConfigA[] = R"( input_stream: "VIDEO:camera_frames" input_stream: "SALIENT_REGIONS:detection_set" output_stream: "BORDERS:borders" + options: { + [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } + } + } )"; const char kConfigB[] = R"( @@ -55,6 +69,16 @@ const char kConfigB[] = R"( width: 1000 height: 500 } + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } } } )"; @@ -64,6 +88,20 @@ const char kConfigC[] = R"( input_stream: "VIDEO_SIZE:size" input_stream: "SALIENT_REGIONS:detection_set" output_stream: "BORDERS:borders" + options: { + [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } + } + } )"; const char kConfigD[] = R"( @@ -71,6 +109,20 @@ const char kConfigD[] = R"( input_stream: "VIDEO_SIZE:size" input_stream: "DETECTIONS:detections" output_stream: "CROP_RECT:rect" + 
options: { + [mediapipe.autoflip.ContentZoomingCalculatorOptions.ext]: { + max_zoom_value_deg: 0 + kinematic_options_zoom { + min_motion_to_reframe: 1.2 + } + kinematic_options_tilt { + min_motion_to_reframe: 1.2 + } + kinematic_options_pan { + min_motion_to_reframe: 1.2 + } + } + } )"; void CheckBorder(const StaticFeatures& static_features, int width, int height, @@ -91,8 +143,9 @@ void CheckBorder(const StaticFeatures& static_features, int width, int height, EXPECT_EQ(Border::BOTTOM, part.relative_position()); } -void AddDetection(const cv::Rect_& position, const int64 time, - CalculatorRunner* runner) { +void AddDetectionFrameSize(const cv::Rect_& position, const int64 time, + const int width, const int height, + CalculatorRunner* runner) { auto detections = std::make_unique>(); mediapipe::Detection detection; detection.mutable_location_data()->set_format( @@ -111,12 +164,17 @@ void AddDetection(const cv::Rect_& position, const int64 time, ->Tag("DETECTIONS") .packets.push_back(Adopt(detections.release()).At(Timestamp(time))); - auto input_size = ::absl::make_unique>(1000, 1000); + auto input_size = ::absl::make_unique>(width, height); runner->MutableInputs() ->Tag("VIDEO_SIZE") .packets.push_back(Adopt(input_size.release()).At(Timestamp(time))); } +void AddDetection(const cv::Rect_& position, const int64 time, + CalculatorRunner* runner) { + AddDetectionFrameSize(position, time, 1000, 1000, runner); +} + void CheckCropRect(const int x_center, const int y_center, const int width, const int height, const int frame_number, const std::vector& output_packets) { @@ -433,7 +491,53 @@ TEST(ContentZoomingCalculatorTest, EmptyDetections) { ->Tag("VIDEO_SIZE") .packets.push_back(Adopt(input_size.release()).At(Timestamp(0))); MP_ASSERT_OK(runner->Run()); - CheckCropRect(0, 0, 1000, 1000, 0, + CheckCropRect(500, 500, 1000, 1000, 0, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, ResolutionChangeStationary) { + auto config = 
ParseTextProtoOrDie(kConfigD); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1, 500, 500, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 222, 222, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500 * 0.5, 500 * 0.5, 222 * 0.5, 222 * 0.5, 1, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, ResolutionChangeZooming) { + auto config = ParseTextProtoOrDie(kConfigD); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.1, .1, .8, .8), 0, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 1000000, 1000, 1000, + runner.get()); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 2000000, 500, 500, + runner.get()); + MP_ASSERT_OK(runner->Run()); + CheckCropRect(500, 500, 888, 888, 0, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500, 500, 588, 588, 1, + runner->Outputs().Tag("CROP_RECT").packets); + CheckCropRect(500 * 0.5, 500 * 0.5, 288 * 0.5, 288 * 0.5, 2, + runner->Outputs().Tag("CROP_RECT").packets); +} + +TEST(ContentZoomingCalculatorTest, MaxZoomValue) { + auto config = ParseTextProtoOrDie(kConfigD); + auto* options = config.mutable_options()->MutableExtension( + ContentZoomingCalculatorOptions::ext); + options->set_max_zoom_value_deg(55); + auto runner = ::absl::make_unique(config); + AddDetectionFrameSize(cv::Rect_(.4, .4, .2, .2), 0, 1000, 1000, + runner.get()); + MP_ASSERT_OK(runner->Run()); + // 55/60 * 1000 = 916 + CheckCropRect(500, 500, 916, 916, 0, runner->Outputs().Tag("CROP_RECT").packets); } diff --git a/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto b/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto new file mode 100644 index 000000000..b00755d36 --- /dev/null +++ 
b/mediapipe/examples/desktop/autoflip/calculators/face_box_adjuster_calculator.proto @@ -0,0 +1,50 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe.autoflip; + +import "mediapipe/framework/calculator.proto"; + +message FaceBoxAdjusterCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional FaceBoxAdjusterCalculatorOptions ext = 347462240; + } + + // When faces are detected in a given frame, we check this number of frames + // in the past. We include only those faces in auto framing that have been + // seen in this past history. This helps reduce False Positives and also + // handles some of the edge cases. Setting the value to 0 disables the + // feature. + optional int32 num_frame_history = 1 [default = 0]; + + // IOU threshold for matching detected faces with the faces in the frame + // history buffer. + optional float iou_threshold = 2 [default = 0.2]; + + // If true, the face boxes are adjusted based on their face pose. This is done + // to correct for extreme poses that can cause the detected face boxes to be + // either too big or too small. + optional bool adjust_for_pose = 3 [default = true]; + + // These are DEPRECATED fields. Do not use.
+ optional float box_area_change_per_up_tilt_degree = 4 [deprecated = true]; + optional float box_area_change_per_down_tilt_degree = 5 [deprecated = true]; + + // The ratios of the face-pose corrected IPD to the face bounding box's width + // and height respectively. + optional float ipd_face_box_width_ratio = 6 [default = 0.5566]; + optional float ipd_face_box_height_ratio = 7 [default = 0.3131]; +} diff --git a/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc index e9904a299..3c9aeb4c8 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/face_to_region_calculator.cc @@ -55,9 +55,9 @@ class FaceToRegionCalculator : public CalculatorBase { FaceToRegionCalculator(const FaceToRegionCalculator&) = delete; FaceToRegionCalculator& operator=(const FaceToRegionCalculator&) = delete; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: double NormalizeX(const int pixel); @@ -78,18 +78,17 @@ REGISTER_CALCULATOR(FaceToRegionCalculator); FaceToRegionCalculator::FaceToRegionCalculator() {} -mediapipe::Status FaceToRegionCalculator::GetContract( +absl::Status FaceToRegionCalculator::GetContract( mediapipe::CalculatorContract* cc) { if (cc->Inputs().HasTag("VIDEO")) { cc->Inputs().Tag("VIDEO").Set(); } cc->Inputs().Tag("FACES").Set>(); cc->Outputs().Tag("REGIONS").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FaceToRegionCalculator::Open( - mediapipe::CalculatorContext* cc) { 
+absl::Status FaceToRegionCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); if (!cc->Inputs().HasTag("VIDEO")) { RET_CHECK(!options_.use_visual_scorer()) @@ -105,7 +104,7 @@ mediapipe::Status FaceToRegionCalculator::Open( scorer_ = absl::make_unique(options_.scorer_options()); frame_width_ = -1; frame_height_ = -1; - return mediapipe::OkStatus(); + return absl::OkStatus(); } inline double FaceToRegionCalculator::NormalizeX(const int pixel) { @@ -146,8 +145,7 @@ void FaceToRegionCalculator::ExtendSalientRegionWithPoint( } } -mediapipe::Status FaceToRegionCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status FaceToRegionCalculator::Process(mediapipe::CalculatorContext* cc) { if (cc->Inputs().HasTag("VIDEO") && cc->Inputs().Tag("VIDEO").Value().IsEmpty()) { return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) @@ -280,7 +278,7 @@ mediapipe::Status FaceToRegionCalculator::Process( } cc->Outputs().Tag("REGIONS").Add(region_set.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc index 106be49b9..80f0f4552 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/localization_to_region_calculator.cc @@ -38,9 +38,9 @@ class LocalizationToRegionCalculator : public mediapipe::CalculatorBase { LocalizationToRegionCalculator& operator=( const LocalizationToRegionCalculator&) = delete; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status 
Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Calculator options. @@ -84,21 +84,21 @@ void FillSalientRegion(const mediapipe::Detection& detection, } // namespace -mediapipe::Status LocalizationToRegionCalculator::GetContract( +absl::Status LocalizationToRegionCalculator::GetContract( mediapipe::CalculatorContract* cc) { cc->Inputs().Tag("DETECTIONS").Set>(); cc->Outputs().Tag("REGIONS").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LocalizationToRegionCalculator::Open( +absl::Status LocalizationToRegionCalculator::Open( mediapipe::CalculatorContext* cc) { options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LocalizationToRegionCalculator::Process( +absl::Status LocalizationToRegionCalculator::Process( mediapipe::CalculatorContext* cc) { const auto& annotations = cc->Inputs().Tag("DETECTIONS").Get>(); @@ -119,7 +119,7 @@ mediapipe::Status LocalizationToRegionCalculator::Process( } cc->Outputs().Tag("REGIONS").Add(regions.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc index 8cd6c42aa..885753d63 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.cc @@ -68,7 +68,7 @@ constexpr char kOutputSummary[] = "CROPPING_SUMMARY"; constexpr char kExternalRenderingPerFrame[] = "EXTERNAL_RENDERING_PER_FRAME"; constexpr char kExternalRenderingFullVid[] = "EXTERNAL_RENDERING_FULL_VID"; -mediapipe::Status SceneCroppingCalculator::GetContract( +absl::Status SceneCroppingCalculator::GetContract( mediapipe::CalculatorContract* cc) { if 
(cc->InputSidePackets().HasTag(kInputExternalSettings)) { cc->InputSidePackets().Tag(kInputExternalSettings).Set(); @@ -136,10 +136,10 @@ mediapipe::Status SceneCroppingCalculator::GetContract( cc->Outputs().HasTag(kExternalRenderingFullVid) || cc->Outputs().HasTag(kOutputCroppedFrames)) << "At leaset one output stream must be specified"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCroppingCalculator::Open(CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::Open(CalculatorContext* cc) { options_ = cc->Options(); RET_CHECK_GT(options_.max_scene_size(), 0) << "Maximum scene size is non-positive."; @@ -175,17 +175,17 @@ mediapipe::Status SceneCroppingCalculator::Open(CalculatorContext* cc) { should_perform_frame_cropping_ = cc->Outputs().HasTag(kOutputCroppedFrames); scene_camera_motion_analyzer_ = absl::make_unique( options_.scene_camera_motion_analyzer_options()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } namespace { -mediapipe::Status ParseAspectRatioString(const std::string& aspect_ratio_string, - double* aspect_ratio) { +absl::Status ParseAspectRatioString(const std::string& aspect_ratio_string, + double* aspect_ratio) { std::string error_msg = "Aspect ratio std::string must be in the format of 'width:height', e.g. 
" "'1:1' or '5:4', your input was " + aspect_ratio_string; - auto pos = aspect_ratio_string.find(":"); + auto pos = aspect_ratio_string.find(':'); RET_CHECK(pos != std::string::npos) << error_msg; double width_ratio; RET_CHECK(absl::SimpleAtod(aspect_ratio_string.substr(0, pos), &width_ratio)) @@ -196,7 +196,7 @@ mediapipe::Status ParseAspectRatioString(const std::string& aspect_ratio_string, &height_ratio)) << error_msg; *aspect_ratio = width_ratio / height_ratio; - return mediapipe::OkStatus(); + return absl::OkStatus(); } void ConstructExternalRenderMessage( const cv::Rect& crop_from_location, const cv::Rect& render_to_location, @@ -235,7 +235,7 @@ int RoundToEven(float value) { } // namespace -mediapipe::Status SceneCroppingCalculator::InitializeSceneCroppingCalculator( +absl::Status SceneCroppingCalculator::InitializeSceneCroppingCalculator( mediapipe::CalculatorContext* cc) { if (cc->Inputs().HasTag(kInputVideoFrames)) { const auto& frame = cc->Inputs().Tag(kInputVideoFrames).Get(); @@ -302,8 +302,7 @@ mediapipe::Status SceneCroppingCalculator::InitializeSceneCroppingCalculator( target_height_ = frame_height_; break; case SceneCroppingCalculatorOptions::UNKNOWN: - return mediapipe::InvalidArgumentError( - "target_size_type not set properly."); + return absl::InvalidArgumentError("target_size_type not set properly."); } target_aspect_ratio_ = GetRatio(target_width_, target_height_); @@ -337,7 +336,7 @@ mediapipe::Status SceneCroppingCalculator::InitializeSceneCroppingCalculator( scene_cropper_ = absl::make_unique( options_.camera_motion_options(), frame_width_, frame_height_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool HasFrameSignal(mediapipe::CalculatorContext* cc) { @@ -347,7 +346,7 @@ bool HasFrameSignal(mediapipe::CalculatorContext* cc) { return !cc->Inputs().Tag(kInputVideoSize).Value().IsEmpty(); } -mediapipe::Status SceneCroppingCalculator::Process( +absl::Status SceneCroppingCalculator::Process( mediapipe::CalculatorContext* cc) 
{ // Sets frame dimension and initializes scenecroppingcalculator on first video // frame. @@ -417,11 +416,10 @@ mediapipe::Status SceneCroppingCalculator::Process( continue_last_scene_ = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCroppingCalculator::Close( - mediapipe::CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::Close(mediapipe::CalculatorContext* cc) { if (!scene_frame_timestamps_.empty()) { MP_RETURN_IF_ERROR(ProcessScene(/* is_end_of_scene = */ true, cc)); } @@ -435,12 +433,12 @@ mediapipe::Status SceneCroppingCalculator::Close( .Tag(kExternalRenderingFullVid) .Add(external_render_list_.release(), Timestamp::PostStream()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // TODO: split this function into two, one for calculating the border // sizes, the other for the actual removal of borders from the frames. -mediapipe::Status SceneCroppingCalculator::RemoveStaticBorders( +absl::Status SceneCroppingCalculator::RemoveStaticBorders( CalculatorContext* cc, int* top_border_size, int* bottom_border_size) { *top_border_size = 0; *bottom_border_size = 0; @@ -492,10 +490,10 @@ mediapipe::Status SceneCroppingCalculator::RemoveStaticBorders( *key_frame_infos_[i].mutable_detections() = adjusted_detections; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCroppingCalculator::InitializeFrameCropRegionComputer() { +absl::Status SceneCroppingCalculator::InitializeFrameCropRegionComputer() { key_frame_crop_options_ = options_.key_frame_crop_options(); MP_RETURN_IF_ERROR( SetKeyFrameCropTarget(frame_width_, effective_frame_height_, @@ -504,7 +502,7 @@ mediapipe::Status SceneCroppingCalculator::InitializeFrameCropRegionComputer() { VLOG(1) << "Target height " << key_frame_crop_options_.target_height(); frame_crop_region_computer_ = absl::make_unique(key_frame_crop_options_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void 
SceneCroppingCalculator::FilterKeyFrameInfo() { @@ -530,8 +528,8 @@ void SceneCroppingCalculator::FilterKeyFrameInfo() { } } -mediapipe::Status SceneCroppingCalculator::ProcessScene( - const bool is_end_of_scene, CalculatorContext* cc) { +absl::Status SceneCroppingCalculator::ProcessScene(const bool is_end_of_scene, + CalculatorContext* cc) { // Removes detections under special circumstances. FilterKeyFrameInfo(); @@ -653,10 +651,10 @@ mediapipe::Status SceneCroppingCalculator::ProcessScene( is_key_frames_.clear(); static_features_.clear(); static_features_timestamps_.clear(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames( +absl::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames( const int crop_width, const int crop_height, const int num_frames, std::vector* render_to_locations, bool* apply_padding, std::vector* padding_colors, float* vertical_fill_percent, @@ -729,7 +727,7 @@ mediapipe::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames( padding_colors->push_back(padding_color_to_add); } if (!cropped_frames_ptr) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Resizes cropped frames, pads frames, and output frames. 
@@ -772,10 +770,10 @@ mediapipe::Status SceneCroppingCalculator::FormatAndOutputCroppedFrames( .Add(scaled_frame.release(), timestamp); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCroppingCalculator::OutputVizFrames( +absl::Status SceneCroppingCalculator::OutputVizFrames( const std::vector& key_frame_crop_results, const std::vector& focus_point_frames, const std::vector& crop_from_locations, @@ -815,7 +813,7 @@ mediapipe::Status SceneCroppingCalculator::OutputVizFrames( .Add(viz_frames[i].release(), Timestamp(scene_frame_timestamps_[i])); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(SceneCroppingCalculator); diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h index 4ffacafca..61b7b53d6 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator.h @@ -125,35 +125,34 @@ namespace autoflip { // fields are optional with default settings. class SceneCroppingCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); // Validates calculator options and initializes SceneCameraMotionAnalyzer and // SceneCropper. - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; // Buffers each scene frame and its timestamp. Packs and stores KeyFrameInfo // for key frames (a.k.a. frames with detection features). When a shot // boundary is encountered or when the buffer is full, calls ProcessScene() // to process the scene at once, and clears buffers. - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; // Calls ProcessScene() on remaining buffered frames. 
Optionally outputs a // VideoCroppingSummary if the output stream CROPPING_SUMMARY is present. - mediapipe::Status Close(mediapipe::CalculatorContext* cc) override; + absl::Status Close(mediapipe::CalculatorContext* cc) override; private: // Removes any static borders from the scene frames before cropping. The // arguments |top_border_size| and |bottom_border_size| report the size of the // removed borders. - mediapipe::Status RemoveStaticBorders(CalculatorContext* cc, - int* top_border_size, - int* bottom_border_size); + absl::Status RemoveStaticBorders(CalculatorContext* cc, int* top_border_size, + int* bottom_border_size); // Sets up autoflip after first frame is received and input size is known. - mediapipe::Status InitializeSceneCroppingCalculator( + absl::Status InitializeSceneCroppingCalculator( mediapipe::CalculatorContext* cc); // Initializes a FrameCropRegionComputer given input and target frame sizes. - mediapipe::Status InitializeFrameCropRegionComputer(); + absl::Status InitializeFrameCropRegionComputer(); // Processes a scene using buffered scene frames and KeyFrameInfos: // 1. Computes key frame crop regions using a FrameCropRegionComputer. @@ -165,8 +164,7 @@ class SceneCroppingCalculator : public CalculatorBase { // to force flush). // 6. Optionally outputs visualization frames. // 7. Optionally updates cropping summary. - mediapipe::Status ProcessScene(const bool is_end_of_scene, - CalculatorContext* cc); + absl::Status ProcessScene(const bool is_end_of_scene, CalculatorContext* cc); // Formats and outputs the cropped frames passed in through // |cropped_frames_ptr|. Scales them to be at least as big as the target @@ -177,14 +175,14 @@ class SceneCroppingCalculator : public CalculatorBase { // cropped frames. This is useful when the calculator is only used for // computing the cropping metadata rather than doing the actual cropping // operation. 
- mediapipe::Status FormatAndOutputCroppedFrames( + absl::Status FormatAndOutputCroppedFrames( const int crop_width, const int crop_height, const int num_frames, std::vector* render_to_locations, bool* apply_padding, std::vector* padding_colors, float* vertical_fill_percent, const std::vector* cropped_frames_ptr, CalculatorContext* cc); // Draws and outputs visualization frames if those streams are present. - mediapipe::Status OutputVizFrames( + absl::Status OutputVizFrames( const std::vector& key_frame_crop_results, const std::vector& focus_point_frames, const std::vector& crop_from_locations, diff --git a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc index 6cc9217e3..27867d31b 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/scene_cropping_calculator_test.cc @@ -803,6 +803,7 @@ TEST(SceneCroppingCalculatorTest, OutputsCropMessageKinematicPath) { SceneCroppingCalculatorOptions::ext); auto* kinematic_options = options->mutable_camera_motion_options()->mutable_kinematic_options(); + kinematic_options->set_min_motion_to_reframe(1.2); kinematic_options->set_max_velocity(200); auto runner = absl::make_unique(config); @@ -875,6 +876,7 @@ TEST(SceneCroppingCalculatorTest, OutputsCropMessageKinematicPathNoVideo) { SceneCroppingCalculatorOptions::ext); auto* kinematic_options = options->mutable_camera_motion_options()->mutable_kinematic_options(); + kinematic_options->set_min_motion_to_reframe(1.2); kinematic_options->set_max_velocity(2.0); auto runner = absl::make_unique(config); diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc index 9a091523d..299f60b10 100644 --- 
a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator.cc @@ -60,9 +60,9 @@ class ShotBoundaryCalculator : public mediapipe::CalculatorBase { ShotBoundaryCalculator(const ShotBoundaryCalculator&) = delete; ShotBoundaryCalculator& operator=(const ShotBoundaryCalculator&) = delete; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; private: // Computes the histogram of an image. @@ -98,12 +98,11 @@ void ShotBoundaryCalculator::ComputeHistogram(const cv::Mat& image, kHistogramBinNum, kHistogramRange, true, false); } -mediapipe::Status ShotBoundaryCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status ShotBoundaryCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); last_shot_timestamp_ = Timestamp(0); init_ = false; - return mediapipe::OkStatus(); + return absl::OkStatus(); } void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, @@ -127,8 +126,7 @@ void ShotBoundaryCalculator::Transmit(mediapipe::CalculatorContext* cc, } } -mediapipe::Status ShotBoundaryCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status ShotBoundaryCalculator::Process(mediapipe::CalculatorContext* cc) { // Connect to input frame and make a mutable copy. 
cv::Mat frame_org = mediapipe::formats::MatView( &cc->Inputs().Tag(kVideoInputTag).Get()); @@ -142,7 +140,7 @@ mediapipe::Status ShotBoundaryCalculator::Process( last_histogram_ = current_histogram; init_ = true; Transmit(cc, false); - return mediapipe::OkStatus(); + return absl::OkStatus(); } double current_motion_estimate = @@ -152,7 +150,7 @@ mediapipe::Status ShotBoundaryCalculator::Process( if (motion_history_.size() != options_.window_size()) { Transmit(cc, false); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Shot detection algorithm is a mixture of adaptive (controlled with @@ -176,14 +174,14 @@ mediapipe::Status ShotBoundaryCalculator::Process( // Store histogram for next frame. last_histogram_ = current_histogram; motion_history_.pop_back(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ShotBoundaryCalculator::GetContract( +absl::Status ShotBoundaryCalculator::GetContract( mediapipe::CalculatorContract* cc) { cc->Inputs().Tag(kVideoInputTag).Set(); cc->Outputs().Tag(kShotChangeTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc index 06e5e768b..e2b4f659d 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/shot_boundary_calculator_test.cc @@ -19,6 +19,7 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/opencv_core_inc.h" diff --git 
a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc index a85c8bb2e..37643b5d1 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/signal_fusing_calculator.cc @@ -105,13 +105,13 @@ class SignalFusingCalculator : public mediapipe::CalculatorBase { SignalFusingCalculator(const SignalFusingCalculator&) = delete; SignalFusingCalculator& operator=(const SignalFusingCalculator&) = delete; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc); - mediapipe::Status Open(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Process(mediapipe::CalculatorContext* cc) override; - mediapipe::Status Close(mediapipe::CalculatorContext* cc) override; + static absl::Status GetContract(mediapipe::CalculatorContract* cc); + absl::Status Open(mediapipe::CalculatorContext* cc) override; + absl::Status Process(mediapipe::CalculatorContext* cc) override; + absl::Status Close(mediapipe::CalculatorContext* cc) override; private: - mediapipe::Status ProcessScene(mediapipe::CalculatorContext* cc); + absl::Status ProcessScene(mediapipe::CalculatorContext* cc); std::vector GetSignalPackets(mediapipe::CalculatorContext* cc); SignalFusingCalculatorOptions options_; std::map settings_by_type_; @@ -154,8 +154,7 @@ void SetupOrderedInput(mediapipe::CalculatorContract* cc) { } } // namespace -mediapipe::Status SignalFusingCalculator::Open( - mediapipe::CalculatorContext* cc) { +absl::Status SignalFusingCalculator::Open(mediapipe::CalculatorContext* cc) { options_ = cc->Options(); for (const auto& setting : options_.signal_settings()) { settings_by_type_[CreateSettingsKey(setting.type())] = setting; @@ -166,19 +165,18 @@ mediapipe::Status SignalFusingCalculator::Open( process_by_scene_ = false; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status 
SignalFusingCalculator::Close( - mediapipe::CalculatorContext* cc) { +absl::Status SignalFusingCalculator::Close(mediapipe::CalculatorContext* cc) { if (!scene_frames_.empty()) { MP_RETURN_IF_ERROR(ProcessScene(cc)); scene_frames_.clear(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SignalFusingCalculator::ProcessScene( +absl::Status SignalFusingCalculator::ProcessScene( mediapipe::CalculatorContext* cc) { std::map detection_count; std::map multiframe_score; @@ -240,7 +238,7 @@ mediapipe::Status SignalFusingCalculator::ProcessScene( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } std::vector SignalFusingCalculator::GetSignalPackets( @@ -260,8 +258,7 @@ std::vector SignalFusingCalculator::GetSignalPackets( return signal_packets; } -mediapipe::Status SignalFusingCalculator::Process( - mediapipe::CalculatorContext* cc) { +absl::Status SignalFusingCalculator::Process(mediapipe::CalculatorContext* cc) { bool is_boundary = false; if (process_by_scene_) { const auto& shot_tag = (tag_input_interface_) @@ -302,17 +299,17 @@ mediapipe::Status SignalFusingCalculator::Process( scene_frames_.clear(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SignalFusingCalculator::GetContract( +absl::Status SignalFusingCalculator::GetContract( mediapipe::CalculatorContract* cc) { if (cc->Inputs().NumEntries(kSignalInputsTag) > 0) { SetupTagInput(cc); } else { SetupOrderedInput(cc); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc index 0a7c34d9d..8d67eb8f0 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator.cc @@ -57,20 +57,19 @@ class VideoFilteringCalculator : public 
CalculatorBase { VideoFilteringCalculator() = default; ~VideoFilteringCalculator() override = default; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(VideoFilteringCalculator); -mediapipe::Status VideoFilteringCalculator::GetContract( - CalculatorContract* cc) { +absl::Status VideoFilteringCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kInputFrameTag).Set(); cc->Outputs().Tag(kOutputFrameTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VideoFilteringCalculator::Process(CalculatorContext* cc) { +absl::Status VideoFilteringCalculator::Process(CalculatorContext* cc) { const auto& options = cc->Options(); const Packet& input_packet = cc->Inputs().Tag(kInputFrameTag).Value(); @@ -84,7 +83,7 @@ mediapipe::Status VideoFilteringCalculator::Process(CalculatorContext* cc) { if (filter_type == VideoFilteringCalculatorOptions::AspectRatioFilter::NO_FILTERING) { cc->Outputs().Tag(kOutputFrameTag).AddPacket(input_packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); } const int target_width = options.aspect_ratio_filter().target_width(); const int target_height = options.aspect_ratio_filter().target_height(); @@ -106,7 +105,7 @@ mediapipe::Status VideoFilteringCalculator::Process(CalculatorContext* cc) { } if (should_pass) { cc->Outputs().Tag(kOutputFrameTag).AddPacket(input_packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (options.fail_if_any()) { return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << absl::Substitute( @@ -115,7 +114,7 @@ mediapipe::Status VideoFilteringCalculator::Process(CalculatorContext* cc) { target_ratio, frame.Width(), frame.Height()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip } // namespace 
mediapipe diff --git a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc index 9927d8077..758193832 100644 --- a/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc +++ b/mediapipe/examples/desktop/autoflip/calculators/video_filtering_calculator_test.cc @@ -166,8 +166,8 @@ TEST(VerticalFrameRemovalCalculatorTest, OutputError) { runner->MutableInputs() ->Tag("INPUT_FRAMES") .packets.push_back(Adopt(input_frame.release()).At(Timestamp(1000))); - mediapipe::Status status = runner->Run(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kUnknown); + absl::Status status = runner->Run(); + EXPECT_EQ(status.code(), absl::StatusCode::kUnknown); EXPECT_THAT(status.ToString(), ::testing::HasSubstr("Failing due to aspect ratio")); } diff --git a/mediapipe/examples/desktop/autoflip/quality/BUILD b/mediapipe/examples/desktop/autoflip/quality/BUILD index a6e79c3a3..4a5ac3b7a 100644 --- a/mediapipe/examples/desktop/autoflip/quality/BUILD +++ b/mediapipe/examples/desktop/autoflip/quality/BUILD @@ -249,6 +249,7 @@ cc_test( ":scene_camera_motion_analyzer", "//mediapipe/examples/desktop/autoflip:autoflip_messages_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:status", diff --git a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc index 8626ae715..5916d1829 100644 --- a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.cc @@ -22,7 +22,7 @@ namespace mediapipe { namespace autoflip { -mediapipe::Status FrameCropRegionComputer::ExpandSegmentUnderConstraint( +absl::Status 
FrameCropRegionComputer::ExpandSegmentUnderConstraint( const Segment& segment_to_add, const Segment& base_segment, const int max_length, Segment* combined_segment, CoverType* cover_type) const { @@ -75,10 +75,10 @@ mediapipe::Status FrameCropRegionComputer::ExpandSegmentUnderConstraint( *combined_segment = std::make_pair(combined_segment_left, combined_segment_right); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameCropRegionComputer::ExpandRectUnderConstraints( +absl::Status FrameCropRegionComputer::ExpandRectUnderConstraints( const Rect& rect_to_add, const int max_width, const int max_height, Rect* base_rect, CoverType* cover_type) const { RET_CHECK(base_rect != nullptr) << "Base rect is null."; @@ -129,7 +129,7 @@ mediapipe::Status FrameCropRegionComputer::ExpandRectUnderConstraints( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void FrameCropRegionComputer::UpdateCropRegionScore( @@ -167,7 +167,7 @@ void FrameCropRegionComputer::UpdateCropRegionScore( } } -mediapipe::Status FrameCropRegionComputer::ComputeFrameCropRegion( +absl::Status FrameCropRegionComputer::ComputeFrameCropRegion( const KeyFrameInfo& frame_info, KeyFrameCropResult* crop_result) const { RET_CHECK(crop_result != nullptr) << "KeyFrameCropResult is null."; @@ -254,7 +254,7 @@ mediapipe::Status FrameCropRegionComputer::ComputeFrameCropRegion( crop_result->set_region_is_empty(crop_region_is_empty); crop_result->set_region_score(crop_region_score); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h index 1d5107bb3..b2be9e28c 100644 --- a/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h +++ b/mediapipe/examples/desktop/autoflip/quality/frame_crop_region_computer.h @@ -43,8 +43,8 @@ class FrameCropRegionComputer { // consider 
static features, and simply tries to fit the detected features // within the target frame size. The score of the crop region is aggregated // from individual feature scores given the score aggregation type. - mediapipe::Status ComputeFrameCropRegion( - const KeyFrameInfo& frame_info, KeyFrameCropResult* crop_result) const; + absl::Status ComputeFrameCropRegion(const KeyFrameInfo& frame_info, + KeyFrameCropResult* crop_result) const; protected: // A segment is a 1-d object defined by its left and right point. @@ -75,11 +75,11 @@ class FrameCropRegionComputer { // fraction of the new segment exceeds the maximum length. // In this case the combined segment is the base segment, and cover // type is NOT_COVERED. - mediapipe::Status ExpandSegmentUnderConstraint(const Segment& segment_to_add, - const Segment& base_segment, - const int max_length, - Segment* combined_segment, - CoverType* cover_type) const; + absl::Status ExpandSegmentUnderConstraint(const Segment& segment_to_add, + const Segment& base_segment, + const int max_length, + Segment* combined_segment, + CoverType* cover_type) const; // Expands a base rectangle to cover a new rectangle to be added under width // and height constraints. The operation is best-effort. It considers @@ -88,11 +88,10 @@ class FrameCropRegionComputer { // FULLY_COVERED if the new rectangle is fully covered in both directions, // PARTIALLY_COVERED if it is at least partially covered in both directions, // and NOT_COVERED if it is not covered in either direction. - mediapipe::Status ExpandRectUnderConstraints(const Rect& rect_to_add, - const int max_width, - const int max_height, - Rect* base_rect, - CoverType* cover_type) const; + absl::Status ExpandRectUnderConstraints(const Rect& rect_to_add, + const int max_width, + const int max_height, Rect* base_rect, + CoverType* cover_type) const; // Updates crop region score given current feature score, whether the feature // is required, and the score aggregation type. Ignores negative scores. 
diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc index 27b42a34b..899724921 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.cc @@ -14,10 +14,66 @@ int Median(const std::deque>& positions_raw) { return positions[n]; } } // namespace -mediapipe::Status KinematicPathSolver::AddObservation(int position, - const uint64 time_us) { +bool KinematicPathSolver::IsMotionTooSmall(double delta_degs) { + if (options_.has_min_motion_to_reframe()) { + return abs(delta_degs) < options_.min_motion_to_reframe(); + } else if (delta_degs > 0) { + return delta_degs < options_.min_motion_to_reframe_upper(); + } else { + return abs(delta_degs) < options_.min_motion_to_reframe_lower(); + } +} +void KinematicPathSolver::ClearHistory() { raw_positions_at_time_.clear(); } +absl::Status KinematicPathSolver::PredictMotionState(int position, + const uint64 time_us, + bool* state) { if (!initialized_) { - current_position_px_ = position; + *state = false; + return absl::OkStatus(); + } + + auto raw_positions_at_time_copy = raw_positions_at_time_; + + raw_positions_at_time_copy.push_front( + std::pair(time_us, position)); + while (raw_positions_at_time_copy.size() > 1) { + if (static_cast(raw_positions_at_time_copy.back().first) < + static_cast(time_us) - options_.filtering_time_window_us()) { + raw_positions_at_time_copy.pop_back(); + } else { + break; + } + } + + int filtered_position = Median(raw_positions_at_time_copy); + double delta_degs = + (filtered_position - current_position_px_) / pixels_per_degree_; + + // If the motion is smaller than the min_motion_to_reframe and camera is + // stationary, don't use the update. 
+ if (IsMotionTooSmall(delta_degs) && !motion_state_) { + *state = false; + } else if (abs(delta_degs) < options_.reframe_window() && motion_state_) { + // If the motion is smaller than the reframe_window and camera is moving, + // don't use the update. + *state = false; + } else { + // Apply new position, plus the reframe window size. + *state = true; + } + + return absl::OkStatus(); +} +absl::Status KinematicPathSolver::AddObservation(int position, + const uint64 time_us) { + if (!initialized_) { + if (position < min_location_) { + current_position_px_ = min_location_; + } else if (position > max_location_) { + current_position_px_ = max_location_; + } else { + current_position_px_ = position; + } target_position_px_ = position; motion_state_ = false; mean_delta_t_ = -1; @@ -30,13 +86,27 @@ mediapipe::Status KinematicPathSolver::AddObservation(int position, << "pixels_per_degree must be larger than 0."; RET_CHECK_GE(options_.update_rate_seconds(), 0) << "update_rate_seconds must be greater than 0."; - RET_CHECK_GE(options_.min_motion_to_reframe(), options_.reframe_window()) - << "Reframe window cannot exceed min_motion_to_reframe."; RET_CHECK_GE(options_.filtering_time_window_us(), 0) << "update_rate_seconds must be greater than 0."; RET_CHECK_GE(options_.mean_period_update_rate(), 0) << "mean_period_update_rate must be greater than 0."; - return mediapipe::OkStatus(); + RET_CHECK(options_.has_min_motion_to_reframe() ^ + (options_.has_min_motion_to_reframe_upper() && + options_.has_min_motion_to_reframe_lower())) + << "Must set min_motion_to_reframe or min_motion_to_reframe_upper and " + "min_motion_to_reframe_lower."; + if (options_.has_min_motion_to_reframe()) { + RET_CHECK_GE(options_.min_motion_to_reframe(), options_.reframe_window()) + << "Reframe window cannot exceed min_motion_to_reframe."; + } else { + RET_CHECK_GE(options_.min_motion_to_reframe_upper(), + options_.reframe_window()) + << "Reframe window cannot exceed min_motion_to_reframe."; + 
RET_CHECK_GE(options_.min_motion_to_reframe_lower(), + options_.reframe_window()) + << "Reframe window cannot exceed min_motion_to_reframe."; + } + return absl::OkStatus(); } RET_CHECK(current_time_ < time_us) @@ -58,7 +128,7 @@ mediapipe::Status KinematicPathSolver::AddObservation(int position, // If the motion is smaller than the min_motion_to_reframe and camera is // stationary, don't use the update. - if (abs(delta_degs) < options_.min_motion_to_reframe() && !motion_state_) { + if (IsMotionTooSmall(delta_degs) && !motion_state_) { delta_degs = 0; motion_state_ = false; } else if (abs(delta_degs) < options_.reframe_window() && motion_state_) { @@ -100,7 +170,7 @@ mediapipe::Status KinematicPathSolver::AddObservation(int position, return UpdatePrediction(time_us); } -mediapipe::Status KinematicPathSolver::UpdatePrediction(const int64 time_us) { +absl::Status KinematicPathSolver::UpdatePrediction(const int64 time_us) { RET_CHECK(current_time_ < time_us) << "Prediction time added before a prior observation or prediction."; @@ -122,36 +192,58 @@ mediapipe::Status KinematicPathSolver::UpdatePrediction(const int64 time_us) { if (update_position_px < min_location_) { current_position_px_ = min_location_; current_velocity_deg_per_s_ = 0; + motion_state_ = false; } else if (update_position_px > max_location_) { current_position_px_ = max_location_; current_velocity_deg_per_s_ = 0; + motion_state_ = false; } else { current_position_px_ = update_position_px; } current_time_ = time_us; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status KinematicPathSolver::GetState(int* position) { +absl::Status KinematicPathSolver::GetState(int* position) { RET_CHECK(initialized_) << "GetState called before first observation added."; *position = round(current_position_px_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status KinematicPathSolver::GetTargetPosition(int* target_position) { +absl::Status 
KinematicPathSolver::GetTargetPosition(int* target_position) { RET_CHECK(initialized_) << "GetTargetPosition called before first observation added."; *target_position = round(target_position_px_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status KinematicPathSolver::UpdatePixelsPerDegree( +absl::Status KinematicPathSolver::UpdatePixelsPerDegree( const float pixels_per_degree) { RET_CHECK_GT(pixels_per_degree_, 0) << "pixels_per_degree must be larger than 0."; pixels_per_degree_ = pixels_per_degree; - return mediapipe::OkStatus(); + return absl::OkStatus(); +} + +absl::Status KinematicPathSolver::UpdateMinMaxLocation(const int min_location, + const int max_location) { + RET_CHECK(initialized_) + << "UpdateMinMaxLocation called before first observation added."; + double prior_distance = max_location_ - min_location_; + double updated_distance = max_location - min_location; + double scale_change = updated_distance / prior_distance; + current_position_px_ = current_position_px_ * scale_change; + target_position_px_ = target_position_px_ * scale_change; + max_location_ = max_location; + min_location_ = min_location; + auto original_positions_at_time = raw_positions_at_time_; + raw_positions_at_time_.clear(); + for (auto position_at_time : original_positions_at_time) { + position_at_time.second = position_at_time.second * scale_change; + raw_positions_at_time_.push_front(position_at_time); + } + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h index 60ac4dc35..4f4b896e2 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.h @@ -43,22 +43,34 @@ class KinematicPathSolver { initialized_(false), pixels_per_degree_(pixels_per_degree) {} // Add an observation (detection) at a position and time. 
- mediapipe::Status AddObservation(int position, const uint64 time_us); + absl::Status AddObservation(int position, const uint64 time_us); // Get the predicted position at a time. - mediapipe::Status UpdatePrediction(const int64 time_us); + absl::Status UpdatePrediction(const int64 time_us); // Get the state at a time. - mediapipe::Status GetState(int* position); + absl::Status GetState(int* position); // Update PixelPerDegree value. - mediapipe::Status UpdatePixelsPerDegree(const float pixels_per_degree); + absl::Status UpdatePixelsPerDegree(const float pixels_per_degree); // Provide the current target position of the reframe action. - mediapipe::Status GetTargetPosition(int* target_position); + absl::Status GetTargetPosition(int* target_position); + // Change min/max location and update state based on new scaling. + absl::Status UpdateMinMaxLocation(const int min_location, + const int max_location); + // Check if motion is within the reframe window, return false if not. + bool IsMotionTooSmall(double delta_degs); + // Check if a position measurement will cause the camera to be in motion + // without updating the internal state. + absl::Status PredictMotionState(int position, const uint64 time_us, + bool* state); + // Clear any history buffer of positions that are used when + // filtering_time_window_us is set to a non-zero value. + void ClearHistory(); private: // Tuning options. KinematicOptions options_; // Min and max value the state can be. - const int min_location_; - const int max_location_; + int min_location_; + int max_location_; bool initialized_; float pixels_per_degree_; // Current state values. 
diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto index 406418733..9f481db6d 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto @@ -8,8 +8,14 @@ message KinematicOptions { optional double update_rate = 1 [default = 0.5, deprecated = true]; // Max velocity (degrees per second) that the camera can move. optional double max_velocity = 2 [default = 18]; - // Min motion (in degrees) to react in pixels. - optional float min_motion_to_reframe = 3 [default = 1.8]; + // Min motion (in degrees) to react for both upper and lower directions. Must + // not be set if using min_motion_to_reframe_lower and + // min_motion_to_reframe_upper. + optional float min_motion_to_reframe = 3; + // Min motion (in degrees) for upper and lower direction to react. Both must + // be set and min_motion_to_reframe cannot be set if these are specified. + optional float min_motion_to_reframe_lower = 9; + optional float min_motion_to_reframe_upper = 10; // When motion exceeds min_motion_to_reframe, move within this distance of the // camera from the starting direction. Setting this value non-zero reduces // total reframe distance on average. 
Value cannot exceed diff --git a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc index 7ca8045e5..d6f14cce4 100644 --- a/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver_test.cc @@ -190,6 +190,50 @@ TEST(KinematicPathSolverTest, PassReframeWindow) { EXPECT_EQ(state, 508); } +TEST(KinematicPathSolverTest, PassReframeWindowLowerUpper) { + KinematicOptions options; + // Set min motion to 1deg + options.set_min_motion_to_reframe_upper(1.3); + options.set_min_motion_to_reframe_lower(1.0); + options.set_update_rate_seconds(.0000001); + options.set_max_update_rate(1.0); + options.set_max_velocity(1000); + // Set reframe window size to .75 for test. + options.set_reframe_window(0.75); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + MP_ASSERT_OK(solver.AddObservation(520, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to not move + EXPECT_EQ(state, 500); + MP_ASSERT_OK(solver.AddObservation(480, kMicroSecInSec * 2)); + MP_ASSERT_OK(solver.GetState(&state)); + // Expect cam to move + EXPECT_EQ(state, 493); +} + +TEST(KinematicPathSolverTest, PassCheckState) { + KinematicOptions options; + // Set min motion to 1deg + options.set_min_motion_to_reframe(1.0); + options.set_update_rate_seconds(.0000001); + options.set_max_update_rate(1.0); + options.set_max_velocity(1000); + // Set reframe window size to .75 for test. 
+ options.set_reframe_window(0.75); + // Set degrees / pixel to 16.6 + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + // Move target by 20px / 16.6 = 1.2deg + bool motion_state; + MP_ASSERT_OK( + solver.PredictMotionState(520, kMicroSecInSec * 1, &motion_state)); + EXPECT_TRUE(motion_state); +} + TEST(KinematicPathSolverTest, PassUpdateRate30FPS) { KinematicOptions options; options.set_min_motion_to_reframe(1.0); @@ -238,6 +282,26 @@ TEST(KinematicPathSolverTest, PassUpdateRate) { EXPECT_EQ(state, 505); } +TEST(KinematicPathSolverTest, PassUpdateRateResolutionChange) { + KinematicOptions options; + options.set_min_motion_to_reframe(1.0); + options.set_update_rate_seconds(4); + options.set_max_update_rate(1.0); + options.set_max_velocity(18); + KinematicPathSolver solver(options, 0, 1000, 1000.0 / kWidthFieldOfView); + int state, target_position; + MP_ASSERT_OK(solver.AddObservation(500, kMicroSecInSec * 0)); + MP_ASSERT_OK(solver.GetTargetPosition(&target_position)); + EXPECT_EQ(target_position, 500); + MP_ASSERT_OK(solver.UpdateMinMaxLocation(0, 500)); + MP_ASSERT_OK(solver.UpdatePixelsPerDegree(500.0 / kWidthFieldOfView)); + MP_ASSERT_OK(solver.AddObservation(520 * 0.5, kMicroSecInSec * 1)); + MP_ASSERT_OK(solver.GetTargetPosition(&target_position)); + EXPECT_EQ(target_position, 520 * 0.5); + MP_ASSERT_OK(solver.GetState(&state)); + EXPECT_EQ(state, 253); +} + TEST(KinematicPathSolverTest, PassMaxVelocity) { KinematicOptions options; options.set_min_motion_to_reframe(1.0); diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc index bdbfe2d42..3d489c395 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.cc @@ -45,7 +45,7 @@ 
PaddingEffectGenerator::PaddingEffectGenerator(const int input_width, } } -mediapipe::Status PaddingEffectGenerator::Process( +absl::Status PaddingEffectGenerator::Process( const ImageFrame& input_frame, const float background_contrast, const int blur_cv_size, const float overlay_opacity, ImageFrame* output_frame, const cv::Scalar* background_color_in_rgb) { @@ -170,7 +170,7 @@ mediapipe::Status PaddingEffectGenerator::Process( output_frame->CopyPixelData(input_frame.Format(), canvas.cols, canvas.rows, canvas.data, ImageFrame::kDefaultAlignmentBoundary); - return mediapipe::OkStatus(); + return absl::OkStatus(); } cv::Rect PaddingEffectGenerator::ComputeOutputLocation() { diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h index 20e34ecd7..2d33593b6 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator.h @@ -49,11 +49,10 @@ class PaddingEffectGenerator { // the opacity of the black layer. // - background_color_in_rgb: If not null, uses this solid color as background // instead of blurring the image, and does not adjust contrast or opacity. - mediapipe::Status Process( - const ImageFrame& input_frame, const float background_contrast, - const int blur_cv_size, const float overlay_opacity, - ImageFrame* output_frame, - const cv::Scalar* background_color_in_rgb = nullptr); + absl::Status Process(const ImageFrame& input_frame, + const float background_contrast, const int blur_cv_size, + const float overlay_opacity, ImageFrame* output_frame, + const cv::Scalar* background_color_in_rgb = nullptr); // Compute the "render location" on the output frame where the "crop from" // location is to be placed. For use with external rendering soutions. 
diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc index f3c3097d7..fcdcf4b09 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc @@ -48,12 +48,14 @@ const cv::Scalar kRed = cv::Scalar(255, 0, 0); void TestWithAspectRatio(const double aspect_ratio, const cv::Scalar* background_color_in_rgb = nullptr) { std::string test_image; - const bool process_arbitrary_image = !FLAGS_input_image.empty(); + const bool process_arbitrary_image = + !absl::GetFlag(FLAGS_input_image).empty(); if (!process_arbitrary_image) { std::string test_image_path = mediapipe::file::JoinPath("./", kTestImage); MP_ASSERT_OK(mediapipe::file::GetContents(test_image_path, &test_image)); } else { - MP_ASSERT_OK(mediapipe::file::GetContents(FLAGS_input_image, &test_image)); + MP_ASSERT_OK(mediapipe::file::GetContents(absl::GetFlag(FLAGS_input_image), + &test_image)); } const std::vector contents_vector(test_image.begin(), test_image.end()); @@ -138,7 +140,7 @@ void TestWithAspectRatio(const double aspect_ratio, EXPECT_EQ(result_image, output_string); } else { std::string output_string_path = mediapipe::file::JoinPath( - FLAGS_output_folder, + absl::GetFlag(FLAGS_output_folder), absl::StrCat("result_", aspect_ratio, background_color_in_rgb ? 
"_solid_background" : "", ".jpg")); diff --git a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc index 2db06151c..dd30566c2 100644 --- a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc +++ b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.cc @@ -91,7 +91,7 @@ void PolynomialRegressionPathSolver::AddCostFunctionToProblem( problem->AddResidualBlock(cost_function, new CauchyLoss(0.5), a, b, c, d, k); } -mediapipe::Status PolynomialRegressionPathSolver::ComputeCameraPath( +absl::Status PolynomialRegressionPathSolver::ComputeCameraPath( const std::vector& focus_point_frames, const std::vector& prior_focus_point_frames, const int original_width, const int original_height, const int output_width, @@ -163,7 +163,7 @@ mediapipe::Status PolynomialRegressionPathSolver::ComputeCameraPath( } all_transforms->push_back(transform); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h index 99d6a2f2c..a510169db 100644 --- a/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h +++ b/mediapipe/examples/desktop/autoflip/quality/polynomial_regression_path_solver.h @@ -42,7 +42,7 @@ class PolynomialRegressionPathSolver { // y-axis, such that focus points can be preserved as much as possible. The // returned |all_transforms| hold the camera location at each timestamp // corresponding to each input frame. 
- mediapipe::Status ComputeCameraPath( + absl::Status ComputeCameraPath( const std::vector& focus_point_frames, const std::vector& prior_focus_point_frames, const int original_width, const int original_height, diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc index 25a6a2c6a..0bfe72548 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.cc @@ -30,8 +30,7 @@ namespace mediapipe { namespace autoflip { -mediapipe::Status -SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( +absl::Status SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -67,7 +66,7 @@ SceneCameraMotionAnalyzer::AnalyzeSceneAndPopulateFocusPointFrames( scene_frame_timestamps, focus_point_frames); } -mediapipe::Status SceneCameraMotionAnalyzer::ToUseSteadyMotion( +absl::Status SceneCameraMotionAnalyzer::ToUseSteadyMotion( const float look_at_center_x, const float look_at_center_y, const int crop_window_width, const int crop_window_height, SceneKeyFrameCropSummary* scene_summary, @@ -77,10 +76,10 @@ mediapipe::Status SceneCameraMotionAnalyzer::ToUseSteadyMotion( auto* steady_motion = scene_camera_motion->mutable_steady_motion(); steady_motion->set_steady_look_at_center_x(look_at_center_x); steady_motion->set_steady_look_at_center_y(look_at_center_y); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCameraMotionAnalyzer::ToUseSweepingMotion( +absl::Status SceneCameraMotionAnalyzer::ToUseSweepingMotion( const float start_x, const float start_y, const float end_x, const float end_y, const int crop_window_width, const int crop_window_height, const double 
time_duration_in_sec, @@ -99,10 +98,10 @@ mediapipe::Status SceneCameraMotionAnalyzer::ToUseSweepingMotion( scene_summary->frame_success_rate(), start_x, start_y, end_x, end_y, time_duration_in_sec); VLOG(1) << sweeping_log; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( +absl::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( const KeyFrameCropOptions& key_frame_crop_options, const double scene_span_sec, const int64 end_time_us, SceneKeyFrameCropSummary* scene_summary, @@ -131,7 +130,7 @@ mediapipe::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( no_salient_position_x, no_salient_position_y, scene_summary->crop_window_width(), scene_summary->crop_window_height(), scene_summary, scene_camera_motion)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Sweep across the scene when 1) success rate is too low, AND 2) the current @@ -164,7 +163,7 @@ mediapipe::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( start_x, start_y, end_x, end_y, key_frame_crop_options.target_width(), key_frame_crop_options.target_height(), scene_span_sec, scene_summary, scene_camera_motion)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // If scene motion is small, then look at a steady point in the scene. @@ -179,14 +178,14 @@ mediapipe::Status SceneCameraMotionAnalyzer::DecideCameraMotionType( // Otherwise, tracks the focus regions. scene_camera_motion->mutable_tracking_motion(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // If there is no required focus region, looks at the middle of the center // range, and snaps to the scene center if close. Otherwise, look at the center // of the union of the required focus regions, and ensures the crop region // covers this union. 
-mediapipe::Status SceneCameraMotionAnalyzer::DecideSteadyLookAtRegion( +absl::Status SceneCameraMotionAnalyzer::DecideSteadyLookAtRegion( const KeyFrameCropOptions& key_frame_crop_options, SceneKeyFrameCropSummary* scene_summary, SceneCameraMotion* scene_camera_motion) const { @@ -252,11 +251,10 @@ mediapipe::Status SceneCameraMotionAnalyzer::DecideSteadyLookAtRegion( MP_RETURN_IF_ERROR(ToUseSteadyMotion(center_x, center_y, crop_width, crop_height, scene_summary, scene_camera_motion)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status -SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( +absl::Status SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( const float center_x, const float center_y, const int frame_width, const int frame_height, const FocusPointFrameType type, const float weight, const float bound, FocusPointFrame* focus_point_frame) const { @@ -294,10 +292,10 @@ SceneCameraMotionAnalyzer::AddFocusPointsFromCenterTypeAndWeight( } else { RET_CHECK_FAIL() << absl::StrCat("Invalid FocusPointFrameType ", type); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( +absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( const SceneKeyFrameCropSummary& scene_summary, const SceneCameraMotion& scene_camera_motion, const std::vector& scene_frame_timestamps, @@ -340,7 +338,7 @@ mediapipe::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( options_.salient_point_bound(), &focus_point_frame)); focus_point_frames->push_back(focus_point_frame); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } else if (scene_camera_motion.has_sweeping_motion()) { // Camera sweeps across the frame. 
const auto& sweeping_motion = scene_camera_motion.sweeping_motion(); @@ -361,7 +359,7 @@ mediapipe::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( options_.salient_point_bound(), &focus_point_frame)); focus_point_frames->push_back(focus_point_frame); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } else if (scene_camera_motion.has_tracking_motion()) { // Camera tracks crop regions. RET_CHECK_GT(scene_summary.num_key_frames(), 0) << "No key frames."; @@ -369,8 +367,7 @@ mediapipe::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( scene_summary, focus_point_frame_type, scene_frame_timestamps, focus_point_frames); } else { - return mediapipe::Status(StatusCode::kInvalidArgument, - "Unknown motion type."); + return absl::Status(StatusCode::kInvalidArgument, "Unknown motion type."); } } @@ -380,8 +377,7 @@ mediapipe::Status SceneCameraMotionAnalyzer::PopulateFocusPointFrames( // The weight for the focus point is proportional to the interpolated score // and scaled so that the maximum weight is equal to // maximum_focus_point_weight in the SceneCameraMotionAnalyzerOptions. 
-mediapipe::Status -SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( +absl::Status SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( const SceneKeyFrameCropSummary& scene_summary, const FocusPointFrameType focus_point_frame_type, const std::vector& scene_frame_timestamps, @@ -440,7 +436,7 @@ SceneCameraMotionAnalyzer::PopulateFocusPointFramesForTracking( focus_point->set_weight(scale * focus_point->weight()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h index 63295ffd1..8688e16ed 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer.h @@ -62,7 +62,7 @@ class SceneCameraMotionAnalyzer { // Aggregates information from KeyFrameInfos and KeyFrameCropResults into // SceneKeyFrameCropSummary, and populates FocusPointFrames given scene // frame timestamps. Optionally returns SceneCameraMotion. - mediapipe::Status AnalyzeSceneAndPopulateFocusPointFrames( + absl::Status AnalyzeSceneAndPopulateFocusPointFrames( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -75,7 +75,7 @@ class SceneCameraMotionAnalyzer { protected: // Decides SceneCameraMotion based on SceneKeyFrameCropSummary. Updates the // crop window in SceneKeyFrameCropSummary in the case of steady motion. 
- mediapipe::Status DecideCameraMotionType( + absl::Status DecideCameraMotionType( const KeyFrameCropOptions& key_frame_crop_options, const double scene_span_sec, const int64 end_time_us, SceneKeyFrameCropSummary* scene_summary, @@ -83,7 +83,7 @@ class SceneCameraMotionAnalyzer { // Populates the FocusPointFrames for each scene frame based on // SceneKeyFrameCropSummary, SceneCameraMotion, and scene frame timestamps. - mediapipe::Status PopulateFocusPointFrames( + absl::Status PopulateFocusPointFrames( const SceneKeyFrameCropSummary& scene_summary, const SceneCameraMotion& scene_camera_motion, const std::vector& scene_frame_timestamps, @@ -91,7 +91,7 @@ class SceneCameraMotionAnalyzer { private: // Decides the look-at region when camera is steady. - mediapipe::Status DecideSteadyLookAtRegion( + absl::Status DecideSteadyLookAtRegion( const KeyFrameCropOptions& key_frame_crop_options, SceneKeyFrameCropSummary* scene_summary, SceneCameraMotion* scene_camera_motion) const; @@ -105,7 +105,7 @@ class SceneCameraMotionAnalyzer { // Adds FocusPoint(s) to given FocusPointFrame given center location, // frame size, FocusPointFrameType, weight, and bound. - mediapipe::Status AddFocusPointsFromCenterTypeAndWeight( + absl::Status AddFocusPointsFromCenterTypeAndWeight( const float center_x, const float center_y, const int frame_width, const int frame_height, const FocusPointFrameType type, const float weight, const float bound, @@ -114,21 +114,22 @@ class SceneCameraMotionAnalyzer { // Populates the FocusPointFrames for each scene frame based on // SceneKeyFrameCropSummary and scene frame timestamps in the case where // camera is tracking the crop regions. - mediapipe::Status PopulateFocusPointFramesForTracking( + absl::Status PopulateFocusPointFramesForTracking( const SceneKeyFrameCropSummary& scene_summary, const FocusPointFrameType focus_point_frame_type, const std::vector& scene_frame_timestamps, std::vector* focus_point_frames) const; // Decide to use steady motion. 
- mediapipe::Status ToUseSteadyMotion( - const float look_at_center_x, const float look_at_center_y, - const int crop_window_width, const int crop_window_height, - SceneKeyFrameCropSummary* scene_summary, - SceneCameraMotion* scene_camera_motion) const; + absl::Status ToUseSteadyMotion(const float look_at_center_x, + const float look_at_center_y, + const int crop_window_width, + const int crop_window_height, + SceneKeyFrameCropSummary* scene_summary, + SceneCameraMotion* scene_camera_motion) const; // Decide to use sweeping motion. - mediapipe::Status ToUseSweepingMotion( + absl::Status ToUseSweepingMotion( const float start_x, const float start_y, const float end_x, const float end_y, const int crop_window_width, const int crop_window_height, const double time_duration_in_sec, diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc index f24a2f22d..1e8805b09 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_camera_motion_analyzer_test.cc @@ -24,6 +24,7 @@ #include "mediapipe/examples/desktop/autoflip/quality/focus_point.pb.h" #include "mediapipe/examples/desktop/autoflip/quality/piecewise_linear_function.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc index 70c2f92b5..a3c6f17c6 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.cc @@ -29,7 +29,7 @@ constexpr float kWidthFieldOfView = 60; namespace mediapipe { namespace autoflip { 
-mediapipe::Status SceneCropper::ProcessKinematicPathSolver( +absl::Status SceneCropper::ProcessKinematicPathSolver( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, @@ -77,10 +77,10 @@ mediapipe::Status SceneCropper::ProcessKinematicPathSolver( -(x_path - scene_summary.crop_window_width() / 2); all_xforms->push_back(transform); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SceneCropper::CropFrames( +absl::Status SceneCropper::CropFrames( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, @@ -151,7 +151,7 @@ mediapipe::Status SceneCropper::CropFrames( // If no cropped_frames is passed in, return directly. if (!cropped_frames) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } RET_CHECK(!scene_frames_or_empty.empty()) << "If |cropped_frames| != nullptr, scene_frames_or_empty must not be " diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h index b77cfc60f..0e5c332db 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropper.h @@ -60,7 +60,7 @@ class SceneCropper { // on the transform matrix if |cropped_frames| is not nullptr and // |scene_frames_or_empty| isn't empty. // TODO: split this function into two separate functions. 
- mediapipe::Status CropFrames( + absl::Status CropFrames( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, @@ -71,7 +71,7 @@ class SceneCropper { const bool continue_last_scene, std::vector* crop_from_location, std::vector* cropped_frames); - mediapipe::Status ProcessKinematicPathSolver( + absl::Status ProcessKinematicPathSolver( const SceneKeyFrameCropSummary& scene_summary, const std::vector& scene_timestamps, const std::vector& is_key_frames, diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc index 76fffb33e..d99292fa3 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.cc @@ -46,7 +46,7 @@ const cv::Scalar kOrange = cv::Scalar(255.0, 165.0, 0.0); // ica object detector const cv::Scalar kWhite = cv::Scalar(255.0, 255.0, 255.0); // others -mediapipe::Status DrawDetectionsAndCropRegions( +absl::Status DrawDetectionsAndCropRegions( const std::vector& scene_frames, const std::vector& is_key_frames, const std::vector& key_frame_infos, @@ -130,7 +130,7 @@ mediapipe::Status DrawDetectionsAndCropRegions( } viz_frames->push_back(std::move(viz_frame)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } namespace { @@ -147,7 +147,7 @@ cv::Rect LimitBounds(const cv::Rect& rect, const int max_width, } } // namespace -mediapipe::Status DrawDetectionAndFramingWindow( +absl::Status DrawDetectionAndFramingWindow( const std::vector& org_scene_frames, const std::vector& crop_from_locations, const ImageFormat::Format image_format, const float overlay_opacity, @@ -166,10 +166,10 @@ mediapipe::Status DrawDetectionAndFramingWindow( scene_frame(crop_from_bounded).copyTo(darkened(crop_from_bounded)); viz_frames->push_back(std::move(viz_frame)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } 
-mediapipe::Status DrawFocusPointAndCropWindow( +absl::Status DrawFocusPointAndCropWindow( const std::vector& scene_frames, const std::vector& focus_point_frames, const float overlay_opacity, const int crop_window_width, @@ -215,7 +215,7 @@ mediapipe::Status DrawFocusPointAndCropWindow( } viz_frames->push_back(std::move(viz_frame)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h index c2309f77a..01f8c5de5 100644 --- a/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h +++ b/mediapipe/examples/desktop/autoflip/quality/scene_cropping_viz.h @@ -36,7 +36,7 @@ namespace autoflip { // magenta, logos are red, ocrs are yellow (foreground) and light yellow // (background), brain objects are cyan, ica objects are orange, and the rest // are white. -mediapipe::Status DrawDetectionsAndCropRegions( +absl::Status DrawDetectionsAndCropRegions( const std::vector& scene_frames, const std::vector& is_key_frames, const std::vector& key_frame_infos, @@ -47,7 +47,7 @@ mediapipe::Status DrawDetectionsAndCropRegions( // Draws the focus point from the given FocusPointFrame and the crop window // centered around it on the scene frame in red. This helps visualize the input // to the retargeter. -mediapipe::Status DrawFocusPointAndCropWindow( +absl::Status DrawFocusPointAndCropWindow( const std::vector& scene_frames, const std::vector& focus_point_frames, const float overlay_opacity, const int crop_window_width, @@ -57,7 +57,7 @@ mediapipe::Status DrawFocusPointAndCropWindow( // Draws the final smoothed path of the camera retargeter by darkening the // removed areas. 
-mediapipe::Status DrawDetectionAndFramingWindow( +absl::Status DrawDetectionAndFramingWindow( const std::vector& org_scene_frames, const std::vector& crop_from_locations, const ImageFormat::Format image_format, const float overlay_opacity, diff --git a/mediapipe/examples/desktop/autoflip/quality/utils.cc b/mediapipe/examples/desktop/autoflip/quality/utils.cc index d1cd7723e..7b25930fc 100644 --- a/mediapipe/examples/desktop/autoflip/quality/utils.cc +++ b/mediapipe/examples/desktop/autoflip/quality/utils.cc @@ -53,12 +53,12 @@ void NormalizedRectToRect(const RectF& normalized_location, const int width, ScaleRect(normalized_location, width, height, location); } -mediapipe::Status ClampRect(const int width, const int height, Rect* location) { +absl::Status ClampRect(const int width, const int height, Rect* location) { return ClampRect(0, 0, width, height, location); } -mediapipe::Status ClampRect(const int x0, const int y0, const int x1, - const int y1, Rect* location) { +absl::Status ClampRect(const int x0, const int y0, const int x1, const int y1, + Rect* location) { RET_CHECK(!(location->x() >= x1 || location->x() + location->width() <= x0 || location->y() >= y1 || location->y() + location->height() <= y0)); @@ -73,7 +73,7 @@ mediapipe::Status ClampRect(const int x0, const int y0, const int x1, location->set_y(clamped_top); location->set_width(std::max(0, clamped_right - clamped_left)); location->set_height(std::max(0, clamped_bottom - clamped_top)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void RectUnion(const Rect& rect_to_add, Rect* rect) { @@ -89,13 +89,13 @@ void RectUnion(const Rect& rect_to_add, Rect* rect) { rect->set_height(y2 - y1); } -mediapipe::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, - const DetectionSet& detections, - const int original_frame_width, - const int original_frame_height, - const int feature_frame_width, - const int feature_frame_height, - KeyFrameInfo* key_frame_info) { +absl::Status 
PackKeyFrameInfo(const int64 frame_timestamp_ms, + const DetectionSet& detections, + const int original_frame_width, + const int original_frame_height, + const int feature_frame_width, + const int feature_frame_height, + KeyFrameInfo* key_frame_info) { RET_CHECK(key_frame_info != nullptr) << "KeyFrameInfo is null"; RET_CHECK(original_frame_width > 0 && original_frame_height > 0 && feature_frame_width > 0 && feature_frame_height > 0) @@ -135,13 +135,12 @@ mediapipe::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SortDetections( - const DetectionSet& detections, - std::vector* required_regions, - std::vector* non_required_regions) { +absl::Status SortDetections(const DetectionSet& detections, + std::vector* required_regions, + std::vector* non_required_regions) { required_regions->clear(); non_required_regions->clear(); @@ -174,13 +173,13 @@ mediapipe::Status SortDetections( non_required_regions->push_back(detections.detections(original_idx)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetKeyFrameCropTarget(const int frame_width, - const int frame_height, - const double target_aspect_ratio, - KeyFrameCropOptions* crop_options) { +absl::Status SetKeyFrameCropTarget(const int frame_width, + const int frame_height, + const double target_aspect_ratio, + KeyFrameCropOptions* crop_options) { RET_CHECK_NE(crop_options, nullptr) << "KeyFrameCropOptions is null."; RET_CHECK_GT(frame_width, 0) << "Frame width is non-positive."; RET_CHECK_GT(frame_height, 0) << "Frame height is non-positive."; @@ -198,10 +197,10 @@ mediapipe::Status SetKeyFrameCropTarget(const int frame_width, : std::round(frame_width / target_aspect_ratio); crop_options->set_target_width(crop_target_width); crop_options->set_target_height(crop_target_height); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AggregateKeyFrameResults( +absl::Status 
AggregateKeyFrameResults( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -231,7 +230,7 @@ mediapipe::Status AggregateKeyFrameResults( // Handles the corner case of no key frames. if (num_key_frames == 0) { scene_summary->set_has_salient_region(false); - return mediapipe::OkStatus(); + return absl::OkStatus(); } scene_summary->set_num_key_frames(num_key_frames); @@ -327,10 +326,10 @@ mediapipe::Status AggregateKeyFrameResults( scene_summary->key_frame_center_min_y()) / scene_frame_height; scene_summary->set_vertical_motion_amount(motion_y); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ComputeSceneStaticBordersSize( +absl::Status ComputeSceneStaticBordersSize( const std::vector& static_features, int* top_border_size, int* bottom_border_size) { RET_CHECK(top_border_size) << "Output top border size is null."; @@ -374,10 +373,10 @@ mediapipe::Status ComputeSceneStaticBordersSize( *top_border_size = std::max(0, *top_border_size); *bottom_border_size = std::max(0, *bottom_border_size); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FindSolidBackgroundColor( +absl::Status FindSolidBackgroundColor( const std::vector& static_features, const std::vector& static_features_timestamps, const double min_fraction_solid_background_color, @@ -422,13 +421,13 @@ mediapipe::Status FindSolidBackgroundColor( min_fraction_solid_background_color) { *has_solid_background = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AffineRetarget(const cv::Size& output_size, - const std::vector& frames, - const std::vector& affine_projection, - std::vector* cropped_frames) { +absl::Status AffineRetarget(const cv::Size& output_size, + const std::vector& frames, + const std::vector& affine_projection, + std::vector* cropped_frames) { RET_CHECK(frames.size() == affine_projection.size()) << 
"number of frames and retarget offsets must be the same."; RET_CHECK(cropped_frames->size() == frames.size()) @@ -442,7 +441,7 @@ mediapipe::Status AffineRetarget(const cv::Size& output_size, RET_CHECK(affine.rows == 2) << "Affine matrix must be 2x3"; cv::warpAffine(frames[i], (*cropped_frames)[i], affine, output_size); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/quality/utils.h b/mediapipe/examples/desktop/autoflip/quality/utils.h index d9f05c4ea..7285e8265 100644 --- a/mediapipe/examples/desktop/autoflip/quality/utils.h +++ b/mediapipe/examples/desktop/autoflip/quality/utils.h @@ -29,31 +29,30 @@ namespace autoflip { // Packs detected features and timestamp (ms) into a KeyFrameInfo object. Scales // features back to the original frame size if features have been detected on a // different frame size. -mediapipe::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, - const DetectionSet& detections, - const int original_frame_width, - const int original_frame_height, - const int feature_frame_width, - const int feature_frame_height, - KeyFrameInfo* key_frame_info); +absl::Status PackKeyFrameInfo(const int64 frame_timestamp_ms, + const DetectionSet& detections, + const int original_frame_width, + const int original_frame_height, + const int feature_frame_width, + const int feature_frame_height, + KeyFrameInfo* key_frame_info); // Sorts required and non-required salient regions given a detection set. -mediapipe::Status SortDetections( - const DetectionSet& detections, - std::vector* required_regions, - std::vector* non_required_regions); +absl::Status SortDetections(const DetectionSet& detections, + std::vector* required_regions, + std::vector* non_required_regions); // Sets the target crop size in KeyFrameCropOptions based on frame size and // target aspect ratio so that the target crop size covers the biggest area // possible in the frame. 
-mediapipe::Status SetKeyFrameCropTarget(const int frame_width, - const int frame_height, - const double target_aspect_ratio, - KeyFrameCropOptions* crop_options); +absl::Status SetKeyFrameCropTarget(const int frame_width, + const int frame_height, + const double target_aspect_ratio, + KeyFrameCropOptions* crop_options); // Aggregates information from KeyFrameInfos and KeyFrameCropResults into // SceneKeyFrameCropSummary. -mediapipe::Status AggregateKeyFrameResults( +absl::Status AggregateKeyFrameResults( const KeyFrameCropOptions& key_frame_crop_options, const std::vector& key_frame_crop_results, const int scene_frame_width, const int scene_frame_height, @@ -61,7 +60,7 @@ mediapipe::Status AggregateKeyFrameResults( // Computes the static top and border size across a scene given a vector of // StaticFeatures over frames. -mediapipe::Status ComputeSceneStaticBordersSize( +absl::Status ComputeSceneStaticBordersSize( const std::vector& static_features, int* top_border_size, int* bottom_border_size); @@ -70,7 +69,7 @@ mediapipe::Status ComputeSceneStaticBordersSize( // background color exceeds given threshold, i.e., // min_fraction_solid_background_color. Builds the background color // interpolation functions in Lab space using input timestamps. -mediapipe::Status FindSolidBackgroundColor( +absl::Status FindSolidBackgroundColor( const std::vector& static_features, const std::vector& static_features_timestamps, const double min_fraction_solid_background_color, @@ -93,12 +92,12 @@ void NormalizedRectToRect(const RectF& normalized_location, const int width, // Clamps a rectangle to lie within [x0, y0] and [x1, y1]. Returns true if the // rectangle has any overlapping with the target window. 
-mediapipe::Status ClampRect(const int x0, const int y0, const int x1, - const int y1, Rect* location); +absl::Status ClampRect(const int x0, const int y0, const int x1, const int y1, + Rect* location); // Convenience function to clamp a rectangle to lie within [0, 0] and // [width, height]. -mediapipe::Status ClampRect(const int width, const int height, Rect* location); +absl::Status ClampRect(const int width, const int height, Rect* location); // Enlarges a given rectangle to cover a new rectangle to be added. void RectUnion(const Rect& rect_to_add, Rect* rect); @@ -106,10 +105,10 @@ void RectUnion(const Rect& rect_to_add, Rect* rect); // Performs an affine retarget on a list of input images. Output vector // cropped_frames must be filled with Mats of the same size as output_size and // type. -mediapipe::Status AffineRetarget(const cv::Size& output_size, - const std::vector& frames, - const std::vector& affine_projection, - std::vector* cropped_frames); +absl::Status AffineRetarget(const cv::Size& output_size, + const std::vector& frames, + const std::vector& affine_projection, + std::vector* cropped_frames); } // namespace autoflip } // namespace mediapipe diff --git a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc index cf01adf2a..9ae612004 100644 --- a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc +++ b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.cc @@ -48,9 +48,9 @@ void CropRectToMat(const cv::Mat& image, cv::Rect* rect) { VisualScorer::VisualScorer(const VisualScorerOptions& options) : options_(options) {} -mediapipe::Status VisualScorer::CalculateScore(const cv::Mat& image, - const SalientRegion& region, - float* score) const { +absl::Status VisualScorer::CalculateScore(const cv::Mat& image, + const SalientRegion& region, + float* score) const { const float weight_sum = options_.area_weight() + options_.sharpness_weight() + 
 options_.colorfulness_weight(); @@ -74,7 +74,7 @@ mediapipe::Status VisualScorer::CalculateScore(const cv::Mat& image, CropRectToMat(image, &region_rect); if (region_rect.area() == 0) { *score = 0; - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Compute a score based on area covered by this region. @@ -108,11 +108,11 @@ mediapipe::Status VisualScorer::CalculateScore(const cv::Mat& image, if (*score > 1.0f || *score < 0.0f) { LOG(WARNING) << "Score of region outside expected range: " << *score; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status VisualScorer::CalculateColorfulness( - const cv::Mat& image, float* colorfulness) const { +absl::Status VisualScorer::CalculateColorfulness(const cv::Mat& image, + float* colorfulness) const { // Convert the image to HSV. cv::Mat image_hsv; cv::cvtColor(image, image_hsv, CV_RGB2HSV); @@ -134,7 +134,7 @@ mediapipe::Status VisualScorer::CalculateColorfulness( // If the mask is empty, return. if (empty_mask) { *colorfulness = 0; - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Generate a 2D histogram (hue/saturation). @@ -162,7 +162,7 @@ mediapipe::Status VisualScorer::CalculateColorfulness( } if (hue_sum == 0.0f) { *colorfulness = 0; - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Compute the histogram entropy. 
@@ -175,7 +175,7 @@ mediapipe::Status VisualScorer::CalculateColorfulness( } *colorfulness /= std::log(2.0f); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace autoflip diff --git a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h index aa2d2d20b..b2c6d3af7 100644 --- a/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h +++ b/mediapipe/examples/desktop/autoflip/quality/visual_scorer.h @@ -30,13 +30,12 @@ class VisualScorer { explicit VisualScorer(const VisualScorerOptions& options); // Computes a score on a salientregion and returns a value [0...1]. - mediapipe::Status CalculateScore(const cv::Mat& image, - const SalientRegion& region, - float* score) const; + absl::Status CalculateScore(const cv::Mat& image, const SalientRegion& region, + float* score) const; private: - mediapipe::Status CalculateColorfulness(const cv::Mat& image, - float* colorfulness) const; + absl::Status CalculateColorfulness(const cv::Mat& image, + float* colorfulness) const; VisualScorerOptions options_; }; diff --git a/mediapipe/examples/desktop/demo_run_graph_main.cc b/mediapipe/examples/desktop/demo_run_graph_main.cc index e09d2abf5..343460eac 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main.cc +++ b/mediapipe/examples/desktop/demo_run_graph_main.cc @@ -40,10 +40,11 @@ DEFINE_string(output_video_path, "", "Full path of where to save result (.mp4 only). 
" "If not provided, show result in a window."); -mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -56,16 +57,16 @@ mediapipe::Status RunMPPGraph() { LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; - const bool load_video = !FLAGS_input_video_path.empty(); + const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { - capture.open(FLAGS_input_video_path); + capture.open(absl::GetFlag(FLAGS_input_video_path)); } else { capture.open(0); } RET_CHECK(capture.isOpened()); cv::VideoWriter writer; - const bool save_video = !FLAGS_output_video_path.empty(); + const bool save_video = !absl::GetFlag(FLAGS_output_video_path).empty(); if (!save_video) { cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); #if (CV_MAJOR_VERSION >= 3) && (CV_MINOR_VERSION >= 2) @@ -125,7 +126,7 @@ mediapipe::Status RunMPPGraph() { if (save_video) { if (!writer.isOpened()) { LOG(INFO) << "Prepare video writer."; - writer.open(FLAGS_output_video_path, + writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), output_frame_mat.size()); RET_CHECK(writer.isOpened()); @@ -148,7 +149,7 @@ mediapipe::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git 
a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc index 4bc7f92e4..6942971f7 100644 --- a/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc +++ b/mediapipe/examples/desktop/demo_run_graph_main_gpu.cc @@ -44,10 +44,11 @@ DEFINE_string(output_video_path, "", "Full path of where to save result (.mp4 only). " "If not provided, show result in a window."); -mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -66,16 +67,16 @@ mediapipe::Status RunMPPGraph() { LOG(INFO) << "Initialize the camera or load the video."; cv::VideoCapture capture; - const bool load_video = !FLAGS_input_video_path.empty(); + const bool load_video = !absl::GetFlag(FLAGS_input_video_path).empty(); if (load_video) { - capture.open(FLAGS_input_video_path); + capture.open(absl::GetFlag(FLAGS_input_video_path)); } else { capture.open(0); } RET_CHECK(capture.isOpened()); cv::VideoWriter writer; - const bool save_video = !FLAGS_output_video_path.empty(); + const bool save_video = !absl::GetFlag(FLAGS_output_video_path).empty(); if (!save_video) { cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); #if (CV_MAJOR_VERSION >= 3) && (CV_MINOR_VERSION >= 2) @@ -122,7 +123,7 @@ mediapipe::Status RunMPPGraph() { (double)cv::getTickCount() / (double)cv::getTickFrequency() * 1e6; MP_RETURN_IF_ERROR( gpu_helper.RunInGlContext([&input_frame, &frame_timestamp_us, &graph, - &gpu_helper]() -> ::mediapipe::Status { + &gpu_helper]() -> absl::Status { // Convert ImageFrame to GpuBuffer. 
auto texture = gpu_helper.CreateSourceTexture(*input_frame.get()); auto gpu_frame = texture.GetFrame(); @@ -132,7 +133,7 @@ mediapipe::Status RunMPPGraph() { MP_RETURN_IF_ERROR(graph.AddPacketToInputStream( kInputStream, mediapipe::Adopt(gpu_frame.release()) .At(mediapipe::Timestamp(frame_timestamp_us)))); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Get the graph result packet, or stop if that fails. @@ -142,7 +143,7 @@ mediapipe::Status RunMPPGraph() { // Convert GpuBuffer to ImageFrame. MP_RETURN_IF_ERROR(gpu_helper.RunInGlContext( - [&packet, &output_frame, &gpu_helper]() -> ::mediapipe::Status { + [&packet, &output_frame, &gpu_helper]() -> absl::Status { auto& gpu_frame = packet.Get(); auto texture = gpu_helper.CreateSourceTexture(gpu_frame); output_frame = absl::make_unique( @@ -150,13 +151,13 @@ mediapipe::Status RunMPPGraph() { gpu_frame.width(), gpu_frame.height(), mediapipe::ImageFrame::kGlDefaultAlignmentBoundary); gpu_helper.BindFramebuffer(texture); - const auto info = - mediapipe::GlTextureInfoForGpuBufferFormat(gpu_frame.format(), 0); + const auto info = mediapipe::GlTextureInfoForGpuBufferFormat( + gpu_frame.format(), 0, gpu_helper.GetGlVersion()); glReadPixels(0, 0, texture.width(), texture.height(), info.gl_format, info.gl_type, output_frame->MutablePixelData()); glFlush(); texture.Release(); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Convert back to opencv for display or saving. 
@@ -168,7 +169,7 @@ mediapipe::Status RunMPPGraph() { if (save_video) { if (!writer.isOpened()) { LOG(INFO) << "Prepare video writer."; - writer.open(FLAGS_output_video_path, + writer.open(absl::GetFlag(FLAGS_output_video_path), mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4 capture.get(cv::CAP_PROP_FPS), output_frame_mat.size()); RET_CHECK(writer.isOpened()); @@ -191,7 +192,7 @@ mediapipe::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/hello_world/hello_world.cc b/mediapipe/examples/desktop/hello_world/hello_world.cc index a5d0e790c..95c34146d 100644 --- a/mediapipe/examples/desktop/hello_world/hello_world.cc +++ b/mediapipe/examples/desktop/hello_world/hello_world.cc @@ -21,7 +21,7 @@ namespace mediapipe { -mediapipe::Status PrintHelloWorld() { +absl::Status PrintHelloWorld() { // Configures a simple graph, which concatenates 2 PassThroughCalculators. 
CalculatorGraphConfig config = ParseTextProtoOrDie(R"( input_stream: "in" diff --git a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc index d08e95a1e..515ee37b0 100644 --- a/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc +++ b/mediapipe/examples/desktop/iris_tracking/iris_depth_from_image_desktop.cc @@ -47,18 +47,16 @@ DEFINE_string(output_image_path, "", namespace { -mediapipe::StatusOr ReadFileToString( - const std::string& file_path) { +absl::StatusOr ReadFileToString(const std::string& file_path) { std::string contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents(file_path, &contents)); return contents; } -mediapipe::Status ProcessImage( - std::unique_ptr graph) { +absl::Status ProcessImage(std::unique_ptr graph) { LOG(INFO) << "Load the image."; ASSIGN_OR_RETURN(const std::string raw_image, - ReadFileToString(FLAGS_input_image_path)); + ReadFileToString(absl::GetFlag(FLAGS_input_image_path))); LOG(INFO) << "Start running the calculator graph."; ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller output_image_poller, @@ -80,7 +78,7 @@ mediapipe::Status ProcessImage( // Get the graph result packets, or stop if that fails. 
mediapipe::Packet left_iris_depth_packet; if (!left_iris_depth_poller.Next(&left_iris_depth_packet)) { - return mediapipe::UnknownError( + return absl::UnknownError( "Failed to get packet from output stream 'left_iris_depth_mm'."); } const auto& left_iris_depth_mm = left_iris_depth_packet.Get(); @@ -89,7 +87,7 @@ mediapipe::Status ProcessImage( mediapipe::Packet right_iris_depth_packet; if (!right_iris_depth_poller.Next(&right_iris_depth_packet)) { - return mediapipe::UnknownError( + return absl::UnknownError( "Failed to get packet from output stream 'right_iris_depth_mm'."); } const auto& right_iris_depth_mm = right_iris_depth_packet.Get(); @@ -99,7 +97,7 @@ mediapipe::Status ProcessImage( mediapipe::Packet output_image_packet; if (!output_image_poller.Next(&output_image_packet)) { - return mediapipe::UnknownError( + return absl::UnknownError( "Failed to get packet from output stream 'output_image'."); } auto& output_frame = output_image_packet.Get(); @@ -107,10 +105,10 @@ mediapipe::Status ProcessImage( // Convert back to opencv for display or saving. 
cv::Mat output_frame_mat = mediapipe::formats::MatView(&output_frame); cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR); - const bool save_image = !FLAGS_output_image_path.empty(); + const bool save_image = !absl::GetFlag(FLAGS_output_image_path).empty(); if (save_image) { LOG(INFO) << "Saving image to file..."; - cv::imwrite(FLAGS_output_image_path, output_frame_mat); + cv::imwrite(absl::GetFlag(FLAGS_output_image_path), output_frame_mat); } else { cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1); cv::imshow(kWindowName, output_frame_mat); @@ -123,7 +121,7 @@ mediapipe::Status ProcessImage( return graph->WaitUntilDone(); } -mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( kCalculatorGraphConfigFile, &calculator_graph_config_contents)); @@ -138,11 +136,11 @@ mediapipe::Status RunMPPGraph() { absl::make_unique(); MP_RETURN_IF_ERROR(graph->Initialize(config)); - const bool load_image = !FLAGS_input_image_path.empty(); + const bool load_image = !absl::GetFlag(FLAGS_input_image_path).empty(); if (load_image) { return ProcessImage(std::move(graph)); } else { - return mediapipe::InvalidArgumentError("Missing image file."); + return absl::InvalidArgumentError("Missing image file."); } } @@ -151,7 +149,7 @@ mediapipe::Status RunMPPGraph() { int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc index 56be7505d..a15f599d1 100644 --- a/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc +++ 
b/mediapipe/examples/desktop/media_sequence/run_graph_file_io_main.cc @@ -38,10 +38,11 @@ DEFINE_string(output_side_packets, "", "side packets and paths to write to disk for the " "CalculatorGraph."); -mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -49,7 +50,7 @@ mediapipe::Status RunMPPGraph() { calculator_graph_config_contents); std::map input_side_packets; std::vector kv_pairs = - absl::StrSplit(FLAGS_input_side_packets, ','); + absl::StrSplit(absl::GetFlag(FLAGS_input_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); @@ -66,26 +67,26 @@ mediapipe::Status RunMPPGraph() { LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.Run()); LOG(INFO) << "Gathering output side packets."; - kv_pairs = absl::StrSplit(FLAGS_output_side_packets, ','); + kv_pairs = absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - mediapipe::StatusOr output_packet = + absl::StatusOr output_packet = graph.GetOutputSidePacket(name_and_value[0]); RET_CHECK(output_packet.ok()) << "Packet " << name_and_value[0] << " was not available."; const std::string& serialized_string = - output_packet.ValueOrDie().Get(); + output_packet.value().Get(); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(name_and_value[1], serialized_string)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } int 
main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/object_detection_3d/BUILD b/mediapipe/examples/desktop/object_detection_3d/BUILD index 0e72c9e51..86e29a728 100644 --- a/mediapipe/examples/desktop/object_detection_3d/BUILD +++ b/mediapipe/examples/desktop/object_detection_3d/BUILD @@ -22,9 +22,9 @@ package(default_visibility = ["//mediapipe/examples:__subpackages__"]) # --calculator_graph_config_file=mediapipe/graphs/object_detection_3d/objectron_desktop_cpu.pbtxt \ # --input_side_packets="input_video_path=,box_landmark_model_path=mediapipe/models/object_detection_3d_sneakers.tflite,output_video_path=,allowed_labels=Footwear" # To detect objects from other categories, change box_landmark_model_path and allowed_labels accordingly. 
-# Chair: box_landmark_model_path=mediapipe/models/object_detection_3d_chair.tflite,allowed_labels=Chair -# Camera: box_landmark_model_path=mediapipe/models/object_detection_3d_camera.tflite,allowed_labels=Camera -# Cup: box_landmark_model_path=mediapipe/models/object_detection_3d_cup.tflite,allowed_labels=Mug +# Chair: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_chair.tflite,allowed_labels=Chair +# Camera: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_camera.tflite,allowed_labels=Camera +# Cup: box_landmark_model_path=mediapipe/modules/objectron/object_detection_3d_cup.tflite,allowed_labels=Mug cc_binary( name = "objectron_cpu", deps = [ diff --git a/mediapipe/examples/desktop/simple_run_graph_main.cc b/mediapipe/examples/desktop/simple_run_graph_main.cc index 0f45a3810..5d33af66c 100644 --- a/mediapipe/examples/desktop/simple_run_graph_main.cc +++ b/mediapipe/examples/desktop/simple_run_graph_main.cc @@ -58,31 +58,29 @@ DEFINE_string(output_side_packets_file, "", "The name of the local file to output all side packets specified " "with --output_side_packets. 
"); -mediapipe::Status OutputStreamToLocalFile( - mediapipe::OutputStreamPoller& poller) { +absl::Status OutputStreamToLocalFile(mediapipe::OutputStreamPoller& poller) { std::ofstream file; - file.open(FLAGS_output_stream_file); + file.open(absl::GetFlag(FLAGS_output_stream_file)); mediapipe::Packet packet; while (poller.Next(&packet)) { std::string output_data; - if (!FLAGS_strip_timestamps) { + if (!absl::GetFlag(FLAGS_strip_timestamps)) { absl::StrAppend(&output_data, packet.Timestamp().Value(), ","); } absl::StrAppend(&output_data, packet.Get(), "\n"); file << output_data; } file.close(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OutputSidePacketsToLocalFile( - mediapipe::CalculatorGraph& graph) { - if (!FLAGS_output_side_packets.empty() && - !FLAGS_output_side_packets_file.empty()) { +absl::Status OutputSidePacketsToLocalFile(mediapipe::CalculatorGraph& graph) { + if (!absl::GetFlag(FLAGS_output_side_packets).empty() && + !absl::GetFlag(FLAGS_output_side_packets_file).empty()) { std::ofstream file; - file.open(FLAGS_output_side_packets_file); + file.open(absl::GetFlag(FLAGS_output_side_packets_file)); std::vector side_packet_names = - absl::StrSplit(FLAGS_output_side_packets, ','); + absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& side_packet_name : side_packet_names) { ASSIGN_OR_RETURN(auto status_or_packet, graph.GetOutputSidePacket(side_packet_name)); @@ -91,27 +89,28 @@ mediapipe::Status OutputSidePacketsToLocalFile( } file.close(); } else { - RET_CHECK(FLAGS_output_side_packets.empty() && - FLAGS_output_side_packets_file.empty()) + RET_CHECK(absl::GetFlag(FLAGS_output_side_packets).empty() && + absl::GetFlag(FLAGS_output_side_packets_file).empty()) << "--output_side_packets and --output_side_packets_file should be " "specified in pair."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string 
calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie( calculator_graph_config_contents); std::map input_side_packets; - if (!FLAGS_input_side_packets.empty()) { + if (!absl::GetFlag(FLAGS_input_side_packets).empty()) { std::vector kv_pairs = - absl::StrSplit(FLAGS_input_side_packets, ','); + absl::StrSplit(absl::GetFlag(FLAGS_input_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); @@ -123,14 +122,16 @@ mediapipe::Status RunMPPGraph() { LOG(INFO) << "Initialize the calculator graph."; mediapipe::CalculatorGraph graph; MP_RETURN_IF_ERROR(graph.Initialize(config, input_side_packets)); - if (!FLAGS_output_stream.empty() && !FLAGS_output_stream_file.empty()) { - ASSIGN_OR_RETURN(auto poller, - graph.AddOutputStreamPoller(FLAGS_output_stream)); + if (!absl::GetFlag(FLAGS_output_stream).empty() && + !absl::GetFlag(FLAGS_output_stream_file).empty()) { + ASSIGN_OR_RETURN(auto poller, graph.AddOutputStreamPoller( + absl::GetFlag(FLAGS_output_stream))); LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.StartRun({})); MP_RETURN_IF_ERROR(OutputStreamToLocalFile(poller)); } else { - RET_CHECK(FLAGS_output_stream.empty() && FLAGS_output_stream_file.empty()) + RET_CHECK(absl::GetFlag(FLAGS_output_stream).empty() && + absl::GetFlag(FLAGS_output_stream_file).empty()) << "--output_stream and --output_stream_file should be specified in " "pair."; LOG(INFO) << "Start running the calculator graph."; @@ -143,7 +144,7 @@ mediapipe::Status RunMPPGraph() { int main(int argc, char** argv) { 
google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/desktop/youtube8m/BUILD b/mediapipe/examples/desktop/youtube8m/BUILD index af85e3113..e6347b243 100644 --- a/mediapipe/examples/desktop/youtube8m/BUILD +++ b/mediapipe/examples/desktop/youtube8m/BUILD @@ -21,7 +21,6 @@ cc_binary( "@com_google_absl//absl/strings", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:matrix", - "//mediapipe/framework/formats:matrix_data_cc_proto", "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:map_util", diff --git a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc index c1ad40a90..a303077cc 100644 --- a/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc +++ b/mediapipe/examples/desktop/youtube8m/extract_yt8m_features.cc @@ -39,10 +39,11 @@ DEFINE_string(output_side_packets, "", "side packets and paths to write to disk for the " "CalculatorGraph."); -mediapipe::Status RunMPPGraph() { +absl::Status RunMPPGraph() { std::string calculator_graph_config_contents; MP_RETURN_IF_ERROR(mediapipe::file::GetContents( - FLAGS_calculator_graph_config_file, &calculator_graph_config_contents)); + absl::GetFlag(FLAGS_calculator_graph_config_file), + &calculator_graph_config_contents)); LOG(INFO) << "Get calculator graph config contents: " << calculator_graph_config_contents; mediapipe::CalculatorGraphConfig config = @@ -50,7 +51,7 @@ mediapipe::Status RunMPPGraph() { calculator_graph_config_contents); std::map input_side_packets; std::vector kv_pairs = - absl::StrSplit(FLAGS_input_side_packets, ','); + 
absl::StrSplit(absl::GetFlag(FLAGS_input_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); @@ -107,26 +108,26 @@ mediapipe::Status RunMPPGraph() { LOG(INFO) << "Start running the calculator graph."; MP_RETURN_IF_ERROR(graph.Run()); LOG(INFO) << "Gathering output side packets."; - kv_pairs = absl::StrSplit(FLAGS_output_side_packets, ','); + kv_pairs = absl::StrSplit(absl::GetFlag(FLAGS_output_side_packets), ','); for (const std::string& kv_pair : kv_pairs) { std::vector name_and_value = absl::StrSplit(kv_pair, '='); RET_CHECK(name_and_value.size() == 2); - mediapipe::StatusOr output_packet = + absl::StatusOr output_packet = graph.GetOutputSidePacket(name_and_value[0]); RET_CHECK(output_packet.ok()) << "Packet " << name_and_value[0] << " was not available."; const std::string& serialized_string = - output_packet.ValueOrDie().Get(); + output_packet.value().Get(); MP_RETURN_IF_ERROR( mediapipe::file::SetContents(name_and_value[1], serialized_string)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } int main(int argc, char** argv) { google::InitGoogleLogging(argv[0]); gflags::ParseCommandLineFlags(&argc, &argv, true); - mediapipe::Status run_status = RunMPPGraph(); + absl::Status run_status = RunMPPGraph(); if (!run_status.ok()) { LOG(ERROR) << "Failed to run the graph: " << run_status.message(); return EXIT_FAILURE; diff --git a/mediapipe/examples/ios/common/CommonViewController.mm b/mediapipe/examples/ios/common/CommonViewController.mm index e819e8170..f6c47eacf 100644 --- a/mediapipe/examples/ios/common/CommonViewController.mm +++ b/mediapipe/examples/ios/common/CommonViewController.mm @@ -137,10 +137,10 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; [self.cameraSource requestCameraAccessWithCompletionHandler:^void(BOOL granted) { if (granted) { - [self startGraphAndCamera]; 
dispatch_async(dispatch_get_main_queue(), ^{ self.noCameraLabel.hidden = YES; }); + [self startGraphAndCamera]; } }]; @@ -155,6 +155,9 @@ static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; if (![self.mediapipeGraph startWithError:&error]) { NSLog(@"Failed to start graph: %@", error); } + else if (![self.mediapipeGraph waitUntilIdleWithError:&error]) { + NSLog(@"Failed to complete graph initial run: %@", error); + } // Start fetching frames from the camera. dispatch_async(self.videoQueue, ^{ diff --git a/mediapipe/examples/ios/faceeffect/BUILD b/mediapipe/examples/ios/faceeffect/BUILD index 7b681cd85..9e074ef2f 100644 --- a/mediapipe/examples/ios/faceeffect/BUILD +++ b/mediapipe/examples/ios/faceeffect/BUILD @@ -57,6 +57,21 @@ objc_library( "FaceEffectViewController.h", ], copts = ["-std=c++17"], + data = [ + "Base.lproj/LaunchScreen.storyboard", + "Base.lproj/Main.storyboard", + "//mediapipe/graphs/face_effect:face_effect_gpu.binarypb", + "//mediapipe/graphs/face_effect/data:axis.binarypb", + "//mediapipe/graphs/face_effect/data:axis.pngblob", + "//mediapipe/graphs/face_effect/data:facepaint.pngblob", + "//mediapipe/graphs/face_effect/data:glasses.binarypb", + "//mediapipe/graphs/face_effect/data:glasses.pngblob", + "//mediapipe/modules/face_detection:face_detection_front.tflite", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata.binarypb", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_detection.binarypb", + "//mediapipe/modules/face_geometry/data:geometry_pipeline_metadata_landmarks.binarypb", + "//mediapipe/modules/face_landmark:face_landmark.tflite", + ], sdk_frameworks = [ "AVFoundation", "CoreGraphics", diff --git a/mediapipe/examples/ios/faceeffect/FaceEffectViewController.mm b/mediapipe/examples/ios/faceeffect/FaceEffectViewController.mm index 2fd8f5f6f..56a895c69 100644 --- a/mediapipe/examples/ios/faceeffect/FaceEffectViewController.mm +++ 
b/mediapipe/examples/ios/faceeffect/FaceEffectViewController.mm @@ -18,6 +18,8 @@ #import "mediapipe/objc/MPPGraph.h" #import "mediapipe/objc/MPPLayerRenderer.h" +#include +#include #include #include "mediapipe/framework/formats/matrix_data.pb.h" @@ -27,13 +29,19 @@ static NSString* const kGraphName = @"face_effect_gpu"; static const char* kInputStream = "input_video"; -static const char* kIsFacepaintEffectSelectedInputStream = "is_facepaint_effect_selected"; static const char* kOutputStream = "output_video"; static const char* kMultiFaceGeometryStream = "multi_face_geometry"; static const char* kVideoQueueLabel = "com.google.mediapipe.example.videoQueue"; +static const char* kSelectedEffectIdInputStream = "selected_effect_id"; +static const char* kUseFaceDetectionInputSourceInputSidePacket = "use_face_detection_input_source"; +static const BOOL kUseFaceDetectionInputSource = NO; static const int kMatrixTranslationZIndex = 14; +static const int kSelectedEffectIdAxis = 0; +static const int kSelectedEffectIdFacepaint = 1; +static const int kSelectedEffectIdGlasses = 2; + @interface FaceEffectViewController () // The MediaPipe graph currently in use. Initialized in viewDidLoad, started in viewWillAppear: and @@ -45,7 +53,7 @@ static const int kMatrixTranslationZIndex = 14; @implementation FaceEffectViewController { /// Handle tap gestures. UITapGestureRecognizer* _tapGestureRecognizer; - BOOL _isFacepaintEffectSelected; + int _selectedEffectId; /// Handles camera access via AVCaptureSession library. MPPCameraInputSource* _cameraSource; @@ -93,8 +101,14 @@ static const int kMatrixTranslationZIndex = 14; mediapipe::CalculatorGraphConfig config; config.ParseFromArray(data.bytes, data.length); + // Pass the kUseFaceDetectionInputSource flag value as an input side packet into the graph. 
+ std::map side_packets; + side_packets[kUseFaceDetectionInputSourceInputSidePacket] = + mediapipe::MakePacket(kUseFaceDetectionInputSource); + // Create MediaPipe graph with mediapipe::CalculatorGraphConfig proto object. MPPGraph* newGraph = [[MPPGraph alloc] initWithGraphConfig:config]; + [newGraph addSidePackets:side_packets]; [newGraph addFrameOutputStream:kOutputStream outputPacketType:MPPPacketTypePixelBuffer]; [newGraph addFrameOutputStream:kMultiFaceGeometryStream outputPacketType:MPPPacketTypeRaw]; return newGraph; @@ -110,8 +124,13 @@ static const int kMatrixTranslationZIndex = 14; action:@selector(handleTap)]; [self.view addGestureRecognizer:_tapGestureRecognizer]; - // By default, render the glasses effect. - _isFacepaintEffectSelected = NO; + // By default, render the axis effect for the face detection input source and the glasses effect + // for the face landmark input source. + if (kUseFaceDetectionInputSource) { + _selectedEffectId = kSelectedEffectIdAxis; + } else { + _selectedEffectId = kSelectedEffectIdGlasses; + } _renderer = [[MPPLayerRenderer alloc] init]; _renderer.layer.frame = _liveView.layer.bounds; @@ -175,7 +194,28 @@ static const int kMatrixTranslationZIndex = 14; // multiple pre-bundled face effects without a need to recompile the app. - (void)handleTap { dispatch_async(_videoQueue, ^{ - _isFacepaintEffectSelected = !_isFacepaintEffectSelected; + // Avoid switching the Axis effect for the face detection input source. + if (kUseFaceDetectionInputSource) { + return; + } + + // Looped effect order: glasses -> facepaint -> axis -> glasses -> ... 
+ switch (_selectedEffectId) { + case kSelectedEffectIdAxis: { + _selectedEffectId = kSelectedEffectIdGlasses; + break; + } + + case kSelectedEffectIdFacepaint: { + _selectedEffectId = kSelectedEffectIdAxis; + break; + } + + case kSelectedEffectIdGlasses: { + _selectedEffectId = kSelectedEffectIdFacepaint; + break; + } + } }); } @@ -189,7 +229,7 @@ static const int kMatrixTranslationZIndex = 14; // Display the captured image on the screen. CVPixelBufferRetain(pixelBuffer); dispatch_async(dispatch_get_main_queue(), ^{ - _effectSwitchingHintLabel.hidden = NO; + _effectSwitchingHintLabel.hidden = kUseFaceDetectionInputSource; [_renderer renderPixelBuffer:pixelBuffer]; CVPixelBufferRelease(pixelBuffer); }); @@ -236,18 +276,18 @@ static const int kMatrixTranslationZIndex = 14; mediapipe::Timestamp graphTimestamp(static_cast( mediapipe::Timestamp::kTimestampUnitsPerSecond * CMTimeGetSeconds(timestamp))); - mediapipe::Packet isFacepaintEffectSelectedPacket = - mediapipe::MakePacket(_isFacepaintEffectSelected).At(graphTimestamp); + mediapipe::Packet selectedEffectIdPacket = + mediapipe::MakePacket(_selectedEffectId).At(graphTimestamp); [self.graph sendPixelBuffer:imageBuffer intoStream:kInputStream packetType:MPPPacketTypePixelBuffer timestamp:graphTimestamp]; - // Alongside the input camera frame, we also send the `is_facepaint_effect_selected` boolean - // packet to indicate which effect should be rendered on this frame. - [self.graph movePacket:std::move(isFacepaintEffectSelectedPacket) - intoStream:kIsFacepaintEffectSelectedInputStream + // Alongside the input camera frame, we also send the `selected_effect_id` int packet to indicate + // which effect should be rendered on this frame. 
+ [self.graph movePacket:std::move(selectedEffectIdPacket) + intoStream:kSelectedEffectIdInputStream error:nil]; } diff --git a/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD new file mode 100644 index 000000000..37e0b85e9 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiontrackinggpu/BUILD @@ -0,0 +1,70 @@ +# Copyright 2019 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load( + "@build_bazel_rules_apple//apple:ios.bzl", + "ios_application", +) +load( + "//mediapipe/examples/ios:bundle_id.bzl", + "BUNDLE_ID_PREFIX", + "example_provisioning", +) + +licenses(["notice"]) + +MIN_IOS_VERSION = "10.0" + +alias( + name = "objectdetectiontrackinggpu", + actual = "ObjectDetectionTrackingGpuApp", +) + +ios_application( + name = "ObjectDetectionTrackingGpuApp", + app_icons = ["//mediapipe/examples/ios/common:AppIcon"], + bundle_id = BUNDLE_ID_PREFIX + ".ObjectDetectionTrackingGpu", + families = [ + "iphone", + "ipad", + ], + infoplists = [ + "//mediapipe/examples/ios/common:Info.plist", + "Info.plist", + ], + minimum_os_version = MIN_IOS_VERSION, + provisioning_profile = example_provisioning(), + deps = [ + ":ObjectDetectionTrackingGpuAppLibrary", + "@ios_opencv//:OpencvFramework", + ], +) + +objc_library( + name = "ObjectDetectionTrackingGpuAppLibrary", + data = [ + "//mediapipe/graphs/tracking:mobile_gpu_binary_graph", + "//mediapipe/models:ssdlite_object_detection.tflite", + 
"//mediapipe/models:ssdlite_object_detection_labelmap.txt", + ], + deps = [ + "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", + ] + select({ + "//mediapipe:ios_i386": [], + "//mediapipe:ios_x86_64": [], + "//conditions:default": [ + "//mediapipe/graphs/tracking:mobile_calculators", + ], + }), +) diff --git a/mediapipe/examples/ios/objectdetectiontrackinggpu/Info.plist b/mediapipe/examples/ios/objectdetectiontrackinggpu/Info.plist new file mode 100644 index 000000000..7e792c9b4 --- /dev/null +++ b/mediapipe/examples/ios/objectdetectiontrackinggpu/Info.plist @@ -0,0 +1,14 @@ + + + + + CameraPosition + back + GraphName + mobile_gpu + GraphOutputStream + output_video + GraphInputStream + input_video + + diff --git a/mediapipe/framework/BUILD b/mediapipe/framework/BUILD index 7f66bdd93..d2ed6cf1e 100644 --- a/mediapipe/framework/BUILD +++ b/mediapipe/framework/BUILD @@ -269,12 +269,6 @@ cc_library( "calculator_graph.h", "scheduler.h", ], - defines = select({ - "//conditions:default": [], - "//mediapipe/gpu:disable_gpu": [ - "MEDIAPIPE_DISABLE_GPU", - ], - }), visibility = [ ":mediapipe_internal", ], @@ -460,6 +454,7 @@ cc_library( ":type_map", "//mediapipe/framework/port:logging", "//mediapipe/framework/tool:tag_map", + "//mediapipe/framework/tool:tag_map_helper", "//mediapipe/framework/tool:validate_name", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", @@ -917,7 +912,7 @@ cc_library( "//conditions:default": [], }) + select({ "//conditions:default": [], - "//mediapipe/gpu:disable_gpu": ["MEDIAPIPE_DISABLE_GPU"], + "//mediapipe/gpu:disable_gpu": ["MEDIAPIPE_DISABLE_GPU=1"], }) + select({ "//conditions:default": [], "//mediapipe/framework:disable_rtti_and_exceptions": [ @@ -928,6 +923,7 @@ cc_library( "//mediapipe/calculators:__subpackages__", "//mediapipe/framework:__subpackages__", "//mediapipe/framework/port:__pkg__", + "//mediapipe/gpu:__pkg__", "//mediapipe/util:__subpackages__", ], ) diff --git 
a/mediapipe/framework/api2/BUILD b/mediapipe/framework/api2/BUILD index a700fa61a..7c9a45e36 100644 --- a/mediapipe/framework/api2/BUILD +++ b/mediapipe/framework/api2/BUILD @@ -3,15 +3,10 @@ package( features = ["-use_header_modules"], ) -# API2 is in preview mode. Internal clients are welcome and encouraged to try -# it out, but be aware that there may be more changes before release. Please -# add your package to this list and reach out to the MediaPipe team (use -# camillol@ as the CL reviewer). package_group( name = "preview_users", packages = [ "//mediapipe/...", - "//video/content_analysis/...", ], ) @@ -134,6 +129,7 @@ cc_library( ":tuple", "//mediapipe/framework:packet", "//mediapipe/framework/port:logging", + "@com_google_absl//absl/meta:type_traits", ], ) diff --git a/mediapipe/framework/api2/builder.h b/mediapipe/framework/api2/builder.h index 10ad555a3..ae32c628a 100644 --- a/mediapipe/framework/api2/builder.h +++ b/mediapipe/framework/api2/builder.h @@ -481,8 +481,7 @@ class Graph { std::string TaggedName(const TagIndexLocation& loc, const std::string& name) { if (loc.tag.empty()) { // ParseTagIndexName does not allow using explicit indices without tags, - // while ParseTagIndex does. There is no explanation for this discrepancy - // in the CLs that introduced them (cl/143209019, cl/156499931). + // while ParseTagIndex does. // TODO: decide whether we should just allow it. 
return name; } else { @@ -494,8 +493,8 @@ class Graph { } } - mediapipe::Status UpdateNodeConfig(const NodeBase& node, - CalculatorGraphConfig::Node* config) { + absl::Status UpdateNodeConfig(const NodeBase& node, + CalculatorGraphConfig::Node* config) { config->set_calculator(node.type_); node.in_streams_.Visit( [&](const TagIndexLocation& loc, const DestinationBase& endpoint) { @@ -521,8 +520,8 @@ class Graph { return {}; } - mediapipe::Status UpdateNodeConfig(const PacketGenerator& node, - PacketGeneratorConfig* config) { + absl::Status UpdateNodeConfig(const PacketGenerator& node, + PacketGeneratorConfig* config) { config->set_packet_generator(node.type_); node.in_sides_.Visit([&](const TagIndexLocation& loc, const DestinationBase& endpoint) { @@ -540,7 +539,7 @@ class Graph { } // For special boundary node. - mediapipe::Status UpdateBoundaryConfig(CalculatorGraphConfig* config) { + absl::Status UpdateBoundaryConfig(CalculatorGraphConfig* config) { graph_boundary_.in_streams_.Visit( [&](const TagIndexLocation& loc, const DestinationBase& endpoint) { CHECK(endpoint.source != nullptr); diff --git a/mediapipe/framework/api2/contract.h b/mediapipe/framework/api2/contract.h index 4ef6096d1..90e4c38cd 100644 --- a/mediapipe/framework/api2/contract.h +++ b/mediapipe/framework/api2/contract.h @@ -26,7 +26,7 @@ class StreamHandler { const const_str& name() { return name_; } - mediapipe::Status AddToContract(CalculatorContract* cc) const { + absl::Status AddToContract(CalculatorContract* cc) const { cc->SetInputStreamHandler(name_.data()); return {}; } @@ -47,7 +47,7 @@ class TimestampChange { return TimestampChange(kUnset); } - mediapipe::Status AddToContract(CalculatorContract* cc) const { + absl::Status AddToContract(CalculatorContract* cc) const { if (offset_ != kUnset) cc->SetTimestampOffset(offset_); return {}; } @@ -71,10 +71,9 @@ struct HasProcessMethod : std::false_type {}; template struct HasProcessMethod< - T, std::void_t>().Process( - std::declval())))>> - : 
std::true_type {}; + T, + std::void_t>().Process( + std::declval())))>> : std::true_type {}; template struct HasNestedItems : std::false_type {}; @@ -142,9 +141,9 @@ class Contract { constexpr Contract(T&&... args) : Contract(std::tuple{std::move(args)...}) {} - mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) const { - std::vector statuses; - auto store_status = [&statuses](mediapipe::Status status) { + absl::Status GetContract(mediapipe::CalculatorContract* cc) const { + std::vector statuses; + auto store_status = [&statuses](absl::Status status) { if (!status.ok()) statuses.push_back(std::move(status)); }; internal::tuple_for_each( @@ -209,7 +208,7 @@ class TaggedContract { public: constexpr TaggedContract() = default; - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) { + static absl::Status GetContract(mediapipe::CalculatorContract* cc) { return c2.GetContract(cc); } @@ -272,34 +271,32 @@ class OutputSender { OutputSender(std::tuple&& args) : outputs_(args) {} template = 0> - mediapipe::Status operator()(CalculatorContext* cc, - mediapipe::StatusOr&& result) { + absl::Status operator()(CalculatorContext* cc, absl::StatusOr&& result) { if (result.ok()) { - return this(cc, result.ValueOrDie()); + return this(cc, result.value()); } else { return result.status(); } } template = 0> - mediapipe::Status operator()(CalculatorContext* cc, R&& result) { + absl::Status operator()(CalculatorContext* cc, R&& result) { std::get<0>(outputs_)(cc).Send(std::forward(result)); return {}; } template - mediapipe::Status operator()(CalculatorContext* cc, - mediapipe::StatusOr>&& result) { + absl::Status operator()(CalculatorContext* cc, + absl::StatusOr>&& result) { if (result.ok()) { - return this(cc, result.ValueOrDie()); + return this(cc, result.value()); } else { return result.status(); } } template - mediapipe::Status operator()(CalculatorContext* cc, - std::tuple&& result) { + absl::Status operator()(CalculatorContext* cc, std::tuple&& 
result) { static_assert(sizeof...(P) == sizeof...(R), ""); internal::tuple_for_each( [cc, &result](const auto& port, auto i_const) { @@ -345,9 +342,9 @@ class FunCaller { auto inputs() const { return internal::filter_tuple(args_); } auto outputs() const { return internal::filter_tuple(args_); } - mediapipe::Status AddToContract(CalculatorContract* cc) const { return {}; } + absl::Status AddToContract(CalculatorContract* cc) const { return {}; } - mediapipe::Status Process(CalculatorContext* cc) const { return (*this)(cc); } + absl::Status Process(CalculatorContext* cc) const { return (*this)(cc); } constexpr std::tuple nested_items() const { return args_; } @@ -359,16 +356,14 @@ class FunCaller { // TODO: implement multiple callers for syncsets. template -mediapipe::Status ProcessFnCallers(CalculatorContext* cc, - std::tuple callers); +absl::Status ProcessFnCallers(CalculatorContext* cc, std::tuple callers); -inline mediapipe::Status ProcessFnCallers(CalculatorContext* cc, std::tuple<>) { - return mediapipe::InternalError("Process unimplemented"); +inline absl::Status ProcessFnCallers(CalculatorContext* cc, std::tuple<>) { + return absl::InternalError("Process unimplemented"); } template -mediapipe::Status ProcessFnCallers(CalculatorContext* cc, - std::tuple callers) { +absl::Status ProcessFnCallers(CalculatorContext* cc, std::tuple callers) { return std::get<0>(callers).Process(cc); } diff --git a/mediapipe/framework/api2/contract_test.cc b/mediapipe/framework/api2/contract_test.cc index 3c2d164e2..08419187a 100644 --- a/mediapipe/framework/api2/contract_test.cc +++ b/mediapipe/framework/api2/contract_test.cc @@ -10,7 +10,7 @@ namespace api2 { namespace { struct ProcessItem { - mediapipe::Status Process(CalculatorContext* cc) { return {}; } + absl::Status Process(CalculatorContext* cc) { return {}; } }; struct ItemWithNested { diff --git a/mediapipe/framework/api2/node.h b/mediapipe/framework/api2/node.h index e54adf328..b5f7586e7 100644 --- 
a/mediapipe/framework/api2/node.h +++ b/mediapipe/framework/api2/node.h @@ -34,7 +34,7 @@ class CalculatorBaseFactoryFor< typename std::enable_if{}>::type> : public CalculatorBaseFactory { public: - mediapipe::Status GetContract(CalculatorContract* cc) final { + absl::Status GetContract(CalculatorContract* cc) final { auto status = T::Contract::GetContract(cc); if (status.ok()) { status = UpdateContract(cc); @@ -54,7 +54,7 @@ class CalculatorBaseFactoryFor< return U::UpdateContract(cc); } template - mediapipe::Status UpdateContract(...) { + absl::Status UpdateContract(...) { return {}; } }; @@ -142,7 +142,7 @@ class RegisteredNode : public Node {}; template struct FunctionNode : public RegisteredNode { - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return internal::ProcessFnCallers(cc, Impl::kContract.process_items()); } }; diff --git a/mediapipe/framework/api2/node_test.cc b/mediapipe/framework/api2/node_test.cc index 22dd8cdea..952a8b010 100644 --- a/mediapipe/framework/api2/node_test.cc +++ b/mediapipe/framework/api2/node_test.cc @@ -12,6 +12,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { @@ -32,7 +33,7 @@ std::vector PacketValues(const std::vector& packets) { class FooImpl : public NodeImpl { public: - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { float bias = kBias(cc).GetOr(0.0); float scale = kScale(cc).GetOr(1.0); kOut(cc).Send(*kBase(cc) * scale + bias); @@ -80,7 +81,7 @@ class Foo5 : public FunctionNode { class Foo2Impl : public NodeImpl { public: - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { float bias = 
SideIn(MPP_TAG("BIAS"), cc).GetOr(0.0); float scale = In(MPP_TAG("SCALE"), cc).GetOr(1.0); Out(MPP_TAG("OUT"), cc).Send(*In(MPP_TAG("BASE"), cc) * scale + bias); @@ -90,7 +91,7 @@ class Foo2Impl : public NodeImpl { class BarImpl : public NodeImpl { public: - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { Packet p = kIn(cc); kOut(cc).Send(p); return {}; @@ -99,9 +100,9 @@ class BarImpl : public NodeImpl { class BazImpl : public NodeImpl { public: - static mediapipe::Status UpdateContract(CalculatorContract* cc) { return {}; } + static absl::Status UpdateContract(CalculatorContract* cc) { return {}; } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { for (int i = 0; i < kData(cc).Count(); ++i) { kDataOut(cc)[i].Send(kData(cc)[i]); } @@ -112,7 +113,7 @@ MEDIAPIPE_NODE_IMPLEMENTATION(BazImpl); class IntForwarderImpl : public NodeImpl { public: - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { kOut(cc).Send(*kIn(cc)); return {}; } @@ -120,7 +121,7 @@ class IntForwarderImpl : public NodeImpl { class ToFloatImpl : public NodeImpl { public: - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { kIn(cc).Visit([cc](auto x) { kOut(cc).Send(x); }); return {}; } @@ -315,7 +316,7 @@ struct SideFallback : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kFactor, kOut); - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { kOut(cc).Send(kIn(cc).Get() * kFactor(cc).Get()); return {}; } @@ -341,7 +342,7 @@ TEST(NodeTest, SideFallbackWithStream) { MP_EXPECT_OK( graph.ObserveOutputStream("out", [&outputs](const mediapipe::Packet& p) { outputs.push_back(p.Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph.StartRun({})); 
MP_EXPECT_OK(graph.AddPacketToInputStream( @@ -372,7 +373,7 @@ TEST(NodeTest, SideFallbackWithSide) { MP_EXPECT_OK( graph.ObserveOutputStream("out", [&outputs](const mediapipe::Packet& p) { outputs.push_back(p.Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph.StartRun({{"factor", mediapipe::MakePacket(2)}})); MP_EXPECT_OK(graph.AddPacketToInputStream( @@ -451,7 +452,7 @@ struct DropEvenTimestamps : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->InputTimestamp().Value() % 2) { kOut(cc).Send(kIn(cc)); } @@ -466,7 +467,7 @@ struct ListIntPackets : public Node { MEDIAPIPE_NODE_CONTRACT(kIn, kOut); - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { std::string result = absl::StrCat(cc->InputTimestamp().DebugString(), ":"); for (int i = 0; i < kIn(cc).Count(); ++i) { if (kIn(cc)[i].IsEmpty()) { @@ -522,6 +523,48 @@ TEST(NodeTest, DefaultTimestampChange0) { MP_EXPECT_OK(graph.WaitUntilDone()); } +struct ConsumerNode : public Node { + static constexpr Input kInt{"INT"}; + static constexpr Input kGeneric{"ANY"}; + static constexpr Input> kOneOf{"NUM"}; + + MEDIAPIPE_NODE_CONTRACT(kInt, kGeneric, kOneOf); + + absl::Status Process(CalculatorContext* cc) override { + ASSIGN_OR_RETURN(auto maybe_int, kInt(cc).Consume()); + ASSIGN_OR_RETURN(auto maybe_float, kGeneric(cc).Consume()); + ASSIGN_OR_RETURN(auto maybe_int2, kOneOf(cc).Consume()); + return {}; + } +}; +MEDIAPIPE_REGISTER_NODE(ConsumerNode); + +TEST(NodeTest, ConsumeInputs) { + CalculatorGraphConfig config = + mediapipe::ParseTextProtoOrDie(R"( + input_stream: "int" + input_stream: "any" + input_stream: "num" + node { + calculator: "ConsumerNode" + input_stream: "INT:int" + input_stream: "ANY:any" + input_stream: "NUM:num" + } + )"); + mediapipe::CalculatorGraph graph; + 
MP_EXPECT_OK(graph.Initialize(config, {})); + MP_EXPECT_OK(graph.StartRun({})); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "int", mediapipe::MakePacket(10).At(Timestamp(0)))); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "any", mediapipe::MakePacket(10).At(Timestamp(0)))); + MP_EXPECT_OK(graph.AddPacketToInputStream( + "num", mediapipe::MakePacket(10).At(Timestamp(0)))); + MP_EXPECT_OK(graph.CloseAllPacketSources()); + MP_EXPECT_OK(graph.WaitUntilDone()); +} + } // namespace test } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/packet.cc b/mediapipe/framework/api2/packet.cc index f3ac6c884..1a2f3774a 100644 --- a/mediapipe/framework/api2/packet.cc +++ b/mediapipe/framework/api2/packet.cc @@ -7,9 +7,19 @@ PacketBase FromOldPacket(const mediapipe::Packet& op) { return PacketBase(packet_internal::GetHolderShared(op)).At(op.Timestamp()); } +PacketBase FromOldPacket(mediapipe::Packet&& op) { + Timestamp t = op.Timestamp(); + return PacketBase(packet_internal::GetHolderShared(std::move(op))).At(t); +} + mediapipe::Packet ToOldPacket(const PacketBase& p) { return mediapipe::packet_internal::Create(p.payload_, p.timestamp_); } +mediapipe::Packet ToOldPacket(PacketBase&& p) { + return mediapipe::packet_internal::Create(std::move(p.payload_), + p.timestamp_); +} + } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/packet.h b/mediapipe/framework/api2/packet.h index 097020372..426d4701d 100644 --- a/mediapipe/framework/api2/packet.h +++ b/mediapipe/framework/api2/packet.h @@ -13,6 +13,7 @@ #include #include +#include "absl/meta/type_traits.h" #include "mediapipe/framework/api2/tuple.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/logging.h" @@ -58,7 +59,22 @@ class PacketBase { const T& Get() const; // Conversion to old Packet type. 
- operator mediapipe::Packet() const { return ToOldPacket(*this); } + operator mediapipe::Packet() const& { return ToOldPacket(*this); } + operator mediapipe::Packet() && { return ToOldPacket(std::move(*this)); } + + // Note: Consume is included for compatibility with the old Packet; however, + // it relies on shared_ptr.unique(), which is deprecated and is not guaranteed + // to give exact results. + template + absl::StatusOr> Consume() { + // Using the implementation in the old Packet for now. + mediapipe::Packet old = + packet_internal::Create(std::move(payload_), timestamp_); + auto result = old.Consume(); + if (!result.ok()) + payload_ = packet_internal::GetHolderShared(std::move(old)); + return result; + } protected: explicit PacketBase(std::shared_ptr payload) @@ -70,11 +86,15 @@ class PacketBase { template friend PacketBase PacketBaseAdopting(const T* ptr); friend PacketBase FromOldPacket(const mediapipe::Packet& op); + friend PacketBase FromOldPacket(mediapipe::Packet&& op); friend mediapipe::Packet ToOldPacket(const PacketBase& p); + friend mediapipe::Packet ToOldPacket(PacketBase&& p); }; PacketBase FromOldPacket(const mediapipe::Packet& op); +PacketBase FromOldPacket(mediapipe::Packet&& op); mediapipe::Packet ToOldPacket(const PacketBase& p); +mediapipe::Packet ToOldPacket(PacketBase&& p); template inline const T& PacketBase::Get() const { @@ -132,6 +152,16 @@ struct Generic { Generic() = delete; }; +template +struct IsCompatibleType : std::false_type {}; +template +struct IsCompatibleType : std::true_type {}; +template +struct IsCompatibleType : std::true_type {}; +template +struct IsCompatibleType> + : std::integral_constant || ...)> {}; + }; // namespace internal template @@ -191,6 +221,13 @@ class Packet : public Packet { return IsEmpty() ? 
static_cast(absl::forward(v)) : **this; } + // Note: Consume is included for compatibility with the old Packet; however, + // it relies on shared_ptr.unique(), which is deprecated and is not guaranteed + // to give exact results. + absl::StatusOr> Consume() { + return PacketBase::Consume(); + } + private: explicit Packet(std::shared_ptr payload) : Packet(std::move(payload)) {} @@ -216,6 +253,44 @@ template struct First { using type = T; }; + +template +struct AddStatus { + using type = StatusOr; +}; +template +struct AddStatus> { + using type = StatusOr; +}; +template <> +struct AddStatus { + using type = Status; +}; +template <> +struct AddStatus { + using type = Status; +}; + +template +struct CallAndAddStatusImpl { + typename AddStatus::type operator()(const F& f, A&&... a) { + return f(std::forward(a)...); + } +}; +template +struct CallAndAddStatusImpl { + Status operator()(const F& f, A&&... a) { + f(std::forward(a)...); + return {}; + } +}; + +template +auto CallAndAddStatus(const F& f, A&&... a) { + return CallAndAddStatusImpl, F, A...>()( + f, std::forward(a)...); +} + } // namespace internal template @@ -276,6 +351,30 @@ class Packet> : public PacketBase { return Invoke(f); } + // Note: Consume is included for compatibility with the old Packet; however, + // it relies on shared_ptr.unique(), which is deprecated and is not guaranteed + // to give exact results. + template > + absl::StatusOr> Consume() { + return PacketBase::Consume(); + } + + template + auto ConsumeAndVisit(const F&... 
args) { + CHECK(payload_); + auto f = internal::Overload{args...}; + using FirstT = typename internal::First::type; + using VisitorResultType = + absl::result_of_t)>; + static_assert( + (std::is_same_v)>> && + ...), + "All visitor overloads must have the same return type"); + using ResultType = typename internal::AddStatus::type; + return InvokeConsuming(f); + } + protected: explicit Packet(std::shared_ptr payload) : PacketBase(std::move(payload)) {} @@ -292,6 +391,21 @@ class Packet> : public PacketBase { auto Invoke(const F& f) const { return Has() ? f(Get()) : Invoke(f); } + + template + auto InvokeConsuming(const F& f) -> R { + auto maybe_value = Consume(); + if (maybe_value.ok()) + return internal::CallAndAddStatus(f, std::move(maybe_value).value()); + else + return maybe_value.status(); + } + + template + auto InvokeConsuming(const F& f) -> R { + return Has() ? InvokeConsuming(f) + : InvokeConsuming(f); + } }; template <> diff --git a/mediapipe/framework/api2/packet_test.cc b/mediapipe/framework/api2/packet_test.cc index 833f67251..6d8fd0015 100644 --- a/mediapipe/framework/api2/packet_test.cc +++ b/mediapipe/framework/api2/packet_test.cc @@ -1,7 +1,9 @@ #include "mediapipe/framework/api2/packet.h" #include "absl/strings/str_cat.h" +#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" namespace mediapipe { namespace api2 { @@ -168,12 +170,26 @@ TEST(PacketTest, FromOldPacket) { mediapipe::Packet op = mediapipe::MakePacket(7); Packet p = FromOldPacket(op).As(); EXPECT_EQ(p.Get(), 7); + EXPECT_EQ(op.Get(), 7); +} + +TEST(PacketTest, FromOldPacketConsume) { + mediapipe::Packet op = mediapipe::MakePacket(7); + Packet p = FromOldPacket(std::move(op)).As(); + MP_EXPECT_OK(p.Consume()); } TEST(PacketTest, ToOldPacket) { auto p = MakePacket(7); mediapipe::Packet op = ToOldPacket(p); EXPECT_EQ(op.Get(), 7); + EXPECT_EQ(p.Get(), 7); +} + +TEST(PacketTest, ToOldPacketConsume) { + 
auto p = MakePacket(7); + mediapipe::Packet op = ToOldPacket(std::move(p)); + MP_EXPECT_OK(op.Consume()); } TEST(PacketTest, OldRefCounting) { @@ -190,6 +206,42 @@ TEST(PacketTest, OldRefCounting) { EXPECT_FALSE(alive); } +TEST(PacketTest, Consume) { + auto p = MakePacket(7); + auto maybe_int = p.Consume(); + EXPECT_TRUE(p.IsEmpty()); + ASSERT_TRUE(maybe_int.ok()); + EXPECT_EQ(*maybe_int.value(), 7); + + p = MakePacket(3); + auto p2 = p; + maybe_int = p.Consume(); + EXPECT_FALSE(maybe_int.ok()); + EXPECT_FALSE(p.IsEmpty()); + EXPECT_FALSE(p2.IsEmpty()); +} + +TEST(PacketTest, OneOfConsume) { + Packet> p = MakePacket("hi"); + EXPECT_TRUE(p.Has()); + EXPECT_FALSE(p.Has()); + EXPECT_EQ(p.Get(), "hi"); + absl::StatusOr out = p.ConsumeAndVisit( + [](std::unique_ptr s) { + return absl::StrCat("string: ", *s); + }, + [](std::unique_ptr i) { return absl::StrCat("int: ", *i); }); + MP_EXPECT_OK(out); + EXPECT_EQ(out.value(), "string: hi"); + EXPECT_TRUE(p.IsEmpty()); + + p = MakePacket(3); + absl::Status out2 = p.ConsumeAndVisit([](std::unique_ptr s) {}, + [](std::unique_ptr i) {}); + MP_EXPECT_OK(out2); + EXPECT_TRUE(p.IsEmpty()); +} + } // namespace } // namespace api2 } // namespace mediapipe diff --git a/mediapipe/framework/api2/port.h b/mediapipe/framework/api2/port.h index ee41a76c4..cab39abdb 100644 --- a/mediapipe/framework/api2/port.h +++ b/mediapipe/framework/api2/port.h @@ -179,7 +179,7 @@ inline void SetType(CalculatorContract* cc, PacketType& pt) { template InputShardAccess SinglePortAccess(mediapipe::CalculatorContext* cc, - const InputStreamShard* stream) { + InputStreamShard* stream) { return InputShardAccess(*cc, stream); } @@ -203,7 +203,7 @@ OutputSidePacketAccess SinglePortAccess( template InputShardOrSideAccess SinglePortAccess( - mediapipe::CalculatorContext* cc, const InputStreamShard* stream, + mediapipe::CalculatorContext* cc, InputStreamShard* stream, const mediapipe::Packet* packet) { return InputShardOrSideAccess(*cc, stream, packet); } @@ 
-226,19 +226,50 @@ auto AccessPort(std::false_type, const PortT& port, CC* cc) { template class MultiplePortAccess { public: + using AccessT = decltype(SinglePortAccess(std::declval(), + std::declval())); + MultiplePortAccess(CC* cc, X* first, int count) : cc_(cc), first_(first), count_(count) {} // TODO: maybe this should be size(), like in a standard C++ // container? int Count() { return count_; } - auto operator[](int pos) { + AccessT operator[](int pos) { CHECK_GE(pos, 0); CHECK_LT(pos, count_); return SinglePortAccess(cc_, &first_[pos]); } - // TODO: add begin/end. + class Iterator { + public: + using iterator_category = std::input_iterator_tag; + using value_type = AccessT; + using difference_type = std::ptrdiff_t; + using pointer = AccessT*; + using reference = AccessT; // allowed; see e.g. std::istreambuf_iterator + + Iterator(CC* cc, X* p) : cc_(cc), p_(p) {} + Iterator& operator++() { + ++p_; + return *this; + } + Iterator operator++(int) { + Iterator res = *this; + ++(*this); + return res; + } + bool operator==(const Iterator& other) const { return p_ == other.p_; } + bool operator!=(const Iterator& other) const { return !(*this == other); } + AccessT operator*() const { return SinglePortAccess(cc_, p_); } + + private: + CC* cc_; + X* p_; + }; + + Iterator begin() { return Iterator(cc_, first_); } + Iterator end() { return Iterator(cc_, first_ + count_); } private: CC* cc_; @@ -307,7 +338,7 @@ class PortCommon : public Base { } private: - mediapipe::Status AddToContract(CalculatorContract* cc) const { + absl::Status AddToContract(CalculatorContract* cc) const { if (kMultiple) { AddMultiple(cc); } else { @@ -385,17 +416,17 @@ class SideFallbackT : public Base { side_port(tag) {} protected: - mediapipe::Status AddToContract(CalculatorContract* cc) const { + absl::Status AddToContract(CalculatorContract* cc) const { stream_port.AddToContract(cc); side_port.AddToContract(cc); int connected_count = stream_port(cc).IsConnected() + side_port(cc).IsConnected(); 
if (connected_count > 1) - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( Tag(), " can be connected as a stream or as a side packet, but not both")); if (!IsOptionalV && connected_count == 0) - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat(Tag(), " must be connected")); return {}; } @@ -452,6 +483,14 @@ class OutputShardAccess : public OutputShardAccessBase { void Send(const T& payload) { Send(payload, context_.InputTimestamp()); } + void Send(T&& payload, Timestamp time) { + Send(api2::MakePacket(std::move(payload)).At(time)); + } + + void Send(T&& payload) { + Send(std::move(payload), context_.InputTimestamp()); + } + void Send(std::unique_ptr payload, Timestamp time) { Send(api2::PacketAdopting(std::move(payload)).At(time)); } @@ -501,6 +540,7 @@ class OutputSidePacketAccess { } void Set(const T& payload) { Set(MakePacket(payload)); } + void Set(T&& payload) { Set(MakePacket(std::move(payload))); } private: OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {} @@ -523,15 +563,54 @@ class InputShardAccess : public Packet { PacketBase Header() const { return FromOldPacket(stream_->Header()); } + // "Consume" requires exclusive ownership of the packet's payload. In the + // current interim implementation, InputShardAccess creates a new reference to + // the payload (as a Packet instead of a type-erased Packet), which means + // the conditions for Consume would never be satisfied. This helper class + // defines wrappers for the Consume methods in Packet which temporarily erase + // the reference held by the underlying InputStreamShard. + // Note that we cannot simply take over the reference when InputShardAccess is + // created, because it is currently created as a temporary and we might create + // more than one instance for the same stream. 
+ template {}, + decltype(&Packet::Consume)>> + absl::StatusOr> Consume() { + return WrapConsumeCall(&Packet::Consume); + } + + template {}, int> = 0> + absl::StatusOr> Consume() { + return WrapConsumeCall(&Packet::template Consume); + } + + template + auto ConsumeAndVisit(F&&... args) { + auto f = &Packet::template ConsumeAndVisit; + return WrapConsumeCall(f, std::forward(args)...); + } + private: - InputShardAccess(const CalculatorContext&, const InputStreamShard* stream) + InputShardAccess(const CalculatorContext&, InputStreamShard* stream) : Packet(stream ? FromOldPacket(stream->Value()).template As() : Packet()), stream_(stream) {} - const InputStreamShard* stream_; + + template + auto WrapConsumeCall(F f, A&&... args) { + stream_->Value() = {}; + auto result = (this->*f)(std::forward(args)...); + if (!result.ok()) { + stream_->Value() = ToOldPacket(*this); + } + return result; + } + + InputStreamShard* stream_; friend InputShardAccess internal::SinglePortAccess( - mediapipe::CalculatorContext*, const InputStreamShard*); + mediapipe::CalculatorContext*, InputStreamShard*); }; template @@ -566,19 +645,18 @@ class InputShardOrSideAccess : public Packet { PacketBase Header() const { return FromOldPacket(stream_->Header()); } private: - InputShardOrSideAccess(const CalculatorContext&, - const InputStreamShard* stream, + InputShardOrSideAccess(const CalculatorContext&, InputStreamShard* stream, const mediapipe::Packet* packet) : Packet(stream ? FromOldPacket(stream->Value()).template As() : packet ? 
FromOldPacket(*packet).template As() : Packet()), stream_(stream), connected_(stream_ != nullptr || packet != nullptr) {} - const InputStreamShard* stream_; + InputStreamShard* stream_; bool connected_; friend InputShardOrSideAccess internal::SinglePortAccess( - mediapipe::CalculatorContext*, const InputStreamShard*, + mediapipe::CalculatorContext*, InputStreamShard*, const mediapipe::Packet*); }; diff --git a/mediapipe/framework/api2/subgraph_test.cc b/mediapipe/framework/api2/subgraph_test.cc index 7e663fce0..a56ba9fe4 100644 --- a/mediapipe/framework/api2/subgraph_test.cc +++ b/mediapipe/framework/api2/subgraph_test.cc @@ -17,7 +17,7 @@ namespace test { class FooBarImpl1 : public SubgraphImpl { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& /*options*/) { builder::Graph graph; auto& foo = graph.AddNode("Foo"); @@ -31,7 +31,7 @@ class FooBarImpl1 : public SubgraphImpl { class FooBarImpl2 : public SubgraphImpl { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& /*options*/) { builder::Graph graph; auto& foo = graph.AddNode(); @@ -44,7 +44,7 @@ class FooBarImpl2 : public SubgraphImpl { }; TEST(SubgraphTest, SubgraphConfig) { - CalculatorGraphConfig subgraph = FooBarImpl1().GetConfig({}).ValueOrDie(); + CalculatorGraphConfig subgraph = FooBarImpl1().GetConfig({}).value(); const CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie(R"( input_stream: "IN:__stream_0" @@ -64,7 +64,7 @@ TEST(SubgraphTest, SubgraphConfig) { } TEST(SubgraphTest, TypedSubgraphConfig) { - CalculatorGraphConfig subgraph = FooBarImpl2().GetConfig({}).ValueOrDie(); + CalculatorGraphConfig subgraph = FooBarImpl2().GetConfig({}).value(); const CalculatorGraphConfig expected_graph = mediapipe::ParseTextProtoOrDie(R"( input_stream: "IN:__stream_0" diff --git a/mediapipe/framework/basic_types_registration.cc b/mediapipe/framework/basic_types_registration.cc index a16daac72..9f7966ef8 100644 --- 
a/mediapipe/framework/basic_types_registration.cc +++ b/mediapipe/framework/basic_types_registration.cc @@ -9,6 +9,11 @@ mediapipe::type_map_internal::ReflectType::Type, #type, \ nullptr, nullptr) +#define MEDIAPIPE_REGISTER_GENERIC_TYPE_WITH_NAME(type, name) \ + MEDIAPIPE_REGISTER_TYPE( \ + mediapipe::type_map_internal::ReflectType::Type, name, \ + nullptr, nullptr) + // Note: we cannot define a type which type hash id is already in the map. // E.g. if tool::GetTypeHash() == tool::GetTypeHash(), then only one // can be registered. @@ -26,3 +31,4 @@ MEDIAPIPE_REGISTER_GENERIC_TYPE(::std::vector); MEDIAPIPE_REGISTER_GENERIC_TYPE(::std::vector); MEDIAPIPE_REGISTER_GENERIC_TYPE(::std::vector); MEDIAPIPE_REGISTER_GENERIC_TYPE(::std::vector<::std::vector>); +MEDIAPIPE_REGISTER_GENERIC_TYPE_WITH_NAME(::std::string, "string"); diff --git a/mediapipe/framework/calculator.proto b/mediapipe/framework/calculator.proto index 503c7d559..3c5ac66e5 100644 --- a/mediapipe/framework/calculator.proto +++ b/mediapipe/framework/calculator.proto @@ -351,7 +351,7 @@ message CalculatorGraphConfig { int32 num_threads = 8; // Configs for StatusHandlers that will be called after each call to // Run() on the graph. StatusHandlers take zero or more input side - // packets and the ::util::Status returned by a graph run. For example, + // packets and the absl::Status returned by a graph run. For example, // a StatusHandler could store information about graph failures and // their causes for later monitoring. Note that graph failures during // initialization may cause required input side packets (created by a diff --git a/mediapipe/framework/calculator_base.h b/mediapipe/framework/calculator_base.h index 1ac0a0184..f9f0d7a8a 100644 --- a/mediapipe/framework/calculator_base.h +++ b/mediapipe/framework/calculator_base.h @@ -82,7 +82,7 @@ class CalculatorBase { // this function is static the registration macro provides access to // each subclass' GetContract function. 
// - // static mediapipe::Status GetContract(CalculatorContract* cc); + // static absl::Status GetContract(CalculatorContract* cc); // // GetContract fills in the calculator's contract with the framework, such // as its expectations of what packets it will receive. When this function @@ -116,23 +116,21 @@ class CalculatorBase { // Open is called before any Process() calls, on a freshly constructed // calculator. Subclasses may override this method to perform necessary // setup, and possibly output Packets and/or set output streams' headers. - // Must return mediapipe::OkStatus() to indicate success. On failure any + // Must return absl::OkStatus() to indicate success. On failure any // other status code can be returned. If failure is returned then the // framework will call neither Process() nor Close() on the calculator (so any // necessary cleanup should be done before returning failure or in the // destructor). - virtual mediapipe::Status Open(CalculatorContext* cc) { - return mediapipe::OkStatus(); - } + virtual absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); } // Processes the incoming inputs. May call the methods on cc to access // inputs and produce outputs. // // Process() called on a non-source node must return - // mediapipe::OkStatus() to indicate that all went well, or any other + // absl::OkStatus() to indicate that all went well, or any other // status code to signal an error. // For example: - // mediapipe::UnknownError("Failure Message"); + // absl::UnknownError("Failure Message"); // Notice the convenience functions in util/task/canonical_errors.h . // If a non-source Calculator returns tool::StatusStop(), then this // signals the graph is being cancelled early. In this case, all @@ -140,23 +138,21 @@ class CalculatorBase { // remaining Packets will propagate through the graph). // // A source node will continue to have Process() called on it as long - // as it returns mediapipe::OkStatus(). 
To indicate that there is + // as it returns absl::OkStatus(). To indicate that there is // no more data to be generated return tool::StatusStop(). Any other // status indicates an error has occurred. - virtual mediapipe::Status Process(CalculatorContext* cc) = 0; + virtual absl::Status Process(CalculatorContext* cc) = 0; // Is called if Open() was called and succeeded. Is called either // immediately after processing is complete or after a graph run has ended - // (if an error occurred in the graph). Must return mediapipe::OkStatus() + // (if an error occurred in the graph). Must return absl::OkStatus() // to indicate success. On failure any other status code can be returned. // Packets may be output during a call to Close(). However, output packets // are silently discarded if Close() is called after a graph run has ended. // // NOTE: If Close() needs to perform an action only when processing is // complete, Close() must check if cc->GraphStatus() is OK. - virtual mediapipe::Status Close(CalculatorContext* cc) { - return mediapipe::OkStatus(); - } + virtual absl::Status Close(CalculatorContext* cc) { return absl::OkStatus(); } // Returns a value according to which the framework selects // the next source calculator to Process(); smaller value means @@ -180,7 +176,7 @@ namespace internal { class CalculatorBaseFactory { public: virtual ~CalculatorBaseFactory() {} - virtual mediapipe::Status GetContract(CalculatorContract* cc) = 0; + virtual absl::Status GetContract(CalculatorContract* cc) = 0; virtual std::unique_ptr CreateCalculator( CalculatorContext* calculator_context) = 0; virtual std::string ContractMethodName() { return "GetContract"; } @@ -189,7 +185,7 @@ class CalculatorBaseFactory { // Functions for checking that the calculator has the required GetContract. 
template constexpr bool CalculatorHasGetContract(decltype(&T::GetContract) /*unused*/) { - typedef mediapipe::Status (*GetContractType)(CalculatorContract * cc); + typedef absl::Status (*GetContractType)(CalculatorContract * cc); return std::is_same::value; } template @@ -219,7 +215,7 @@ class CalculatorBaseFactoryFor< // Provides access to the static function GetContract within a specific // subclass of CalculatorBase. - mediapipe::Status GetContract(CalculatorContract* cc) final { + absl::Status GetContract(CalculatorContract* cc) final { // CalculatorBaseSubclass must implement this function, since it is not // implemented in the parent class. return T::GetContract(cc); diff --git a/mediapipe/framework/calculator_base_test.cc b/mediapipe/framework/calculator_base_test.cc index b9f5b9f5d..4d3891818 100644 --- a/mediapipe/framework/calculator_base_test.cc +++ b/mediapipe/framework/calculator_base_test.cc @@ -41,7 +41,7 @@ namespace test_ns { // streams and input side packets. class DeadEndCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); } @@ -51,16 +51,14 @@ class DeadEndCalculator : public CalculatorBase { for (int i = 0; i < cc->InputSidePackets().NumEntries(); ++i) { cc->InputSidePackets().Index(i).SetAny(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().NumEntries() > 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { // This is a source calculator, but we don't produce any outputs. 
return tool::StatusStop(); @@ -73,14 +71,12 @@ namespace whitelisted_ns { class DeadCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - return mediapipe::OkStatus(); + static absl::Status GetContract(CalculatorContract* cc) { + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { - return mediapipe::OkStatus(); - } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); } + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; @@ -89,14 +85,12 @@ class DeadCalculator : public CalculatorBase { class EndCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - return mediapipe::OkStatus(); + static absl::Status GetContract(CalculatorContract* cc) { + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { - return mediapipe::OkStatus(); - } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); } + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; REGISTER_CALCULATOR(::mediapipe::EndCalculator); @@ -105,7 +99,7 @@ namespace { TEST(CalculatorTest, SourceProcessOrder) { internal::Collection output_stream_managers( - tool::CreateTagMap(2).ValueOrDie()); + tool::CreateTagMap(2).value()); PacketType output0_type; PacketType output1_type; @@ -117,7 +111,7 @@ TEST(CalculatorTest, SourceProcessOrder) { MP_ASSERT_OK( output_stream_managers.Index(1).Initialize("output1", &output1_type)); - PacketSet input_side_packets(tool::CreateTagMap({}).ValueOrDie()); + PacketSet input_side_packets(tool::CreateTagMap({}).value()); CalculatorState calculator_state("Node", /*node_id=*/0, "Calculator", 
CalculatorGraphConfig::Node(), nullptr); @@ -126,7 +120,7 @@ TEST(CalculatorTest, SourceProcessOrder) { CalculatorContextManager calculator_context_manager; CalculatorContext calculator_context(&calculator_state, - tool::CreateTagMap({}).ValueOrDie(), + tool::CreateTagMap({}).value(), output_stream_managers.TagMap()); OutputStreamShardSet& output_set = calculator_context.Outputs(); output_set.Index(0).SetSpec(output_stream_managers.Index(0).Spec()); @@ -167,13 +161,13 @@ TEST(CalculatorTest, CreateByName) { "mediapipe", "DeadEndCalculator") .status() .code(), - mediapipe::StatusCode::kNotFound); + absl::StatusCode::kNotFound); EXPECT_EQ(CalculatorBaseRegistry::CreateByName( // "DeadEndCalculator") .status() .code(), - mediapipe::StatusCode::kNotFound); + absl::StatusCode::kNotFound); } // Tests registration of a calculator within a whitelisted namespace. diff --git a/mediapipe/framework/calculator_context.cc b/mediapipe/framework/calculator_context.cc index 31e4d2da8..0d1c05b1e 100644 --- a/mediapipe/framework/calculator_context.cc +++ b/mediapipe/framework/calculator_context.cc @@ -41,6 +41,11 @@ Counter* CalculatorContext::GetCounter(const std::string& name) { return calculator_state_->GetCounter(name); } +CounterSet* CalculatorContext::GetCounterSet() { + CHECK(calculator_state_); + return calculator_state_->GetCounterSet(); +} + const PacketSet& CalculatorContext::InputSidePackets() const { return calculator_state_->InputSidePackets(); } diff --git a/mediapipe/framework/calculator_context.h b/mediapipe/framework/calculator_context.h index 14d35550a..e73dd66ce 100644 --- a/mediapipe/framework/calculator_context.h +++ b/mediapipe/framework/calculator_context.h @@ -74,6 +74,10 @@ class CalculatorContext { // the calculator's type (if not). Counter* GetCounter(const std::string& name); + // Returns the counter set, which can be used to create new counters. + // No prefix is added to counters created in this way. 
+ CounterSet* GetCounterSet(); + // Returns the current input timestamp, or Timestamp::Unset if there are // no input packets. Timestamp InputTimestamp() const { @@ -103,7 +107,7 @@ class CalculatorContext { // Returns the status of the graph run. // // NOTE: This method should only be called during CalculatorBase::Close(). - mediapipe::Status GraphStatus() const { return graph_status_; } + absl::Status GraphStatus() const { return graph_status_; } ProfilingContext* GetProfilingContext() const { return calculator_state_->GetSharedProfilingContext().get(); @@ -148,9 +152,7 @@ class CalculatorContext { input_timestamps_.pop(); } - void SetGraphStatus(const mediapipe::Status& status) { - graph_status_ = status; - } + void SetGraphStatus(const absl::Status& status) { graph_status_ = status; } // Interface for the friend class Calculator. const InputStreamSet& InputStreams() const; @@ -171,7 +173,7 @@ class CalculatorContext { std::queue input_timestamps_; // The status of the graph run. Only used when Close() is called. - mediapipe::Status graph_status_; + absl::Status graph_status_; // Accesses CalculatorContext for setting input timestamp. 
friend class CalculatorContextManager; diff --git a/mediapipe/framework/calculator_context_manager.cc b/mediapipe/framework/calculator_context_manager.cc index 271976628..acd70dd94 100644 --- a/mediapipe/framework/calculator_context_manager.cc +++ b/mediapipe/framework/calculator_context_manager.cc @@ -34,9 +34,8 @@ void CalculatorContextManager::Initialize( calculator_run_in_parallel_ = calculator_run_in_parallel; } -mediapipe::Status CalculatorContextManager::PrepareForRun( - std::function - setup_shards_callback) { +absl::Status CalculatorContextManager::PrepareForRun( + std::function setup_shards_callback) { setup_shards_callback_ = std::move(setup_shards_callback); default_context_ = absl::make_unique( calculator_state_, input_tag_map_, output_tag_map_); diff --git a/mediapipe/framework/calculator_context_manager.h b/mediapipe/framework/calculator_context_manager.h index 14e49e9bf..6b988b03d 100644 --- a/mediapipe/framework/calculator_context_manager.h +++ b/mediapipe/framework/calculator_context_manager.h @@ -45,9 +45,8 @@ class CalculatorContextManager { // Sets the callback that can setup the input and output stream shards in a // newly constructed calculator context. Then, initializes the default // calculator context. - mediapipe::Status PrepareForRun( - std::function - setup_shards_callback); + absl::Status PrepareForRun( + std::function setup_shards_callback); // Invoked by CalculatorNode::CleanupAfterRun(). void CleanupAfterRun() ABSL_LOCKS_EXCLUDED(contexts_mutex_); @@ -108,7 +107,7 @@ class CalculatorContextManager { } void SetGraphStatusInContext(CalculatorContext* calculator_context, - const mediapipe::Status& status) { + const absl::Status& status) { CHECK(calculator_context); calculator_context->SetGraphStatus(status); } @@ -124,7 +123,7 @@ class CalculatorContextManager { // NOTE: This callback invokes input/output stream handler methods. 
// The callback is used to break the circular dependency between // calculator context manager and input/output stream handlers. - std::function setup_shards_callback_; + std::function setup_shards_callback_; // The default calculator context that is always reused for sequential // execution. It is also used by Open() and Close() method of a parallel diff --git a/mediapipe/framework/calculator_context_test.cc b/mediapipe/framework/calculator_context_test.cc index 044e10310..e7612501a 100644 --- a/mediapipe/framework/calculator_context_test.cc +++ b/mediapipe/framework/calculator_context_test.cc @@ -99,9 +99,9 @@ std::unique_ptr MakeCalculatorState( std::unique_ptr MakeCalculatorContext( CalculatorState* calculator_state) { - return absl::make_unique( - calculator_state, tool::CreateTagMap({}).ValueOrDie(), - tool::CreateTagMap({}).ValueOrDie()); + return absl::make_unique(calculator_state, + tool::CreateTagMap({}).value(), + tool::CreateTagMap({}).value()); } TEST(CalculatorTest, NodeId) { diff --git a/mediapipe/framework/calculator_contract.cc b/mediapipe/framework/calculator_contract.cc index 503d47106..cff9fcd84 100644 --- a/mediapipe/framework/calculator_contract.cc +++ b/mediapipe/framework/calculator_contract.cc @@ -24,9 +24,9 @@ namespace mediapipe { -mediapipe::Status CalculatorContract::Initialize( +absl::Status CalculatorContract::Initialize( const CalculatorGraphConfig::Node& node) { - std::vector statuses; + std::vector statuses; auto input_stream_statusor = tool::TagMap::Create(node.input_stream()); if (!input_stream_statusor.ok()) { @@ -64,19 +64,18 @@ mediapipe::Status CalculatorContract::Initialize( options_.Initialize(*node_config_); // Create the PacketTypeSets. 
inputs_ = absl::make_unique( - std::move(input_stream_statusor).ValueOrDie()); + std::move(input_stream_statusor).value()); outputs_ = absl::make_unique( - std::move(output_stream_statusor).ValueOrDie()); + std::move(output_stream_statusor).value()); input_side_packets_ = absl::make_unique( - std::move(input_side_packet_statusor).ValueOrDie()); + std::move(input_side_packet_statusor).value()); output_side_packets_ = absl::make_unique( - std::move(output_side_packet_statusor).ValueOrDie()); - return mediapipe::OkStatus(); + std::move(output_side_packet_statusor).value()); + return absl::OkStatus(); } -mediapipe::Status CalculatorContract::Initialize( - const PacketGeneratorConfig& node) { - std::vector statuses; +absl::Status CalculatorContract::Initialize(const PacketGeneratorConfig& node) { + std::vector statuses; auto input_side_packet_statusor = tool::TagMap::Create(node.input_side_packet()); @@ -103,15 +102,14 @@ mediapipe::Status CalculatorContract::Initialize( } input_side_packets_ = absl::make_unique( - std::move(input_side_packet_statusor).ValueOrDie()); + std::move(input_side_packet_statusor).value()); output_side_packets_ = absl::make_unique( - std::move(output_side_packet_statusor).ValueOrDie()); - return mediapipe::OkStatus(); + std::move(output_side_packet_statusor).value()); + return absl::OkStatus(); } -mediapipe::Status CalculatorContract::Initialize( - const StatusHandlerConfig& node) { - std::vector statuses; +absl::Status CalculatorContract::Initialize(const StatusHandlerConfig& node) { + std::vector statuses; auto input_side_packet_statusor = tool::TagMap::Create(node.input_side_packet()); @@ -133,8 +131,8 @@ mediapipe::Status CalculatorContract::Initialize( } input_side_packets_ = absl::make_unique( - std::move(input_side_packet_statusor).ValueOrDie()); - return mediapipe::OkStatus(); + std::move(input_side_packet_statusor).value()); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/framework/calculator_contract.h 
b/mediapipe/framework/calculator_contract.h index dfe4a897b..9ff189ffb 100644 --- a/mediapipe/framework/calculator_contract.h +++ b/mediapipe/framework/calculator_contract.h @@ -47,9 +47,9 @@ namespace mediapipe { // class CalculatorContract { public: - mediapipe::Status Initialize(const CalculatorGraphConfig::Node& node); - mediapipe::Status Initialize(const PacketGeneratorConfig& node); - mediapipe::Status Initialize(const StatusHandlerConfig& node); + absl::Status Initialize(const CalculatorGraphConfig::Node& node); + absl::Status Initialize(const PacketGeneratorConfig& node); + absl::Status Initialize(const StatusHandlerConfig& node); void SetNodeName(const std::string& node_name) { node_name_ = node_name; } // Returns the options given to this node. diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index 2111569e8..961861abe 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -62,9 +62,9 @@ #include "mediapipe/framework/validated_graph_config.h" #include "mediapipe/gpu/graph_support.h" #include "mediapipe/util/cpu_util.h" -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gpu_shared_data_internal.h" -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { @@ -129,13 +129,13 @@ CalculatorGraph::CalculatorGraph(const CalculatorGraphConfig& config) // instantiated. CalculatorGraph::~CalculatorGraph() { // Stop periodic profiler output to ublock Executor destructors. - mediapipe::Status status = profiler()->Stop(); + absl::Status status = profiler()->Stop(); if (!status.ok()) { LOG(ERROR) << "During graph destruction: " << status; } } -mediapipe::Status CalculatorGraph::InitializePacketGeneratorGraph( +absl::Status CalculatorGraph::InitializePacketGeneratorGraph( const std::map& side_packets) { // Create and initialize the output side packets. 
if (!validated_graph_->OutputSidePacketInfos().empty()) { @@ -164,7 +164,7 @@ mediapipe::Status CalculatorGraph::InitializePacketGeneratorGraph( default_executor, side_packets); } -mediapipe::Status CalculatorGraph::InitializeStreams() { +absl::Status CalculatorGraph::InitializeStreams() { any_packet_type_.SetAny(); // Create and initialize the input streams. @@ -221,16 +221,16 @@ mediapipe::Status CalculatorGraph::InitializeStreams() { graph_input_stream_add_mode_ = GraphInputStreamAddMode::WAIT_TILL_NOT_FULL; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::InitializeCalculatorNodes() { +absl::Status CalculatorGraph::InitializeCalculatorNodes() { // Check if the user has specified a maximum queue size for an input stream. max_queue_size_ = validated_graph_->Config().max_queue_size(); max_queue_size_ = max_queue_size_ ? max_queue_size_ : 100; // Use a local variable to avoid needing to lock errors_. - std::vector errors; + std::vector errors; // Create and initialize all the nodes in the graph. nodes_ = absl::make_unique>( @@ -240,7 +240,7 @@ mediapipe::Status CalculatorGraph::InitializeCalculatorNodes() { // buffer_size_hint will be positive if one was specified in // the graph proto. 
int buffer_size_hint = 0; - const mediapipe::Status result = (*nodes_)[node_id].Initialize( + const absl::Status result = (*nodes_)[node_id].Initialize( validated_graph_.get(), node_id, input_stream_managers_.get(), output_stream_managers_.get(), output_side_packets_.get(), &buffer_size_hint, profiler_); @@ -259,15 +259,15 @@ mediapipe::Status CalculatorGraph::InitializeCalculatorNodes() { VLOG(2) << "Maximum input stream queue size based on graph config: " << max_queue_size_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::InitializeProfiler() { +absl::Status CalculatorGraph::InitializeProfiler() { profiler_->Initialize(*validated_graph_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::InitializeExecutors() { +absl::Status CalculatorGraph::InitializeExecutors() { // If the ExecutorConfig for the default executor leaves the executor type // unspecified, default_executor_options points to the // ThreadPoolExecutorOptions in that ExecutorConfig. Otherwise, @@ -324,10 +324,10 @@ mediapipe::Status CalculatorGraph::InitializeExecutors() { use_application_thread)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::InitializeDefaultExecutor( +absl::Status CalculatorGraph::InitializeDefaultExecutor( const ThreadPoolExecutorOptions* default_executor_options, bool use_application_thread) { #ifdef __EMSCRIPTEN__ @@ -340,7 +340,7 @@ mediapipe::Status CalculatorGraph::InitializeDefaultExecutor( "", std::make_shared( std::bind(&internal::Scheduler::AddApplicationThreadTask, &scheduler_, std::placeholders::_1)))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Check the number of threads specified in the proto. 
@@ -359,10 +359,10 @@ mediapipe::Status CalculatorGraph::InitializeDefaultExecutor( } MP_RETURN_IF_ERROR( CreateDefaultThreadPool(default_executor_options, num_threads)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::Initialize( +absl::Status CalculatorGraph::Initialize( std::unique_ptr validated_graph, const std::map& side_packets) { RET_CHECK(!initialized_).SetNoLogging() @@ -380,15 +380,15 @@ mediapipe::Status CalculatorGraph::Initialize( #endif initialized_ = true; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::Initialize( +absl::Status CalculatorGraph::Initialize( const CalculatorGraphConfig& input_config) { return Initialize(input_config, {}); } -mediapipe::Status CalculatorGraph::Initialize( +absl::Status CalculatorGraph::Initialize( const CalculatorGraphConfig& input_config, const std::map& side_packets) { auto validated_graph = absl::make_unique(); @@ -396,7 +396,7 @@ mediapipe::Status CalculatorGraph::Initialize( return Initialize(std::move(validated_graph), side_packets); } -mediapipe::Status CalculatorGraph::Initialize( +absl::Status CalculatorGraph::Initialize( const std::vector& input_configs, const std::vector& input_templates, const std::map& side_packets, @@ -407,9 +407,9 @@ mediapipe::Status CalculatorGraph::Initialize( return Initialize(std::move(validated_graph), side_packets); } -mediapipe::Status CalculatorGraph::ObserveOutputStream( +absl::Status CalculatorGraph::ObserveOutputStream( const std::string& stream_name, - std::function packet_callback) { + std::function packet_callback) { RET_CHECK(initialized_).SetNoLogging() << "CalculatorGraph is not initialized."; // TODO Allow output observers to be attached by graph level @@ -425,10 +425,10 @@ mediapipe::Status CalculatorGraph::ObserveOutputStream( stream_name, &any_packet_type_, std::move(packet_callback), &output_stream_managers_[output_stream_index])); 
graph_output_streams_.push_back(std::move(observer)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::StatusOr CalculatorGraph::AddOutputStreamPoller( +absl::StatusOr CalculatorGraph::AddOutputStreamPoller( const std::string& stream_name) { RET_CHECK(initialized_).SetNoLogging() << "CalculatorGraph is not initialized."; @@ -449,7 +449,7 @@ mediapipe::StatusOr CalculatorGraph::AddOutputStreamPoller( return std::move(poller); } -mediapipe::StatusOr CalculatorGraph::GetOutputSidePacket( +absl::StatusOr CalculatorGraph::GetOutputSidePacket( const std::string& packet_name) { int side_packet_index = validated_graph_->OutputSidePacketIndex(packet_name); if (side_packet_index < 0) { @@ -486,7 +486,7 @@ mediapipe::StatusOr CalculatorGraph::GetOutputSidePacket( return output_packet; } -mediapipe::Status CalculatorGraph::Run( +absl::Status CalculatorGraph::Run( const std::map& extra_side_packets) { RET_CHECK(graph_input_streams_.empty()).SetNoLogging() << "When using graph input streams, call StartRun() instead of Run() so " @@ -495,7 +495,7 @@ mediapipe::Status CalculatorGraph::Run( return WaitUntilDone(); } -mediapipe::Status CalculatorGraph::StartRun( +absl::Status CalculatorGraph::StartRun( const std::map& extra_side_packets, const std::map& stream_headers) { RET_CHECK(initialized_).SetNoLogging() @@ -503,18 +503,18 @@ mediapipe::Status CalculatorGraph::StartRun( MP_RETURN_IF_ERROR(PrepareForRun(extra_side_packets, stream_headers)); MP_RETURN_IF_ERROR(profiler_->Start(executors_[""].get())); scheduler_.Start(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -#ifndef MEDIAPIPE_DISABLE_GPU -mediapipe::Status CalculatorGraph::SetGpuResources( +#if !MEDIAPIPE_DISABLE_GPU +absl::Status CalculatorGraph::SetGpuResources( std::shared_ptr<::mediapipe::GpuResources> resources) { RET_CHECK(!ContainsKey(service_packets_, kGpuService.key)) << "The GPU resources have already been configured."; service_packets_[kGpuService.key] = MakePacket>( 
std::move(resources)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources() @@ -524,7 +524,7 @@ std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources() return service_iter->second.Get>(); } -mediapipe::StatusOr> CalculatorGraph::PrepareGpu( +absl::StatusOr> CalculatorGraph::PrepareGpu( const std::map& side_packets) { std::map additional_side_packets; bool update_sp = false; @@ -588,9 +588,9 @@ mediapipe::StatusOr> CalculatorGraph::PrepareGpu( } return additional_side_packets; } -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU -mediapipe::Status CalculatorGraph::PrepareForRun( +absl::Status CalculatorGraph::PrepareForRun( const std::map& extra_side_packets, const std::map& stream_headers) { if (VLOG_IS_ON(1)) { @@ -607,9 +607,9 @@ mediapipe::Status CalculatorGraph::PrepareForRun( num_closed_graph_input_streams_ = 0; std::map additional_side_packets; -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU ASSIGN_OR_RETURN(additional_side_packets, PrepareGpu(extra_side_packets)); -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU const std::map* input_side_packets; if (!additional_side_packets.empty()) { @@ -621,7 +621,7 @@ mediapipe::Status CalculatorGraph::PrepareForRun( } current_run_side_packets_.clear(); - mediapipe::Status generator_status = packet_generator_graph_.RunGraphSetup( + absl::Status generator_status = packet_generator_graph_.RunGraphSetup( *input_side_packets, ¤t_run_side_packets_); CallStatusHandlers(GraphRunState::PRE_RUN, generator_status); @@ -632,7 +632,7 @@ mediapipe::Status CalculatorGraph::PrepareForRun( // If there was an error on the CallStatusHandlers (PRE_RUN), it was stored // in the error list. We return immediately notifying this to the caller. 
- mediapipe::Status error_status; + absl::Status error_status; if (has_error_) { GetCombinedErrors(&error_status); LOG(ERROR) << error_status; @@ -682,7 +682,7 @@ mediapipe::Status CalculatorGraph::PrepareForRun( std::placeholders::_1, std::placeholders::_2); node.SetQueueSizeCallbacks(queue_size_callback, queue_size_callback); scheduler_.AssignNodeToSchedulerQueue(&node); - const mediapipe::Status result = node.PrepareForRun( + const absl::Status result = node.PrepareForRun( current_run_side_packets_, service_packets_, std::bind(&internal::Scheduler::ScheduleNodeForOpen, &scheduler_, &node), @@ -700,13 +700,13 @@ mediapipe::Status CalculatorGraph::PrepareForRun( for (auto& graph_output_stream : graph_output_streams_) { graph_output_stream->PrepareForRun( [&graph_output_stream, this] { - mediapipe::Status status = graph_output_stream->Notify(); + absl::Status status = graph_output_stream->Notify(); if (!status.ok()) { RecordError(status); } scheduler_.EmittedObservedOutput(); }, - [this](mediapipe::Status status) { RecordError(status); }); + [this](absl::Status status) { RecordError(status); }); } if (GetCombinedErrors(&error_status)) { @@ -759,20 +759,20 @@ mediapipe::Status CalculatorGraph::PrepareForRun( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::WaitUntilIdle() { +absl::Status CalculatorGraph::WaitUntilIdle() { MP_RETURN_IF_ERROR(scheduler_.WaitUntilIdle()); VLOG(2) << "Scheduler idle."; - mediapipe::Status status = mediapipe::OkStatus(); + absl::Status status = absl::OkStatus(); if (GetCombinedErrors(&status)) { LOG(ERROR) << status; } return status; } -mediapipe::Status CalculatorGraph::WaitUntilDone() { +absl::Status CalculatorGraph::WaitUntilDone() { VLOG(2) << "Waiting for scheduler to terminate..."; MP_RETURN_IF_ERROR(scheduler_.WaitUntilDone()); VLOG(2) << "Scheduler terminated."; @@ -780,16 +780,16 @@ mediapipe::Status CalculatorGraph::WaitUntilDone() { return FinishRun(); } -mediapipe::Status 
CalculatorGraph::WaitForObservedOutput() { +absl::Status CalculatorGraph::WaitForObservedOutput() { return scheduler_.WaitForObservedOutput(); } -mediapipe::Status CalculatorGraph::AddPacketToInputStream( +absl::Status CalculatorGraph::AddPacketToInputStream( const std::string& stream_name, const Packet& packet) { return AddPacketToInputStreamInternal(stream_name, packet); } -mediapipe::Status CalculatorGraph::AddPacketToInputStream( +absl::Status CalculatorGraph::AddPacketToInputStream( const std::string& stream_name, Packet&& packet) { return AddPacketToInputStreamInternal(stream_name, std::move(packet)); } @@ -799,7 +799,7 @@ mediapipe::Status CalculatorGraph::AddPacketToInputStream( // internal-only templated version. T&& is a forwarding reference here, so // std::forward will deduce the correct type as we pass along packet. template -mediapipe::Status CalculatorGraph::AddPacketToInputStreamInternal( +absl::Status CalculatorGraph::AddPacketToInputStreamInternal( const std::string& stream_name, T&& packet) { std::unique_ptr* stream = mediapipe::FindOrNull(graph_input_streams_, stream_name); @@ -814,7 +814,7 @@ mediapipe::Status CalculatorGraph::AddPacketToInputStreamInternal( if (graph_input_stream_add_mode_ == GraphInputStreamAddMode::ADD_IF_NOT_FULL) { if (has_error_) { - mediapipe::Status error_status; + absl::Status error_status; GetCombinedErrors("Graph has errors: ", &error_status); return error_status; } @@ -835,7 +835,7 @@ mediapipe::Status CalculatorGraph::AddPacketToInputStreamInternal( &full_input_streams_mutex_); } if (has_error_) { - mediapipe::Status error_status; + absl::Status error_status; GetCombinedErrors("Graph has errors: ", &error_status); return error_status; } @@ -857,7 +857,7 @@ mediapipe::Status CalculatorGraph::AddPacketToInputStreamInternal( // because we don't have the lock over the input stream. 
(*stream)->AddPacket(std::forward(packet)); if (has_error_) { - mediapipe::Status error_status; + absl::Status error_status; GetCombinedErrors("Graph has errors: ", &error_status); return error_status; } @@ -869,23 +869,22 @@ mediapipe::Status CalculatorGraph::AddPacketToInputStreamInternal( // again if the graph is still idle. Unthrottling basically only lets in one // packet at a time. TODO: add test. scheduler_.AddedPacketToGraphInputStream(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::SetInputStreamMaxQueueSize( +absl::Status CalculatorGraph::SetInputStreamMaxQueueSize( const std::string& stream_name, int max_queue_size) { // graph_input_streams_ has not been filled in yet, so we'll check this when // it is applied when the graph is started. graph_input_stream_max_queue_size_[stream_name] = max_queue_size; - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool CalculatorGraph::HasInputStream(const std::string& stream_name) { return mediapipe::FindOrNull(graph_input_streams_, stream_name) != nullptr; } -mediapipe::Status CalculatorGraph::CloseInputStream( - const std::string& stream_name) { +absl::Status CalculatorGraph::CloseInputStream(const std::string& stream_name) { std::unique_ptr* stream = mediapipe::FindOrNull(graph_input_streams_, stream_name); RET_CHECK(stream).SetNoLogging() << absl::Substitute( @@ -896,7 +895,7 @@ mediapipe::Status CalculatorGraph::CloseInputStream( // threads cannot call CloseInputStream() on the same stream_name at the same // time. 
if ((*stream)->IsClosed()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } (*stream)->Close(); @@ -905,10 +904,10 @@ mediapipe::Status CalculatorGraph::CloseInputStream( scheduler_.ClosedAllGraphInputStreams(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::CloseAllInputStreams() { +absl::Status CalculatorGraph::CloseAllInputStreams() { for (auto& item : graph_input_streams_) { item.second->Close(); } @@ -916,10 +915,10 @@ mediapipe::Status CalculatorGraph::CloseAllInputStreams() { num_closed_graph_input_streams_ = graph_input_streams_.size(); scheduler_.ClosedAllGraphInputStreams(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::CloseAllPacketSources() { +absl::Status CalculatorGraph::CloseAllPacketSources() { for (auto& item : graph_input_streams_) { item.second->Close(); } @@ -928,10 +927,10 @@ mediapipe::Status CalculatorGraph::CloseAllPacketSources() { scheduler_.ClosedAllGraphInputStreams(); scheduler_.CloseAllSourceNodes(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -void CalculatorGraph::RecordError(const mediapipe::Status& error) { +void CalculatorGraph::RecordError(const absl::Status& error) { VLOG(2) << "RecordError called with " << error; { absl::MutexLock lock(&error_mutex_); @@ -942,7 +941,7 @@ void CalculatorGraph::RecordError(const mediapipe::Status& error) { stream->NotifyError(); } if (errors_.size() > kMaxNumAccumulatedErrors) { - for (const mediapipe::Status& error : errors_) { + for (const absl::Status& error : errors_) { LOG(ERROR) << error; } LOG(FATAL) << "Forcefully aborting to prevent the framework running out " @@ -951,13 +950,13 @@ void CalculatorGraph::RecordError(const mediapipe::Status& error) { } } -bool CalculatorGraph::GetCombinedErrors(mediapipe::Status* error_status) { +bool CalculatorGraph::GetCombinedErrors(absl::Status* error_status) { return GetCombinedErrors("CalculatorGraph::Run() failed in 
Run: ", error_status); } bool CalculatorGraph::GetCombinedErrors(const std::string& error_prefix, - mediapipe::Status* error_status) { + absl::Status* error_status) { absl::MutexLock lock(&error_mutex_); if (!errors_.empty()) { *error_status = tool::CombinedStatus(error_prefix, errors_); @@ -967,7 +966,7 @@ bool CalculatorGraph::GetCombinedErrors(const std::string& error_prefix, } void CalculatorGraph::CallStatusHandlers(GraphRunState graph_run_state, - const mediapipe::Status& status) { + const absl::Status& status) { for (int status_handler_index = 0; status_handler_index < validated_graph_->Config().status_handler_size(); ++status_handler_index) { @@ -979,7 +978,7 @@ void CalculatorGraph::CallStatusHandlers(GraphRunState graph_run_state, validated_graph_->StatusHandlerInfos()[status_handler_index]; const PacketTypeSet& packet_type_set = status_handler_info.InputSidePacketTypes(); - mediapipe::StatusOr> packet_set_statusor = + absl::StatusOr> packet_set_statusor = tool::FillPacketSet(packet_type_set, current_run_side_packets_, nullptr); if (!packet_set_statusor.ok()) { @@ -989,18 +988,18 @@ void CalculatorGraph::CallStatusHandlers(GraphRunState graph_run_state, << "Skipping run of " << handler_type << ": "); continue; } - mediapipe::StatusOr> + absl::StatusOr> static_access_statusor = internal::StaticAccessToStatusHandlerRegistry:: CreateByNameInNamespace(validated_graph_->Package(), handler_type); CHECK(static_access_statusor.ok()) << handler_type << " is not registered."; - auto static_access = std::move(static_access_statusor).ValueOrDie(); - mediapipe::Status handler_result; + auto static_access = std::move(static_access_statusor).value(); + absl::Status handler_result; if (graph_run_state == GraphRunState::PRE_RUN) { handler_result = static_access->HandlePreRunStatus( - handler_config.options(), *packet_set_statusor.ValueOrDie(), status); + handler_config.options(), *packet_set_statusor.value(), status); } else { // POST_RUN handler_result = 
static_access->HandleStatus( - handler_config.options(), *packet_set_statusor.ValueOrDie(), status); + handler_config.options(), *packet_set_statusor.value(), status); } if (!handler_result.ok()) { mediapipe::StatusBuilder builder(std::move(handler_result), @@ -1134,7 +1133,7 @@ bool CalculatorGraph::UnthrottleSources() { } for (InputStreamManager* stream : full_streams) { if (Config().report_deadlock()) { - RecordError(mediapipe::UnavailableError(absl::StrCat( + RecordError(absl::UnavailableError(absl::StrCat( "Detected a deadlock due to input throttling for: \"", stream->Name(), "\". All calculators are idle while packet sources remain active " "and throttled. Consider adjusting \"max_queue_size\" or " @@ -1163,7 +1162,7 @@ void CalculatorGraph::SetGraphInputStreamAddMode(GraphInputStreamAddMode mode) { } void CalculatorGraph::Cancel() { - // TODO This function should return mediapipe::Status. + // TODO This function should return absl::Status. scheduler_.Cancel(); } @@ -1171,11 +1170,11 @@ void CalculatorGraph::Pause() { scheduler_.Pause(); } void CalculatorGraph::Resume() { scheduler_.Resume(); } -mediapipe::Status CalculatorGraph::SetServicePacket( - const GraphServiceBase& service, Packet p) { +absl::Status CalculatorGraph::SetServicePacket(const GraphServiceBase& service, + Packet p) { // TODO: check that the graph has not been started! 
service_packets_[service.key] = std::move(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); } Packet CalculatorGraph::GetServicePacket(const GraphServiceBase& service) { @@ -1186,7 +1185,7 @@ Packet CalculatorGraph::GetServicePacket(const GraphServiceBase& service) { return it->second; } -mediapipe::Status CalculatorGraph::SetExecutorInternal( +absl::Status CalculatorGraph::SetExecutorInternal( const std::string& name, std::shared_ptr executor) { if (!executors_.emplace(name, executor).second) { return mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC) @@ -1198,11 +1197,11 @@ mediapipe::Status CalculatorGraph::SetExecutorInternal( } else { MP_RETURN_IF_ERROR(scheduler_.SetNonDefaultExecutor(name, executor.get())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorGraph::SetExecutor( - const std::string& name, std::shared_ptr executor) { +absl::Status CalculatorGraph::SetExecutor(const std::string& name, + std::shared_ptr executor) { RET_CHECK(!initialized_) << "SetExecutor can only be called before Initialize()"; if (IsReservedExecutorName(name)) { @@ -1212,7 +1211,7 @@ mediapipe::Status CalculatorGraph::SetExecutor( return SetExecutorInternal(name, std::move(executor)); } -mediapipe::Status CalculatorGraph::CreateDefaultThreadPool( +absl::Status CalculatorGraph::CreateDefaultThreadPool( const ThreadPoolExecutorOptions* default_executor_options, int num_threads) { MediaPipeOptions extendable_options; @@ -1234,16 +1233,16 @@ bool CalculatorGraph::IsReservedExecutorName(const std::string& name) { return ValidatedGraphConfig::IsReservedExecutorName(name); } -mediapipe::Status CalculatorGraph::FinishRun() { +absl::Status CalculatorGraph::FinishRun() { // Check for any errors that may have occurred. 
- mediapipe::Status status = mediapipe::OkStatus(); + absl::Status status = absl::OkStatus(); MP_RETURN_IF_ERROR(profiler_->Stop()); GetCombinedErrors(&status); CleanupAfterRun(&status); return status; } -void CalculatorGraph::CleanupAfterRun(mediapipe::Status* status) { +void CalculatorGraph::CleanupAfterRun(absl::Status* status) { for (auto& item : graph_input_streams_) { item.second->Close(); } @@ -1310,7 +1309,7 @@ bool MetricElementComparator(const std::pair& e1, } } // namespace -mediapipe::Status CalculatorGraph::GetCalculatorProfiles( +absl::Status CalculatorGraph::GetCalculatorProfiles( std::vector* profiles) const { return profiler_->GetCalculatorProfiles(profiles); } diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 56d46b0ae..a70da438b 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -53,16 +53,16 @@ #include "mediapipe/framework/scheduler.h" #include "mediapipe/framework/thread_pool_executor.pb.h" -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU namespace mediapipe { class GpuResources; struct GpuSharedData; } // namespace mediapipe -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { -typedef mediapipe::StatusOr StatusOrPoller; +typedef absl::StatusOr StatusOrPoller; // The class representing a DAG of calculator nodes. // @@ -126,12 +126,11 @@ class CalculatorGraph { // Initializes the graph from a its proto description. // side_packets that are provided at this stage are common across all Run() // invocations and could be used to execute PacketGenerators immediately. - mediapipe::Status Initialize( - const CalculatorGraphConfig& config, - const std::map& side_packets); + absl::Status Initialize(const CalculatorGraphConfig& config, + const std::map& side_packets); // Convenience version which does not take side packets. 
- mediapipe::Status Initialize(const CalculatorGraphConfig& config); + absl::Status Initialize(const CalculatorGraphConfig& config); // Initializes the CalculatorGraph from the specified graph and subgraph // configs. Template graph and subgraph configs can be specified through @@ -139,7 +138,7 @@ class CalculatorGraph { // CalclatorGraphConfig.type. A subgraph can be instantiated directly by // specifying its type in |graph_type|. A template graph can be instantiated // directly by specifying its template arguments in |options|. - mediapipe::Status Initialize( + absl::Status Initialize( const std::vector& configs, const std::vector& templates, const std::map& side_packets = {}, @@ -155,9 +154,9 @@ class CalculatorGraph { // packet emitted by the output stream. Can only be called before Run() or // StartRun(). // TODO: Rename to AddOutputStreamCallback. - mediapipe::Status ObserveOutputStream( + absl::Status ObserveOutputStream( const std::string& stream_name, - std::function packet_callback); + std::function packet_callback); // Adds an OutputStreamPoller for a stream. This provides a synchronous, // polling API for accessing a stream's output. Should only be called before @@ -169,17 +168,16 @@ class CalculatorGraph { // packets (generated by PacketGenerators) can be retrieved before // graph is done. Returns error if the graph is still running (for non-base // packets) or the output side packet is not found or empty. - mediapipe::StatusOr GetOutputSidePacket( - const std::string& packet_name); + absl::StatusOr GetOutputSidePacket(const std::string& packet_name); // Runs the graph after adding the given extra input side packets. All // arguments are forgotten after Run() returns. // Run() is a blocking call and will return when all calculators are done. - virtual mediapipe::Status Run( + virtual absl::Status Run( const std::map& extra_side_packets); // Run the graph without adding any input side packets. 
- mediapipe::Status Run() { return Run({}); } + absl::Status Run() { return Run({}); } // Start a run of the graph. StartRun, WaitUntilDone, HasError, // AddPacketToInputStream, and CloseInputStream allow more control over @@ -199,7 +197,7 @@ class CalculatorGraph { // MP_RETURN_IF_ERROR(graph.CloseInputStream(stream)); // } // MP_RETURN_IF_ERROR(graph.WaitUntilDone()); - mediapipe::Status StartRun( + absl::Status StartRun( const std::map& extra_side_packets) { return StartRun(extra_side_packets, {}); } @@ -208,28 +206,27 @@ class CalculatorGraph { // stream header before running. // Note: We highly discourage the use of stream headers, this is added for the // compatibility of existing calculators that use headers during Open(). - mediapipe::Status StartRun( - const std::map& extra_side_packets, - const std::map& stream_headers); + absl::Status StartRun(const std::map& extra_side_packets, + const std::map& stream_headers); // Wait for the current run to finish (block the current thread // until all source calculators have returned StatusStop(), all // graph_input_streams_ have been closed, and no more calculators can // be run). This function can be called only after StartRun(). - mediapipe::Status WaitUntilDone(); + absl::Status WaitUntilDone(); // Wait until the running graph is in the idle mode, which is when nothing can // be scheduled and nothing is running in the worker threads. This function // can be called only after StartRun(). // NOTE: The graph must not have any source nodes because source nodes prevent // the running graph from becoming idle until the source nodes are done. - mediapipe::Status WaitUntilIdle(); + absl::Status WaitUntilIdle(); // Wait until a packet is emitted on one of the observed output streams. // Returns immediately if a packet has already been emitted since the last // call to this function. // Returns OutOfRangeError if the graph terminated while waiting. 
- mediapipe::Status WaitForObservedOutput(); + absl::Status WaitForObservedOutput(); // Quick non-locking means of checking if the graph has encountered an error. bool HasError() const { return has_error_; } @@ -243,8 +240,8 @@ class CalculatorGraph { // sizes of the queues in the graph. The input stream must have been specified // in the configuration as a graph level input_stream. On error, nothing is // added. - mediapipe::Status AddPacketToInputStream(const std::string& stream_name, - const Packet& packet); + absl::Status AddPacketToInputStream(const std::string& stream_name, + const Packet& packet); // Same as the l-value version of this function by the same name, but moves // the r-value referenced packet into the stream instead of copying it over. @@ -253,12 +250,12 @@ class CalculatorGraph { // packet may remain valid. In particular, when using the ADD_IF_NOT_FULL // mode with a full queue, this will return StatusUnavailable and the caller // may try adding the packet again later. - mediapipe::Status AddPacketToInputStream(const std::string& stream_name, - Packet&& packet); + absl::Status AddPacketToInputStream(const std::string& stream_name, + Packet&& packet); // Sets the queue size of a graph input stream, overriding the graph default. - mediapipe::Status SetInputStreamMaxQueueSize(const std::string& stream_name, - int max_queue_size); + absl::Status SetInputStreamMaxQueueSize(const std::string& stream_name, + int max_queue_size); // Check if an input stream exists in the graph bool HasInputStream(const std::string& name); @@ -268,14 +265,14 @@ class CalculatorGraph { // been closed (and all packets propagate through the graph). // Note that multiple threads cannot call CloseInputStream() on the same // stream_name at the same time. - mediapipe::Status CloseInputStream(const std::string& stream_name); + absl::Status CloseInputStream(const std::string& stream_name); // Closes all the graph input streams. 
// TODO: deprecate this function in favor of CloseAllPacketSources. - mediapipe::Status CloseAllInputStreams(); + absl::Status CloseAllInputStreams(); // Closes all the graph input streams and source calculator nodes. - mediapipe::Status CloseAllPacketSources(); + absl::Status CloseAllPacketSources(); // Returns the pointer to the stream with the given name, or dies if none // exists. The result remains owned by the CalculatorGraph. @@ -290,8 +287,7 @@ class CalculatorGraph { // calculator in the graph. May be called at any time after the graph has been // initialized. ABSL_DEPRECATED("Use profiler()->GetCalculatorProfiles() instead") - mediapipe::Status GetCalculatorProfiles( - std::vector*) const; + absl::Status GetCalculatorProfiles(std::vector*) const; // Set the type of counter used in this graph. void SetCounterFactory(CounterFactory* factory) { @@ -301,15 +297,14 @@ class CalculatorGraph { // Callback when an error is encountered. // Adds the error to the vector of errors. - void RecordError(const mediapipe::Status& error) - ABSL_LOCKS_EXCLUDED(error_mutex_); + void RecordError(const absl::Status& error) ABSL_LOCKS_EXCLUDED(error_mutex_); // Combines errors into a status. Returns true if the vector of errors is // non-empty. bool GetCombinedErrors(const std::string& error_prefix, - mediapipe::Status* error_status); + absl::Status* error_status); // Convenience overload which specifies a default error prefix. - bool GetCombinedErrors(mediapipe::Status* error_status); + bool GetCombinedErrors(absl::Status* error_status); // Returns the maximum input stream queue size. int GetMaxInputStreamQueueSize(); @@ -338,8 +333,8 @@ class CalculatorGraph { // Sets the executor that will run the nodes assigned to the executor // named |name|. If |name| is empty, this sets the default executor. Must // be called before the graph is initialized. 
- mediapipe::Status SetExecutor(const std::string& name, - std::shared_ptr executor); + absl::Status SetExecutor(const std::string& name, + std::shared_ptr executor); // WARNING: the following public methods are exposed to Scheduler only. @@ -365,23 +360,23 @@ class CalculatorGraph { return scheduler_.GetSchedulerTimes(); } -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU // Returns a pointer to the GpuResources in use, if any. // Only meant for internal use. std::shared_ptr<::mediapipe::GpuResources> GetGpuResources() const; - mediapipe::Status SetGpuResources( + absl::Status SetGpuResources( std::shared_ptr<::mediapipe::GpuResources> resources); // Helper for PrepareForRun. If it returns a non-empty map, those packets // must be added to the existing side packets, replacing existing values // that have the same key. - mediapipe::StatusOr> PrepareGpu( + absl::StatusOr> PrepareGpu( const std::map& side_packets); -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU template - mediapipe::Status SetServiceObject(const GraphService& service, - std::shared_ptr object) { + absl::Status SetServiceObject(const GraphService& service, + std::shared_ptr object) { return SetServicePacket(service, MakePacket>(std::move(object))); } @@ -394,7 +389,7 @@ class CalculatorGraph { } // Only the Java API should call this directly. - mediapipe::Status SetServicePacket(const GraphServiceBase& service, Packet p); + absl::Status SetServicePacket(const GraphServiceBase& service, Packet p); private: // GraphRunState is used as a parameter in the function CallStatusHandlers. @@ -418,7 +413,7 @@ class CalculatorGraph { shard_.SetSpec(manager_->Spec()); } - void PrepareForRun(std::function error_callback) { + void PrepareForRun(std::function error_callback) { manager_->PrepareForRun(std::move(error_callback)); } @@ -446,36 +441,35 @@ class CalculatorGraph { }; // Initializes the graph from a ValidatedGraphConfig object. 
- mediapipe::Status Initialize( - std::unique_ptr validated_graph, - const std::map& side_packets); + absl::Status Initialize(std::unique_ptr validated_graph, + const std::map& side_packets); // AddPacketToInputStreamInternal template is called by either // AddPacketToInputStream(Packet&& packet) or // AddPacketToInputStream(const Packet& packet). template - mediapipe::Status AddPacketToInputStreamInternal( - const std::string& stream_name, T&& packet); + absl::Status AddPacketToInputStreamInternal(const std::string& stream_name, + T&& packet); // Sets the executor that will run the nodes assigned to the executor // named |name|. If |name| is empty, this sets the default executor. // Does not check that the graph is uninitialized and |name| is not a // reserved executor name. - mediapipe::Status SetExecutorInternal(const std::string& name, - std::shared_ptr executor); + absl::Status SetExecutorInternal(const std::string& name, + std::shared_ptr executor); // If the num_threads field in default_executor_options is not specified, // assigns a reasonable value based on system configuration and the graph. // Then, creates the default thread pool if appropriate. // // Only called by InitializeExecutors(). - mediapipe::Status InitializeDefaultExecutor( + absl::Status InitializeDefaultExecutor( const ThreadPoolExecutorOptions* default_executor_options, bool use_application_thread); // Creates a thread pool as the default executor. The num_threads argument // overrides the num_threads field in default_executor_options. - mediapipe::Status CreateDefaultThreadPool( + absl::Status CreateDefaultThreadPool( const ThreadPoolExecutorOptions* default_executor_options, int num_threads); @@ -483,39 +477,38 @@ class CalculatorGraph { static bool IsReservedExecutorName(const std::string& name); // Helper functions for Initialize(). 
- mediapipe::Status InitializeExecutors(); - mediapipe::Status InitializePacketGeneratorGraph( + absl::Status InitializeExecutors(); + absl::Status InitializePacketGeneratorGraph( const std::map& side_packets); - mediapipe::Status InitializeStreams(); - mediapipe::Status InitializeProfiler(); - mediapipe::Status InitializeCalculatorNodes(); + absl::Status InitializeStreams(); + absl::Status InitializeProfiler(); + absl::Status InitializeCalculatorNodes(); // Iterates through all nodes and schedules any that can be opened. void ScheduleAllOpenableNodes(); // Does the bulk of the work for StartRun but does not start the scheduler. - mediapipe::Status PrepareForRun( + absl::Status PrepareForRun( const std::map& extra_side_packets, const std::map& stream_headers); // Cleans up any remaining state after the run and returns any errors that may // have occurred during the run. Called after the scheduler has terminated. - mediapipe::Status FinishRun(); + absl::Status FinishRun(); // Cleans up any remaining state after the run. All status handlers run here // if their requested input side packets exist. // The original |*status| is passed to all the status handlers. If any status // handler fails, it appends its error to errors_, and CleanupAfterRun sets // |*status| to the new combined errors on return. - void CleanupAfterRun(mediapipe::Status* status) - ABSL_LOCKS_EXCLUDED(error_mutex_); + void CleanupAfterRun(absl::Status* status) ABSL_LOCKS_EXCLUDED(error_mutex_); // Calls HandlePreRunStatus or HandleStatus on the StatusHandlers. Which one // is called depends on the GraphRunState parameter (PRE_RUN or POST_RUN). // current_run_side_packets_ must be set before this function is called. // On error, has_error_ will be set. void CallStatusHandlers(GraphRunState graph_run_state, - const mediapipe::Status& status); + const absl::Status& status); // Callback function to throttle or unthrottle source nodes when a stream // becomes full or non-full. A node is throttled (i.e. 
prevented being @@ -531,11 +524,11 @@ class CalculatorGraph { void UpdateThrottledNodes(InputStreamManager* stream, bool* stream_was_full); Packet GetServicePacket(const GraphServiceBase& service); -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU // Owns the legacy GpuSharedData if we need to create one for backwards // compatibility. std::unique_ptr<::mediapipe::GpuSharedData> legacy_gpu_shared_; -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU // True if the graph was initialized. bool initialized_ = false; @@ -609,7 +602,7 @@ class CalculatorGraph { // Vector of errors encountered while running graph. Always use RecordError() // to add an error to this vector. - std::vector errors_ ABSL_GUARDED_BY(error_mutex_); + std::vector errors_ ABSL_GUARDED_BY(error_mutex_); // True if the default executor uses the application thread. bool use_application_thread_ = false; diff --git a/mediapipe/framework/calculator_graph_bounds_test.cc b/mediapipe/framework/calculator_graph_bounds_test.cc index 21ee132d9..2c71f8cb3 100644 --- a/mediapipe/framework/calculator_graph_bounds_test.cc +++ b/mediapipe/framework/calculator_graph_bounds_test.cc @@ -30,7 +30,7 @@ namespace { constexpr int kIntTestValue = 33; -typedef std::function +typedef std::function CalculatorContextFunction; // Returns the contents of a set of Packets. @@ -87,26 +87,24 @@ class CountingExecutor : public Executor { // streams and outputs the sum to the output stream. 
class IntAdderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).Set(); } cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int sum = 0; for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { sum += cc->Inputs().Index(i).Get(); } cc->Outputs().Index(0).Add(new int(sum), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(IntAdderCalculator); @@ -114,13 +112,13 @@ REGISTER_CALCULATOR(IntAdderCalculator); template class TypedSinkCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; typedef TypedSinkCalculator StringSinkCalculator; @@ -132,13 +130,13 @@ REGISTER_CALCULATOR(IntSinkCalculator); // integer. 
class EvenIntFilterCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int value = cc->Inputs().Index(0).Get(); if (value % 2 == 0) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); @@ -146,7 +144,7 @@ class EvenIntFilterCalculator : public CalculatorBase { cc->Outputs().Index(0).SetNextTimestampBound( cc->InputTimestamp().NextAllowedInStream()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(EvenIntFilterCalculator); @@ -156,19 +154,19 @@ REGISTER_CALCULATOR(EvenIntFilterCalculator); // input stream carries the value true. class ValveCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Inputs().Index(1).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (cc->Inputs().Index(1).Get()) { cc->GetCounter("PassThrough")->Increment(); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); @@ -182,7 +180,7 @@ class ValveCalculator : public CalculatorBase { cc->Outputs().Index(0).SetNextTimestampBound( cc->InputTimestamp().NextAllowedInStream()); } - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } }; REGISTER_CALCULATOR(ValveCalculator); @@ -191,27 +189,27 @@ REGISTER_CALCULATOR(ValveCalculator); // but shifts the timestamp. class TimeShiftCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->InputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { // Input: arbitrary Packets. // Output: copy of the input. cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); shift_ = cc->InputSidePackets().Index(0).Get(); cc->SetOffset(shift_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->GetCounter("PassThrough")->Increment(); cc->Outputs().Index(0).AddPacket( cc->Inputs().Index(0).Value().At(cc->InputTimestamp() + shift_)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -231,17 +229,17 @@ REGISTER_CALCULATOR(TimeShiftCalculator); // T=2000 Output 100 class OutputAndBoundSourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { counter_ = 0; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { Timestamp timestamp(counter_); if (counter_ % 20 == 0) { cc->Outputs().Index(0).AddPacket( @@ -253,7 
+251,7 @@ class OutputAndBoundSourceCalculator : public CalculatorBase { return tool::StatusStop(); } counter_ += 10; - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -266,42 +264,40 @@ REGISTER_CALCULATOR(OutputAndBoundSourceCalculator); // Process() method. The input stream and output stream have the integer type. class Delay20Calculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(20)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket(MakePacket(0).At(Timestamp(0))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const Packet& packet = cc->Inputs().Index(0).Value(); Timestamp timestamp = packet.Timestamp() + 20; cc->Outputs().Index(0).AddPacket(packet.At(timestamp)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(Delay20Calculator); class CustomBoundCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->Outputs().Index(0).SetNextTimestampBound(cc->InputTimestamp() + 1); - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } }; REGISTER_CALCULATOR(CustomBoundCalculator); @@ -613,7 +609,7 @@ TEST(CalculatorGraphBoundsTest, ImmediateHandlerBounds) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { output_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -638,47 +634,41 @@ TEST(CalculatorGraphBoundsTest, ImmediateHandlerBounds) { // A Calculator that only sets timestamp bound by SetOffset(). class OffsetBoundCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } }; REGISTER_CALCULATOR(OffsetBoundCalculator); // A Calculator that produces a packet for each call to Process. 
class BoundToPacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); } for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { cc->Outputs().Index(i).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { Timestamp t = cc->Inputs().Index(i).Value().Timestamp(); cc->Outputs().Index(i).AddPacket( mediapipe::MakePacket(t).At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(BoundToPacketCalculator); @@ -688,22 +678,20 @@ class FuturePacketCalculator : public CalculatorBase { public: static constexpr int64 kOutputFutureMicros = 3; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const Packet& packet = cc->Inputs().Index(0).Value(); Timestamp timestamp = Timestamp(packet.Timestamp().Value() + kOutputFutureMicros); cc->Outputs().Index(0).AddPacket(packet.At(timestamp)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(FuturePacketCalculator); @@ -735,7 +723,7 
@@ TEST(CalculatorGraphBoundsTest, OffsetBoundPropagation) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { output_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -786,7 +774,7 @@ TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { output_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -860,13 +848,13 @@ TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) { std::vector outputs; MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { outputs.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); std::vector thinned_outputs; MP_ASSERT_OK( graph.ObserveOutputStream("thinned_output", [&](const Packet& p) { thinned_outputs.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // The enter_semaphore is used to wait for LambdaCalculator::Process. @@ -875,13 +863,13 @@ TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) { AtomicSemaphore exit_semaphore(0); CalculatorContextFunction open_fn = [&](CalculatorContext* cc) { cc->SetOffset(0); - return mediapipe::OkStatus(); + return absl::OkStatus(); }; CalculatorContextFunction process_fn = [&](CalculatorContext* cc) { enter_semaphore.Release(1); exit_semaphore.Acquire(1); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); }; MP_ASSERT_OK(graph.StartRun({ {"open_fn", Adopt(new auto(open_fn))}, @@ -935,22 +923,20 @@ TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) { // A Calculator that outputs only the last packet from its input stream. 
class LastPacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetAny(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } + absl::Status Process(CalculatorContext* cc) final { cc->Outputs().Index(0).SetNextTimestampBound(cc->InputTimestamp()); last_packet_ = cc->Inputs().Index(0).Value(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket(last_packet_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -992,12 +978,12 @@ TEST(CalculatorGraphBoundsTest, LastPacketCheck) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) { output_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); std::vector last_output_packets; MP_ASSERT_OK(graph.ObserveOutputStream("last_output", [&](const Packet& p) { last_output_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -1055,11 +1041,11 @@ void TestBoundsForEmptyInputs(std::string input_stream_handler) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("input_ts", [&](const Packet& p) { input_ts_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream("bounds_ts", [&](const Packet& p) { bounds_ts_packets.push_back(p); - return mediapipe::OkStatus(); + 
return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -1129,7 +1115,7 @@ TEST(CalculatorGraphBoundsTest, BoundsForEmptyInputs_SyncSets) { // A Calculator that produces a packet for each timestamp bounds update. class ProcessBoundToPacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); } @@ -1138,10 +1124,10 @@ class ProcessBoundToPacketCalculator : public CalculatorBase { } cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { Timestamp t = cc->Inputs().Index(i).Value().Timestamp(); // Create a new packet for each input stream with a new timestamp bound, @@ -1151,7 +1137,7 @@ class ProcessBoundToPacketCalculator : public CalculatorBase { cc->Outputs().Index(i).Add(new auto(t), t); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(ProcessBoundToPacketCalculator); @@ -1159,7 +1145,7 @@ REGISTER_CALCULATOR(ProcessBoundToPacketCalculator); // A Calculator that passes through each packet and timestamp immediately. 
class ImmediatePassthroughCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); } @@ -1168,10 +1154,10 @@ class ImmediatePassthroughCalculator : public CalculatorBase { } cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { for (int i = 0; i < cc->Outputs().NumEntries(); ++i) { if (!cc->Inputs().Index(i).IsEmpty()) { cc->Outputs().Index(i).AddPacket(cc->Inputs().Index(i).Value()); @@ -1185,7 +1171,7 @@ class ImmediatePassthroughCalculator : public CalculatorBase { } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(ImmediatePassthroughCalculator); @@ -1224,7 +1210,7 @@ void TestProcessForEmptyInputs(const std::string& input_stream_handler) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("bounds_ts", [&](const Packet& p) { bounds_ts_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -1324,11 +1310,11 @@ TEST(CalculatorGraphBoundsTest, ProcessTimestampBounds_Passthrough) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { output_0_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream("output_1", [&](const Packet& p) { output_1_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -1378,20 +1364,20 @@ TEST(CalculatorGraphBoundsTest, 
ProcessTimestampBounds_Passthrough) { // A Calculator that sends a timestamp bound for every other input. class OccasionalBoundCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { absl::SleepFor(absl::Milliseconds(1)); if (cc->InputTimestamp().Value() % 20 == 0) { Timestamp bound = cc->InputTimestamp().NextAllowedInStream(); cc->Outputs().Index(0).SetNextTimestampBound( std::max(bound, cc->Outputs().Index(0).NextTimestampBound())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(OccasionalBoundCalculator); @@ -1419,7 +1405,7 @@ TEST(CalculatorGraphBoundsTest, MaxInFlightWithOccasionalBound) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { output_0_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); @@ -1443,20 +1429,18 @@ TEST(CalculatorGraphBoundsTest, MaxInFlightWithOccasionalBound) { // A Calculator that uses both SetTimestampOffset and SetNextTimestampBound. 
class OffsetAndBoundCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } + absl::Status Process(CalculatorContext* cc) final { if (cc->InputTimestamp().Value() % 20 == 0) { cc->Outputs().Index(0).SetNextTimestampBound(Timestamp(10000)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(OffsetAndBoundCalculator); @@ -1481,7 +1465,7 @@ TEST(CalculatorGraphBoundsTest, OffsetAndBound) { MP_ASSERT_OK(graph.Initialize(config)); MP_ASSERT_OK(graph.ObserveOutputStream("output_0", [&](const Packet& p) { output_0_packets.push_back(p); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); MP_ASSERT_OK(graph.WaitUntilIdle()); diff --git a/mediapipe/framework/calculator_graph_event_loop_test.cc b/mediapipe/framework/calculator_graph_event_loop_test.cc index 4cdae61a4..a8343046d 100644 --- a/mediapipe/framework/calculator_graph_event_loop_test.cc +++ b/mediapipe/framework/calculator_graph_event_loop_test.cc @@ -53,25 +53,25 @@ class CalculatorGraphEventLoopTest : public testing::Test { // testing. 
class BlockingPassThroughCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->InputSidePackets().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { mutex_ = GetFromUniquePtr(cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { mutex_->Lock(); cc->Outputs().Index(0).AddPacket( cc->Inputs().Index(0).Value().At(cc->InputTimestamp())); mutex_->Unlock(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -87,15 +87,15 @@ struct SimpleHeader { class UsingHeaderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { if (cc->Inputs().Index(0).Header().IsEmpty()) { - return mediapipe::UnknownError("No stream header present."); + return absl::UnknownError("No stream header present."); } const SimpleHeader& header = @@ -105,13 +105,13 @@ class UsingHeaderCalculator : public CalculatorBase { output_header->height = header.height; cc->Outputs().Index(0).SetHeader(Adopt(output_header.release())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket( 
cc->Inputs().Index(0).Value().At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(UsingHeaderCalculator); @@ -187,21 +187,20 @@ TEST_F(CalculatorGraphEventLoopTest, WellProvisionedEventLoop) { // Pass-Through calculator that fails upon receiving the 10th packet. class FailingPassThroughCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { Timestamp timestamp = cc->InputTimestamp(); if (timestamp.Value() == 9) { - return mediapipe::UnknownError( - "Meant to fail (magicstringincludedhere)."); + return absl::UnknownError("Meant to fail (magicstringincludedhere)."); } cc->Outputs().Index(0).AddPacket( cc->Inputs().Index(0).Value().At(timestamp)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(FailingPassThroughCalculator); @@ -231,7 +230,7 @@ TEST_F(CalculatorGraphEventLoopTest, FailingEventLoop) { this, std::placeholders::_1))}})); // Insert packets. - mediapipe::Status status; + absl::Status status; for (int i = 0; true; ++i) { status = graph.AddPacketToInputStream("input_numbers", Adopt(new int(i)).At(Timestamp(i))); @@ -315,10 +314,10 @@ TEST_F(CalculatorGraphEventLoopTest, SetStreamHeader) { &CalculatorGraphEventLoopTest::AddThreadSafeVectorSink, this, std::placeholders::_1))}})); - mediapipe::Status status = graph.WaitUntilIdle(); + absl::Status status = graph.WaitUntilIdle(); // Expect to fail if header not set. 
ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kUnknown); + EXPECT_EQ(status.code(), absl::StatusCode::kUnknown); EXPECT_THAT(status.message(), testing::HasSubstr("No stream header present.")); @@ -387,7 +386,7 @@ TEST_F(CalculatorGraphEventLoopTest, TryToAddPacketToInputStream) { // mechanism could be off by 1 at most due to the order of acquisition of // locks. for (int i = 0; i < kNumInputPackets; ++i) { - mediapipe::Status status = graph.AddPacketToInputStream( + absl::Status status = graph.AddPacketToInputStream( "input_numbers", Adopt(new int(i)).At(Timestamp(i))); if (!status.ok()) { ++fail_count; @@ -472,7 +471,7 @@ TEST_F(CalculatorGraphEventLoopTest, ThrottleGraphInputStreamTwice) { // Lock the mutex so that the BlockingPassThroughCalculator cannot read any // of these packets. mutex->Lock(); - mediapipe::Status status = mediapipe::OkStatus(); + absl::Status status = absl::OkStatus(); for (int i = 0; i < 10; ++i) { status = graph.AddPacketToInputStream("input_numbers", Adopt(new int(i)).At(Timestamp(i))); @@ -482,7 +481,7 @@ TEST_F(CalculatorGraphEventLoopTest, ThrottleGraphInputStreamTwice) { } mutex->Unlock(); ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kUnavailable); + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); EXPECT_THAT(status.message(), testing::HasSubstr("Graph is throttled.")); MP_ASSERT_OK(graph.CloseInputStream("input_numbers")); MP_ASSERT_OK(graph.WaitUntilDone()); @@ -523,7 +522,7 @@ TEST_F(CalculatorGraphEventLoopTest, WaitToAddPacketToInputStream) { // All of these packets should be accepted by the graph. 
int fail_count = 0; for (int i = 0; i < kNumInputPackets; ++i) { - mediapipe::Status status = graph.AddPacketToInputStream( + absl::Status status = graph.AddPacketToInputStream( "input_numbers", Adopt(new int(i)).At(Timestamp(i))); if (!status.ok()) { ++fail_count; @@ -576,7 +575,7 @@ TEST_F(CalculatorGraphEventLoopTest, UnthrottleSources) { CalculatorGraph::GraphInputStreamAddMode::ADD_IF_NOT_FULL); auto poller_status = graph.AddOutputStreamPoller("output_numbers"); MP_ASSERT_OK(poller_status.status()); - mediapipe::OutputStreamPoller& poller = poller_status.ValueOrDie(); + mediapipe::OutputStreamPoller& poller = poller_status.value(); poller.SetMaxQueueSize(kQueueSize); MP_ASSERT_OK(graph.StartRun({})); diff --git a/mediapipe/framework/calculator_graph_side_packet_test.cc b/mediapipe/framework/calculator_graph_side_packet_test.cc index e530238e0..91cf9f3c8 100644 --- a/mediapipe/framework/calculator_graph_side_packet_test.cc +++ b/mediapipe/framework/calculator_graph_side_packet_test.cc @@ -38,16 +38,16 @@ namespace { // output side packet. class OutputSidePacketInProcessCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->OutputSidePackets().Index(0).Set( cc->Inputs().Index(0).Value().At(Timestamp::Unset())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); @@ -56,22 +56,22 @@ REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); // receives. Outputs the total number of packets as a side packet in Close. 
class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->OutputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { ++count_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { absl::SleepFor(absl::Milliseconds(300)); // For GetOutputSidePacket test. cc->OutputSidePackets().Index(0).Set( MakePacket(count_).At(Timestamp::Unset())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } int count_ = 0; @@ -82,15 +82,15 @@ REGISTER_CALCULATOR(CountAndOutputSummarySidePacketInCloseCalculator); // output side packet. This triggers an error in the graph. class OutputSidePacketWithTimestampCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->OutputSidePackets().Index(0).Set(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(OutputSidePacketWithTimestampCalculator); @@ -98,19 +98,19 @@ REGISTER_CALCULATOR(OutputSidePacketWithTimestampCalculator); // Generates an output side packet containing the integer 1. 
class IntegerOutputSidePacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->OutputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->OutputSidePackets().Index(0).Set(MakePacket(1)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { LOG(FATAL) << "Not reached."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(IntegerOutputSidePacketCalculator); @@ -119,23 +119,23 @@ REGISTER_CALCULATOR(IntegerOutputSidePacketCalculator); // side packets. class SidePacketAdderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).Set(); cc->InputSidePackets().Index(1).Set(); cc->OutputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->OutputSidePackets().Index(0).Set( MakePacket(cc->InputSidePackets().Index(1).Get() + cc->InputSidePackets().Index(0).Get())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { LOG(FATAL) << "Not reached."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(SidePacketAdderCalculator); @@ -144,20 +144,20 @@ REGISTER_CALCULATOR(SidePacketAdderCalculator); // input side packet. 
class SidePacketToStreamPacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket( cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); cc->Outputs().Index(0).Close(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { return mediapipe::tool::StatusStop(); } }; @@ -166,18 +166,18 @@ REGISTER_CALCULATOR(SidePacketToStreamPacketCalculator); // Packet generator for an arbitrary unit64 packet. class Uint64PacketGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { output_side_packets->Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(Uint64PacketGenerator); @@ -204,7 +204,7 @@ TEST(CalculatorGraph, OutputSidePacketInProcess) { MP_ASSERT_OK(graph.ObserveOutputStream( "output", [&output_packets](const Packet& packet) { 
output_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Run the graph twice. @@ -226,11 +226,11 @@ TEST(CalculatorGraph, OutputSidePacketInProcess) { // also be ignored. class PassThroughGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* inputs, PacketTypeSet* outputs) { if (!inputs->TagMap()->SameAs(*outputs->TagMap())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and outputs to PassThroughGenerator must use the same tags " "and indexes."); } @@ -238,17 +238,17 @@ class PassThroughGenerator : public PacketGenerator { inputs->Get(id).SetAny(); outputs->Get(id).SetSameAs(&inputs->Get(id)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { for (CollectionItemId id = input_side_packets.BeginId(); id < input_side_packets.EndId(); ++id) { output_side_packets->Get(id) = input_side_packets.Get(id); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(PassThroughGenerator); @@ -402,8 +402,8 @@ TEST(CalculatorGraph, OutputSidePacketAlreadySet) { "offset", MakePacket(offset).At(Timestamp(1)))); MP_ASSERT_OK(graph.CloseInputStream("offset")); - mediapipe::Status status = graph.WaitUntilDone(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kAlreadyExists); + absl::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), absl::StatusCode::kAlreadyExists); EXPECT_THAT(status.message(), testing::HasSubstr("was already set.")); } @@ -428,8 +428,8 @@ TEST(CalculatorGraph, 
OutputSidePacketWithTimestamp) { MP_ASSERT_OK(graph.AddPacketToInputStream( "offset", MakePacket(offset).At(Timestamp(237)))); MP_ASSERT_OK(graph.CloseInputStream("offset")); - mediapipe::Status status = graph.WaitUntilDone(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::HasSubstr("has a timestamp 237.")); } @@ -460,7 +460,7 @@ TEST(CalculatorGraph, OutputSidePacketConsumedBySourceNode) { MP_ASSERT_OK(graph.ObserveOutputStream( "output", [&output_packets](const Packet& packet) { output_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); // Wait until the graph is idle so that @@ -486,19 +486,19 @@ class FirstPacketFilterCalculator : public CalculatorBase { FirstPacketFilterCalculator() {} ~FirstPacketFilterCalculator() override {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (!seen_first_packet_) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); cc->Outputs().Index(0).Close(); seen_first_packet_ = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -568,8 +568,8 @@ TEST(CalculatorGraph, SourceLayerInversion) { MP_ASSERT_OK(graph.Initialize( config, {{"max_count", MakePacket(max_count)}, {"initial_value1", MakePacket(initial_value1)}})); - mediapipe::Status status = graph.Run(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kUnknown); + absl::Status status = graph.Run(); + EXPECT_EQ(status.code(), absl::StatusCode::kUnknown); 
EXPECT_THAT(status.message(), testing::HasSubstr("deadlock")); } @@ -614,7 +614,7 @@ TEST(CalculatorGraph, PacketGeneratorLikeCalculators) { MP_ASSERT_OK(graph.ObserveOutputStream( "output", [&output_packets](const Packet& packet) { output_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.Run()); ASSERT_EQ(1, output_packets.size()); @@ -643,7 +643,7 @@ TEST(CalculatorGraph, OutputSummarySidePacketInClose) { MP_ASSERT_OK(graph.ObserveOutputStream( "output", [&output_packets](const Packet& packet) { output_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Run the graph twice. @@ -686,15 +686,14 @@ TEST(CalculatorGraph, GetOutputSidePacket) { MP_ASSERT_OK(graph.Initialize(config)); // Check a packet generated by the PacketGenerator, which is available after // graph initialization, can be fetched before graph starts. - mediapipe::StatusOr status_or_packet = + absl::StatusOr status_or_packet = graph.GetOutputSidePacket("output_uint64"); MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp()); // IntSplitterPacketGenerator is missing its input side packet and we // won't be able to get its output side packet now. status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); - EXPECT_EQ(mediapipe::StatusCode::kUnavailable, - status_or_packet.status().code()); + EXPECT_EQ(absl::StatusCode::kUnavailable, status_or_packet.status().code()); // Run the graph twice. 
int max_count = 100; std::map extra_side_packets; @@ -703,7 +702,7 @@ TEST(CalculatorGraph, GetOutputSidePacket) { MP_ASSERT_OK(graph.StartRun(extra_side_packets)); status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp()); for (int i = 0; i < max_count; ++i) { MP_ASSERT_OK(graph.AddPacketToInputStream( "input_packets", MakePacket(i).At(Timestamp(i)))); @@ -713,34 +712,32 @@ TEST(CalculatorGraph, GetOutputSidePacket) { // Should return NOT_FOUND for invalid side packets. status_or_packet = graph.GetOutputSidePacket("unknown"); EXPECT_FALSE(status_or_packet.ok()); - EXPECT_EQ(mediapipe::StatusCode::kNotFound, - status_or_packet.status().code()); + EXPECT_EQ(absl::StatusCode::kNotFound, status_or_packet.status().code()); // Should return UNAVAILABLE before graph is done for valid non-base // packets. status_or_packet = graph.GetOutputSidePacket("num_of_packets"); EXPECT_FALSE(status_or_packet.ok()); - EXPECT_EQ(mediapipe::StatusCode::kUnavailable, - status_or_packet.status().code()); + EXPECT_EQ(absl::StatusCode::kUnavailable, status_or_packet.status().code()); // Should stil return a base even before graph is done. status_or_packet = graph.GetOutputSidePacket("output_uint64"); MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp()); MP_ASSERT_OK(graph.WaitUntilDone()); // Check packets are available after graph is done. 
status_or_packet = graph.GetOutputSidePacket("num_of_packets"); MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(max_count, status_or_packet.ValueOrDie().Get()); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + EXPECT_EQ(max_count, status_or_packet.value().Get()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp()); // Should still return a base packet after graph is done. status_or_packet = graph.GetOutputSidePacket("output_uint64"); MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp()); // Should still return a non-base packet after graph is done. status_or_packet = graph.GetOutputSidePacket("output_uint32_pair"); MP_ASSERT_OK(status_or_packet); - EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp()); + EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp()); } } @@ -749,20 +746,20 @@ typedef std::string HugeModel; // Generates an output-side-packet once for each calculator-graph. 
class OutputSidePacketCachedCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->OutputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->OutputSidePackets().Index(0).Set(MakePacket( R"(An expensive side-packet created only once per graph)")); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { LOG(FATAL) << "Not reached."; - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(OutputSidePacketCachedCalculator); @@ -791,7 +788,7 @@ TEST(CalculatorGraph, OutputSidePacketCached) { MP_ASSERT_OK(graph.ObserveOutputStream( "output", [&output_packets](const Packet& packet) { output_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Run the graph three times. 
diff --git a/mediapipe/framework/calculator_graph_stopping_test.cc b/mediapipe/framework/calculator_graph_stopping_test.cc index 75262a03f..d31a5e7e9 100644 --- a/mediapipe/framework/calculator_graph_stopping_test.cc +++ b/mediapipe/framework/calculator_graph_stopping_test.cc @@ -49,24 +49,24 @@ using mediapipe::Packet; class InfiniteSequenceCalculator : public mediapipe::CalculatorBase { public: - static mediapipe::Status GetContract(mediapipe::CalculatorContract* cc) { + static absl::Status GetContract(mediapipe::CalculatorContract* cc) { cc->Outputs().Tag("OUT").Set(); cc->Outputs().Tag("EVENT").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(1).At(Timestamp(1))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Tag("OUT").AddPacket( MakePacket(count_).At(Timestamp(count_))); count_++; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(2).At(Timestamp(2))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -76,30 +76,30 @@ REGISTER_CALCULATOR(::testing_ns::InfiniteSequenceCalculator); class StoppingPassThroughCalculator : public mediapipe::CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { cc->Inputs().Get("", i).SetAny(); cc->Outputs().Get("", i).SetSameAs(&cc->Inputs().Get("", i)); } cc->Outputs().Tag("EVENT").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status 
Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(1).At(Timestamp(1))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { for (int i = 0; i < cc->Inputs().NumEntries(""); ++i) { if (!cc->Inputs().Get("", i).IsEmpty()) { cc->Outputs().Get("", i).AddPacket(cc->Inputs().Get("", i).Value()); } } - return (++count_ <= max_count_) ? mediapipe::OkStatus() + return (++count_ <= max_count_) ? absl::OkStatus() : mediapipe::tool::StatusStop(); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { cc->Outputs().Tag("EVENT").AddPacket(MakePacket(2).At(Timestamp(2))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -124,39 +124,39 @@ class AtomicSemaphore { }; // A ProcessFunction that passes through all packets. -mediapipe::Status DoProcess(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status DoProcess(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -typedef std::function +typedef std::function ProcessFunction; // A Calculator that delegates its Process function to a callback function. 
class ProcessCallbackCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(0)); } cc->InputSidePackets().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { callback_ = *GetFromUniquePtr(cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { return callback_(cc->Inputs(), &(cc->Outputs())); } @@ -202,22 +202,22 @@ TEST(CalculatorGraphStoppingTest, CloseAllPacketSources) { if (out_packets.size() >= kNumPackets) { MP_EXPECT_OK(graph.CloseAllPacketSources()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream( // "count_out", [&](const Packet& packet) { count_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream( // "event", [&](const Packet& packet) { event_packets.push_back(packet.Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.ObserveOutputStream( // "event_out", [&](const Packet& packet) { event_out_packets.push_back(packet.Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); for (int i = 0; i < kNumPackets; ++i) { @@ -261,7 +261,7 @@ TEST(CalculatorGraphStoppingTest, DeadlockReporting) { MP_ASSERT_OK( graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Lambda that waits for a local semaphore. 
@@ -289,8 +289,8 @@ TEST(CalculatorGraphStoppingTest, DeadlockReporting) { MP_EXPECT_OK(add_packet("in_1", 2)); EXPECT_FALSE(add_packet("in_1", 3).ok()); - mediapipe::Status status = graph.WaitUntilIdle(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kUnavailable); + absl::Status status = graph.WaitUntilIdle(); + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); EXPECT_THAT( status.message(), testing::HasSubstr("Detected a deadlock due to input throttling")); @@ -326,7 +326,7 @@ TEST(CalculatorGraphStoppingTest, DeadlockResolution) { MP_ASSERT_OK( graph.ObserveOutputStream("out_1", [&out_packets](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Lambda that waits for a local semaphore. diff --git a/mediapipe/framework/calculator_graph_test.cc b/mediapipe/framework/calculator_graph_test.cc index 38b4bf8d0..9464f5c32 100644 --- a/mediapipe/framework/calculator_graph_test.cc +++ b/mediapipe/framework/calculator_graph_test.cc @@ -72,24 +72,24 @@ using testing::HasSubstr; // instead of Open(). class SetOffsetInProcessCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { // Input: arbitrary Packets. // Output: copy of the input. 
cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); cc->GetCounter("PassThrough")->Increment(); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(SetOffsetInProcessCalculator); @@ -97,21 +97,19 @@ REGISTER_CALCULATOR(SetOffsetInProcessCalculator); // A Calculator that outputs the square of its input packet (an int). class SquareIntCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new int(value * value), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(SquareIntCalculator); @@ -125,7 +123,7 @@ REGISTER_CALCULATOR(SquareIntCalculator); // the unselected outputs. 
class DemuxTimedCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); cc->Inputs().Tag("SELECT").Set(); PacketType* data_input = &cc->Inputs().Tag("INPUT"); @@ -135,18 +133,18 @@ class DemuxTimedCalculator : public CalculatorBase { cc->Outputs().Get(id).SetSameAs(data_input); } cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { select_input_ = cc->Inputs().GetId("SELECT", 0); data_input_ = cc->Inputs().GetId("INPUT", 0); output_base_ = cc->Outputs().GetId("OUTPUT", 0); num_outputs_ = cc->Outputs().NumEntries("OUTPUT"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int select = cc->Inputs().Get(select_input_).Get(); RET_CHECK(0 <= select && select < num_outputs_); const Timestamp next_timestamp_bound = @@ -162,7 +160,7 @@ class DemuxTimedCalculator : public CalculatorBase { .SetNextTimestampBound(next_timestamp_bound); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -183,7 +181,7 @@ REGISTER_CALCULATOR(DemuxTimedCalculator); // propagation on the unselected inputs. 
class MuxTimedCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("SELECT").Set(); CollectionItemId data_input_id = cc->Inputs().BeginId("INPUT"); PacketType* data_input0 = &cc->Inputs().Get(data_input_id); @@ -195,23 +193,23 @@ class MuxTimedCalculator : public CalculatorBase { RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); cc->Outputs().Tag("OUTPUT").SetSameAs(data_input0); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { select_input_ = cc->Inputs().GetId("SELECT", 0); data_input_base_ = cc->Inputs().GetId("INPUT", 0); num_data_inputs_ = cc->Inputs().NumEntries("INPUT"); output_ = cc->Outputs().GetId("OUTPUT", 0); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int select = cc->Inputs().Get(select_input_).Get(); RET_CHECK(0 <= select && select < num_data_inputs_); cc->Outputs().Get(output_).AddPacket( cc->Inputs().Get(data_input_base_ + select).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -227,26 +225,24 @@ REGISTER_CALCULATOR(MuxTimedCalculator); // streams and outputs the sum to the output stream. 
class IntAdderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).Set(); } cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int sum = 0; for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { sum += cc->Inputs().Index(i).Get(); } cc->Outputs().Index(0).Add(new int(sum), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(IntAdderCalculator); @@ -255,26 +251,24 @@ REGISTER_CALCULATOR(IntAdderCalculator); // streams and outputs the sum to the output stream. 
class FloatAdderCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).Set(); } cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { float sum = 0.0; for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { sum += cc->Inputs().Index(i).Get(); } cc->Outputs().Index(0).Add(new float(sum), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(FloatAdderCalculator); @@ -283,26 +277,24 @@ REGISTER_CALCULATOR(FloatAdderCalculator); // input streams and outputs the product to the output stream. 
class IntMultiplierCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).Set(); } cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int product = 1; for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { product *= cc->Inputs().Index(i).Get(); } cc->Outputs().Index(0).Add(new int(product), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(IntMultiplierCalculator); @@ -312,24 +304,24 @@ REGISTER_CALCULATOR(IntMultiplierCalculator); // output stream. 
class FloatScalarMultiplierCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); cc->InputSidePackets().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { scalar_ = cc->InputSidePackets().Index(0).Get(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { float value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new float(scalar_ * value), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -340,22 +332,20 @@ REGISTER_CALCULATOR(FloatScalarMultiplierCalculator); // A Calculator that converts an integer to a float. 
class IntToFloatCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new float(static_cast(value)), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(IntToFloatCalculator); @@ -363,12 +353,12 @@ REGISTER_CALCULATOR(IntToFloatCalculator); template class TypedEmptySourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).SetAny(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).Add(new OutputType(), Timestamp::PostStream()); return tool::StatusStop(); } @@ -381,13 +371,13 @@ REGISTER_CALCULATOR(IntEmptySourceCalculator); template class TypedSinkCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } }; typedef 
TypedSinkCalculator StringSinkCalculator; @@ -403,29 +393,29 @@ class GlobalCountSourceCalculator : public CalculatorBase { public: static const int kNumOutputPackets; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).Set*>(); if (cc->InputSidePackets().NumEntries() >= 2) { cc->InputSidePackets().Index(1).Set(); } cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { if (cc->InputSidePackets().NumEntries() >= 2 && cc->InputSidePackets().Index(1).Get()) { OutputOne(cc); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { OutputOne(cc); if (local_count_ >= kNumOutputPackets) { return tool::StatusStop(); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -448,19 +438,19 @@ static const int kTestSequenceLength = 15; // Outputs the integers 0, 1, 2, 3, ..., 14, all with timestamp 0. class TestSequence1SourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).Add(new int(count_), Timestamp(0)); ++count_; ++num_outputs_; if (num_outputs_ >= kTestSequenceLength) { return tool::StatusStop(); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -474,12 +464,12 @@ REGISTER_CALCULATOR(TestSequence1SourceCalculator); // 100, 99, 98, 97, .... 
class TestSequence2SourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).Add(new int(count_), Timestamp(timestamp_)); ++count_; ++num_outputs_; @@ -487,7 +477,7 @@ class TestSequence2SourceCalculator : public CalculatorBase { if (num_outputs_ >= kTestSequenceLength) { return tool::StatusStop(); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -501,19 +491,19 @@ REGISTER_CALCULATOR(TestSequence2SourceCalculator); // Outputs the integers 0, 1, 2 repeatedly for a total of 15 outputs. class Modulo3SourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).Add(new int(count_ % 3), Timestamp(count_ % 3)); ++count_; ++num_outputs_; if (num_outputs_ >= kTestSequenceLength) { return tool::StatusStop(); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -532,12 +522,12 @@ class OutputAllSourceCalculator : public CalculatorBase { public: static constexpr int kNumOutputPackets = 100; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { for (int i = 0; i < 
kNumOutputPackets; ++i) { cc->Outputs().Index(0).Add(new int(0), Timestamp(i)); } @@ -555,16 +545,16 @@ class OutputOneAtATimeSourceCalculator : public CalculatorBase { public: static constexpr int kNumOutputPackets = 1000; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (index_ < kNumOutputPackets) { cc->Outputs().Index(0).Add(new int(0), Timestamp(index_)); ++index_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } return tool::StatusStop(); } @@ -582,18 +572,18 @@ class DecimatorCalculator : public CalculatorBase { public: static constexpr int kDecimationRatio = 101; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (index_ % kDecimationRatio == 0) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); } ++index_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -605,23 +595,23 @@ REGISTER_CALCULATOR(DecimatorCalculator); // this calculator simply passes its input packets through, unchanged. 
class ErrorOnOpenCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->InputSidePackets().Tag("ERROR_ON_OPEN").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { if (cc->InputSidePackets().Tag("ERROR_ON_OPEN").Get()) { - return mediapipe::NotFoundError("expected error"); + return absl::NotFoundError("expected error"); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(ErrorOnOpenCalculator); @@ -631,64 +621,64 @@ REGISTER_CALCULATOR(ErrorOnOpenCalculator); // Process() method. The input stream and output stream have the integer type. 
class UnitDelayCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).Add(new int(0), Timestamp(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const Packet& packet = cc->Inputs().Index(0).Value(); cc->Outputs().Index(0).AddPacket( packet.At(packet.Timestamp().NextAllowedInStream())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(UnitDelayCalculator); class UnitDelayUntimedCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).Add(new int(0), Timestamp::Min()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(UnitDelayUntimedCalculator); class FloatUnitDelayCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + 
return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).Add(new float(0.0), Timestamp(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const Packet& packet = cc->Inputs().Index(0).Value(); cc->Outputs().Index(0).AddPacket( packet.At(packet.Timestamp().NextAllowedInStream())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(FloatUnitDelayCalculator); @@ -697,20 +687,18 @@ REGISTER_CALCULATOR(FloatUnitDelayCalculator); // discards input packets in Process(). class AssertEmptyInputInOpenCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { RET_CHECK(cc->Inputs().Index(0).Value().IsEmpty()); RET_CHECK_EQ(cc->Inputs().Index(0).Value().Timestamp(), Timestamp::Unset()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } }; REGISTER_CALCULATOR(AssertEmptyInputInOpenCalculator); @@ -718,22 +706,22 @@ REGISTER_CALCULATOR(AssertEmptyInputInOpenCalculator); // 0, 1, ..., 9. 
class SlowCountingSinkCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { absl::SleepFor(absl::Milliseconds(10)); int value = cc->Inputs().Index(0).Get(); CHECK_EQ(value, counter_); ++counter_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { CHECK_EQ(10, counter_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -745,25 +733,24 @@ template class TypedStatusHandler : public StatusHandler { public: ~TypedStatusHandler() override = 0; - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets) { input_side_packets->Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status HandlePreRunStatus( + static absl::Status HandlePreRunStatus( const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, // - const mediapipe::Status& pre_run_status) { - return mediapipe::OkStatus(); + const absl::Status& pre_run_status) { + return absl::OkStatus(); } - static mediapipe::Status HandleStatus( - const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, // - const mediapipe::Status& run_status) { - return mediapipe::OkStatus(); + static absl::Status HandleStatus(const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, // + const absl::Status& run_status) { + return absl::OkStatus(); } }; typedef TypedStatusHandler StringStatusHandler; @@ -774,22 +761,22 @@ 
REGISTER_STATUS_HANDLER(Uint32StatusHandler); // A std::string generator that will succeed. class StaticCounterStringGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { for (int i = 0; i < input_side_packets->NumEntries(); ++i) { input_side_packets->Index(i).SetAny(); } output_side_packets->Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { output_side_packets->Index(0) = MakePacket("fixed_string"); ++num_packets_generated_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } static int NumPacketsGenerated() { return num_packets_generated_; } @@ -806,20 +793,20 @@ REGISTER_PACKET_GENERATOR(StaticCounterStringGenerator); // called. Both claim to output strings but instead always fail. 
class FailingPacketGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { for (int i = 0; i < input_side_packets->NumEntries(); ++i) { input_side_packets->Index(i).SetAny(); } output_side_packets->Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { - return mediapipe::UnknownError("this always fails."); + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { + return absl::UnknownError("this always fails."); } }; REGISTER_PACKET_GENERATOR(FailingPacketGenerator); @@ -827,28 +814,28 @@ REGISTER_PACKET_GENERATOR(FailingPacketGenerator); // Passes the integer through if it is positive. 
class EnsurePositivePacketGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { for (int i = 0; i < input_side_packets->NumEntries(); ++i) { input_side_packets->Index(i).Set(); output_side_packets->Index(i).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { for (int i = 0; i < input_side_packets.NumEntries(); ++i) { if (input_side_packets.Index(i).Get() > 0) { output_side_packets->Index(i) = input_side_packets.Index(i); } else { - return mediapipe::UnknownError( + return absl::UnknownError( absl::StrCat("Integer ", i, " was not positive.")); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(EnsurePositivePacketGenerator); @@ -865,32 +852,30 @@ class FailableStatusHandler : public StatusHandler { kFailPostRun = 2, }; - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets) { input_side_packets->Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status HandlePreRunStatus( + static absl::Status HandlePreRunStatus( const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, - const mediapipe::Status& pre_run_status) { + const PacketSet& input_side_packets, const absl::Status& pre_run_status) { if (input_side_packets.Index(0).Get() == kFailPreRun) { - return mediapipe::UnknownError( + return absl::UnknownError( 
"FailableStatusHandler failing pre run as intended."); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } - static mediapipe::Status HandleStatus( - const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, - const mediapipe::Status& run_status) { + static absl::Status HandleStatus(const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, + const absl::Status& run_status) { if (input_side_packets.Index(0).Get() == kFailPostRun) { - return mediapipe::UnknownError( + return absl::UnknownError( "FailableStatusHandler failing post run as intended."); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } }; @@ -898,13 +883,13 @@ REGISTER_STATUS_HANDLER(FailableStatusHandler); class FailingSourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::UnknownError("this always fails."); + absl::Status Process(CalculatorContext* cc) override { + return absl::UnknownError("this always fails."); } }; REGISTER_CALCULATOR(FailingSourceCalculator); @@ -932,24 +917,22 @@ class SemaphoreCalculator : public CalculatorBase { public: using Semaphore = AtomicSemaphore; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->InputSidePackets().Tag("POST_SEM").Set(); cc->InputSidePackets().Tag("WAIT_SEM").Set(); cc->SetTimestampOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) override 
{ return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->InputSidePackets().Tag("POST_SEM").Get()->Release(1); cc->InputSidePackets().Tag("WAIT_SEM").Get()->Acquire(1); cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(SemaphoreCalculator); @@ -958,11 +941,11 @@ REGISTER_CALCULATOR(SemaphoreCalculator); // and takes 20 milliseconds to run. class OneShot20MsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { - return mediapipe::OkStatus(); + static absl::Status GetContract(CalculatorContract* cc) { + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { absl::SleepFor(absl::Milliseconds(20)); return tool::StatusStop(); } @@ -973,12 +956,12 @@ REGISTER_CALCULATOR(OneShot20MsCalculator); // pthread_self() (the pthread id of the current thread). class PthreadSelfSourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).AddPacket( MakePacket(pthread_self()).At(Timestamp(0))); return tool::StatusStop(); @@ -990,38 +973,38 @@ REGISTER_CALCULATOR(PthreadSelfSourceCalculator); // It outputs five int packets with timestamps 0, 1, 2, 3, 4. 
class CheckInputTimestampSourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() always returns Timestamp(0) in Process() for source // nodes. - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0)); cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); ++count_; if (count_ >= 5) { return tool::StatusStop(); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { // Must use CHECK instead of RET_CHECK in Close(), because the framework // may call the Close() method of a source node with .IgnoreError(). CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -1033,33 +1016,33 @@ REGISTER_CALCULATOR(CheckInputTimestampSourceCalculator); // It expects to consume the output of a CheckInputTimestampSourceCalculator. 
class CheckInputTimestampSinkCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns the timestamp of input packets in Process() for // non-source nodes. - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), cc->Inputs().Index(0).Value().Timestamp()); RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_)); ++count_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -1072,34 +1055,34 @@ REGISTER_CALCULATOR(CheckInputTimestampSinkCalculator); // the framework. class CheckInputTimestamp2SourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. 
- mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() always returns Timestamp(0) in Process() for source // nodes. - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(0)); cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); ++count_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { // Must use CHECK instead of RET_CHECK in Close(), because the framework // may call the Close() method of a source node with .IgnoreError(). CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -1112,21 +1095,21 @@ REGISTER_CALCULATOR(CheckInputTimestamp2SourceCalculator); // It returns tool::StatusStop() after consuming five input packets. class CheckInputTimestamp2SinkCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns Timestamp::Unstarted() in Open() for both source // and non-source nodes. - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Unstarted()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // InputTimestamp() returns the timestamp of input packets in Process() for // non-source nodes. 
- mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), cc->Inputs().Index(0).Value().Timestamp()); RET_CHECK_EQ(cc->InputTimestamp(), Timestamp(count_)); @@ -1134,15 +1117,15 @@ class CheckInputTimestamp2SinkCalculator : public CalculatorBase { if (count_ >= 5) { return tool::StatusStop(); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // InputTimestamp() returns Timestamp::Done() in Close() for both source // and non-source nodes. - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { RET_CHECK_EQ(cc->InputTimestamp(), Timestamp::Done()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -1154,16 +1137,16 @@ REGISTER_CALCULATOR(CheckInputTimestamp2SinkCalculator); // output side packet. class OutputSidePacketInProcessCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->OutputSidePackets().Index(0).Set( cc->Inputs().Index(0).Value().At(Timestamp::Unset())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); @@ -1172,21 +1155,21 @@ REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator); // sends the packet to the single output stream with the same timestamp. 
class SimpleMuxCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); cc->Inputs().Index(0).SetAny(); cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0)); RET_CHECK_EQ(cc->Outputs().NumEntries(), 1); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { data_input_base_ = cc->Inputs().BeginId(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int select_packet_index = -1; if (!cc->Inputs().Index(0).IsEmpty()) { select_packet_index = 0; @@ -1197,7 +1180,7 @@ class SimpleMuxCalculator : public CalculatorBase { cc->Outputs().Index(0).AddPacket( cc->Inputs().Get(data_input_base_ + select_packet_index).Value()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -1209,51 +1192,50 @@ REGISTER_CALCULATOR(SimpleMuxCalculator); // by modifying the int in its input side packet. 
class IncrementingStatusHandler : public StatusHandler { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets) { input_side_packets->Tag("EXTRA").SetAny().Optional(); input_side_packets->Tag("COUNTER1").Set>(); input_side_packets->Tag("COUNTER2").Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status HandlePreRunStatus( + static absl::Status HandlePreRunStatus( const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, // - const mediapipe::Status& pre_run_status) { + const absl::Status& pre_run_status) { int* counter = GetFromUniquePtr(input_side_packets.Tag("COUNTER1")); (*counter)++; return pre_run_status_result_; } - static mediapipe::Status HandleStatus( - const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, // - const mediapipe::Status& run_status) { + static absl::Status HandleStatus(const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, // + const absl::Status& run_status) { int* counter = GetFromUniquePtr(input_side_packets.Tag("COUNTER2")); (*counter)++; return post_run_status_result_; } - static void SetPreRunStatusResult(const mediapipe::Status& status) { + static void SetPreRunStatusResult(const absl::Status& status) { pre_run_status_result_ = status; } - static void SetPostRunStatusResult(const mediapipe::Status& status) { + static void SetPostRunStatusResult(const absl::Status& status) { post_run_status_result_ = status; } private: // Return values of HandlePreRunStatus() and HandleSTatus(), respectively. 
- static mediapipe::Status pre_run_status_result_; - static mediapipe::Status post_run_status_result_; + static absl::Status pre_run_status_result_; + static absl::Status post_run_status_result_; }; -mediapipe::Status IncrementingStatusHandler::pre_run_status_result_ = - mediapipe::OkStatus(); -mediapipe::Status IncrementingStatusHandler::post_run_status_result_ = - mediapipe::OkStatus(); +absl::Status IncrementingStatusHandler::pre_run_status_result_ = + absl::OkStatus(); +absl::Status IncrementingStatusHandler::post_run_status_result_ = + absl::OkStatus(); REGISTER_STATUS_HANDLER(IncrementingStatusHandler); @@ -1601,18 +1583,18 @@ TEST(CalculatorGraph, RunsCorrectlyWithMultipleExecutors) { // Packet generator for an arbitrary unit64 packet. class Uint64PacketGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { output_side_packets->Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(Uint64PacketGenerator); @@ -1806,7 +1788,7 @@ TEST(CalculatorGraph, StatusHandlerInputVerification) { invalid_handler->set_status_handler("Uint32StatusHandler"); invalid_handler->add_input_side_packet("created_by_factory"); graph.reset(new CalculatorGraph()); - mediapipe::Status status = graph->Initialize(config); + absl::Status status = graph->Initialize(config); EXPECT_THAT(status.message(), 
testing::AllOf(testing::HasSubstr("Uint32StatusHandler"), // The problematic input side packet. @@ -2037,17 +2019,17 @@ TEST(CalculatorGraph, HandlersRun) { EXPECT_EQ(1, *GetFromUniquePtr( input_side_packets.at("unavailable_input_counter2"))); - mediapipe::Status run_status; + absl::Status run_status; // Make status handlers fail. The graph should fail. // First, when the PRE_run fails IncrementingStatusHandler::SetPreRunStatusResult( - mediapipe::InternalError("Fail at pre-run")); + absl::InternalError("Fail at pre-run")); graph.reset(new CalculatorGraph()); MP_ASSERT_OK(graph->Initialize(config)); ResetCounters(&input_side_packets); run_status = graph->Run(input_side_packets); - EXPECT_TRUE(run_status.code() == mediapipe::StatusCode::kInternal); + EXPECT_TRUE(run_status.code() == absl::StatusCode::kInternal); EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at pre-run")); EXPECT_EQ(1, *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); @@ -2063,14 +2045,14 @@ TEST(CalculatorGraph, HandlersRun) { input_side_packets.at("unavailable_input_counter2"))); // Second, when the POST_run fails - IncrementingStatusHandler::SetPreRunStatusResult(mediapipe::OkStatus()); + IncrementingStatusHandler::SetPreRunStatusResult(absl::OkStatus()); IncrementingStatusHandler::SetPostRunStatusResult( - mediapipe::InternalError("Fail at post-run")); + absl::InternalError("Fail at post-run")); graph.reset(new CalculatorGraph()); MP_ASSERT_OK(graph->Initialize(config)); ResetCounters(&input_side_packets); run_status = graph->Run(input_side_packets); - EXPECT_TRUE(run_status.code() == mediapipe::StatusCode::kInternal); + EXPECT_TRUE(run_status.code() == absl::StatusCode::kInternal); EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Fail at post-run")); EXPECT_EQ(1, *GetFromUniquePtr(input_side_packets.at("no_input_counter1"))); @@ -2087,7 +2069,7 @@ TEST(CalculatorGraph, HandlersRun) { } // Test that calling SetOffset() in Calculator::Process() results in the -// 
mediapipe::StatusCode::kFailedPrecondition error. +// absl::StatusCode::kFailedPrecondition error. TEST(CalculatorGraph, SetOffsetInProcess) { CalculatorGraph graph; CalculatorGraphConfig config = @@ -2104,9 +2086,9 @@ TEST(CalculatorGraph, SetOffsetInProcess) { MP_EXPECT_OK(graph.StartRun({})); MP_EXPECT_OK( graph.AddPacketToInputStream("in", MakePacket(0).At(Timestamp(0)))); - mediapipe::Status status = graph.WaitUntilIdle(); + absl::Status status = graph.WaitUntilIdle(); EXPECT_FALSE(status.ok()); - EXPECT_EQ(mediapipe::StatusCode::kFailedPrecondition, status.code()); + EXPECT_EQ(absl::StatusCode::kFailedPrecondition, status.code()); } // Test that MediaPipe releases input packets when it is done with them. @@ -2262,7 +2244,7 @@ TEST(CalculatorGraph, IfThenElse) { // for the unselected outputs. class DemuxUntimedCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK_EQ(cc->Inputs().NumEntries(), 2); cc->Inputs().Tag("INPUT").SetAny(); cc->Inputs().Tag("SELECT").Set(); @@ -2270,9 +2252,9 @@ class DemuxUntimedCalculator : public CalculatorBase { id < cc->Outputs().EndId("OUTPUT"); ++id) { cc->Outputs().Get(id).SetSameAs(&cc->Inputs().Tag("INPUT")); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int index = cc->Inputs().Tag("SELECT").Get(); if (!cc->Inputs().Tag("INPUT").IsEmpty()) { cc->Outputs() @@ -2283,7 +2265,7 @@ class DemuxUntimedCalculator : public CalculatorBase { .Get("OUTPUT", index) .SetNextTimestampBound(cc->InputTimestamp() + 1); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(DemuxUntimedCalculator); @@ -3088,8 +3070,8 @@ TEST(CalculatorGraph, TerminatesOnCancelWithOpenGraphInputStreams) { graph.Cancel(); // This tests that the graph doesn't deadlock on 
WaitUntilDone (because // the scheduler thread is sleeping). - mediapipe::Status status = graph.WaitUntilDone(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kCancelled); + absl::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), absl::StatusCode::kCancelled); } TEST(CalculatorGraph, TerminatesOnCancelAfterPause) { @@ -3117,8 +3099,8 @@ TEST(CalculatorGraph, TerminatesOnCancelAfterPause) { graph.Pause(); // This tests that the graph doesn't deadlock on WaitUntilDone (because // the scheduler thread is sleeping). - mediapipe::Status status = graph.WaitUntilDone(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kCancelled); + absl::Status status = graph.WaitUntilDone(); + EXPECT_EQ(status.code(), absl::StatusCode::kCancelled); } // A PacketGenerator that simply passes its input Packets through @@ -3127,11 +3109,11 @@ TEST(CalculatorGraph, TerminatesOnCancelAfterPause) { // also be ignored. class PassThroughGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* inputs, PacketTypeSet* outputs) { if (!inputs->TagMap()->SameAs(*outputs->TagMap())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and outputs to PassThroughGenerator must use the same tags " "and indexes."); } @@ -3139,17 +3121,17 @@ class PassThroughGenerator : public PacketGenerator { inputs->Get(id).SetAny(); outputs->Get(id).SetSameAs(&inputs->Get(id)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { for (CollectionItemId id = input_side_packets.BeginId(); id < 
input_side_packets.EndId(); ++id) { output_side_packets->Get(id) = input_side_packets.Get(id); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(PassThroughGenerator); @@ -3183,7 +3165,7 @@ TEST(CalculatorGraph, RecoverAfterRunError) { MP_ASSERT_OK(graph.ObserveOutputStream("count1", [&packet_count](const Packet& packet) { ++packet_count; - return mediapipe::OkStatus(); + return absl::OkStatus(); })); // Set ERROR_COUNT higher than MAX_COUNT and hence the calculator will // finish successfully. @@ -3328,9 +3310,9 @@ TEST(CalculatorGraph, SetInputStreamMaxQueueSizeWorksSlowCalculator) { MP_EXPECT_OK( graph.AddPacketToInputStream("in", MakePacket(i).At(timestamp))); // We should be prevented from adding another, since the queue is now full. - mediapipe::Status status = graph.AddPacketToInputStream( + absl::Status status = graph.AddPacketToInputStream( "in", MakePacket(i).At(timestamp + 1)); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kUnavailable); + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); // Allow calculator to complete its Process call. calc_can_exit_process.Release(1); } @@ -3393,7 +3375,7 @@ TEST(CalculatorGraph, AddPacketNoBusyLoop) { MP_ASSERT_OK( graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); @@ -3438,29 +3420,29 @@ TEST(CalculatorGraph, AddPacketNoBusyLoop) { namespace nested_ns { -typedef std::function +typedef std::function ProcessFunction; // A Calculator that delegates its Process function to a callback function. 
class ProcessCallbackCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).SetAny(); cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(0)); } cc->InputSidePackets().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { callback_ = *GetFromUniquePtr(cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { return callback_(cc->Inputs(), &(cc->Outputs())); } @@ -3492,14 +3474,14 @@ TEST(CalculatorGraph, CalculatorInNamepsace) { } // A ProcessFunction that passes through all packets. -mediapipe::Status DoProcess(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status DoProcess(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } TEST(CalculatorGraph, ObserveOutputStream) { @@ -3532,12 +3514,12 @@ TEST(CalculatorGraph, ObserveOutputStream) { MP_ASSERT_OK(graph.ObserveOutputStream( "count", [&count_packets](const Packet& packet) { count_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK( graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.Run()); ASSERT_EQ(max_count, count_packets.size()); @@ -3554,7 +3536,7 @@ TEST(CalculatorGraph, ObserveOutputStream) { class 
PassThroughSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& options) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"( @@ -3594,7 +3576,7 @@ TEST(CalculatorGraph, ObserveOutputStreamSubgraph) { MP_ASSERT_OK( graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.Run()); ASSERT_EQ(max_count, out_packets.size()); @@ -3636,17 +3618,17 @@ TEST(CalculatorGraph, ObserveOutputStreamError) { "count", [&count_packets](const Packet& packet) { count_packets.push_back(packet); if (count_packets.size() >= fail_count) { - return mediapipe::UnknownError("Expected. MagicString-eatnhuea"); + return absl::UnknownError("Expected. MagicString-eatnhuea"); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } })); MP_ASSERT_OK( graph.ObserveOutputStream("out", [&out_packets](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); - mediapipe::Status status = graph.Run(); + absl::Status status = graph.Run(); ASSERT_THAT(status.message(), testing::HasSubstr("MagicString-eatnhuea")); ASSERT_EQ(fail_count, count_packets.size()); for (int i = 0; i < count_packets.size(); ++i) { @@ -3680,12 +3662,12 @@ TEST(CalculatorGraph, ObserveOutputStreamNonexistent) { graph.Initialize(config, {{"max_count", MakePacket(max_count)}})); // Observe the internal output stream "count". std::vector count_packets; // Packets from the output stream "count". 
- mediapipe::Status status = graph.ObserveOutputStream( + absl::Status status = graph.ObserveOutputStream( "not_found", [&count_packets](const Packet& packet) { count_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); }); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kNotFound); + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); EXPECT_THAT(status.message(), testing::HasSubstr("not_found")); } @@ -3856,7 +3838,7 @@ TEST(CalculatorGraph, ReuseValidatedGraphConfig) { class TestRangeStdDevSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& options) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"( @@ -3886,7 +3868,7 @@ REGISTER_MEDIAPIPE_GRAPH(TestRangeStdDevSubgraph); class TestMergeSaverSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& options) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"( @@ -3987,18 +3969,18 @@ TEST(CalculatorGraph, SetExecutorTwice) { graph.SetExecutor("xyz", std::make_shared(1))); MP_EXPECT_OK( graph.SetExecutor("abc", std::make_shared(1))); - mediapipe::Status status = + absl::Status status = graph.SetExecutor("xyz", std::make_shared(1)); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kAlreadyExists); + EXPECT_EQ(status.code(), absl::StatusCode::kAlreadyExists); EXPECT_THAT(status.message(), testing::HasSubstr("xyz")); } TEST(CalculatorGraph, ReservedNameSetExecutor) { // A reserved executor name such as "__gpu" must not be used. 
CalculatorGraph graph; - mediapipe::Status status = + absl::Status status = graph.SetExecutor("__gpu", std::make_shared(1)); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"))); } @@ -4022,8 +4004,8 @@ TEST(CalculatorGraph, ReservedNameExecutorConfig) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"))); } @@ -4041,8 +4023,8 @@ TEST(CalculatorGraph, ReservedNameNodeExecutor) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("__gpu"), testing::HasSubstr("reserved"))); } @@ -4062,8 +4044,8 @@ TEST(CalculatorGraph, NonExistentExecutor) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("not declared"))); @@ -4086,8 +4068,8 @@ TEST(CalculatorGraph, UndeclaredExecutor) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + 
EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("not declared"))); @@ -4108,8 +4090,8 @@ TEST(CalculatorGraph, UntypedExecutorDeclaredButNotSet) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("SetExecutor"))); @@ -4132,8 +4114,8 @@ TEST(CalculatorGraph, DuplicateExecutorConfig) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("duplicate"))); @@ -4162,8 +4144,8 @@ TEST(CalculatorGraph, TypedExecutorDeclaredAndSet) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("xyz"), testing::HasSubstr("SetExecutor"))); @@ -4194,8 +4176,8 @@ TEST(CalculatorGraph, NumThreadsAndDefaultExecutorConfig) { output_stream: 'out' } )"); - mediapipe::Status status = graph.Initialize(config); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = graph.Initialize(config); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(status.message(), testing::AllOf(testing::HasSubstr("num_threads"), testing::HasSubstr("default executor"))); @@ 
-4268,7 +4250,7 @@ TEST(CalculatorGraph, RunWithNumThreadsInExecutorConfig) { MP_ASSERT_OK( graph.ObserveOutputStream("out", [&out_packet](const Packet& packet) { out_packet = packet; - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.Run()); EXPECT_EQ(cases[i].use_app_thread_is_expected, @@ -4380,19 +4362,19 @@ class FirstPacketFilterCalculator : public CalculatorBase { FirstPacketFilterCalculator() {} ~FirstPacketFilterCalculator() override {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (!seen_first_packet_) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); cc->Outputs().Index(0).Close(); seen_first_packet_ = true; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -4412,7 +4394,7 @@ TEST(CalculatorGraph, TestPollPacket) { MP_ASSERT_OK(graph.Initialize(config)); auto status_or_poller = graph.AddOutputStreamPoller("output"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); MP_ASSERT_OK( graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); Packet packet; @@ -4439,7 +4421,7 @@ TEST(CalculatorGraph, TestOutputStreamPollerDesiredQueueSize) { MP_ASSERT_OK(graph.Initialize(config)); auto status_or_poller = graph.AddOutputStreamPoller("output"); ASSERT_TRUE(status_or_poller.ok()); - OutputStreamPoller poller = std::move(status_or_poller.ValueOrDie()); + OutputStreamPoller poller = std::move(status_or_poller.value()); poller.SetMaxQueueSize(queue_size); MP_ASSERT_OK( graph.StartRun({{"max_count", 
MakePacket(kDefaultMaxCount)}})); @@ -4471,10 +4453,10 @@ TEST(CalculatorGraph, TestPollPacketsFromMultipleStreams) { MP_ASSERT_OK(graph.Initialize(config)); auto status_or_poller1 = graph.AddOutputStreamPoller("stream1"); ASSERT_TRUE(status_or_poller1.ok()); - OutputStreamPoller poller1 = std::move(status_or_poller1.ValueOrDie()); + OutputStreamPoller poller1 = std::move(status_or_poller1.value()); auto status_or_poller2 = graph.AddOutputStreamPoller("stream2"); ASSERT_TRUE(status_or_poller2.ok()); - OutputStreamPoller poller2 = std::move(status_or_poller2.ValueOrDie()); + OutputStreamPoller poller2 = std::move(status_or_poller2.value()); MP_ASSERT_OK( graph.StartRun({{"max_count", MakePacket(kDefaultMaxCount)}})); Packet packet1; @@ -4550,7 +4532,7 @@ TEST(CalculatorGraph, SimpleMuxCalculatorWithCustomInputStreamHandler) { Timestamp input1_timestamp = Timestamp(0); MP_EXPECT_OK(graph.AddPacketToInputStream( "input1", MakePacket(2).At(input1_timestamp))); - mediapipe::Status run_status = graph.WaitUntilIdle(); + absl::Status run_status = graph.WaitUntilIdle(); EXPECT_THAT( run_status.ToString(), testing::AllOf( diff --git a/mediapipe/framework/calculator_node.cc b/mediapipe/framework/calculator_node.cc index 8d9419127..8d538486b 100644 --- a/mediapipe/framework/calculator_node.cc +++ b/mediapipe/framework/calculator_node.cc @@ -63,6 +63,52 @@ const PacketType* GetPacketType(const PacketTypeSet& packet_type_set, return &packet_type_set.Get(id); } +// Copies a TagMap omitting entries with certain names. +std::shared_ptr RemoveNames(const tool::TagMap& tag_map, + std::set names) { + auto tag_index_names = tag_map.CanonicalEntries(); + for (auto id = tag_map.EndId() - 1; id >= tag_map.BeginId(); --id) { + std::string name = tag_map.Names()[id.value()]; + if (names.count(name) > 0) { + tag_index_names.erase(tag_index_names.begin() + id.value()); + } + } + return tool::TagMap::Create(tag_index_names).value(); +} + +// Copies matching entries from another Collection. 
+template +void CopyCollection(const CollectionType& other, CollectionType* result) { + auto tag_map = result->TagMap(); + for (auto id = tag_map->BeginId(); id != tag_map->EndId(); ++id) { + auto tag_index = tag_map->TagAndIndexFromId(id); + auto other_id = other.GetId(tag_index.first, tag_index.second); + if (other_id.IsValid()) { + result->Get(id) = other.Get(other_id); + } + } +} + +// Copies packet types omitting entries that are optional and not provided. +std::unique_ptr RemoveOmittedPacketTypes( + const PacketTypeSet& packet_types, + const std::map& all_side_packets, + const ValidatedGraphConfig* validated_graph) { + std::set omitted_names; + for (auto id = packet_types.BeginId(); id != packet_types.EndId(); ++id) { + std::string name = packet_types.TagMap()->Names()[id.value()]; + if (packet_types.Get(id).IsOptional() && + validated_graph->IsExternalSidePacket(name) && + all_side_packets.count(name) == 0) { + omitted_names.insert(name); + } + } + auto tag_map = RemoveNames(*packet_types.TagMap(), omitted_names); + auto result = std::make_unique(tag_map); + CopyCollection(packet_types, result.get()); + return result; +} + } // namespace CalculatorNode::CalculatorNode() {} @@ -72,7 +118,7 @@ Timestamp CalculatorNode::SourceProcessOrder( return calculator_->SourceProcessOrder(cc); } -mediapipe::Status CalculatorNode::Initialize( +absl::Status CalculatorNode::Initialize( const ValidatedGraphConfig* validated_graph, int node_id, InputStreamManager* input_stream_managers, OutputStreamManager* output_stream_managers, @@ -158,7 +204,7 @@ mediapipe::Status CalculatorNode::Initialize( return InitializeInputStreams(input_stream_managers, output_stream_managers); } -mediapipe::Status CalculatorNode::InitializeOutputSidePackets( +absl::Status CalculatorNode::InitializeOutputSidePackets( const PacketTypeSet& output_side_packet_types, OutputSidePacketImpl* output_side_packets) { output_side_packets_ = @@ -172,10 +218,10 @@ mediapipe::Status 
CalculatorNode::InitializeOutputSidePackets( output_side_packets_->GetPtr(id) = &output_side_packets[base_index + id.value()]; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorNode::InitializeInputSidePackets( +absl::Status CalculatorNode::InitializeInputSidePackets( OutputSidePacketImpl* output_side_packets) { const NodeTypeInfo& node_type_info = validated_graph_->CalculatorInfos()[node_id_]; @@ -200,10 +246,10 @@ mediapipe::Status CalculatorNode::InitializeInputSidePackets( << output_side_packet_index; origin_output_side_packet->AddMirror(&input_side_packet_handler_, id); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorNode::InitializeOutputStreams( +absl::Status CalculatorNode::InitializeOutputStreams( OutputStreamManager* output_stream_managers) { RET_CHECK(output_stream_managers) << "output_stream_managers is NULL"; const NodeTypeInfo& node_type_info = @@ -215,7 +261,7 @@ mediapipe::Status CalculatorNode::InitializeOutputStreams( current_output_stream_managers); } -mediapipe::Status CalculatorNode::InitializeInputStreams( +absl::Status CalculatorNode::InitializeInputStreams( InputStreamManager* input_stream_managers, OutputStreamManager* output_stream_managers) { RET_CHECK(input_stream_managers) << "input_stream_managers is NULL"; @@ -246,10 +292,10 @@ mediapipe::Status CalculatorNode::InitializeInputStreams( << output_stream_index; origin_output_stream_manager->AddMirror(input_stream_handler_.get(), id); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorNode::InitializeInputStreamHandler( +absl::Status CalculatorNode::InitializeInputStreamHandler( const InputStreamHandlerConfig& handler_config, const PacketTypeSet& input_stream_types) { const ProtoString& input_stream_handler_name = @@ -264,10 +310,10 @@ mediapipe::Status CalculatorNode::InitializeInputStreamHandler( _ << "\"" << input_stream_handler_name << "\" is not a registered 
input stream handler."); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorNode::InitializeOutputStreamHandler( +absl::Status CalculatorNode::InitializeOutputStreamHandler( const OutputStreamHandlerConfig& handler_config, const PacketTypeSet& output_stream_types) { const ProtoString& output_stream_handler_name = @@ -281,10 +327,10 @@ mediapipe::Status CalculatorNode::InitializeOutputStreamHandler( /*calculator_run_in_parallel=*/max_in_flight_ > 1), _ << "\"" << output_stream_handler_name << "\" is not a registered output stream handler."); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorNode::ConnectShardsToStreams( +absl::Status CalculatorNode::ConnectShardsToStreams( CalculatorContext* calculator_context) { RET_CHECK(calculator_context); MP_RETURN_IF_ERROR( @@ -324,13 +370,13 @@ void CalculatorNode::SetMaxInputStreamQueueSize(int max_queue_size) { input_stream_handler_->SetMaxQueueSize(max_queue_size); } -mediapipe::Status CalculatorNode::PrepareForRun( +absl::Status CalculatorNode::PrepareForRun( const std::map& all_side_packets, const std::map& service_packets, std::function ready_for_open_callback, std::function source_node_opened_callback, std::function schedule_callback, - std::function error_callback, + std::function error_callback, CounterFactory* counter_factory) { RET_CHECK(ready_for_open_callback) << "ready_for_open_callback is NULL"; RET_CHECK(schedule_callback) << "schedule_callback is NULL"; @@ -345,10 +391,12 @@ mediapipe::Status CalculatorNode::PrepareForRun( std::move(schedule_callback), error_callback); output_stream_handler_->PrepareForRun(error_callback); - const PacketTypeSet* input_side_packet_types = + const PacketTypeSet* packet_types = &validated_graph_->CalculatorInfos()[node_id_].InputSidePacketTypes(); + input_side_packet_types_ = RemoveOmittedPacketTypes( + *packet_types, all_side_packets, validated_graph_); 
MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun( - input_side_packet_types, all_side_packets, + input_side_packet_types_.get(), all_side_packets, [this]() { CalculatorNode::InputSidePacketsReady(); }, std::move(error_callback))); calculator_state_->SetInputSidePackets( @@ -394,7 +442,7 @@ mediapipe::Status CalculatorNode::PrepareForRun( input_side_packets_ready_ = (input_side_packet_handler_.MissingInputSidePacketCount() == 0); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } namespace { @@ -406,7 +454,7 @@ const Packet GetPacket(const OutputSidePacket& out) { } // Resends the output-side-packets from the previous graph run. -mediapipe::Status ResendSidePackets(CalculatorContext* cc) { +absl::Status ResendSidePackets(CalculatorContext* cc) { auto& outs = cc->OutputSidePackets(); for (CollectionItemId id = outs.BeginId(); id < outs.EndId(); ++id) { Packet packet = GetPacket(outs.Get(id)); @@ -415,7 +463,7 @@ mediapipe::Status ResendSidePackets(CalculatorContext* cc) { outs.Get(id).Set(packet); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -429,7 +477,7 @@ bool CalculatorNode::OutputsAreConstant(CalculatorContext* cc) { return true; } -mediapipe::Status CalculatorNode::OpenNode() { +absl::Status CalculatorNode::OpenNode() { VLOG(2) << "CalculatorNode::OpenNode() for " << DebugName(); CalculatorContext* default_context = @@ -444,7 +492,7 @@ mediapipe::Status CalculatorNode::OpenNode() { calculator_context_manager_.PushInputTimestampToContext( default_context, Timestamp::Unstarted()); - mediapipe::Status result; + absl::Status result; if (OutputsAreConstant(default_context)) { result = ResendSidePackets(default_context); } else { @@ -489,7 +537,7 @@ mediapipe::Status CalculatorNode::OpenNode() { status_ = kStateOpened; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void CalculatorNode::ActivateNode() { @@ -523,8 +571,8 @@ void CalculatorNode::CloseOutputStreams(OutputStreamShardSet* outputs) 
{ output_stream_handler_->Close(outputs); } -mediapipe::Status CalculatorNode::CloseNode( - const mediapipe::Status& graph_status, bool graph_run_ended) { +absl::Status CalculatorNode::CloseNode(const absl::Status& graph_status, + bool graph_run_ended) { { absl::MutexLock status_lock(&status_mutex_); RET_CHECK_NE(status_, kStateClosed) @@ -544,11 +592,11 @@ mediapipe::Status CalculatorNode::CloseNode( calculator_context_manager_.SetGraphStatusInContext(default_context, graph_status); - mediapipe::Status result; + absl::Status result; if (OutputsAreConstant(default_context)) { // Do nothing. - result = mediapipe::OkStatus(); + result = absl::OkStatus(); } else { MEDIAPIPE_PROFILING(CLOSE, default_context); LegacyCalculatorSupport::Scoped s(default_context); @@ -578,10 +626,10 @@ mediapipe::Status CalculatorNode::CloseNode( "Calculator::Close() for node \"$0\" failed: ", DebugName()); VLOG(2) << "Closed node " << DebugName(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -void CalculatorNode::CleanupAfterRun(const mediapipe::Status& graph_status) { +void CalculatorNode::CleanupAfterRun(const absl::Status& graph_status) { if (needs_to_close_) { calculator_context_manager_.PushInputTimestampToContext( calculator_context_manager_.GetDefaultCalculatorContext(), @@ -750,12 +798,12 @@ std::string CalculatorNode::DebugName() const { } // TODO: Split this function. -mediapipe::Status CalculatorNode::ProcessNode( +absl::Status CalculatorNode::ProcessNode( CalculatorContext* calculator_context) { if (IsSource()) { // This is a source Calculator. 
if (Closed()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const Timestamp input_timestamp = calculator_context->InputTimestamp(); @@ -764,7 +812,7 @@ mediapipe::Status CalculatorNode::ProcessNode( output_stream_handler_->PrepareOutputs(input_timestamp, outputs); VLOG(2) << "Calling Calculator::Process() for node: " << DebugName(); - mediapipe::Status result; + absl::Status result; { MEDIAPIPE_PROFILING(PROCESS, calculator_context); @@ -787,15 +835,15 @@ mediapipe::Status CalculatorNode::ProcessNode( output_stream_handler_->PostProcess(input_timestamp); if (node_stopped) { MP_RETURN_IF_ERROR( - CloseNode(mediapipe::OkStatus(), /*graph_run_ended=*/false)); + CloseNode(absl::OkStatus(), /*graph_run_ended=*/false)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { // This is not a source Calculator. InputStreamShardSet* const inputs = &calculator_context->Inputs(); OutputStreamShardSet* const outputs = &calculator_context->Outputs(); - mediapipe::Status result = - mediapipe::InternalError("Calculator context has no input packets."); + absl::Status result = + absl::InternalError("Calculator context has no input packets."); int num_invocations = calculator_context_manager_.NumberOfContextTimestamps( *calculator_context); @@ -814,7 +862,7 @@ mediapipe::Status CalculatorNode::ProcessNode( if (OutputsAreConstant(calculator_context)) { // Do nothing. - result = mediapipe::OkStatus(); + result = absl::OkStatus(); } else { MEDIAPIPE_PROFILING(PROCESS, calculator_context); LegacyCalculatorSupport::Scoped s( @@ -851,7 +899,7 @@ mediapipe::Status CalculatorNode::ProcessNode( CHECK_EQ(calculator_context_manager_.NumberOfContextTimestamps( *calculator_context), 1); - return CloseNode(mediapipe::OkStatus(), /*graph_run_ended=*/false); + return CloseNode(absl::OkStatus(), /*graph_run_ended=*/false); } else { RET_CHECK_FAIL() << "Invalid input timestamp in ProcessNode(). 
timestamp: " diff --git a/mediapipe/framework/calculator_node.h b/mediapipe/framework/calculator_node.h index a9e9fd2ff..98e11a399 100644 --- a/mediapipe/framework/calculator_node.h +++ b/mediapipe/framework/calculator_node.h @@ -95,7 +95,7 @@ class CalculatorNode { void SetExecutor(const std::string& executor); // Calls Process() on the Calculator corresponding to this node. - mediapipe::Status ProcessNode(CalculatorContext* calculator_context); + absl::Status ProcessNode(CalculatorContext* calculator_context); // Initializes the node. The buffer_size_hint argument is // set to the value specified in the graph proto for this field. @@ -105,12 +105,13 @@ class CalculatorNode { // output_side_packets is expected to point to a contiguous flat array with // OutputSidePacketImpls corresponding to the output side packet indexes in // validated_graph. - mediapipe::Status Initialize( - const ValidatedGraphConfig* validated_graph, int node_id, - InputStreamManager* input_stream_managers, - OutputStreamManager* output_stream_managers, - OutputSidePacketImpl* output_side_packets, int* buffer_size_hint, - std::shared_ptr profiling_context); + absl::Status Initialize(const ValidatedGraphConfig* validated_graph, + int node_id, + InputStreamManager* input_stream_managers, + OutputStreamManager* output_stream_managers, + OutputSidePacketImpl* output_side_packets, + int* buffer_size_hint, + std::shared_ptr profiling_context); // Sets up the node at the beginning of CalculatorGraph::Run(). This // method is executed before any OpenNode() calls to the nodes @@ -121,22 +122,22 @@ class CalculatorNode { // can be scheduled. source_node_opened_callback is called when a source // node is opened. schedule_callback is passed to the InputStreamHandler // and is called each time a new invocation can be scheduled. 
- mediapipe::Status PrepareForRun( + absl::Status PrepareForRun( const std::map& all_side_packets, const std::map& service_packets, std::function ready_for_open_callback, std::function source_node_opened_callback, std::function schedule_callback, - std::function error_callback, + std::function error_callback, CounterFactory* counter_factory) ABSL_LOCKS_EXCLUDED(status_mutex_); // Opens the node. - mediapipe::Status OpenNode() ABSL_LOCKS_EXCLUDED(status_mutex_); + absl::Status OpenNode() ABSL_LOCKS_EXCLUDED(status_mutex_); // Called when a source node's layer becomes active. void ActivateNode() ABSL_LOCKS_EXCLUDED(status_mutex_); // Cleans up the node after the CalculatorGraph has been run. Deletes // the Calculator managed by this node. graph_status is the status of // the graph run. - void CleanupAfterRun(const mediapipe::Status& graph_status) + void CleanupAfterRun(const absl::Status& graph_status) ABSL_LOCKS_EXCLUDED(status_mutex_); // Returns true iff PrepareForRun() has been called (and types verified). @@ -218,8 +219,7 @@ class CalculatorNode { // Closes the node's calculator and input and output streams. // graph_status is the current status of the graph run. graph_run_ended // indicates whether the graph run has ended. - mediapipe::Status CloseNode(const mediapipe::Status& graph_status, - bool graph_run_ended) + absl::Status CloseNode(const absl::Status& graph_status, bool graph_run_ended) ABSL_LOCKS_EXCLUDED(status_mutex_); // Returns a pointer to the default calculator context that is used for @@ -235,35 +235,34 @@ class CalculatorNode { private: // Sets up the output side packets from the master flat array. - mediapipe::Status InitializeOutputSidePackets( + absl::Status InitializeOutputSidePackets( const PacketTypeSet& output_side_packet_types, OutputSidePacketImpl* output_side_packets); // Connects the input side packets as mirrors on the output side packets. // Output side packets are looked up in the master flat array which is // provided. 
- mediapipe::Status InitializeInputSidePackets( + absl::Status InitializeInputSidePackets( OutputSidePacketImpl* output_side_packets); // Sets up the output streams from the master flat array. - mediapipe::Status InitializeOutputStreams( + absl::Status InitializeOutputStreams( OutputStreamManager* output_stream_managers); // Sets up the input streams and connects them as mirrors on the // output streams. Both input streams and output streams are looked // up in the master flat arrays which are provided. - mediapipe::Status InitializeInputStreams( + absl::Status InitializeInputStreams( InputStreamManager* input_stream_managers, OutputStreamManager* output_stream_managers); - mediapipe::Status InitializeInputStreamHandler( + absl::Status InitializeInputStreamHandler( const InputStreamHandlerConfig& handler_config, const PacketTypeSet& input_stream_types); - mediapipe::Status InitializeOutputStreamHandler( + absl::Status InitializeOutputStreamHandler( const OutputStreamHandlerConfig& handler_config, const PacketTypeSet& output_stream_types); // Connects the input/output stream shards in the given calculator context to // the input/output streams of the node. - mediapipe::Status ConnectShardsToStreams( - CalculatorContext* calculator_context); + absl::Status ConnectShardsToStreams(CalculatorContext* calculator_context); // The general scheduling logic shared by EndScheduling() and // CheckIfBecameReady(). @@ -351,6 +350,9 @@ class CalculatorNode { // Mutex for node status. mutable absl::Mutex status_mutex_; + // Describes the input side packets required to run this node. + std::unique_ptr input_side_packet_types_; + // Manages the set of input side packets. 
InputSidePacketHandler input_side_packet_handler_; diff --git a/mediapipe/framework/calculator_node_test.cc b/mediapipe/framework/calculator_node_test.cc index e72a178ca..325dd8ff7 100644 --- a/mediapipe/framework/calculator_node_test.cc +++ b/mediapipe/framework/calculator_node_test.cc @@ -37,23 +37,23 @@ class CountCalculator : public CalculatorBase { CountCalculator() { ++num_constructed_; } ~CountCalculator() override { ++num_destroyed_; } - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { ++num_fill_expectations_; cc->Inputs().Get(cc->Inputs().BeginId()).Set(); cc->Outputs().Get(cc->Outputs().BeginId()).Set(); cc->InputSidePackets().Get(cc->InputSidePackets().BeginId()).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { ++num_open_; // Simulate doing nontrivial work to ensure that the time spent in the // method will register on streamz each time it is called. usleep(100); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { ++num_process_; int input_stream_int = cc->Inputs().Get(cc->Inputs().BeginId()).Get(); int side_packet_int = @@ -65,15 +65,15 @@ class CountCalculator : public CalculatorBase { // Simulate doing nontrivial work to ensure that the time spent in the // method will register on streamz each time it is called. usleep(100); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { ++num_close_; // Simulate doing nontrivial work to ensure that the time spent in the // method will register on streamz each time it is called. 
usleep(100); - return mediapipe::OkStatus(); + return absl::OkStatus(); } static int num_constructed_; @@ -94,7 +94,7 @@ int CountCalculator::num_destroyed_ = 0; void SourceNodeOpenedNoOp() {} -void CheckFail(const mediapipe::Status& status) { +void CheckFail(const absl::Status& status) { LOG(FATAL) << "The test triggered the error callback with status: " << status; } @@ -165,7 +165,7 @@ class CalculatorNodeTest : public ::testing::Test { &buffer_size_hint_, graph_profiler_)); } - mediapipe::Status PrepareNodeForRun() { + absl::Status PrepareNodeForRun() { return node_->PrepareForRun( // input_side_packets_, // service_packets_, // @@ -180,7 +180,7 @@ class CalculatorNodeTest : public ::testing::Test { nullptr); } - mediapipe::Status InitializeStreams() { + absl::Status InitializeStreams() { // START OF: code is copied from // CalculatorGraph::InitializePacketGeneratorGraph. // Create and initialize the output side packets. @@ -220,7 +220,7 @@ class CalculatorNodeTest : public ::testing::Test { stream_a_manager_ = &output_stream_managers_[1]; stream_b_manager_ = &output_stream_managers_[2]; - return mediapipe::OkStatus(); + return absl::OkStatus(); } virtual void SimulateParentOpenNode() { stream_a_manager_->LockIntroData(); } @@ -482,7 +482,7 @@ TEST_F(CalculatorNodeTest, CleanupAfterRun) { node_->EndScheduling(); // The max parallelism is already reached. 
EXPECT_FALSE(node_->TryToBeginScheduling()); - node_->CleanupAfterRun(mediapipe::OkStatus()); + node_->CleanupAfterRun(absl::OkStatus()); EXPECT_FALSE(node_->Prepared()); EXPECT_FALSE(node_->Opened()); @@ -517,7 +517,7 @@ void CalculatorNodeTest::TestCleanupAfterRunTwice() { EXPECT_TRUE(node_->TryToBeginScheduling()); MP_EXPECT_OK(node_->ProcessNode(cc_)); node_->EndScheduling(); - node_->CleanupAfterRun(mediapipe::OkStatus()); + node_->CleanupAfterRun(absl::OkStatus()); stream_a_manager_->PrepareForRun(nullptr); @@ -543,7 +543,7 @@ void CalculatorNodeTest::TestCleanupAfterRunTwice() { node_->EndScheduling(); // The max parallelism is already reached. EXPECT_FALSE(node_->TryToBeginScheduling()); - node_->CleanupAfterRun(mediapipe::OkStatus()); + node_->CleanupAfterRun(absl::OkStatus()); EXPECT_FALSE(node_->Prepared()); EXPECT_FALSE(node_->Opened()); diff --git a/mediapipe/framework/calculator_options.proto b/mediapipe/framework/calculator_options.proto index 5680dd9ed..747e9c4af 100644 --- a/mediapipe/framework/calculator_options.proto +++ b/mediapipe/framework/calculator_options.proto @@ -37,7 +37,6 @@ option java_outer_classname = "CalculatorOptionsProto"; message CalculatorOptions { // If true, this proto specifies a subset of field values, // which should override corresponding field values. - // Deprecated in cl/228195782. 
optional bool merge_fields = 1 [deprecated = true]; extensions 20000 to max; diff --git a/mediapipe/framework/calculator_parallel_execution_test.cc b/mediapipe/framework/calculator_parallel_execution_test.cc index 887fd4352..2ed385b09 100644 --- a/mediapipe/framework/calculator_parallel_execution_test.cc +++ b/mediapipe/framework/calculator_parallel_execution_test.cc @@ -50,20 +50,20 @@ inline void BusySleep(absl::Duration duration) { class SlowPlusOneCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(mediapipe::TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->InputTimestamp().Value() % 4 == 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } RandomEngine random(testing::UnitTest::GetInstance()->random_seed()); @@ -71,7 +71,7 @@ class SlowPlusOneCalculator : public CalculatorBase { BusySleep(absl::Milliseconds(90 + uniform_dist(random))); cc->Outputs().Index(0).Add(new int(cc->Inputs().Index(0).Get() + 1), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; @@ -124,7 +124,7 @@ TEST_F(ParallelExecutionTest, SlowPlusOneCalculatorsTest) { const int kTotalNums = 100; int fail_count = 0; for (int i = 0; i < kTotalNums; ++i) { - mediapipe::Status status = graph.AddPacketToInputStream( + absl::Status status = graph.AddPacketToInputStream( "input", Adopt(new int(i)).At(Timestamp(i))); if (!status.ok()) { ++fail_count; diff --git a/mediapipe/framework/calculator_runner.cc 
b/mediapipe/framework/calculator_runner.cc index 827ac1c0a..833797483 100644 --- a/mediapipe/framework/calculator_runner.cc +++ b/mediapipe/framework/calculator_runner.cc @@ -36,15 +36,15 @@ namespace { // Input side packets: 1, pointing to CalculatorRunner::StreamContents. class CalculatorRunnerSourceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets() .Index(0) .Set(); cc->Outputs().Index(0).SetAny(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { const auto* contents = cc->InputSidePackets() .Index(0) .Get(); @@ -53,9 +53,9 @@ class CalculatorRunnerSourceCalculator : public CalculatorBase { for (const Packet& packet : contents->packets) { cc->Outputs().Index(0).AddPacket(packet); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return tool::StatusStop(); } }; @@ -67,23 +67,23 @@ REGISTER_CALCULATOR(CalculatorRunnerSourceCalculator); // Input side packets: 1, pointing to CalculatorRunner::StreamContents. 
class CalculatorRunnerSinkCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->InputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { contents_ = cc->InputSidePackets() .Index(0) .Get(); contents_->header = cc->Inputs().Index(0).Header(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { contents_->packets.push_back(cc->Inputs().Index(0).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -98,7 +98,7 @@ CalculatorRunner::CalculatorRunner( MEDIAPIPE_CHECK_OK(InitializeFromNodeConfig(node_config)); } -mediapipe::Status CalculatorRunner::InitializeFromNodeConfig( +absl::Status CalculatorRunner::InitializeFromNodeConfig( const CalculatorGraphConfig::Node& node_config) { node_config_ = node_config; @@ -126,7 +126,7 @@ mediapipe::Status CalculatorRunner::InitializeFromNodeConfig( tool::TagMap::Create(node_config_.output_side_packet())); output_side_packets_ = absl::make_unique(output_side_map); - return mediapipe::OkStatus(); + return absl::OkStatus(); } CalculatorRunner::CalculatorRunner(const std::string& calculator_type, @@ -220,10 +220,10 @@ std::map CalculatorRunner::GetCountersValues() { return graph_->GetCounterFactory()->GetCounterSet()->GetCountersValues(); } -mediapipe::Status CalculatorRunner::BuildGraph() { +absl::Status CalculatorRunner::BuildGraph() { if (graph_ != nullptr) { // The graph was already built. 
- return mediapipe::OkStatus(); + return absl::OkStatus(); } RET_CHECK(inputs_) << "The inputs were not initialized."; RET_CHECK(outputs_) << "The outputs were not initialized."; @@ -277,10 +277,10 @@ mediapipe::Status CalculatorRunner::BuildGraph() { graph_ = absl::make_unique(); MP_RETURN_IF_ERROR(graph_->Initialize(config)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CalculatorRunner::Run() { +absl::Status CalculatorRunner::Run() { MP_RETURN_IF_ERROR(BuildGraph()); // Set the input side packets for the sources. std::map input_side_packets; @@ -352,7 +352,7 @@ mediapipe::Status CalculatorRunner::Run() { tag, (index == -1) ? ++positional_index : index); ASSIGN_OR_RETURN(contents, graph_->GetOutputSidePacket(name)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/framework/calculator_runner.h b/mediapipe/framework/calculator_runner.h index 8a3e8ba83..0254d03a0 100644 --- a/mediapipe/framework/calculator_runner.h +++ b/mediapipe/framework/calculator_runner.h @@ -109,11 +109,11 @@ class CalculatorRunner { // Runs the calculator, by calling Open(), Process() with the // inputs provided via mutable_inputs(), and Close(). Returns the - // mediapipe::Status from CalculatorGraph::Run(). Internally, Run() + // absl::Status from CalculatorGraph::Run(). Internally, Run() // constructs a CalculatorGraph in the first call, and calls // CalculatorGraph::Run(). A single instance of CalculatorRunner // uses the same instance of CalculatorGraph for all runs. - mediapipe::Status Run(); + absl::Status Run(); // Returns the vector of contents of the output streams. The .header // field contains the stream header and the .packets field contains @@ -135,11 +135,11 @@ class CalculatorRunner { static const char kSinkPrefix[]; // Initialize using a node config (does the constructor's work). 
- mediapipe::Status InitializeFromNodeConfig( + absl::Status InitializeFromNodeConfig( const CalculatorGraphConfig::Node& node_config); // Builds the graph if one does not already exist. - mediapipe::Status BuildGraph(); + absl::Status BuildGraph(); CalculatorGraphConfig::Node node_config_; diff --git a/mediapipe/framework/calculator_runner_test.cc b/mediapipe/framework/calculator_runner_test.cc index 51484a4e3..259a7c31f 100644 --- a/mediapipe/framework/calculator_runner_test.cc +++ b/mediapipe/framework/calculator_runner_test.cc @@ -40,7 +40,7 @@ namespace { // at InputTimestamp. The headers are strings. class CalculatorRunnerTestCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Inputs().Index(1).Set(); cc->Outputs().Index(0).Set(); @@ -50,10 +50,10 @@ class CalculatorRunnerTestCalculator : public CalculatorBase { cc->OutputSidePackets() .Tag("SIDE_OUTPUT") .SetSameAs(&cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { std::string input_header_string = absl::StrCat(cc->Inputs().Index(0).Header().Get(), cc->Inputs().Index(1).Header().Get()); @@ -66,17 +66,17 @@ class CalculatorRunnerTestCalculator : public CalculatorBase { cc->OutputSidePackets() .Tag("SIDE_OUTPUT") .Set(cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { for (int index = 0; index < 2; ++index) { cc->Outputs().Index(index).Add( new int(-cc->Inputs().Index(index).Get()), cc->InputTimestamp()); } cc->Outputs().Index(2).AddPacket( cc->InputSidePackets().Index(0).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } }; REGISTER_CALCULATOR(CalculatorRunnerTestCalculator); @@ -87,7 +87,7 @@ REGISTER_CALCULATOR(CalculatorRunnerTestCalculator); // with the same tag name (and any index). class CalculatorRunnerMultiTagTestCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (const std::string& tag : cc->Inputs().GetTags()) { for (CollectionItemId item_id = cc->Inputs().BeginId(tag); item_id < cc->Inputs().EndId(tag); ++item_id) { @@ -95,10 +95,10 @@ class CalculatorRunnerMultiTagTestCalculator : public CalculatorBase { } cc->Outputs().Get(tag, 0).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { for (const std::string& tag : cc->Inputs().GetTags()) { auto sum = absl::make_unique(0); for (CollectionItemId item_id = cc->Inputs().BeginId(tag); @@ -109,7 +109,7 @@ class CalculatorRunnerMultiTagTestCalculator : public CalculatorBase { } cc->Outputs().Get(tag, 0).Add(sum.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(CalculatorRunnerMultiTagTestCalculator); diff --git a/mediapipe/framework/calculator_state.cc b/mediapipe/framework/calculator_state.cc index 12e1a6829..fcd20c2a1 100644 --- a/mediapipe/framework/calculator_state.cc +++ b/mediapipe/framework/calculator_state.cc @@ -61,6 +61,11 @@ Counter* CalculatorState::GetCounter(const std::string& name) { return counter_factory_->GetCounter(absl::StrCat(NodeName(), "-", name)); } +CounterSet* CalculatorState::GetCounterSet() { + CHECK(counter_factory_); + return counter_factory_->GetCounterSet(); +} + void CalculatorState::SetServicePacket(const std::string& key, Packet packet) { service_packets_[key] = std::move(packet); } diff --git a/mediapipe/framework/calculator_state.h 
b/mediapipe/framework/calculator_state.h index 919775c1e..42d1f1d4a 100644 --- a/mediapipe/framework/calculator_state.h +++ b/mediapipe/framework/calculator_state.h @@ -78,6 +78,11 @@ class CalculatorState { // name is the passed-in name, prefixed by the calculator NodeName. Counter* GetCounter(const std::string& name); + // Returns a counter set, which can be passed to other classes, to generate + // counters. NOTE: This differs from GetCounter, in that the counters + // created by this counter set do not have the NodeName prefix. + CounterSet* GetCounterSet(); + std::shared_ptr GetSharedProfilingContext() const { return profiling_context_; } diff --git a/mediapipe/framework/collection.h b/mediapipe/framework/collection.h index 5c1ac199c..a8b63435a 100644 --- a/mediapipe/framework/collection.h +++ b/mediapipe/framework/collection.h @@ -29,6 +29,7 @@ #include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/tool/tag_map.h" +#include "mediapipe/framework/tool/tag_map_helper.h" #include "mediapipe/framework/tool/validate_name.h" #include "mediapipe/framework/type_map.h" @@ -379,39 +380,17 @@ Collection::Collection( template Collection::Collection( - const tool::TagAndNameInfo& info) { - tag_map_ = std::move(tool::TagMap::Create(info).ValueOrDie()); - if (tag_map_->NumEntries() != 0) { - data_ = absl::make_unique(tag_map_->NumEntries()); - } -} + const tool::TagAndNameInfo& info) + : Collection(tool::TagMap::Create(info).value()) {} template -Collection::Collection(const int num_entries) { - proto_ns::RepeatedPtrField fields; - for (int i = 0; i < num_entries; ++i) { - *fields.Add() = absl::StrCat("name", i); - } - tag_map_ = std::move(tool::TagMap::Create(fields).ValueOrDie()); - if (tag_map_->NumEntries() != 0) { - data_ = absl::make_unique(tag_map_->NumEntries()); - } -} +Collection::Collection(const int num_entries) + : Collection(tool::CreateTagMap(num_entries).value()) {} template 
Collection::Collection( - const std::initializer_list& tag_names) { - proto_ns::RepeatedPtrField fields; - int i = 0; - for (const std::string& name : tag_names) { - *fields.Add() = absl::StrCat(name, ":name", i); - ++i; - } - tag_map_ = std::move(tool::TagMap::Create(fields).ValueOrDie()); - if (tag_map_->NumEntries() != 0) { - data_ = absl::make_unique(tag_map_->NumEntries()); - } -} + const std::initializer_list& tag_names) + : Collection(tool::CreateTagMapFromTags(tag_names).value()) {} template bool Collection::UsesTags() const { diff --git a/mediapipe/framework/collection_test.cc b/mediapipe/framework/collection_test.cc index 704a3f08e..45a819e49 100644 --- a/mediapipe/framework/collection_test.cc +++ b/mediapipe/framework/collection_test.cc @@ -78,7 +78,7 @@ TEST(CollectionTest, MixedTagAndIndexUsage) { "TAG_C:0:e", "TAG_A:1:f"}); MP_ASSERT_OK(tags_statusor); - internal::Collection collection1(std::move(tags_statusor.ValueOrDie())); + internal::Collection collection1(std::move(tags_statusor.value())); collection1.Get("TAG_A", 0) = 100; collection1.Get("TAG_A", 1) = 101; collection1.Get("TAG_A", 2) = 102; @@ -165,16 +165,16 @@ TEST(CollectionTest, StaticEmptyCollectionHeapCheck) { // "new T[0]" returns a non-null pointer which the heap checker has // issues in tracking. Additionally, allocating of empty arrays is // also inefficient as it invokes heap management routines. - static auto* collection1 = new PacketSet(tool::CreateTagMap({}).ValueOrDie()); + static auto* collection1 = new PacketSet(tool::CreateTagMap({}).value()); // Heap check issues are most triggered when zero length and non-zero // length allocations are interleaved. Additionally, this heap check // wasn't triggered by "char", so a more complex type (Packet) is used. 
static auto* collection2 = - new PacketSet(tool::CreateTagMap({"TAG:name"}).ValueOrDie()); - static auto* collection3 = new PacketSet(tool::CreateTagMap({}).ValueOrDie()); + new PacketSet(tool::CreateTagMap({"TAG:name"}).value()); + static auto* collection3 = new PacketSet(tool::CreateTagMap({}).value()); static auto* collection4 = - new PacketSet(tool::CreateTagMap({"TAG:name"}).ValueOrDie()); - static auto* collection5 = new PacketSet(tool::CreateTagMap({}).ValueOrDie()); + new PacketSet(tool::CreateTagMap({"TAG:name"}).value()); + static auto* collection5 = new PacketSet(tool::CreateTagMap({}).value()); EXPECT_EQ(0, collection1->NumEntries()); EXPECT_EQ(1, collection2->NumEntries()); EXPECT_EQ(0, collection3->NumEntries()); @@ -183,12 +183,12 @@ TEST(CollectionTest, StaticEmptyCollectionHeapCheck) { } template -mediapipe::Status TestCollectionWithPointers( - const std::vector& original_values, const T& inject1, const T& inject2) { +absl::Status TestCollectionWithPointers(const std::vector& original_values, + const T& inject1, const T& inject2) { std::shared_ptr tag_map = tool::CreateTagMap({"TAG_A:a", "TAG_B:1:b", "TAG_A:2:c", "TAG_B:d", "TAG_C:0:e", "TAG_A:1:f"}) - .ValueOrDie(); + .value(); { // Test a regular collection. 
@@ -451,7 +451,7 @@ mediapipe::Status TestCollectionWithPointers( ++i; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } TEST(CollectionTest, TestCollectionWithPointersIntAndString) { @@ -464,7 +464,7 @@ TEST(CollectionTest, TestIteratorFunctions) { std::shared_ptr tag_map = tool::CreateTagMap({"TAG_A:a", "TAG_B:1:b", "TAG_A:2:c", "TAG_B:d", "TAG_C:0:e", "TAG_A:1:f"}) - .ValueOrDie(); + .value(); std::vector values = {"a0", "a1", "a2", "b0", "b1", "c0"}; internal::Collection diff --git a/mediapipe/framework/deps/BUILD b/mediapipe/framework/deps/BUILD index cf2bf46bd..062551342 100644 --- a/mediapipe/framework/deps/BUILD +++ b/mediapipe/framework/deps/BUILD @@ -94,7 +94,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":file_path", - ":status", + "//mediapipe/framework/port:status", "@com_google_absl//absl/strings", ], ) @@ -257,22 +257,6 @@ cc_library( ], ) -cc_library( - name = "statusor", - srcs = ["statusor.cc"], - hdrs = [ - "statusor.h", - "statusor_internals.h", - ], - # Use this library through "mediapipe/framework/port:statusor". - visibility = ["//mediapipe/framework/port:__pkg__"], - deps = [ - ":status", - "//mediapipe/framework/port:logging", - "@com_google_absl//absl/base:core_headers", - ], -) - cc_library( name = "re2", hdrs = [ @@ -429,18 +413,6 @@ cc_test( ], ) -cc_test( - name = "statusor_test", - size = "small", - srcs = ["statusor_test.cc"], - linkstatic = 1, - deps = [ - ":status", - ":statusor", - "//mediapipe/framework/port:gtest_main", - ], -) - cc_test( name = "topologicalsorter_test", srcs = ["topologicalsorter_test.cc"], diff --git a/mediapipe/framework/deps/canonical_errors.h b/mediapipe/framework/deps/canonical_errors.h index 9f82b2094..e804d9c66 100644 --- a/mediapipe/framework/deps/canonical_errors.h +++ b/mediapipe/framework/deps/canonical_errors.h @@ -22,60 +22,60 @@ namespace mediapipe { // Each of the functions below creates a canonical error with the given // message. 
The error code of the returned status object matches the name of // the function. -inline mediapipe::Status AlreadyExistsError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kAlreadyExists, message); +inline absl::Status AlreadyExistsError(absl::string_view message) { + return absl::Status(absl::StatusCode::kAlreadyExists, message); } -inline mediapipe::Status CancelledError() { - return mediapipe::Status(mediapipe::StatusCode::kCancelled, ""); +inline absl::Status CancelledError() { + return absl::Status(absl::StatusCode::kCancelled, ""); } -inline mediapipe::Status CancelledError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kCancelled, message); +inline absl::Status CancelledError(absl::string_view message) { + return absl::Status(absl::StatusCode::kCancelled, message); } -inline mediapipe::Status InternalError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kInternal, message); +inline absl::Status InternalError(absl::string_view message) { + return absl::Status(absl::StatusCode::kInternal, message); } -inline mediapipe::Status InvalidArgumentError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kInvalidArgument, message); +inline absl::Status InvalidArgumentError(absl::string_view message) { + return absl::Status(absl::StatusCode::kInvalidArgument, message); } -inline mediapipe::Status FailedPreconditionError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kFailedPrecondition, message); +inline absl::Status FailedPreconditionError(absl::string_view message) { + return absl::Status(absl::StatusCode::kFailedPrecondition, message); } -inline mediapipe::Status NotFoundError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kNotFound, message); +inline absl::Status NotFoundError(absl::string_view message) { + return absl::Status(absl::StatusCode::kNotFound, message); } -inline 
mediapipe::Status OutOfRangeError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kOutOfRange, message); +inline absl::Status OutOfRangeError(absl::string_view message) { + return absl::Status(absl::StatusCode::kOutOfRange, message); } -inline mediapipe::Status PermissionDeniedError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kPermissionDenied, message); +inline absl::Status PermissionDeniedError(absl::string_view message) { + return absl::Status(absl::StatusCode::kPermissionDenied, message); } -inline mediapipe::Status UnimplementedError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kUnimplemented, message); +inline absl::Status UnimplementedError(absl::string_view message) { + return absl::Status(absl::StatusCode::kUnimplemented, message); } -inline mediapipe::Status UnknownError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kUnknown, message); +inline absl::Status UnknownError(absl::string_view message) { + return absl::Status(absl::StatusCode::kUnknown, message); } -inline mediapipe::Status UnavailableError(absl::string_view message) { - return mediapipe::Status(mediapipe::StatusCode::kUnavailable, message); +inline absl::Status UnavailableError(absl::string_view message) { + return absl::Status(absl::StatusCode::kUnavailable, message); } -inline bool IsCancelled(const mediapipe::Status& status) { - return status.code() == mediapipe::StatusCode::kCancelled; +inline bool IsCancelled(const absl::Status& status) { + return status.code() == absl::StatusCode::kCancelled; } -inline bool IsNotFound(const mediapipe::Status& status) { - return status.code() == mediapipe::StatusCode::kNotFound; +inline bool IsNotFound(const absl::Status& status) { + return status.code() == absl::StatusCode::kNotFound; } } // namespace mediapipe diff --git a/mediapipe/framework/deps/file_helpers.cc b/mediapipe/framework/deps/file_helpers.cc index 
a3807b1d0..23b50310f 100644 --- a/mediapipe/framework/deps/file_helpers.cc +++ b/mediapipe/framework/deps/file_helpers.cc @@ -26,11 +26,11 @@ #include -#include "mediapipe/framework/deps/canonical_errors.h" #include "mediapipe/framework/deps/file_path.h" -#include "mediapipe/framework/deps/status.h" -#include "mediapipe/framework/deps/status_builder.h" -#include "mediapipe/framework/deps/status_macros.h" +#include "mediapipe/framework/port/canonical_errors.h" +#include "mediapipe/framework/port/status.h" +#include "mediapipe/framework/port/status_builder.h" +#include "mediapipe/framework/port/status_macros.h" namespace mediapipe { namespace file { @@ -138,8 +138,8 @@ class DirectoryListing { } // namespace -mediapipe::Status GetContents(absl::string_view file_name, std::string* output, - bool read_as_binary) { +absl::Status GetContents(absl::string_view file_name, std::string* output, + bool read_as_binary) { FILE* fp = fopen(file_name.data(), read_as_binary ? "rb" : "r"); if (fp == NULL) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) @@ -157,11 +157,11 @@ mediapipe::Status GetContents(absl::string_view file_name, std::string* output, output->append(std::string(buf, ret)); } fclose(fp); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetContents(absl::string_view file_name, - absl::string_view content) { +absl::Status SetContents(absl::string_view file_name, + absl::string_view content) { FILE* fp = fopen(file_name.data(), "w"); if (fp == NULL) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) @@ -175,12 +175,12 @@ mediapipe::Status SetContents(absl::string_view file_name, << "Error while writing file: " << file_name << ". 
Error message: " << strerror(write_error); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MatchInTopSubdirectories(const std::string& parent_directory, - const std::string& file_name, - std::vector* results) { +absl::Status MatchInTopSubdirectories(const std::string& parent_directory, + const std::string& file_name, + std::vector* results) { DirectoryListing parent_listing(parent_directory); while (parent_listing.HasNextEntry()) { @@ -194,12 +194,12 @@ mediapipe::Status MatchInTopSubdirectories(const std::string& parent_directory, } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MatchFileTypeInDirectory(const std::string& directory, - const std::string& file_suffix, - std::vector* results) { +absl::Status MatchFileTypeInDirectory(const std::string& directory, + const std::string& file_suffix, + std::vector* results) { DirectoryListing directory_listing(directory); while (directory_listing.HasNextEntry()) { @@ -209,21 +209,21 @@ mediapipe::Status MatchFileTypeInDirectory(const std::string& directory, } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Exists(absl::string_view file_name) { +absl::Status Exists(absl::string_view file_name) { struct stat buffer; int status; status = stat(std::string(file_name).c_str(), &buffer); if (status == 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } switch (errno) { case EACCES: return mediapipe::PermissionDeniedError("Insufficient permissions."); default: - return mediapipe::NotFoundError("The path does not exist."); + return absl::NotFoundError("The path does not exist."); } } @@ -235,9 +235,9 @@ int mkdir(std::string path) { int mkdir(std::string path) { return _mkdir(path.c_str()); } #endif -mediapipe::Status RecursivelyCreateDir(absl::string_view path) { +absl::Status RecursivelyCreateDir(absl::string_view path) { if (path.empty() || Exists(path).ok()) { - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } auto split_path = file::SplitPath(path); MP_RETURN_IF_ERROR(RecursivelyCreateDir(split_path.first)); @@ -246,10 +246,10 @@ mediapipe::Status RecursivelyCreateDir(absl::string_view path) { case EACCES: return mediapipe::PermissionDeniedError("Insufficient permissions."); default: - return mediapipe::UnavailableError("Failed to create directory."); + return absl::UnavailableError("Failed to create directory."); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace file diff --git a/mediapipe/framework/deps/file_helpers.h b/mediapipe/framework/deps/file_helpers.h index 079fd24a6..278a6bb40 100644 --- a/mediapipe/framework/deps/file_helpers.h +++ b/mediapipe/framework/deps/file_helpers.h @@ -16,27 +16,27 @@ #define MEDIAPIPE_DEPS_FILE_HELPERS_H_ #include "absl/strings/match.h" -#include "mediapipe/framework/deps/status.h" +#include "mediapipe/framework/port/status.h" namespace mediapipe { namespace file { -mediapipe::Status GetContents(absl::string_view file_name, std::string* output, - bool read_as_binary = true); +absl::Status GetContents(absl::string_view file_name, std::string* output, + bool read_as_binary = true); -mediapipe::Status SetContents(absl::string_view file_name, - absl::string_view content); +absl::Status SetContents(absl::string_view file_name, + absl::string_view content); -mediapipe::Status MatchInTopSubdirectories(const std::string& parent_directory, - const std::string& file_name, - std::vector* results); +absl::Status MatchInTopSubdirectories(const std::string& parent_directory, + const std::string& file_name, + std::vector* results); -mediapipe::Status MatchFileTypeInDirectory(const std::string& directory, - const std::string& file_suffix, - std::vector* results); +absl::Status MatchFileTypeInDirectory(const std::string& directory, + const std::string& file_suffix, + std::vector* results); -mediapipe::Status Exists(absl::string_view file_name); +absl::Status Exists(absl::string_view file_name); 
-mediapipe::Status RecursivelyCreateDir(absl::string_view path); +absl::Status RecursivelyCreateDir(absl::string_view path); } // namespace file } // namespace mediapipe diff --git a/mediapipe/framework/deps/registration.h b/mediapipe/framework/deps/registration.h index 006d9a8a3..9eb0c8ddf 100644 --- a/mediapipe/framework/deps/registration.h +++ b/mediapipe/framework/deps/registration.h @@ -67,7 +67,7 @@ namespace mediapipe { // class Client {}; // // using ClientRegistry = -// GlobalFactoryRegistry>; +// GlobalFactoryRegistry>; // // class MyClient : public Client { // public: @@ -84,7 +84,7 @@ namespace mediapipe { // ::my_ns::MyClient, // []() { // auto backend = absl::make_unique("/path/to/backend"); -// const mediapipe::Status status = backend->Init(); +// const absl::Status status = backend->Init(); // if (!status.ok()) { // return status; // } @@ -95,13 +95,13 @@ namespace mediapipe { // // === Using the registry to create instances ============================== // -// // Registry will return mediapipe::StatusOr -// mediapipe::StatusOr> s_or_widget = +// // Registry will return absl::StatusOr +// absl::StatusOr> s_or_widget = // WidgetRegistry::CreateByName( // "my_ns.MyWidget", std::move(gadget), thing); // // Registry will return NOT_FOUND if the name is unknown. // if (!s_or_widget.ok()) ... // handle error -// DoStuffWithWidget(std::move(s_or_widget).ValueOrDie()); +// DoStuffWithWidget(std::move(s_or_widget).value()); // // // It's also possible to find an instance by name within a source namespace. // auto s_or_widget = WidgetRegistry::CreateByNameInNamespace( @@ -115,7 +115,7 @@ namespace mediapipe { // // This might be useful if clients outside of your codebase are registering // // plugins. // for (const auto& name : WidgetRegistry::GetRegisteredNames()) { -// mediapipe::StatusOr> s_or_widget = +// absl::StatusOr> s_or_widget = // WidgetRegistry::CreateByName(name, std::move(gadget), thing); // ... 
// } @@ -134,13 +134,13 @@ constexpr char kNameSep[] = "."; template struct WrapStatusOr { - using type = mediapipe::StatusOr; + using type = absl::StatusOr; }; // Specialization to avoid double-wrapping types that are already StatusOrs. template -struct WrapStatusOr> { - using type = mediapipe::StatusOr; +struct WrapStatusOr> { + using type = absl::StatusOr; }; } // namespace registration_internal @@ -196,8 +196,7 @@ class FunctionRegistry { absl::ReaderMutexLock lock(&lock_); auto it = functions_.find(name); if (it == functions_.end()) { - return mediapipe::NotFoundError("No registered object with name: " + - name); + return absl::NotFoundError("No registered object with name: " + name); } function = it->second; } diff --git a/mediapipe/framework/deps/ret_check.cc b/mediapipe/framework/deps/ret_check.cc index 1e4b49e26..8d758a0f9 100644 --- a/mediapipe/framework/deps/ret_check.cc +++ b/mediapipe/framework/deps/ret_check.cc @@ -31,7 +31,7 @@ mediapipe::StatusBuilder RetCheckFailSlowPath( mediapipe::StatusBuilder RetCheckFailSlowPath( mediapipe::source_location location, const char* condition, - const mediapipe::Status& status) { + const absl::Status& status) { return mediapipe::RetCheckFailSlowPath(location) << condition << " returned " << status << " "; } diff --git a/mediapipe/framework/deps/ret_check.h b/mediapipe/framework/deps/ret_check.h index c81baa245..fec7a0318 100644 --- a/mediapipe/framework/deps/ret_check.h +++ b/mediapipe/framework/deps/ret_check.h @@ -31,9 +31,9 @@ mediapipe::StatusBuilder RetCheckFailSlowPath( // Returns a StatusBuilder that corresponds to a `RET_CHECK` failure. 
mediapipe::StatusBuilder RetCheckFailSlowPath( mediapipe::source_location location, const char* condition, - const mediapipe::Status& status); + const absl::Status& status); -inline StatusBuilder RetCheckImpl(const mediapipe::Status& status, +inline StatusBuilder RetCheckImpl(const absl::Status& status, const char* condition, mediapipe::source_location location) { if (ABSL_PREDICT_TRUE(status.ok())) diff --git a/mediapipe/framework/deps/status.cc b/mediapipe/framework/deps/status.cc index da7f7718e..b51c9f9db 100644 --- a/mediapipe/framework/deps/status.cc +++ b/mediapipe/framework/deps/status.cc @@ -23,7 +23,7 @@ std::ostream& operator<<(std::ostream& os, const Status& x) { return os; } -std::string* MediaPipeCheckOpHelperOutOfLine(const mediapipe::Status& v, +std::string* MediaPipeCheckOpHelperOutOfLine(const absl::Status& v, const char* msg) { std::string r("Non-OK-status: "); r += msg; diff --git a/mediapipe/framework/deps/status.h b/mediapipe/framework/deps/status.h index c1d21ebc9..492e4d434 100644 --- a/mediapipe/framework/deps/status.h +++ b/mediapipe/framework/deps/status.h @@ -20,22 +20,24 @@ #include #include +#include "absl/base/attributes.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "mediapipe/framework/port/logging.h" namespace mediapipe { -using Status = absl::Status; -using StatusCode = absl::StatusCode; +using Status ABSL_DEPRECATED("Use absl::Status directly") = absl::Status; +using StatusCode ABSL_DEPRECATED("Use absl::StatusCode directly") = + absl::StatusCode; -inline mediapipe::Status OkStatus() { return absl::OkStatus(); } +ABSL_DEPRECATED("Use absl::OkStatus directly") +inline absl::Status OkStatus() { return absl::OkStatus(); } -extern std::string* MediaPipeCheckOpHelperOutOfLine(const mediapipe::Status& v, +extern std::string* MediaPipeCheckOpHelperOutOfLine(const absl::Status& v, const char* msg); -inline std::string* MediaPipeCheckOpHelper(mediapipe::Status v, - const char* msg) { +inline std::string* 
MediaPipeCheckOpHelper(absl::Status v, const char* msg) { if (v.ok()) return nullptr; return MediaPipeCheckOpHelperOutOfLine(v, msg); } @@ -51,7 +53,7 @@ inline std::string* MediaPipeCheckOpHelper(mediapipe::Status v, #define MEDIAPIPE_DCHECK_OK(val) MEDIAPIPE_CHECK_OK(val) #else #define MEDIAPIPE_DCHECK_OK(val) \ - while (false && (mediapipe::OkStatus() == (val))) LOG(FATAL) + while (false && (absl::OkStatus() == (val))) LOG(FATAL) #endif #define CHECK_OK MEDIAPIPE_CHECK_OK diff --git a/mediapipe/framework/deps/status_builder.cc b/mediapipe/framework/deps/status_builder.cc index fef0e4b2e..de46a5d8e 100644 --- a/mediapipe/framework/deps/status_builder.cc +++ b/mediapipe/framework/deps/status_builder.cc @@ -68,7 +68,7 @@ StatusBuilder::operator Status() && { return JoinMessageToStatus(); } -mediapipe::Status StatusBuilder::JoinMessageToStatus() { +absl::Status StatusBuilder::JoinMessageToStatus() { std::string message; if (join_style_ == MessageJoinStyle::kAnnotate) { if (!status_.ok()) { diff --git a/mediapipe/framework/deps/status_builder.h b/mediapipe/framework/deps/status_builder.h index a42bfa939..dad49f11e 100644 --- a/mediapipe/framework/deps/status_builder.h +++ b/mediapipe/framework/deps/status_builder.h @@ -30,14 +30,14 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // Creates a `StatusBuilder` based on an original status. If logging is // enabled, it will use `location` as the location from which the log message // occurs. A typical user will call this with `MEDIAPIPE_LOC`. 
- StatusBuilder(const mediapipe::Status& original_status, + StatusBuilder(const absl::Status& original_status, mediapipe::source_location location) : status_(original_status), line_(location.line()), file_(location.file_name()), stream_(new std::ostringstream) {} - StatusBuilder(mediapipe::Status&& original_status, + StatusBuilder(absl::Status&& original_status, mediapipe::source_location location) : status_(std::move(original_status)), line_(location.line()), @@ -47,14 +47,13 @@ class ABSL_MUST_USE_RESULT StatusBuilder { // Creates a `StatusBuilder` from a mediapipe status code. If logging is // enabled, it will use `location` as the location from which the log message // occurs. A typical user will call this with `MEDIAPIPE_LOC`. - StatusBuilder(mediapipe::StatusCode code, mediapipe::source_location location) + StatusBuilder(absl::StatusCode code, mediapipe::source_location location) : status_(code, ""), line_(location.line()), file_(location.file_name()), stream_(new std::ostringstream) {} - StatusBuilder(const mediapipe::Status& original_status, const char* file, - int line) + StatusBuilder(const absl::Status& original_status, const char* file, int line) : status_(original_status), line_(line), file_(file), @@ -78,7 +77,7 @@ class ABSL_MUST_USE_RESULT StatusBuilder { operator Status() const&; operator Status() &&; - mediapipe::Status JoinMessageToStatus(); + absl::Status JoinMessageToStatus(); private: // Specifies how to join the error message in the original status and any @@ -90,7 +89,7 @@ class ABSL_MUST_USE_RESULT StatusBuilder { }; // The status that the result will be based on. - mediapipe::Status status_; + absl::Status status_; // The line to record if this file is logged. int line_; // Not-owned: The file to record if this status is logged. 
@@ -104,39 +103,39 @@ class ABSL_MUST_USE_RESULT StatusBuilder { inline StatusBuilder AlreadyExistsErrorBuilder( mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kAlreadyExists, location); + return StatusBuilder(absl::StatusCode::kAlreadyExists, location); } inline StatusBuilder FailedPreconditionErrorBuilder( mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kFailedPrecondition, location); + return StatusBuilder(absl::StatusCode::kFailedPrecondition, location); } inline StatusBuilder InternalErrorBuilder(mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kInternal, location); + return StatusBuilder(absl::StatusCode::kInternal, location); } inline StatusBuilder InvalidArgumentErrorBuilder( mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kInvalidArgument, location); + return StatusBuilder(absl::StatusCode::kInvalidArgument, location); } inline StatusBuilder NotFoundErrorBuilder(mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kNotFound, location); + return StatusBuilder(absl::StatusCode::kNotFound, location); } inline StatusBuilder UnavailableErrorBuilder( mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kUnavailable, location); + return StatusBuilder(absl::StatusCode::kUnavailable, location); } inline StatusBuilder UnimplementedErrorBuilder( mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kUnimplemented, location); + return StatusBuilder(absl::StatusCode::kUnimplemented, location); } inline StatusBuilder UnknownErrorBuilder(mediapipe::source_location location) { - return StatusBuilder(mediapipe::StatusCode::kUnknown, location); + return StatusBuilder(absl::StatusCode::kUnknown, location); } } // namespace mediapipe diff --git a/mediapipe/framework/deps/status_builder_test.cc 
b/mediapipe/framework/deps/status_builder_test.cc index b7e6a978d..63166a106 100644 --- a/mediapipe/framework/deps/status_builder_test.cc +++ b/mediapipe/framework/deps/status_builder_test.cc @@ -19,54 +19,52 @@ namespace mediapipe { TEST(StatusBuilder, AnnotateMode) { - mediapipe::Status status = - StatusBuilder(mediapipe::Status(mediapipe::StatusCode::kNotFound, - "original message"), - MEDIAPIPE_LOC) - << "annotated message1 " - << "annotated message2"; + absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kNotFound, + "original message"), + MEDIAPIPE_LOC) + << "annotated message1 " + << "annotated message2"; ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kNotFound); + EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); EXPECT_EQ(status.message(), "original message; annotated message1 annotated message2"); } TEST(StatusBuilder, PrependMode) { - mediapipe::Status status = - StatusBuilder(mediapipe::Status(mediapipe::StatusCode::kInvalidArgument, - "original message"), - MEDIAPIPE_LOC) + absl::Status status = + StatusBuilder( + absl::Status(absl::StatusCode::kInvalidArgument, "original message"), + MEDIAPIPE_LOC) .SetPrepend() << "prepended message1 " << "prepended message2 "; ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); EXPECT_EQ(status.message(), "prepended message1 prepended message2 original message"); } TEST(StatusBuilder, AppendMode) { - mediapipe::Status status = - StatusBuilder(mediapipe::Status(mediapipe::StatusCode::kInternal, - "original message"), - MEDIAPIPE_LOC) - .SetAppend() - << " extra message1" - << " extra message2"; + absl::Status status = StatusBuilder(absl::Status(absl::StatusCode::kInternal, + "original message"), + MEDIAPIPE_LOC) + .SetAppend() + << " extra message1" + << " extra message2"; ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInternal); + 
EXPECT_EQ(status.code(), absl::StatusCode::kInternal); EXPECT_EQ(status.message(), "original message extra message1 extra message2"); } TEST(StatusBuilder, NoLoggingMode) { - mediapipe::Status status = - StatusBuilder(mediapipe::Status(mediapipe::StatusCode::kUnavailable, - "original message"), - MEDIAPIPE_LOC) + absl::Status status = + StatusBuilder( + absl::Status(absl::StatusCode::kUnavailable, "original message"), + MEDIAPIPE_LOC) .SetNoLogging() << " extra message"; ASSERT_FALSE(status.ok()); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kUnavailable); + EXPECT_EQ(status.code(), absl::StatusCode::kUnavailable); EXPECT_EQ(status.message(), "original message"); } diff --git a/mediapipe/framework/deps/status_macros.h b/mediapipe/framework/deps/status_macros.h index 229f61b48..c9df245f2 100644 --- a/mediapipe/framework/deps/status_macros.h +++ b/mediapipe/framework/deps/status_macros.h @@ -13,7 +13,7 @@ // limitations under the License. // // Helper macros and methods to return and propagate errors with -// `mediapipe::Status`. +// `absl::Status`. // // The owners of mediapipe do not endorse use of these macros as a good // programming practice, and would prefer that you write the equivalent C++ @@ -26,14 +26,14 @@ #include "mediapipe/framework/deps/status.h" #include "mediapipe/framework/deps/status_builder.h" -// Evaluates an expression that produces a `mediapipe::Status`. If the status +// Evaluates an expression that produces a `absl::Status`. If the status // is not ok, returns it from the current function. // // For example: -// mediapipe::Status MultiStepFunction() { +// absl::Status MultiStepFunction() { // MP_RETURN_IF_ERROR(Function(args...)); // MP_RETURN_IF_ERROR(foo.Method(args...)); -// return mediapipe::OkStatus(); +// return absl::OkStatus(); // } // // The macro ends with a `mediapipe::StatusBuilder` which allows the returned @@ -41,11 +41,11 @@ // macro will not be evaluated unless there is an error. 
// // For example: -// mediapipe::Status MultiStepFunction() { +// absl::Status MultiStepFunction() { // MP_RETURN_IF_ERROR(Function(args...)) << "in MultiStepFunction"; // MP_RETURN_IF_ERROR(foo.Method(args...)).Log(base_logging::ERROR) // << "while processing query: " << query.DebugString(); -// return mediapipe::OkStatus(); +// return absl::OkStatus(); // } // // `mediapipe::StatusBuilder` supports adapting the builder chain using a @@ -74,12 +74,12 @@ // // If using this macro inside a lambda, you need to annotate the return type // to avoid confusion between a `mediapipe::StatusBuilder` and a -// `mediapipe::Status` type. E.g. +// `absl::Status` type. E.g. // -// []() -> mediapipe::Status { +// []() -> absl::Status { // MP_RETURN_IF_ERROR(Function(args...)); // MP_RETURN_IF_ERROR(foo.Method(args...)); -// return mediapipe::OkStatus(); +// return absl::OkStatus(); // } #define MP_RETURN_IF_ERROR(expr) \ STATUS_MACROS_IMPL_ELSE_BLOCKER_ \ @@ -88,7 +88,7 @@ } else /* NOLINT */ \ return status_macro_internal_adaptor.Consume() -// Executes an expression `rexpr` that returns a `mediapipe::StatusOr`. On +// Executes an expression `rexpr` that returns a `absl::StatusOr`. On // OK, extracts its value into the variable defined by `lhs`, otherwise returns // from the current function. By default the error status is returned // unchanged, but it may be modified by an `error_expression`. If there is an @@ -165,7 +165,7 @@ (void)_; /* error_expression is allowed to not use this variable */ \ return (error_expression); \ } \ - lhs = std::move(statusor).ValueOrDie() + lhs = std::move(statusor).value() // Internal helper for concatenating macro values. 
#define STATUS_MACROS_IMPL_CONCAT_INNER_(x, y) x##y diff --git a/mediapipe/framework/deps/status_matchers.h b/mediapipe/framework/deps/status_matchers.h index 5e7f34272..62a22afb3 100644 --- a/mediapipe/framework/deps/status_matchers.h +++ b/mediapipe/framework/deps/status_matchers.h @@ -50,8 +50,8 @@ inline IsOkMatcher IsOk() { return IsOkMatcher(); } } // namespace mediapipe -// Macros for testing the results of functions that return mediapipe::Status or -// mediapipe::StatusOr (for any type T). +// Macros for testing the results of functions that return absl::Status or +// absl::StatusOr (for any type T). #define MP_EXPECT_OK(expression) EXPECT_THAT(expression, mediapipe::IsOk()) #define MP_ASSERT_OK(expression) ASSERT_THAT(expression, mediapipe::IsOk()) diff --git a/mediapipe/framework/deps/status_test.cc b/mediapipe/framework/deps/status_test.cc index ea5ad3629..c0292dcf3 100644 --- a/mediapipe/framework/deps/status_test.cc +++ b/mediapipe/framework/deps/status_test.cc @@ -20,7 +20,7 @@ namespace mediapipe { TEST(Status, OK) { - EXPECT_EQ(OkStatus().code(), mediapipe::StatusCode::kOk); + EXPECT_EQ(OkStatus().code(), absl::StatusCode::kOk); EXPECT_EQ(OkStatus().message(), ""); MP_EXPECT_OK(OkStatus()); MP_ASSERT_OK(OkStatus()); @@ -30,25 +30,25 @@ TEST(Status, OK) { } TEST(DeathStatus, CheckOK) { - Status status(mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status status(absl::StatusCode::kInvalidArgument, "Invalid"); ASSERT_DEATH(MEDIAPIPE_CHECK_OK(status), "Invalid"); } TEST(Status, Set) { Status status; - status = Status(mediapipe::StatusCode::kCancelled, "Error message"); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kCancelled); + status = Status(absl::StatusCode::kCancelled, "Error message"); + EXPECT_EQ(status.code(), absl::StatusCode::kCancelled); EXPECT_EQ(status.message(), "Error message"); } TEST(Status, Copy) { - Status a(mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status a(absl::StatusCode::kInvalidArgument, "Invalid"); 
Status b(a); ASSERT_EQ(a.ToString(), b.ToString()); } TEST(Status, Assign) { - Status a(mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status a(absl::StatusCode::kInvalidArgument, "Invalid"); Status b; b = a; ASSERT_EQ(a.ToString(), b.ToString()); @@ -58,10 +58,10 @@ TEST(Status, Update) { Status s; s.Update(OkStatus()); ASSERT_TRUE(s.ok()); - Status a(mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status a(absl::StatusCode::kInvalidArgument, "Invalid"); s.Update(a); ASSERT_EQ(s.ToString(), a.ToString()); - Status b(mediapipe::StatusCode::kInternal, "Invalid"); + Status b(absl::StatusCode::kInternal, "Invalid"); s.Update(b); ASSERT_EQ(s.ToString(), a.ToString()); s.Update(OkStatus()); @@ -72,26 +72,26 @@ TEST(Status, Update) { TEST(Status, EqualsOK) { ASSERT_EQ(OkStatus(), Status()); } TEST(Status, EqualsSame) { - Status a(mediapipe::StatusCode::kInvalidArgument, "Invalid"); - Status b(mediapipe::StatusCode::kInvalidArgument, "Invalid"); + Status a(absl::StatusCode::kInvalidArgument, "Invalid"); + Status b(absl::StatusCode::kInvalidArgument, "Invalid"); ASSERT_EQ(a, b); } TEST(Status, EqualsCopy) { - const Status a(mediapipe::StatusCode::kInvalidArgument, "Invalid"); + const Status a(absl::StatusCode::kInvalidArgument, "Invalid"); const Status b = a; ASSERT_EQ(a, b); } TEST(Status, EqualsDifferentCode) { - const Status a(mediapipe::StatusCode::kInvalidArgument, "Invalid"); - const Status b(mediapipe::StatusCode::kInternal, "Internal"); + const Status a(absl::StatusCode::kInvalidArgument, "Invalid"); + const Status b(absl::StatusCode::kInternal, "Internal"); ASSERT_NE(a, b); } TEST(Status, EqualsDifferentMessage) { - const Status a(mediapipe::StatusCode::kInvalidArgument, "message"); - const Status b(mediapipe::StatusCode::kInvalidArgument, "another"); + const Status a(absl::StatusCode::kInvalidArgument, "message"); + const Status b(absl::StatusCode::kInvalidArgument, "another"); ASSERT_NE(a, b); } diff --git a/mediapipe/framework/deps/statusor.cc 
b/mediapipe/framework/deps/statusor.cc deleted file mode 100644 index 5d7eddec2..000000000 --- a/mediapipe/framework/deps/statusor.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mediapipe/framework/deps/statusor.h" - -#include "absl/base/attributes.h" -#include "mediapipe/framework/deps/canonical_errors.h" -#include "mediapipe/framework/deps/status.h" -#include "mediapipe/framework/port/logging.h" - -namespace mediapipe { -namespace internal_statusor { - -void Helper::HandleInvalidStatusCtorArg(mediapipe::Status* status) { - const char* kMessage = - "An OK status is not a valid constructor argument to StatusOr"; - LOG(ERROR) << kMessage; - *status = mediapipe::InternalError(kMessage); -} - -void Helper::Crash(const mediapipe::Status& status) { - LOG(FATAL) << "Attempting to fetch value instead of handling error " - << status; -} - -} // namespace internal_statusor -} // namespace mediapipe diff --git a/mediapipe/framework/deps/statusor.h b/mediapipe/framework/deps/statusor.h deleted file mode 100644 index 9dc7fc332..000000000 --- a/mediapipe/framework/deps/statusor.h +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// StatusOr is the union of a Status object and a T -// object. StatusOr models the concept of an object that is either a -// usable value, or an error Status explaining why such a value is -// not present. To this end, StatusOr does not allow its Status -// value to be Status::OK. Furthermore, the value of a StatusOr -// must not be null. This is enforced by a debug check in most cases, -// but even when it is not, clients must not set the value to null. -// -// The primary use-case for StatusOr is as the return value of a -// function which may fail. -// -// Example client usage for a StatusOr, where T is not a pointer: -// -// mediapipe::StatusOr result = DoBigCalculationThatCouldFail(); -// if (result.ok()) { -// float answer = result.ValueOrDie(); -// printf("Big calculation yielded: %f", answer); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr: -// -// mediapipe::StatusOr result = FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example client usage for a StatusOr>: -// -// mediapipe::StatusOr> result = -// FooFactory::MakeNewFoo(arg); -// if (result.ok()) { -// std::unique_ptr foo = std::move(result.ValueOrDie()); -// foo->DoSomethingCool(); -// } else { -// LOG(ERROR) << result.status(); -// } -// -// Example factory implementation returning StatusOr: -// -// mediapipe::StatusOr FooFactory::MakeNewFoo(int arg) { -// if (arg <= 0) { -// return 
mediapipe::InvalidArgumentError("Arg must be positive"); -// } else { -// return new Foo(arg); -// } -// } -// -// Note that the assignment operators require that destroying the currently -// stored value cannot invalidate the argument; in other words, the argument -// cannot be an alias for the current value, or anything owned by the current -// value. - -#ifndef MEDIAPIPE_DEPS_DEFAULT_STATUSOR_H_ -#define MEDIAPIPE_DEPS_DEFAULT_STATUSOR_H_ - -#include "absl/base/attributes.h" -#include "mediapipe/framework/deps/status.h" -#include "mediapipe/framework/deps/status_builder.h" -#include "mediapipe/framework/deps/statusor_internals.h" - -namespace mediapipe { - -#if defined(__clang__) -// Only clang supports warn_unused_result as a type annotation. -template -class ABSL_MUST_USE_RESULT StatusOr; -#endif - -template -class StatusOr : private internal_statusor::StatusOrData, - private internal_statusor::TraitsBase< - std::is_copy_constructible::value, - std::is_move_constructible::value> { - template - friend class StatusOr; - - typedef internal_statusor::StatusOrData Base; - - public: - typedef T element_type; - - // Constructs a new StatusOr with Status::UNKNOWN status. This is marked - // 'explicit' to try to catch cases like 'return {};', where people think - // StatusOr> will be initialized with an empty vector, - // instead of a Status::UNKNOWN status. - explicit StatusOr(); - - // StatusOr will be copy constructible/assignable if T is copy - // constructible. - StatusOr(const StatusOr&) = default; - StatusOr& operator=(const StatusOr&) = default; - - // StatusOr will be move constructible/assignable if T is move - // constructible. - StatusOr(StatusOr&&) = default; - StatusOr& operator=(StatusOr&&) = default; - - // Conversion copy/move constructor, T must be convertible from U. - // TODO: These should not participate in overload resolution if U - // is not convertible to T. 
- template - StatusOr(const StatusOr& other); - template - StatusOr(StatusOr&& other); - - // Conversion copy/move assignment operator, T must be convertible from U. - template - StatusOr& operator=(const StatusOr& other); - template - StatusOr& operator=(StatusOr&& other); - - // Constructs a new StatusOr with the given value. After calling this - // constructor, calls to ValueOrDie() will succeed, and calls to status() will - // return OK. - // - // NOTE: Not explicit - we want to use StatusOr as a return type - // so it is convenient and sensible to be able to do 'return T()' - // when the return type is StatusOr. - // - // REQUIRES: T is copy constructible. - StatusOr(const T& value); - - // Constructs a new StatusOr with the given non-ok status. After calling - // this constructor, calls to ValueOrDie() will CHECK-fail. - // - // NOTE: Not explicit - we want to use StatusOr as a return - // value, so it is convenient and sensible to be able to do 'return - // Status()' when the return type is StatusOr. - // - // REQUIRES: !status.ok(). This requirement is DCHECKed. - // In optimized builds, passing Status::OK() here will have the effect - // of passing mediapipe::StatusCode::kInternal as a fallback. - StatusOr(const mediapipe::Status& status); - StatusOr& operator=(const mediapipe::Status& status); - StatusOr(const mediapipe::StatusBuilder& builder); - StatusOr& operator=(const mediapipe::StatusBuilder& builder); - - // TODO: Add operator=(T) overloads. - - // Similar to the `const T&` overload. - // - // REQUIRES: T is move constructible. - StatusOr(T&& value); - - // RValue versions of the operations declared above. - StatusOr(mediapipe::Status&& status); - StatusOr& operator=(mediapipe::Status&& status); - StatusOr(mediapipe::StatusBuilder&& builder); - StatusOr& operator=(mediapipe::StatusBuilder&& builder); - - // Returns this->status().ok() - bool ok() const { return this->status_.ok(); } - - // Returns a reference to mediapipe status. 
If this contains a T, then - // returns Status::OK(). - const mediapipe::Status& status() const&; - mediapipe::Status status() &&; - - // Returns a reference to our current value, or CHECK-fails if !this->ok(). - // - // Note: for value types that are cheap to copy, prefer simple code: - // - // T value = statusor.ValueOrDie(); - // - // Otherwise, if the value type is expensive to copy, but can be left - // in the StatusOr, simply assign to a reference: - // - // T& value = statusor.ValueOrDie(); // or `const T&` - // - // Otherwise, if the value type supports an efficient move, it can be - // used as follows: - // - // T value = std::move(statusor).ValueOrDie(); - // - // The std::move on statusor instead of on the whole expression enables - // warnings about possible uses of the statusor object after the move. - // C++ style guide waiver for ref-qualified overloads granted in cl/143176389 - // See go/ref-qualifiers for more details on such overloads. - const T& ValueOrDie() const&; - T& ValueOrDie() &; - const T&& ValueOrDie() const&&; - T&& ValueOrDie() &&; - - T ConsumeValueOrDie() { return std::move(ValueOrDie()); } - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. 
- void IgnoreError() const; -}; - -//////////////////////////////////////////////////////////////////////////////// -// Implementation details for StatusOr - -template -StatusOr::StatusOr() - : Base(mediapipe::Status(mediapipe::StatusCode::kUnknown, "")) {} - -template -StatusOr::StatusOr(const T& value) : Base(value) {} - -template -StatusOr::StatusOr(const mediapipe::Status& status) : Base(status) {} - -template -StatusOr::StatusOr(const mediapipe::StatusBuilder& builder) - : Base(builder) {} - -template -StatusOr& StatusOr::operator=(const mediapipe::Status& status) { - this->Assign(status); - return *this; -} - -template -StatusOr& StatusOr::operator=(const mediapipe::StatusBuilder& builder) { - return *this = static_cast(builder); -} - -template -StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} - -template -StatusOr::StatusOr(mediapipe::Status&& status) : Base(std::move(status)) {} - -template -StatusOr::StatusOr(mediapipe::StatusBuilder&& builder) - : Base(std::move(builder)) {} - -template -StatusOr& StatusOr::operator=(mediapipe::Status&& status) { - this->Assign(std::move(status)); - return *this; -} - -template -StatusOr& StatusOr::operator=(mediapipe::StatusBuilder&& builder) { - return *this = static_cast(std::move(builder)); -} - -template -template -inline StatusOr::StatusOr(const StatusOr& other) - : Base(static_cast::Base&>(other)) {} - -template -template -inline StatusOr& StatusOr::operator=(const StatusOr& other) { - if (other.ok()) - this->Assign(other.ValueOrDie()); - else - this->Assign(other.status()); - return *this; -} - -template -template -inline StatusOr::StatusOr(StatusOr&& other) - : Base(static_cast::Base&&>(other)) {} - -template -template -inline StatusOr& StatusOr::operator=(StatusOr&& other) { - if (other.ok()) { - this->Assign(std::move(other).ValueOrDie()); - } else { - this->Assign(std::move(other).status()); - } - return *this; -} - -template -const mediapipe::Status& StatusOr::status() const& { - return 
this->status_; -} -template -mediapipe::Status StatusOr::status() && { - return ok() ? mediapipe::OkStatus() : std::move(this->status_); -} - -template -const T& StatusOr::ValueOrDie() const& { - this->EnsureOk(); - return this->data_; -} - -template -T& StatusOr::ValueOrDie() & { - this->EnsureOk(); - return this->data_; -} - -template -const T&& StatusOr::ValueOrDie() const&& { - this->EnsureOk(); - return std::move(this->data_); -} - -template -T&& StatusOr::ValueOrDie() && { - this->EnsureOk(); - return std::move(this->data_); -} - -template -void StatusOr::IgnoreError() const { - // no-op -} - -} // namespace mediapipe - -#endif // MEDIAPIPE_DEPS_DEFAULT_STATUSOR_H_ diff --git a/mediapipe/framework/deps/statusor_internals.h b/mediapipe/framework/deps/statusor_internals.h deleted file mode 100644 index 8a419ba3b..000000000 --- a/mediapipe/framework/deps/statusor_internals.h +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MEDIAPIPE_DEPS_STATUSOR_INTERNALS_H_ -#define MEDIAPIPE_DEPS_STATUSOR_INTERNALS_H_ - -#include "absl/base/attributes.h" -#include "mediapipe/framework/deps/status.h" - -namespace mediapipe { -namespace internal_statusor { - -class Helper { - public: - // Move type-agnostic error handling to the .cc. 
- static void HandleInvalidStatusCtorArg(mediapipe::Status*); - ABSL_ATTRIBUTE_NORETURN static void Crash(const mediapipe::Status& status); -}; - -// Construct an instance of T in `p` through placement new, passing Args... to -// the constructor. -// This abstraction is here mostly for the gcc performance fix. -template -void PlacementNew(void* p, Args&&... args) { -#if defined(__GNUC__) && !defined(__clang__) - // Teach gcc that 'p' cannot be null, fixing code size issues. - if (p == nullptr) __builtin_unreachable(); -#endif - new (p) T(std::forward(args)...); -} - -// Helper base class to hold the data and all operations. -// We move all this to a base class to allow mixing with the appropriate -// TraitsBase specialization. -template -class StatusOrData { - template - friend class StatusOrData; - - public: - StatusOrData() = delete; - - StatusOrData(const StatusOrData& other) { - if (other.ok()) { - MakeValue(other.data_); - MakeStatus(); - } else { - MakeStatus(other.status_); - } - } - - StatusOrData(StatusOrData&& other) noexcept { - if (other.ok()) { - MakeValue(std::move(other.data_)); - MakeStatus(); - } else { - MakeStatus(std::move(other.status_)); - } - } - - template - StatusOrData(const StatusOrData& other) { - if (other.ok()) { - MakeValue(other.data_); - MakeStatus(); - } else { - MakeStatus(other.status_); - } - } - - template - StatusOrData(StatusOrData&& other) { - if (other.ok()) { - MakeValue(std::move(other.data_)); - MakeStatus(); - } else { - MakeStatus(std::move(other.status_)); - } - } - - explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } - explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } - - explicit StatusOrData(const mediapipe::Status& status) : status_(status) { - EnsureNotOk(); - } - explicit StatusOrData(mediapipe::Status&& status) - : status_(std::move(status)) { - EnsureNotOk(); - } - - StatusOrData& operator=(const StatusOrData& other) { - if (this == &other) return *this; - if 
(other.ok()) - Assign(other.data_); - else - Assign(other.status_); - return *this; - } - - StatusOrData& operator=(StatusOrData&& other) { - if (this == &other) return *this; - if (other.ok()) - Assign(std::move(other.data_)); - else - Assign(std::move(other.status_)); - return *this; - } - - ~StatusOrData() { - if (ok()) { - status_.~Status(); - data_.~T(); - } else { - status_.~Status(); - } - } - - void Assign(const T& value) { - if (ok()) { - data_.~T(); - MakeValue(value); - } else { - MakeValue(value); - status_ = mediapipe::OkStatus(); - } - } - - void Assign(T&& value) { - if (ok()) { - data_.~T(); - MakeValue(std::move(value)); - } else { - MakeValue(std::move(value)); - status_ = mediapipe::OkStatus(); - } - } - - void Assign(const mediapipe::Status& status) { - Clear(); - status_ = status; - EnsureNotOk(); - } - - void Assign(mediapipe::Status&& status) { - Clear(); - status_ = std::move(status); - EnsureNotOk(); - } - - bool ok() const { return status_.ok(); } - - protected: - // status_ will always be active after the constructor. - // We make it a union to be able to initialize exactly how we need without - // waste. - // Eg. in the copy constructor we use the default constructor of Status in - // the ok() path to avoid an extra Ref call. - union { - mediapipe::Status status_; - }; - - // data_ is active iff status_.ok()==true - struct Dummy {}; - union { - // When T is const, we need some non-const object we can cast to void* for - // the placement new. dummy_ is that object. - Dummy dummy_; - T data_; - }; - - void Clear() { - if (ok()) data_.~T(); - } - - void EnsureOk() const { - if (!ok()) Helper::Crash(status_); - } - - void EnsureNotOk() { - if (ok()) Helper::HandleInvalidStatusCtorArg(&status_); - } - - // Construct the value (ie. data_) through placement new with the passed - // argument. - template - void MakeValue(Arg&& arg) { - internal_statusor::PlacementNew(&dummy_, std::forward(arg)); - } - - // Construct the status (ie. 
status_) through placement new with the passed - // argument. - template - void MakeStatus(Args&&... args) { - internal_statusor::PlacementNew( - &status_, std::forward(args)...); - } -}; - -// Helper base class to allow implicitly deleted constructors and assignment -// operations in StatusOr. -// TraitsBase will explicitly delete what it can't support and StatusOr will -// inherit that behavior implicitly. -template -struct TraitsBase { - TraitsBase() = default; - TraitsBase(const TraitsBase&) = default; - TraitsBase(TraitsBase&&) = default; - TraitsBase& operator=(const TraitsBase&) = default; - TraitsBase& operator=(TraitsBase&&) = default; -}; - -template <> -struct TraitsBase { - TraitsBase() = default; - TraitsBase(const TraitsBase&) = delete; - TraitsBase(TraitsBase&&) = default; - TraitsBase& operator=(const TraitsBase&) = delete; - TraitsBase& operator=(TraitsBase&&) = default; -}; - -template <> -struct TraitsBase { - TraitsBase() = default; - TraitsBase(const TraitsBase&) = delete; - TraitsBase(TraitsBase&&) = delete; - TraitsBase& operator=(const TraitsBase&) = delete; - TraitsBase& operator=(TraitsBase&&) = delete; -}; - -} // namespace internal_statusor -} // namespace mediapipe - -#endif // MEDIAPIPE_DEPS_STATUSOR_INTERNALS_H_ diff --git a/mediapipe/framework/deps/statusor_test.cc b/mediapipe/framework/deps/statusor_test.cc deleted file mode 100644 index bf2b436d1..000000000 --- a/mediapipe/framework/deps/statusor_test.cc +++ /dev/null @@ -1,437 +0,0 @@ -// Copyright 2019 The MediaPipe Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// Unit tests for StatusOr - -#include "mediapipe/framework/deps/statusor.h" - -#include -#include - -#include "mediapipe/framework/deps/canonical_errors.h" -#include "mediapipe/framework/deps/status.h" -#include "mediapipe/framework/port/gtest.h" - -namespace mediapipe { -namespace { - -class Base1 { - public: - virtual ~Base1() {} - int pad_; -}; - -class Base2 { - public: - virtual ~Base2() {} - int yetotherpad_; -}; - -class Derived : public Base1, public Base2 { - public: - ~Derived() override {} - int evenmorepad_; -}; - -class CopyNoAssign { - public: - explicit CopyNoAssign(int value) : foo_(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} - int foo_; - - private: - const CopyNoAssign& operator=(const CopyNoAssign&); -}; - -class NoDefaultConstructor { - public: - explicit NoDefaultConstructor(int foo); -}; - -static_assert(!std::is_default_constructible(), - "Should not be default-constructible."); - -StatusOr> ReturnUniquePtr() { - // Uses implicit constructor from T&& - return std::unique_ptr(new int(0)); -} - -TEST(StatusOr, ElementType) { - static_assert(std::is_same::element_type, int>(), ""); - static_assert(std::is_same::element_type, char>(), ""); -} - -TEST(StatusOr, TestNoDefaultConstructorInitialization) { - // Explicitly initialize it with an error code. - mediapipe::StatusOr statusor( - mediapipe::CancelledError("")); - EXPECT_FALSE(statusor.ok()); - EXPECT_EQ(statusor.status().code(), mediapipe::StatusCode::kCancelled); - - // Default construction of StatusOr initializes it with an UNKNOWN error code. 
- mediapipe::StatusOr statusor2; - EXPECT_FALSE(statusor2.ok()); - EXPECT_EQ(statusor2.status().code(), mediapipe::StatusCode::kUnknown); -} - -TEST(StatusOr, TestMoveOnlyInitialization) { - mediapipe::StatusOr> thing(ReturnUniquePtr()); - ASSERT_TRUE(thing.ok()); - EXPECT_EQ(0, *thing.ValueOrDie()); - int* previous = thing.ValueOrDie().get(); - - thing = ReturnUniquePtr(); - EXPECT_TRUE(thing.ok()); - EXPECT_EQ(0, *thing.ValueOrDie()); - EXPECT_NE(previous, thing.ValueOrDie().get()); -} - -TEST(StatusOr, TestMoveOnlyStatusCtr) { - mediapipe::StatusOr> thing( - mediapipe::CancelledError("")); - ASSERT_FALSE(thing.ok()); -} - -TEST(StatusOr, TestMoveOnlyValueExtraction) { - mediapipe::StatusOr> thing(ReturnUniquePtr()); - ASSERT_TRUE(thing.ok()); - std::unique_ptr ptr = thing.ConsumeValueOrDie(); - EXPECT_EQ(0, *ptr); - - thing = std::move(ptr); - ptr = std::move(thing.ValueOrDie()); - EXPECT_EQ(0, *ptr); -} - -TEST(StatusOr, TestMoveOnlyConversion) { - mediapipe::StatusOr> const_thing( - ReturnUniquePtr()); - EXPECT_TRUE(const_thing.ok()); - EXPECT_EQ(0, *const_thing.ValueOrDie()); - - // Test rvalue converting assignment - const int* const_previous = const_thing.ValueOrDie().get(); - const_thing = ReturnUniquePtr(); - EXPECT_TRUE(const_thing.ok()); - EXPECT_EQ(0, *const_thing.ValueOrDie()); - EXPECT_NE(const_previous, const_thing.ValueOrDie().get()); -} - -TEST(StatusOr, TestMoveOnlyVector) { - // Sanity check that mediapipe::StatusOr works in vector. 
- std::vector>> vec; - vec.push_back(ReturnUniquePtr()); - vec.resize(2); - auto another_vec = std::move(vec); - EXPECT_EQ(0, *another_vec[0].ValueOrDie()); - EXPECT_EQ(mediapipe::StatusCode::kUnknown, another_vec[1].status().code()); -} - -TEST(StatusOr, TestMoveWithValuesAndErrors) { - mediapipe::StatusOr status_or(std::string(1000, '0')); - mediapipe::StatusOr value1(std::string(1000, '1')); - mediapipe::StatusOr value2(std::string(1000, '2')); - mediapipe::StatusOr error1( - Status(mediapipe::StatusCode::kUnknown, "error1")); - mediapipe::StatusOr error2( - Status(mediapipe::StatusCode::kUnknown, "error2")); - - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie()); - - // Overwrite the value in status_or with another value. - status_or = std::move(value1); - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie()); - - // Overwrite the value in status_or with an error. - status_or = std::move(error1); - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error1", status_or.status().message()); - - // Overwrite the error in status_or with another error. - status_or = std::move(error2); - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error2", status_or.status().message()); - - // Overwrite the error with a value. - status_or = std::move(value2); - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie()); -} - -TEST(StatusOr, TestCopyWithValuesAndErrors) { - mediapipe::StatusOr status_or(std::string(1000, '0')); - mediapipe::StatusOr value1(std::string(1000, '1')); - mediapipe::StatusOr value2(std::string(1000, '2')); - mediapipe::StatusOr error1( - Status(mediapipe::StatusCode::kUnknown, "error1")); - mediapipe::StatusOr error2( - Status(mediapipe::StatusCode::kUnknown, "error2")); - - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '0'), status_or.ValueOrDie()); - - // Overwrite the value in status_or with another value. 
- status_or = value1; - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '1'), status_or.ValueOrDie()); - - // Overwrite the value in status_or with an error. - status_or = error1; - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error1", status_or.status().message()); - - // Overwrite the error in status_or with another error. - status_or = error2; - ASSERT_FALSE(status_or.ok()); - EXPECT_EQ("error2", status_or.status().message()); - - // Overwrite the error with a value. - status_or = value2; - ASSERT_TRUE(status_or.ok()); - EXPECT_EQ(std::string(1000, '2'), status_or.ValueOrDie()); - - // Verify original values unchanged. - EXPECT_EQ(std::string(1000, '1'), value1.ValueOrDie()); - EXPECT_EQ("error1", error1.status().message()); - EXPECT_EQ("error2", error2.status().message()); - EXPECT_EQ(std::string(1000, '2'), value2.ValueOrDie()); -} - -TEST(StatusOr, TestDefaultCtor) { - mediapipe::StatusOr thing; - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status().code(), mediapipe::StatusCode::kUnknown); -} - -TEST(StatusOrDeathTest, TestDefaultCtorValue) { - mediapipe::StatusOr thing; - EXPECT_DEATH(thing.ValueOrDie(), ""); - - const mediapipe::StatusOr thing2; - EXPECT_DEATH(thing.ValueOrDie(), ""); -} - -TEST(StatusOr, TestStatusCtor) { - mediapipe::StatusOr thing( - mediapipe::Status(mediapipe::StatusCode::kCancelled, "")); - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status().code(), mediapipe::StatusCode::kCancelled); -} - -TEST(StatusOr, TestValueCtor) { - const int kI = 4; - const mediapipe::StatusOr thing(kI); - EXPECT_TRUE(thing.ok()); - EXPECT_EQ(kI, thing.ValueOrDie()); -} - -TEST(StatusOr, TestCopyCtorStatusOk) { - const int kI = 4; - const mediapipe::StatusOr original(kI); - const mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie()); -} - -TEST(StatusOr, TestCopyCtorStatusNotOk) { - mediapipe::StatusOr original( - Status(mediapipe::StatusCode::kCancelled, "")); - 
mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestCopyCtorNonAssignable) { - const int kI = 4; - CopyNoAssign value(kI); - mediapipe::StatusOr original(value); - mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); -} - -TEST(StatusOr, TestCopyCtorStatusOKConverting) { - const int kI = 4; - mediapipe::StatusOr original(kI); - mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_DOUBLE_EQ(original.ValueOrDie(), copy.ValueOrDie()); -} - -TEST(StatusOr, TestCopyCtorStatusNotOkConverting) { - mediapipe::StatusOr original( - Status(mediapipe::StatusCode::kCancelled, "")); - mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestAssignmentStatusOk) { - const int kI = 4; - mediapipe::StatusOr source(kI); - mediapipe::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); - EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie()); -} - -TEST(StatusOr, TestAssignmentStatusNotOk) { - mediapipe::StatusOr source( - Status(mediapipe::StatusCode::kCancelled, "")); - mediapipe::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); -} - -TEST(StatusOr, TestStatus) { - mediapipe::StatusOr good(4); - EXPECT_TRUE(good.ok()); - mediapipe::StatusOr bad(Status(mediapipe::StatusCode::kCancelled, "")); - EXPECT_FALSE(bad.ok()); - EXPECT_EQ(bad.status(), Status(mediapipe::StatusCode::kCancelled, "")); -} - -TEST(StatusOr, TestValue) { - const int kI = 4; - mediapipe::StatusOr thing(kI); - EXPECT_EQ(kI, thing.ValueOrDie()); -} - -TEST(StatusOr, TestValueConst) { - const int kI = 4; - const mediapipe::StatusOr thing(kI); - EXPECT_EQ(kI, thing.ValueOrDie()); -} - -TEST(StatusOrDeathTest, TestValueNotOk) { - mediapipe::StatusOr thing( - mediapipe::Status(mediapipe::StatusCode::kCancelled, 
"cancelled")); - EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); -} - -TEST(StatusOrDeathTest, TestValueNotOkConst) { - const mediapipe::StatusOr thing( - mediapipe::Status(mediapipe::StatusCode::kUnknown, "")); - EXPECT_DEATH(thing.ValueOrDie(), ""); -} - -TEST(StatusOr, TestPointerDefaultCtor) { - mediapipe::StatusOr thing; - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status().code(), mediapipe::StatusCode::kUnknown); -} - -TEST(StatusOrDeathTest, TestPointerDefaultCtorValue) { - mediapipe::StatusOr thing; - EXPECT_DEATH(thing.ValueOrDie(), ""); -} - -TEST(StatusOr, TestPointerStatusCtor) { - mediapipe::StatusOr thing( - Status(mediapipe::StatusCode::kCancelled, "")); - EXPECT_FALSE(thing.ok()); - EXPECT_EQ(thing.status(), Status(mediapipe::StatusCode::kCancelled, "")); -} - -TEST(StatusOr, TestPointerValueCtor) { - const int kI = 4; - mediapipe::StatusOr thing(&kI); - EXPECT_TRUE(thing.ok()); - EXPECT_EQ(&kI, thing.ValueOrDie()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusOk) { - const int kI = 0; - mediapipe::StatusOr original(&kI); - mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusNotOk) { - mediapipe::StatusOr original( - Status(mediapipe::StatusCode::kCancelled, "")); - mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusOKConverting) { - Derived derived; - mediapipe::StatusOr original(&derived); - mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(static_cast(original.ValueOrDie()), - copy.ValueOrDie()); -} - -TEST(StatusOr, TestPointerCopyCtorStatusNotOkConverting) { - mediapipe::StatusOr original( - mediapipe::Status(mediapipe::StatusCode::kCancelled, "")); - mediapipe::StatusOr copy(original); - EXPECT_EQ(copy.status(), original.status()); -} - -TEST(StatusOr, TestPointerAssignmentStatusOk) { - 
const int kI = 0; - mediapipe::StatusOr source(&kI); - mediapipe::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); - EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie()); -} - -TEST(StatusOr, TestPointerAssignmentStatusNotOk) { - mediapipe::StatusOr source( - mediapipe::Status(mediapipe::StatusCode::kCancelled, "")); - mediapipe::StatusOr target; - target = source; - EXPECT_EQ(target.status(), source.status()); -} - -TEST(StatusOr, TestPointerStatus) { - const int kI = 0; - mediapipe::StatusOr good(&kI); - EXPECT_TRUE(good.ok()); - mediapipe::StatusOr bad( - mediapipe::Status(mediapipe::StatusCode::kCancelled, "")); - EXPECT_EQ(bad.status(), - mediapipe::Status(mediapipe::StatusCode::kCancelled, "")); -} - -TEST(StatusOr, TestPointerValue) { - const int kI = 0; - mediapipe::StatusOr thing(&kI); - EXPECT_EQ(&kI, thing.ValueOrDie()); -} - -TEST(StatusOr, TestPointerValueConst) { - const int kI = 0; - const mediapipe::StatusOr thing(&kI); - EXPECT_EQ(&kI, thing.ValueOrDie()); -} - -TEST(StatusOrDeathTest, TestPointerValueNotOk) { - mediapipe::StatusOr thing( - mediapipe::Status(mediapipe::StatusCode::kCancelled, "cancelled")); - EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); -} - -TEST(StatusOrDeathTest, TestPointerValueNotOkConst) { - const mediapipe::StatusOr thing( - mediapipe::Status(mediapipe::StatusCode::kCancelled, "cancelled")); - EXPECT_DEATH(thing.ValueOrDie(), "cancelled"); -} - -} // namespace -} // namespace mediapipe diff --git a/mediapipe/framework/deps/vector.h b/mediapipe/framework/deps/vector.h index ae6fa0068..2d4de82f3 100644 --- a/mediapipe/framework/deps/vector.h +++ b/mediapipe/framework/deps/vector.h @@ -241,7 +241,7 @@ class BasicVector { return out << "]"; } - // These are only public for technical reasons (see cl/121145822). + // These are only public for technical reasons. 
template D MulScalarInternal(const K& k) const { return Generate([k](const T& x) { return k * x; }, AsD()); diff --git a/mediapipe/framework/executor.h b/mediapipe/framework/executor.h index 50eb9854a..521f9bb81 100644 --- a/mediapipe/framework/executor.h +++ b/mediapipe/framework/executor.h @@ -13,7 +13,6 @@ // limitations under the License. // Executor class for the MediaPipe scheduler. -// Design doc: go/mediapipe-executor #ifndef MEDIAPIPE_FRAMEWORK_EXECUTOR_H_ #define MEDIAPIPE_FRAMEWORK_EXECUTOR_H_ @@ -48,7 +47,7 @@ class Executor { // A registered Executor subclass must implement the static factory method // Create. The Executor subclass cannot be registered without it. // - // static mediapipe::StatusOr Create( + // static absl::StatusOr Create( // const MediaPipeOptions& extendable_options); // // Create validates extendable_options, then calls the constructor, and @@ -65,8 +64,8 @@ class Executor { virtual void Schedule(std::function task) = 0; }; -using ExecutorRegistry = GlobalFactoryRegistry, - const MediaPipeOptions&>; +using ExecutorRegistry = + GlobalFactoryRegistry, const MediaPipeOptions&>; // Macro for registering the executor. 
#define REGISTER_EXECUTOR(name) \ diff --git a/mediapipe/framework/formats/BUILD b/mediapipe/framework/formats/BUILD index dfe432972..e404e7218 100644 --- a/mediapipe/framework/formats/BUILD +++ b/mediapipe/framework/formats/BUILD @@ -165,6 +165,7 @@ cc_library( }), visibility = ["//visibility:public"], deps = [ + "@com_google_protobuf//:protobuf", "//mediapipe/framework/formats:location_data_cc_proto", "//mediapipe/framework/formats/annotation:locus_cc_proto", "@com_google_absl//absl/base:core_headers", @@ -182,8 +183,6 @@ cc_library( "//mediapipe/framework/port:statusor", "//mediapipe/framework/formats/annotation:rasterization_cc_proto", ] + select({ - "//conditions:default": ["@com_google_protobuf//:protobuf"], - }) + select({ "//conditions:default": [ "//mediapipe/framework/port:opencv_imgproc", ], @@ -277,6 +276,95 @@ filegroup( visibility = ["//mediapipe:__subpackages__"], ) +cc_library( + name = "image", + srcs = ["image.cc"], + hdrs = ["image.h"], + copts = select({ + "//mediapipe:ios": [ + "-x objective-c++", + "-fobjc-arc", # enable reference-counting + ], + "//conditions:default": [], + }), + visibility = ["//visibility:public"], + deps = [ + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_format_cc_proto", + "@com_google_absl//absl/synchronization", + "//mediapipe/framework:port", + "//mediapipe/framework/port:logging", + ] + select({ + "//conditions:default": [ + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_buffer_format", + "//mediapipe/gpu:gl_texture_buffer", + ], + "//mediapipe:ios": [ + "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_buffer_format", + ], + "//mediapipe/gpu:disable_gpu": [], + }) + select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "//mediapipe/objc:CFHolder", + "//mediapipe/objc:util", + ], + }), +) + +cc_library( + name = "image_multi_pool", + srcs = ["image_multi_pool.cc"], + hdrs = ["image_multi_pool.h"], + visibility = ["//visibility:public"], + deps = [ + 
":image", + "//mediapipe/framework/formats:image_frame_pool", + "//mediapipe/framework:port", + "//mediapipe/framework/port:logging", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", + ] + select({ + "//conditions:default": [ + "//mediapipe/gpu:gl_texture_buffer", + "//mediapipe/gpu:gl_texture_buffer_pool", + "//mediapipe/gpu:gl_base", + "//mediapipe/gpu:gpu_buffer", + ], + "//mediapipe:ios": [ + "//mediapipe/gpu:gl_base", + "//mediapipe/gpu:gpu_buffer", + ], + "//mediapipe/gpu:disable_gpu": [], + }) + select({ + "//conditions:default": [], + "//mediapipe:apple": [ + "//mediapipe/gpu:pixel_buffer_pool_util", + "//mediapipe/objc:CFHolder", + ], + }), +) + +cc_library( + name = "image_opencv", + srcs = [ + "image_opencv.cc", + ], + hdrs = [ + "image_opencv.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":image", + "//mediapipe/framework/formats:image_format_cc_proto", + "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:statusor", + ], +) + cc_library( name = "image_frame_pool", srcs = ["image_frame_pool.cc"], diff --git a/mediapipe/framework/formats/classification.proto b/mediapipe/framework/formats/classification.proto index 328d9ecb5..8a777f105 100644 --- a/mediapipe/framework/formats/classification.proto +++ b/mediapipe/framework/formats/classification.proto @@ -21,6 +21,8 @@ syntax = "proto2"; package mediapipe; option objc_class_prefix = "MediaPipe"; +option java_package = "com.google.mediapipe.formats.proto"; +option java_outer_classname = "ClassificationProto"; message Classification { // The index of the class in the corresponding label map. 
diff --git a/mediapipe/framework/formats/detection.proto b/mediapipe/framework/formats/detection.proto index 402408199..3e2d067b3 100644 --- a/mediapipe/framework/formats/detection.proto +++ b/mediapipe/framework/formats/detection.proto @@ -29,7 +29,6 @@ option java_outer_classname = "DetectionProto"; message Detection { // i-th label or label_id has a score encoded by the i-th element in score. - // Either string or integer labels must be used but not both at the same time. repeated string label = 1; repeated int32 label_id = 2 [packed = true]; repeated float score = 3 [packed = true]; diff --git a/mediapipe/framework/formats/image.cc b/mediapipe/framework/formats/image.cc new file mode 100644 index 000000000..b8944593c --- /dev/null +++ b/mediapipe/framework/formats/image.cc @@ -0,0 +1,97 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/image.h" + +namespace mediapipe { + +// TODO Refactor common code from GpuBufferToImageFrameCalculator +bool Image::ConvertToCpu() const { + if (!use_gpu_) return true; // Already on CPU. 
+#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + image_frame_ = CreateImageFrameForCVPixelBuffer(GetCVPixelBufferRef()); +#else + auto gl_texture = gpu_buffer_.GetGlTextureBufferSharedPtr(); + if (!gl_texture->GetProducerContext()) return false; + gl_texture->GetProducerContext()->Run([this, &gl_texture]() { + gl_texture->WaitOnGpu(); + const auto gpu_buf = mediapipe::GpuBuffer(GetGlTextureBufferSharedPtr()); +#ifdef __ANDROID__ + glBindFramebuffer(GL_FRAMEBUFFER, 0); // b/32091368 +#endif + GLuint fb = 0; + glDisable(GL_DEPTH_TEST); + // TODO Re-use a shared framebuffer. + glGenFramebuffers(1, &fb); + glBindFramebuffer(GL_FRAMEBUFFER, fb); + glViewport(0, 0, gpu_buf.width(), gpu_buf.height()); + glActiveTexture(GL_TEXTURE0); + glBindTexture(gl_texture->target(), gl_texture->name()); + glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, + gl_texture->target(), gl_texture->name(), 0); + auto frame = std::make_shared( + mediapipe::ImageFormatForGpuBufferFormat(gpu_buf.format()), + gpu_buf.width(), gpu_buf.height(), + ImageFrame::kGlDefaultAlignmentBoundary); + const auto info = GlTextureInfoForGpuBufferFormat( + gpu_buf.format(), 0, gl_texture->GetProducerContext()->GetGlVersion()); + glReadPixels(0, 0, gpu_buf.width(), gpu_buf.height(), info.gl_format, + info.gl_type, frame->MutablePixelData()); + glDeleteFramebuffers(1, &fb); + // Cleanup + gl_texture->DidRead(gl_texture->GetProducerContext()->CreateSyncToken()); + image_frame_ = frame; + }); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // !MEDIAPIPE_DISABLE_GPU + use_gpu_ = false; + return true; +} + +// TODO Refactor common code from ImageFrameToGpuBufferCalculator +bool Image::ConvertToGpu() const { +#if MEDIAPIPE_DISABLE_GPU + return false; +#else + if (use_gpu_) return true; // Already on GPU. 
+#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + auto packet = MakePacket(std::move(*image_frame_)); + image_frame_ = nullptr; + CFHolder buffer; + auto status = CreateCVPixelBufferForImageFramePacket(packet, true, &buffer); + CHECK_OK(status); + gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer)); +#else + // GlCalculatorHelperImpl::MakeGlTextureBuffer (CreateSourceTexture) + auto buffer = mediapipe::GlTextureBuffer::Create( + image_frame_->Width(), image_frame_->Height(), + mediapipe::GpuBufferFormatForImageFormat(image_frame_->Format()), + image_frame_->PixelData()); + glBindTexture(GL_TEXTURE_2D, buffer->name()); + // See GlCalculatorHelperImpl::SetStandardTextureParams + glTexParameteri(buffer->target(), GL_TEXTURE_MIN_FILTER, GL_LINEAR); + glTexParameteri(buffer->target(), GL_TEXTURE_MAG_FILTER, GL_LINEAR); + glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); + glTexParameteri(buffer->target(), GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); + glBindTexture(GL_TEXTURE_2D, 0); + glFlush(); + gpu_buffer_ = mediapipe::GpuBuffer(std::move(buffer)); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + use_gpu_ = true; + return true; +#endif // MEDIAPIPE_DISABLE_GPU +} + +} // namespace mediapipe diff --git a/mediapipe/framework/formats/image.h b/mediapipe/framework/formats/image.h new file mode 100644 index 000000000..9f36471c8 --- /dev/null +++ b/mediapipe/framework/formats/image.h @@ -0,0 +1,318 @@ +// Copyright 2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_H_ + +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/port/logging.h" + +#if !MEDIAPIPE_DISABLE_GPU + +#include "mediapipe/gpu/gpu_buffer.h" +#include "mediapipe/gpu/gpu_buffer_format.h" + +#if defined(__APPLE__) +#include + +#include "mediapipe/objc/CFHolder.h" +#include "mediapipe/objc/util.h" +#if !TARGET_OS_OSX // iOS, use CVPixelBuffer. +#define MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER 1 +#endif // TARGET_OS_OSX +#endif // defined(__APPLE__) + +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER // OSX, use GL textures. +#include "mediapipe/gpu/gl_texture_buffer.h" +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +using ImageFrameSharedPtr = std::shared_ptr; + +// This class wraps ImageFrame(CPU) & GpuBuffer(GPU) data. +// An instance of Image acts as an opaque reference to the underlying +// data objects. Image also maintains backwards compatability with GpuBuffer. +// +// Accessing GPU storage requires a valid OpenGL context active beforehand. +// i.e.: GetGlTextureBufferSharedPtr() & ConvertToGpu() & GetGpuBuffer() +// should be called inside an active GL context. +// +// Note: 'use_gpu_' flag is used to keep track of where data is (dirty bit). +// +// TODO Refactor Image to use 'Impl' class delegation system. +// +class Image { + public: + // Default constructor creates invalid object. + Image() = default; + + // Copy and move constructors and assignment operators are supported. 
+ Image(const Image& other) = default; + Image(Image&& other) = default; + Image& operator=(const Image& other) = default; + Image& operator=(Image&& other) = default; + + // Creates an Image representing the same image content as the ImageFrame + // the input shared pointer points to, and retaining shared ownership. + explicit Image(ImageFrameSharedPtr frame_buffer) + : image_frame_(std::move(frame_buffer)) { + use_gpu_ = false; + pixel_mutex_ = std::make_shared(); + } + + // Creates an Image representing the same image content as the input GPU + // buffer in platform-specific representations. +#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + explicit Image(CFHolder pixel_buffer) + : Image(mediapipe::GpuBuffer(std::move(pixel_buffer))) {} + explicit Image(CVPixelBufferRef pixel_buffer) + : Image(mediapipe::GpuBuffer(pixel_buffer)) {} +#else + explicit Image(mediapipe::GlTextureBufferSharedPtr texture_buffer) + : Image(mediapipe::GpuBuffer(std::move(texture_buffer))) {} +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + explicit Image(mediapipe::GpuBuffer gpu_buffer) { + use_gpu_ = true; + gpu_buffer_ = gpu_buffer; + pixel_mutex_ = std::make_shared(); + } +#endif // !MEDIAPIPE_DISABLE_GPU + + const ImageFrameSharedPtr& GetImageFrameSharedPtr() const { + if (use_gpu_ == true) ConvertToCpu(); + return image_frame_; + } +#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + CVPixelBufferRef GetCVPixelBufferRef() const { + if (use_gpu_ == false) ConvertToGpu(); + return gpu_buffer_.GetCVPixelBufferRef(); + } +#else + const mediapipe::GlTextureBufferSharedPtr& GetGlTextureBufferSharedPtr() + const { + if (use_gpu_ == false) ConvertToGpu(); + return gpu_buffer_.GetGlTextureBufferSharedPtr(); + } +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + // Get a GPU view. Automatically uploads from CPU if needed. 
+ const mediapipe::GpuBuffer GetGpuBuffer() const { + if (use_gpu_ == false) ConvertToGpu(); + return gpu_buffer_; + } +#endif // !MEDIAPIPE_DISABLE_GPU + + // Returns image properties. + int width() const; + int height() const; + int channels() const; + int step() const; // Row size in bytes. + bool UsesGpu() const { return use_gpu_; } + ImageFormat::Format image_format() const; +#if !MEDIAPIPE_DISABLE_GPU + mediapipe::GpuBufferFormat format() const; +#endif // !MEDIAPIPE_DISABLE_GPU + + // Converts to true iff valid. + explicit operator bool() const { return operator!=(nullptr); } + + bool operator==(const Image& other) const; + bool operator!=(const Image& other) const { return !operator==(other); } + + // Allow comparison with nullptr. + bool operator==(std::nullptr_t other) const; + bool operator!=(std::nullptr_t other) const { return !operator==(other); } + + // Allow assignment from nullptr. + Image& operator=(std::nullptr_t other); + + // Lock/Unlock pixel data. + // Should be used exclusively by the PixelLock helper class. + void LockPixels() const ABSL_EXCLUSIVE_LOCK_FUNCTION(pixel_mutex_); + void UnlockPixels() const ABSL_UNLOCK_FUNCTION(pixel_mutex_); + + // Helper utility for GPU->CPU data transfer. + bool ConvertToCpu() const; + // Helper utility for CPU->GPU data transfer. + // *Requires a valid OpenGL context to be active before calling!* + bool ConvertToGpu() const; + + private: +#if !MEDIAPIPE_DISABLE_GPU + mutable mediapipe::GpuBuffer gpu_buffer_; +#endif // !MEDIAPIPE_DISABLE_GPU + mutable ImageFrameSharedPtr image_frame_; + mutable bool use_gpu_ = false; + mutable std::shared_ptr pixel_mutex_; // ImageFrame only. 
+}; + +inline int Image::width() const { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu_) + return gpu_buffer_.width(); + else +#endif // !MEDIAPIPE_DISABLE_GPU + return image_frame_->Width(); +} + +inline int Image::height() const { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu_) + return gpu_buffer_.height(); + else +#endif // !MEDIAPIPE_DISABLE_GPU + return image_frame_->Height(); +} + +inline ImageFormat::Format Image::image_format() const { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu_) + return mediapipe::ImageFormatForGpuBufferFormat(gpu_buffer_.format()); + else +#endif // !MEDIAPIPE_DISABLE_GPU + return image_frame_->Format(); +} + +#if !MEDIAPIPE_DISABLE_GPU +inline mediapipe::GpuBufferFormat Image::format() const { + if (use_gpu_) + return gpu_buffer_.format(); + else + return mediapipe::GpuBufferFormatForImageFormat(image_frame_->Format()); +} +#endif // !MEDIAPIPE_DISABLE_GPU + +inline bool Image::operator==(std::nullptr_t other) const { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu_) + return gpu_buffer_ == other; + else +#endif // !MEDIAPIPE_DISABLE_GPU + return image_frame_ == other; +} + +inline bool Image::operator==(const Image& other) const { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu_) + return gpu_buffer_ == other.gpu_buffer_; + else +#endif // !MEDIAPIPE_DISABLE_GPU + return image_frame_ == other.image_frame_; +} + +inline Image& Image::operator=(std::nullptr_t other) { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu_) + gpu_buffer_ = other; + else +#endif // !MEDIAPIPE_DISABLE_GPU + image_frame_ = other; + return *this; +} + +inline int Image::channels() const { + return ImageFrame::NumberOfChannelsForFormat(image_format()); +} + +inline int Image::step() const { + if (use_gpu_) + return width() * ImageFrame::ByteDepthForFormat(image_format()); + else + return image_frame_->WidthStep(); +} + +inline void Image::LockPixels() const { + pixel_mutex_->Lock(); + ConvertToCpu(); // Download data if necessary. 
+} + +inline void Image::UnlockPixels() const { pixel_mutex_->Unlock(); } + +// Helper class for getting access to Image CPU data, +// and handles automatically locking/unlocking CPU data access. +// +// Returns pointer to first pixel, or nullptr if invaild Image is provided +// +// Example use: +// Image buf = ... +// { +// PixelLock lock(&buf); +// uint8* buf_ptr = lock.Pixels(); +// ... use buf_ptr to access pixel data ... +// ... lock released automatically at end of scope ... +// } +// +// Note: should be used in separate minimal scope where possible; see example^. +// +class PixelReadLock { + public: + explicit PixelReadLock(const Image& image) { + buffer_ = ℑ + if (buffer_) buffer_->LockPixels(); + } + ~PixelReadLock() { + if (buffer_) buffer_->UnlockPixels(); + } + PixelReadLock(const PixelReadLock&) = delete; + + const uint8* Pixels() const { + if (buffer_ && !buffer_->UsesGpu()) { + ImageFrame* frame = buffer_->GetImageFrameSharedPtr().get(); + if (frame) return frame->PixelData(); + } + return nullptr; + } + + PixelReadLock& operator=(const PixelReadLock&) = delete; + + private: + const Image* buffer_ = nullptr; +}; + +class PixelWriteLock { + public: + explicit PixelWriteLock(Image* image) { + buffer_ = image; + if (buffer_) buffer_->LockPixels(); + } + ~PixelWriteLock() { + if (buffer_) buffer_->UnlockPixels(); + } + PixelWriteLock(const PixelWriteLock&) = delete; + + uint8* Pixels() { + if (buffer_ && !buffer_->UsesGpu()) { + ImageFrame* frame = buffer_->GetImageFrameSharedPtr().get(); + if (frame) return frame->MutablePixelData(); + } + return nullptr; + } + + PixelWriteLock& operator=(const PixelWriteLock&) = delete; + + private: + const Image* buffer_ = nullptr; +}; + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_H_ diff --git a/mediapipe/framework/formats/image_multi_pool.cc b/mediapipe/framework/formats/image_multi_pool.cc new file mode 100644 index 000000000..b79c81db8 --- /dev/null +++ 
b/mediapipe/framework/formats/image_multi_pool.cc @@ -0,0 +1,219 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/image_multi_pool.h" + +#include + +#include "absl/memory/memory.h" +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/port/logging.h" + +#if !MEDIAPIPE_DISABLE_GPU +#ifdef __APPLE__ +#include "mediapipe/objc/CFHolder.h" +#endif // __APPLE__ +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +// Keep this many buffers allocated for a given frame size. +static constexpr int kKeepCount = 2; +// The maximum size of the ImageMultiPool. When the limit is reached, the +// oldest IBufferSpec will be dropped. 
+static constexpr int kMaxPoolCount = 20; + +#if !MEDIAPIPE_DISABLE_GPU + +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +ImageMultiPool::SimplePoolGpu ImageMultiPool::MakeSimplePoolGpu( + IBufferSpec spec) { + OSType cv_format = mediapipe::CVPixelFormatForGpuBufferFormat( + GpuBufferFormatForImageFormat(spec.format)); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + return MakeCFHolderAdopting(mediapipe::CreateCVPixelBufferPool( + spec.width, spec.height, cv_format, kKeepCount, + 0.1 /* max age in seconds */)); +} + +Image ImageMultiPool::GetBufferFromSimplePool( + IBufferSpec spec, const ImageMultiPool::SimplePoolGpu& pool) { +#if TARGET_IPHONE_SIMULATOR + // On the simulator, syncing the texture with the pixelbuffer does not work, + // and we have to use glReadPixels. Since GL_UNPACK_ROW_LENGTH is not + // available in OpenGL ES 2, we should create the buffer so the pixels are + // contiguous. + // + // TODO: verify if we can use kIOSurfaceBytesPerRow to force the + // pool to give us contiguous data. + OSType cv_format = mediapipe::CVPixelFormatForGpuBufferFormat( + mediapipe::GpuBufferFormatForImageFormat(spec.format)); + CHECK_NE(cv_format, -1) << "unsupported pixel format"; + CVPixelBufferRef buffer; + CVReturn err = mediapipe::CreateCVPixelBufferWithoutPool( + spec.width, spec.height, cv_format, &buffer); + CHECK(!err) << "Error creating pixel buffer: " << err; + return Image(MakeCFHolderAdopting(buffer)); +#else + CVPixelBufferRef buffer; + // TODO: allow the keepCount and the allocation threshold to be set + // by the application, and to be set independently. 
+ static CFDictionaryRef auxAttributes = + mediapipe::CreateCVPixelBufferPoolAuxiliaryAttributesForThreshold( + kKeepCount); + CVReturn err = mediapipe::CreateCVPixelBufferWithPool( + *pool, auxAttributes, + [this]() { + absl::MutexLock lock(&mutex_gpu_); + for (const auto& cache : texture_caches_) { +#if TARGET_OS_OSX + CVOpenGLTextureCacheFlush(*cache, 0); +#else + CVOpenGLESTextureCacheFlush(*cache, 0); +#endif // TARGET_OS_OSX + } + }, + &buffer); + CHECK(!err) << "Error creating pixel buffer: " << err; + return Image(MakeCFHolderAdopting(buffer)); +#endif // TARGET_IPHONE_SIMULATOR +} + +#else + +ImageMultiPool::SimplePoolGpu ImageMultiPool::MakeSimplePoolGpu( + IBufferSpec spec) { + return mediapipe::GlTextureBufferPool::Create( + spec.width, spec.height, GpuBufferFormatForImageFormat(spec.format), + kKeepCount); +} + +Image ImageMultiPool::GetBufferFromSimplePool( + IBufferSpec spec, const ImageMultiPool::SimplePoolGpu& pool) { + return Image(pool->GetBuffer()); +} + +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + +#endif // !MEDIAPIPE_DISABLE_GPU + +ImageMultiPool::SimplePoolCpu ImageMultiPool::MakeSimplePoolCpu( + IBufferSpec spec) { + return ImageFramePool::Create(spec.width, spec.height, spec.format, + kKeepCount); +} + +Image ImageMultiPool::GetBufferFromSimplePool( + IBufferSpec spec, const ImageMultiPool::SimplePoolCpu& pool) { + return Image(pool->GetBuffer()); +} + +Image ImageMultiPool::GetBuffer(int width, int height, bool use_gpu, + ImageFormat::Format format) { +#if !MEDIAPIPE_DISABLE_GPU + if (use_gpu) { + absl::MutexLock lock(&mutex_gpu_); + IBufferSpec key(width, height, format); + auto pool_it = pools_gpu_.find(key); + if (pool_it == pools_gpu_.end()) { + // Discard the least recently used pool in LRU cache. + if (pools_gpu_.size() >= kMaxPoolCount) { + auto old_spec = buffer_specs_gpu_.front(); // Front has LRU. 
+ buffer_specs_gpu_.pop_front(); + pools_gpu_.erase(old_spec); + } + buffer_specs_gpu_.push_back(key); // Push new spec to back. + std::tie(pool_it, std::ignore) = pools_gpu_.emplace( + std::piecewise_construct, std::forward_as_tuple(key), + std::forward_as_tuple(MakeSimplePoolGpu(key))); + } else { + // Find and move current 'key' spec to back, keeping others in same order. + auto specs_it = buffer_specs_gpu_.begin(); + while (specs_it != buffer_specs_gpu_.end()) { + if (*specs_it == key) { + buffer_specs_gpu_.erase(specs_it); + break; + } + ++specs_it; + } + buffer_specs_gpu_.push_back(key); + } + return GetBufferFromSimplePool(pool_it->first, pool_it->second); + } else // NOLINT(readability/braces) +#endif // !MEDIAPIPE_DISABLE_GPU + { + absl::MutexLock lock(&mutex_cpu_); + IBufferSpec key(width, height, format); + auto pool_it = pools_cpu_.find(key); + if (pool_it == pools_cpu_.end()) { + // Discard the least recently used pool in LRU cache. + if (pools_cpu_.size() >= kMaxPoolCount) { + auto old_spec = buffer_specs_cpu_.front(); // Front has LRU. + buffer_specs_cpu_.pop_front(); + pools_cpu_.erase(old_spec); + } + buffer_specs_cpu_.push_back(key); // Push new spec to back. + std::tie(pool_it, std::ignore) = pools_cpu_.emplace( + std::piecewise_construct, std::forward_as_tuple(key), + std::forward_as_tuple(MakeSimplePoolCpu(key))); + } else { + // Find and move current 'key' spec to back, keeping others in same order. 
+ auto specs_it = buffer_specs_cpu_.begin(); + while (specs_it != buffer_specs_cpu_.end()) { + if (*specs_it == key) { + buffer_specs_cpu_.erase(specs_it); + break; + } + ++specs_it; + } + buffer_specs_cpu_.push_back(key); + } + return GetBufferFromSimplePool(pool_it->first, pool_it->second); + } +} + +ImageMultiPool::~ImageMultiPool() { +#if !MEDIAPIPE_DISABLE_GPU +#ifdef __APPLE__ + CHECK_EQ(texture_caches_.size(), 0) + << "Failed to unregister texture caches before deleting pool"; +#endif // defined(__APPLE__) +#endif // !MEDIAPIPE_DISABLE_GPU +} + +#if !MEDIAPIPE_DISABLE_GPU +#ifdef __APPLE__ +void ImageMultiPool::RegisterTextureCache(mediapipe::CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_gpu_); + + CHECK(std::find(texture_caches_.begin(), texture_caches_.end(), cache) == + texture_caches_.end()) + << "Attempting to register a texture cache twice"; + texture_caches_.emplace_back(cache); +} + +void ImageMultiPool::UnregisterTextureCache( + mediapipe::CVTextureCacheType cache) { + absl::MutexLock lock(&mutex_gpu_); + + auto it = std::find(texture_caches_.begin(), texture_caches_.end(), cache); + CHECK(it != texture_caches_.end()) + << "Attempting to unregister an unknown texture cache"; + texture_caches_.erase(it); +} +#endif // defined(__APPLE__) +#endif // !MEDIAPIPE_DISABLE_GPU + +} // namespace mediapipe diff --git a/mediapipe/framework/formats/image_multi_pool.h b/mediapipe/framework/formats/image_multi_pool.h new file mode 100644 index 000000000..39593f2af --- /dev/null +++ b/mediapipe/framework/formats/image_multi_pool.h @@ -0,0 +1,154 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This class lets calculators allocate GpuBuffers of various sizes, caching +// and reusing them as needed. It does so by automatically creating and using +// platform-specific buffer pools for the requested sizes. +// +// This class is not meant to be used directly by calculators, but is instead +// used by GlCalculatorHelper to allocate buffers. + +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_MULTI_POOL_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_MULTI_POOL_H_ + +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame_pool.h" + +#if !MEDIAPIPE_DISABLE_GPU +#include "mediapipe/gpu/gpu_buffer.h" + +#ifdef __APPLE__ +#include "mediapipe/gpu/pixel_buffer_pool_util.h" +#endif // __APPLE__ + +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#include "mediapipe/gpu/gl_texture_buffer_pool.h" +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER +#endif // !MEDIAPIPE_DISABLE_GPU + +namespace mediapipe { + +using ImageFrameSharedPtr = std::shared_ptr; + +// TODO: Update to use new pool eviction policy. +class ImageMultiPool { + public: + ImageMultiPool() {} + explicit ImageMultiPool(void* ignored) {} + ~ImageMultiPool(); + + // Obtains a buffer. May either be reused or created anew. + Image GetBuffer(int width, int height, bool use_gpu, + ImageFormat::Format format /*= ImageFormat::SRGBA*/); + +#if !MEDIAPIPE_DISABLE_GPU +#ifdef __APPLE__ + // TODO: add tests for the texture cache registration. 
+ + // Inform the pool of a cache that should be flushed when it is low on + // reusable buffers. + void RegisterTextureCache(mediapipe::CVTextureCacheType cache); + + // Remove a texture cache from the list of caches to be flushed. + void UnregisterTextureCache(mediapipe::CVTextureCacheType cache); + +#endif // defined(__APPLE__) +#endif // !MEDIAPIPE_DISABLE_GPU + + static std::size_t RotateLeftN(std::size_t x, int n) { + return (x << n) | (x >> (std::numeric_limits::digits - n)); + } + + struct IBufferSpec { + IBufferSpec(int w, int h, mediapipe::ImageFormat::Format f) + : width(w), height(h), format(f) {} + int width; + int height; + mediapipe::ImageFormat::Format format; + // Note: alignment should be added here if ImageFrameBufferPool is changed + // to allow for customizable alignment sizes (currently fixed at 4 for best + // compatability with OpenGL). + }; + + struct IBufferSpecHash { + std::size_t operator()(const IBufferSpec& spec) const { + // Width and height are expected to be smaller than half the width of + // size_t. We can combine them into a single integer using std::hash. + constexpr int kWidth = std::numeric_limits::digits; + return std::hash{}( + spec.width ^ RotateLeftN(spec.height, kWidth / 2) ^ + RotateLeftN(static_cast(spec.format), kWidth / 4)); + // Note: alignment should be added here if ImageFrameBufferPool is changed + // to allow for customizable alignment sizes (currently fixed at 4 for + // best compatability with OpenGL). 
+ } + }; + + private: +#if !MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + typedef CFHolder SimplePoolGpu; +#else + typedef std::shared_ptr SimplePoolGpu; +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + SimplePoolGpu MakeSimplePoolGpu(IBufferSpec spec); + Image GetBufferFromSimplePool(IBufferSpec spec, const SimplePoolGpu& pool); + + absl::Mutex mutex_gpu_; + std::unordered_map pools_gpu_ + ABSL_GUARDED_BY(mutex_gpu_); + // A queue of IBufferSpecs to keep track of the age of each IBufferSpec added + // to the pool. + std::deque buffer_specs_gpu_; +#endif // !MEDIAPIPE_DISABLE_GPU + + typedef std::shared_ptr SimplePoolCpu; + SimplePoolCpu MakeSimplePoolCpu(IBufferSpec spec); + Image GetBufferFromSimplePool(IBufferSpec spec, const SimplePoolCpu& pool); + + absl::Mutex mutex_cpu_; + std::unordered_map pools_cpu_ + ABSL_GUARDED_BY(mutex_cpu_); + // A queue of IBufferSpecs to keep track of the age of each IBufferSpec added + // to the pool. + std::deque buffer_specs_cpu_; + +#if !MEDIAPIPE_DISABLE_GPU +#ifdef __APPLE__ + // Texture caches used with this pool. + std::vector> texture_caches_ + GUARDED_BY(mutex_gpu_); +#endif // defined(__APPLE__) +#endif // !MEDIAPIPE_DISABLE_GPU +}; + +// IBufferSpec equality operators +inline bool operator==(const ImageMultiPool::IBufferSpec& lhs, + const ImageMultiPool::IBufferSpec& rhs) { + return lhs.width == rhs.width && lhs.height == rhs.height && + lhs.format == rhs.format; +} +inline bool operator!=(const ImageMultiPool::IBufferSpec& lhs, + const ImageMultiPool::IBufferSpec& rhs) { + return !operator==(lhs, rhs); +} + +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_MULTI_POOL_H_ diff --git a/mediapipe/framework/formats/image_opencv.cc b/mediapipe/framework/formats/image_opencv.cc new file mode 100644 index 000000000..c03183d14 --- /dev/null +++ b/mediapipe/framework/formats/image_opencv.cc @@ -0,0 +1,103 @@ +// Copyright 2019 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/formats/image_opencv.h" + +#include "mediapipe/framework/formats/image_format.pb.h" +#include "mediapipe/framework/port/logging.h" + +namespace { +// Maps Image format to OpenCV Mat type. +// See mediapipe...image_format.proto and cv...opencv2/core/hal/interface.h +// for more details on respective formats. +int GetMatType(const mediapipe::ImageFormat::Format format) { + int type = 0; + switch (format) { + case mediapipe::ImageFormat::UNKNOWN: + // Invalid; Default to uchar. + type = CV_8U; + break; + case mediapipe::ImageFormat::SRGB: + type = CV_8U; + break; + case mediapipe::ImageFormat::SRGBA: + type = CV_8U; + break; + case mediapipe::ImageFormat::GRAY8: + type = CV_8U; + break; + case mediapipe::ImageFormat::GRAY16: + type = CV_16U; + break; + case mediapipe::ImageFormat::YCBCR420P: + // Invalid; Default to uchar. + type = CV_8U; + break; + case mediapipe::ImageFormat::YCBCR420P10: + // Invalid; Default to uint16. 
+ type = CV_16U; + break; + case mediapipe::ImageFormat::SRGB48: + type = CV_16U; + break; + case mediapipe::ImageFormat::SRGBA64: + type = CV_16U; + break; + case mediapipe::ImageFormat::VEC32F1: + type = CV_32F; + break; + case mediapipe::ImageFormat::VEC32F2: + type = CV_32FC2; + break; + case mediapipe::ImageFormat::LAB8: + type = CV_8U; + break; + case mediapipe::ImageFormat::SBGRA: + type = CV_8U; + break; + default: + // Invalid or unknown; Default to uchar. + type = CV_8U; + break; + } + return type; +} +} // namespace +namespace mediapipe { + +namespace formats { + +cv::Mat MatView(const mediapipe::Image* image) { + const int dims = 2; + const int sizes[] = {image->height(), image->width()}; + const int type = + CV_MAKETYPE(GetMatType(image->image_format()), image->channels()); + const size_t steps[] = {static_cast(image->step()), + static_cast(ImageFrame::ByteDepthForFormat( + image->image_format()))}; + mediapipe::PixelWriteLock dst_lock(const_cast(image)); + uint8* data_ptr = dst_lock.Pixels(); + CHECK(data_ptr != nullptr); + // Use Image to initialize in-place. Image still owns memory. + if (steps[0] == sizes[1] * image->channels() * + ImageFrame::ByteDepthForFormat(image->image_format())) { + // Contiguous memory optimization. See b/78570764 + return cv::Mat(dims, sizes, type, data_ptr); + } else { + // Custom width step. + return cv::Mat(dims, sizes, type, data_ptr, steps); + } +} +} // namespace formats +} // namespace mediapipe diff --git a/mediapipe/framework/formats/image_opencv.h b/mediapipe/framework/formats/image_opencv.h new file mode 100644 index 000000000..48824a4dd --- /dev/null +++ b/mediapipe/framework/formats/image_opencv.h @@ -0,0 +1,37 @@ +// Copyright 2019-2020 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Helper functions for working with mediapipe::Image and OpenCV. +#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_OPENCV_H_ +#define MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_OPENCV_H_ + +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/port/opencv_core_inc.h" + +namespace mediapipe { +namespace formats { + +// Image to OpenCV helper conversion function. +// A view into existing data is created (zero copy). +// The pixel data remains owned and maintained by mediapipe::Image. +// When converting a const Image into a cv::Mat, +// the const modifier is lost. The caller must be careful +// not to use the returned object to modify the data in a const Image, +// even though the returned data is mutable. +cv::Mat MatView(const mediapipe::Image* image); + +} // namespace formats +} // namespace mediapipe + +#endif // MEDIAPIPE_FRAMEWORK_FORMATS_IMAGE_OPENCV_H_ diff --git a/mediapipe/framework/formats/location.cc b/mediapipe/framework/formats/location.cc index 9a9b83233..66d46bf76 100644 --- a/mediapipe/framework/formats/location.cc +++ b/mediapipe/framework/formats/location.cc @@ -72,13 +72,13 @@ std::unique_ptr MaskToMat(const LocationData::BinaryMask& mask) { } return image; } -mediapipe::StatusOr> RectangleToMat( +absl::StatusOr> RectangleToMat( int image_width, int image_height, const Rectangle_i& rect) { // These checks prevent undefined behavior caused when setting memory for // rectangles whose edges lie outside image edges.
if (rect.ymin() < 0 || rect.xmin() < 0 || rect.xmax() > image_width || rect.ymax() > image_height) { - return mediapipe::InvalidArgumentError(absl::Substitute( + return absl::InvalidArgumentError(absl::Substitute( "Rectangle must be bounded by image boundaries.\nImage Width: " "$0\nImage Height: $1\nRectangle: [($2, $3), ($4, $5)]", image_width, image_height, rect.xmin(), rect.ymin(), rect.xmax(), @@ -643,7 +643,7 @@ std::unique_ptr Location::ConvertToCvMask(int image_width, LOG(ERROR) << status_or_mat.status().message(); return nullptr; } - return std::move(status_or_mat).ValueOrDie(); + return std::move(status_or_mat).value(); } case LocationData::MASK: { return MaskToMat(location_data_.mask()); diff --git a/mediapipe/framework/formats/motion/BUILD b/mediapipe/framework/formats/motion/BUILD index a0422f555..5beeb5703 100644 --- a/mediapipe/framework/formats/motion/BUILD +++ b/mediapipe/framework/formats/motion/BUILD @@ -63,9 +63,11 @@ cc_library( cc_test( name = "optical_flow_field_test", srcs = ["optical_flow_field_test.cc"], + linkstatic = 1, deps = [ ":optical_flow_field", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", diff --git a/mediapipe/framework/formats/motion/optical_flow_field_test.cc b/mediapipe/framework/formats/motion/optical_flow_field_test.cc index 50a453111..44474120f 100644 --- a/mediapipe/framework/formats/motion/optical_flow_field_test.cc +++ b/mediapipe/framework/formats/motion/optical_flow_field_test.cc @@ -19,6 +19,7 @@ #include #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/integral_types.h" diff --git a/mediapipe/framework/formats/tensor.cc b/mediapipe/framework/formats/tensor.cc index 
9c66c242a..a5964567a 100644 --- a/mediapipe/framework/formats/tensor.cc +++ b/mediapipe/framework/formats/tensor.cc @@ -140,34 +140,29 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dReadView() const { auto lock = absl::make_unique(&view_mutex_); AllocateOpenGlTexture2d(); if (!(valid_ & kValidOpenGlTexture2d)) { - uint8_t* buffer; - std::unique_ptr temp_buffer; - if (BhwcDepthFromShape(shape_) % 4 == 0) { - // No padding exists because number of channels are multiple of 4. - buffer = reinterpret_cast(cpu_buffer_); - } else { - const int padded_depth = (BhwcDepthFromShape(shape_) + 3) / 4 * 4; - const int padded_depth_size = padded_depth * element_size(); - const int padded_size = BhwcBatchFromShape(shape_) * - BhwcHeightFromShape(shape_) * - BhwcWidthFromShape(shape_) * padded_depth_size; - temp_buffer = absl::make_unique(padded_size); - buffer = temp_buffer.get(); - uint8_t* src_buffer = reinterpret_cast(cpu_buffer_); - const int actual_depth_size = BhwcDepthFromShape(shape_) * element_size(); - for (int e = 0; - e < BhwcBatchFromShape(shape_) * BhwcHeightFromShape(shape_) * - BhwcWidthFromShape(shape_); - e++) { - std::memcpy(buffer, src_buffer, actual_depth_size); - src_buffer += actual_depth_size; - buffer += padded_depth_size; - } + const int padded_size = + texture_height_ * texture_width_ * 4 * element_size(); + auto temp_buffer = absl::make_unique(padded_size); + uint8_t* dest_buffer = temp_buffer.get(); + uint8_t* src_buffer = reinterpret_cast(cpu_buffer_); + const int num_elements = BhwcWidthFromShape(shape_) * + BhwcHeightFromShape(shape_) * + BhwcBatchFromShape(shape_); + const int actual_depth_size = BhwcDepthFromShape(shape_) * element_size(); + const int padded_depth_size = + (BhwcDepthFromShape(shape_) + 3) / 4 * 4 * element_size(); + for (int e = 0; e < num_elements; e++) { + std::memcpy(dest_buffer, src_buffer, actual_depth_size); + src_buffer += actual_depth_size; + dest_buffer += padded_depth_size; } // Transfer from CPU memory into 
GPU memory. glBindTexture(GL_TEXTURE_2D, opengl_texture2d_); - glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, BhwcWidthFromShape(shape_), - BhwcHeightFromShape(shape_), GL_RGBA, GL_FLOAT, buffer); + // Set alignment for the proper value (default) to avoid address sanitizer + // error "out of boundary reading". + glPixelStorei(GL_UNPACK_ALIGNMENT, 4); + glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texture_width_, texture_height_, + GL_RGBA, GL_FLOAT, temp_buffer.get()); glBindTexture(GL_TEXTURE_2D, 0); valid_ |= kValidOpenGlTexture2d; } @@ -181,6 +176,48 @@ Tensor::OpenGlTexture2dView Tensor::GetOpenGlTexture2dWriteView() const { return {opengl_texture2d_, std::move(lock)}; } +Tensor::OpenGlTexture2dView::Layout +Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape, + int* width, int* height) { + static int max_size = 0; + if (max_size == 0) { + int max_texture_size; + glGetIntegerv(GL_MAX_TEXTURE_SIZE, &max_texture_size); + int max_renderbuffer_size; + glGetIntegerv(GL_MAX_RENDERBUFFER_SIZE, &max_renderbuffer_size); + int max_viewport_dims[2]; + glGetIntegerv(GL_MAX_VIEWPORT_DIMS, max_viewport_dims); + max_size = std::min(std::min(max_texture_size, max_renderbuffer_size), + std::min(max_viewport_dims[0], max_viewport_dims[1])); + } + const int num_slices = (BhwcDepthFromShape(shape) + 3) / 4; + const int num_elements = BhwcBatchFromShape(shape) * + BhwcHeightFromShape(shape) * + BhwcWidthFromShape(shape); + const int num_pixels = num_slices * num_elements; + int w = BhwcWidthFromShape(shape) * num_slices; + if (w <= max_size) { + int h = (num_pixels + w - 1) / w; + if (h <= max_size) { + *width = w; + *height = h; + return Tensor::OpenGlTexture2dView::Layout::kAligned; + } + } + // The best performance of a compute shader can be achieved with textures' + // width multiple of 256. Making minimum fixed width of 256 wastes memory for + // small tensors. The optimal balance memory-vs-performance is power of 2.
+ // The texture width and height are chosen to be closer to square. + float power = std::log2(std::sqrt(static_cast(num_pixels))); + w = 1 << static_cast(power); + int h = (num_pixels + w - 1) / w; + LOG_IF(FATAL, w > max_size || h > max_size) + << "The tensor can't fit into OpenGL Texture2D View."; + *width = w; + *height = h; + return Tensor::OpenGlTexture2dView::Layout::kLinearized; +} + void Tensor::AllocateOpenGlTexture2d() const { if (opengl_texture2d_ == GL_INVALID_INDEX) { gl_context_ = mediapipe::GlContext::GetCurrent(); @@ -192,12 +229,26 @@ void Tensor::AllocateOpenGlTexture2d() const { // supported from floating point textures. glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST); glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST); - glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_BASE_LEVEL, 0); - glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAX_LEVEL, 0); - const int pixels_per_depth = (BhwcDepthFromShape(shape_) + 3) / 4; - const int width = BhwcWidthFromShape(shape_) * pixels_per_depth; - glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, width, - BhwcHeightFromShape(shape_)); + OpenGlTexture2dView::GetLayoutDimensions(shape_, &texture_width_, + &texture_height_); + if (gl_context_->GetGlVersion() != mediapipe::GlVersion::kGLES2) { + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_BASE_LEVEL, 0); + glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAX_LEVEL, 0); + glTexStorage2D(GL_TEXTURE_2D, 1, GL_RGBA32F, texture_width_, + texture_height_); + } else { + // We assume all contexts will have the same extensions, so we only check + // once for OES_texture_float extension, to save time. + static bool has_oes_extension = + gl_context_->HasGlExtension("OES_texture_float"); + LOG_IF(FATAL, !has_oes_extension) + << "OES_texture_float extension required in order to use MP tensor " + << "with GLES 2.0"; + // Allocate the image data; note that it's no longer RGBA32F, so will be + // lower precision.
+ glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texture_width_, texture_height_, + 0, GL_RGBA, GL_FLOAT, 0 /* data */); + } glBindTexture(GL_TEXTURE_2D, 0); glGenFramebuffers(1, &frame_buffer_); } @@ -272,6 +323,8 @@ void Tensor::Move(Tensor* src) { src->frame_buffer_ = GL_INVALID_INDEX; opengl_texture2d_ = src->opengl_texture2d_; src->opengl_texture2d_ = GL_INVALID_INDEX; + texture_width_ = src->texture_width_; + texture_height_ = src->texture_height_; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 opengl_buffer_ = src->opengl_buffer_; src->opengl_buffer_ = GL_INVALID_INDEX; @@ -283,42 +336,54 @@ Tensor::Tensor(ElementType element_type, const Shape& shape) : element_type_(element_type), shape_(shape) {} void Tensor::Invalidate() { - absl::MutexLock lock(&view_mutex_); -#if MEDIAPIPE_METAL_ENABLED - // If memory is allocated and not owned by the metal buffer. - // TODO: Re-design cpu buffer memory management. - if (cpu_buffer_ && !metal_buffer_) { - DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); - } - metal_buffer_ = nil; -#else - if (cpu_buffer_) { - free(cpu_buffer_); - } -#endif // MEDIAPIPE_METAL_ENABLED - cpu_buffer_ = nullptr; - - // Don't need to wait for the resource to be deleted bacause if will be - // released on last reference deletion inside the OpenGL driver. 
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 - if (opengl_texture2d_ != GL_INVALID_INDEX) { - GLuint opengl_texture2d = opengl_texture2d_; - GLuint frame_buffer = frame_buffer_; - gl_context_->RunWithoutWaiting([opengl_texture2d, frame_buffer]() { - glDeleteTextures(1, &opengl_texture2d); - glDeleteFramebuffers(1, &frame_buffer); - }); - opengl_texture2d_ = GL_INVALID_INDEX; - frame_buffer_ = GL_INVALID_INDEX; - } + GLuint cleanup_gl_tex = GL_INVALID_INDEX; + GLuint cleanup_gl_fb = GL_INVALID_INDEX; + GLuint cleanup_gl_buf = GL_INVALID_INDEX; +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + { + absl::MutexLock lock(&view_mutex_); +#if MEDIAPIPE_METAL_ENABLED + // If memory is allocated and not owned by the metal buffer. + // TODO: Re-design cpu buffer memory management. + if (cpu_buffer_ && !metal_buffer_) { + DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); + } + metal_buffer_ = nil; +#else + if (cpu_buffer_) { + free(cpu_buffer_); + } +#endif // MEDIAPIPE_METAL_ENABLED + cpu_buffer_ = nullptr; + + // Don't need to wait for the resource to be deleted because it will be + // released on last reference deletion inside the OpenGL driver. +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + std::swap(cleanup_gl_tex, opengl_texture2d_); + std::swap(cleanup_gl_fb, frame_buffer_); #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 - if (opengl_buffer_ != GL_INVALID_INDEX) { - GLuint opengl_buffer = opengl_buffer_; - gl_context_->RunWithoutWaiting( - [opengl_buffer]() { glDeleteBuffers(1, &opengl_buffer); }); - opengl_buffer_ = GL_INVALID_INDEX; - } + std::swap(cleanup_gl_buf, opengl_buffer_); #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + } + // Do not hold the view mutex while invoking GlContext::RunWithoutWaiting, + // since that method may acquire the context's own lock.
+#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 + if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX || + cleanup_gl_buf != GL_INVALID_INDEX) + gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + , + cleanup_gl_buf +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + ]() { + glDeleteTextures(1, &cleanup_gl_tex); + glDeleteFramebuffers(1, &cleanup_gl_fb); +#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + glDeleteBuffers(1, &cleanup_gl_buf); +#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + }); #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 } @@ -341,6 +406,8 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 + // TODO: we cannot just grab the GL context's lock while holding + // the view mutex here. if (valid_ & kValidOpenGlBuffer) { gl_context_->Run([this]() { glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); @@ -356,42 +423,30 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const { // yet. 
if (valid_ & kValidOpenGlTexture2d) { gl_context_->Run([this]() { - const int pixels_per_depth = (BhwcDepthFromShape(shape_) + 3) / 4; - const int width = BhwcWidthFromShape(shape_) * pixels_per_depth; - - uint8_t* buffer; - std::unique_ptr temp_buffer; - if (BhwcDepthFromShape(shape_) % 4 == 0) { - buffer = reinterpret_cast(cpu_buffer_); - } else { - const int padded_size = BhwcBatchFromShape(shape_) * - BhwcHeightFromShape(shape_) * width * - pixels_per_depth * 4 * element_size(); - temp_buffer = absl::make_unique(padded_size); - buffer = temp_buffer.get(); - } + const int padded_size = + texture_height_ * texture_width_ * 4 * element_size(); + auto temp_buffer = absl::make_unique(padded_size); + uint8_t* buffer = temp_buffer.get(); glBindFramebuffer(GL_FRAMEBUFFER, frame_buffer_); glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0, GL_TEXTURE_2D, opengl_texture2d_, 0); - glPixelStorei(GL_PACK_ROW_LENGTH, width); - glPixelStorei(GL_PACK_ALIGNMENT, 1); - glReadPixels(0, 0, width, BhwcHeightFromShape(shape_), GL_RGBA, - GL_FLOAT, buffer); + glPixelStorei(GL_PACK_ALIGNMENT, 4); + glReadPixels(0, 0, texture_width_, texture_height_, GL_RGBA, GL_FLOAT, + buffer); - if (BhwcDepthFromShape(shape_) % 4) { - uint8_t* dest_buffer = reinterpret_cast(cpu_buffer_); - const int actual_depth_size = - BhwcDepthFromShape(shape_) * element_size(); - const int padded_depth_size = pixels_per_depth * 4 * element_size(); - for (int e = 0; - e < BhwcBatchFromShape(shape_) * BhwcHeightFromShape(shape_) * - BhwcWidthFromShape(shape_); - e++) { - std::memcpy(dest_buffer, buffer, actual_depth_size); - dest_buffer += actual_depth_size; - buffer += padded_depth_size; - } + uint8_t* dest_buffer = reinterpret_cast(cpu_buffer_); + const int actual_depth_size = + BhwcDepthFromShape(shape_) * element_size(); + const int num_slices = (BhwcDepthFromShape(shape_) + 3) / 4; + const int padded_depth_size = num_slices * 4 * element_size(); + const int num_elements = BhwcWidthFromShape(shape_) * 
+ BhwcHeightFromShape(shape_) * + BhwcBatchFromShape(shape_); + for (int e = 0; e < num_elements; e++) { + std::memcpy(dest_buffer, buffer, actual_depth_size); + dest_buffer += actual_depth_size; + buffer += padded_depth_size; } }); } diff --git a/mediapipe/framework/formats/tensor.h b/mediapipe/framework/formats/tensor.h index b3cfa5de3..66eb7e604 100644 --- a/mediapipe/framework/formats/tensor.h +++ b/mediapipe/framework/formats/tensor.h @@ -161,6 +161,16 @@ class Tensor { : View(std::move(src)), name_(src.name_) { src.name_ = GL_INVALID_INDEX; } + // To fit a tensor into a texture two layouts are used: + // 1. Aligned. Width of the texture = tensor_width * num_slices, where slice + // is a group of 4 depth values. Tensor depth is padded to 4. + // 2. Linearized. If texture width or height with the layout 1. is greater + // than the GPU supports then all tensor values are packed into a texture + // with fixed width calculated by this method. + // Must be called with the valid GL context bound to the current thread. 
+ enum class Layout { kAligned, kLinearized }; + static Layout GetLayoutDimensions(const Tensor::Shape& shape, int* width, + int* height); protected: friend class Tensor; @@ -254,6 +264,8 @@ class Tensor { mutable std::shared_ptr gl_context_; mutable GLuint opengl_texture2d_ = GL_INVALID_INDEX; mutable GLuint frame_buffer_ = GL_INVALID_INDEX; + mutable int texture_width_; + mutable int texture_height_; void AllocateOpenGlTexture2d() const; #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 mutable GLuint opengl_buffer_ = GL_INVALID_INDEX; diff --git a/mediapipe/framework/formats/tensor_test.cc b/mediapipe/framework/formats/tensor_test.cc index 42c86fd4c..c8c990a2b 100644 --- a/mediapipe/framework/formats/tensor_test.cc +++ b/mediapipe/framework/formats/tensor_test.cc @@ -2,7 +2,7 @@ #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" -#if !defined(MEDIAPIPE_DISABLE_GPU) +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gpu_buffer_format.h" #endif diff --git a/mediapipe/framework/graph_output_stream.cc b/mediapipe/framework/graph_output_stream.cc index 8e99a050f..0a2bc4c18 100644 --- a/mediapipe/framework/graph_output_stream.cc +++ b/mediapipe/framework/graph_output_stream.cc @@ -18,7 +18,7 @@ namespace mediapipe { namespace internal { -mediapipe::Status GraphOutputStream::Initialize( +absl::Status GraphOutputStream::Initialize( const std::string& stream_name, const PacketType* packet_type, OutputStreamManager* output_stream_manager) { RET_CHECK(output_stream_manager); @@ -27,7 +27,7 @@ mediapipe::Status GraphOutputStream::Initialize( proto_ns::RepeatedPtrField input_stream_field; input_stream_field.Add()->assign(stream_name); std::shared_ptr tag_map = - tool::TagMap::Create(input_stream_field).ValueOrDie(); + tool::TagMap::Create(input_stream_field).value(); input_stream_handler_ = absl::make_unique( tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(), 
/*calculator_run_in_parallel=*/false); @@ -38,20 +38,20 @@ mediapipe::Status GraphOutputStream::Initialize( MP_RETURN_IF_ERROR(input_stream_handler_->InitializeInputStreamManagers( input_stream_.get())); output_stream_manager->AddMirror(input_stream_handler_.get(), id); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void GraphOutputStream::PrepareForRun( std::function notification_callback, - std::function error_callback) { + std::function error_callback) { input_stream_handler_->PrepareForRun( /*headers_ready_callback=*/[] {}, std::move(notification_callback), /*schedule_callback=*/nullptr, std::move(error_callback)); } -mediapipe::Status OutputStreamObserver::Initialize( +absl::Status OutputStreamObserver::Initialize( const std::string& stream_name, const PacketType* packet_type, - std::function packet_callback, + std::function packet_callback, OutputStreamManager* output_stream_manager) { RET_CHECK(output_stream_manager); @@ -60,7 +60,7 @@ mediapipe::Status OutputStreamObserver::Initialize( output_stream_manager); } -mediapipe::Status OutputStreamObserver::Notify() { +absl::Status OutputStreamObserver::Notify() { while (true) { bool empty; Timestamp min_timestamp = input_stream_->MinTimestampOrBound(&empty); @@ -76,10 +76,10 @@ mediapipe::Status OutputStreamObserver::Notify() { num_packets_dropped, input_stream_->Name()); MP_RETURN_IF_ERROR(packet_callback_(packet)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OutputStreamPollerImpl::Initialize( +absl::Status OutputStreamPollerImpl::Initialize( const std::string& stream_name, const PacketType* packet_type, std::function queue_size_callback, OutputStreamManager* output_stream_manager) { @@ -87,12 +87,12 @@ mediapipe::Status OutputStreamPollerImpl::Initialize( output_stream_manager)); input_stream_handler_->SetQueueSizeCallbacks(queue_size_callback, queue_size_callback); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void 
OutputStreamPollerImpl::PrepareForRun( std::function notification_callback, - std::function error_callback) { + std::function error_callback) { input_stream_handler_->PrepareForRun( /*headers_ready_callback=*/[] {}, std::move(notification_callback), /*schedule_callback=*/nullptr, std::move(error_callback)); @@ -116,11 +116,11 @@ void OutputStreamPollerImpl::SetMaxQueueSize(int queue_size) { int OutputStreamPollerImpl::QueueSize() { return input_stream_->QueueSize(); } -mediapipe::Status OutputStreamPollerImpl::Notify() { +absl::Status OutputStreamPollerImpl::Notify() { mutex_.Lock(); handler_condvar_.Signal(); mutex_.Unlock(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void OutputStreamPollerImpl::NotifyError() { diff --git a/mediapipe/framework/graph_output_stream.h b/mediapipe/framework/graph_output_stream.h index 06f26fe7b..9a60fa4bd 100644 --- a/mediapipe/framework/graph_output_stream.h +++ b/mediapipe/framework/graph_output_stream.h @@ -50,18 +50,17 @@ class GraphOutputStream { // input stream and attaches the input stream to an output stream as // the mirror for observation/polling. Ownership of output_stream_manager // is not transferred to the graph output stream object. - mediapipe::Status Initialize(const std::string& stream_name, - const PacketType* packet_type, - OutputStreamManager* output_stream_manager); + absl::Status Initialize(const std::string& stream_name, + const PacketType* packet_type, + OutputStreamManager* output_stream_manager); // Installs callbacks into its GraphOutputStreamHandler. - virtual void PrepareForRun( - std::function notification_callback, - std::function error_callback); + virtual void PrepareForRun(std::function notification_callback, + std::function error_callback); // Notifies the graph output stream of new packets emitted by the output // stream. - virtual mediapipe::Status Notify() = 0; + virtual absl::Status Notify() = 0; // Notifies the graph output stream of the errors in the calculator graph. 
virtual void NotifyError() = 0; @@ -110,21 +109,21 @@ class OutputStreamObserver : public GraphOutputStream { public: virtual ~OutputStreamObserver() {} - mediapipe::Status Initialize( + absl::Status Initialize( const std::string& stream_name, const PacketType* packet_type, - std::function packet_callback, + std::function packet_callback, OutputStreamManager* output_stream_manager); // Notifies the observer of new packets emitted by the observed // output stream. - mediapipe::Status Notify() override; + absl::Status Notify() override; // Notifies the observer of the errors in the calculator graph. void NotifyError() override {} private: // Invoked on every packet emitted by the observed output stream. - std::function packet_callback_; + std::function packet_callback_; }; // OutputStreamPollerImpl that returns packets to the caller via @@ -134,14 +133,13 @@ class OutputStreamPollerImpl : public GraphOutputStream { virtual ~OutputStreamPollerImpl() {} // Initializes an OutputStreamPollerImpl. - mediapipe::Status Initialize( + absl::Status Initialize( const std::string& stream_name, const PacketType* packet_type, std::function queue_size_callback, OutputStreamManager* output_stream_manager); - void PrepareForRun( - std::function notification_callback, - std::function error_callback) override; + void PrepareForRun(std::function notification_callback, + std::function error_callback) override; // Resets graph_has_error_ and cleans the internal packet queue. void Reset(); @@ -152,7 +150,7 @@ class OutputStreamPollerImpl : public GraphOutputStream { int QueueSize(); // Notifies the poller of new packets emitted by the output stream. - mediapipe::Status Notify() override; + absl::Status Notify() override; // Notifies the poller of the errors in the calculator graph. 
void NotifyError() override; diff --git a/mediapipe/framework/graph_service_test.cc b/mediapipe/framework/graph_service_test.cc index 31cd2aa77..6ccf59c31 100644 --- a/mediapipe/framework/graph_service_test.cc +++ b/mediapipe/framework/graph_service_test.cc @@ -60,7 +60,7 @@ class GraphServiceTest : public ::testing::Test { MP_ASSERT_OK( graph_.ObserveOutputStream("out", [this](const Packet& packet) { output_packets_.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); } diff --git a/mediapipe/framework/graph_validation.h b/mediapipe/framework/graph_validation.h index 63ab02bd0..10869eda6 100644 --- a/mediapipe/framework/graph_validation.h +++ b/mediapipe/framework/graph_validation.h @@ -28,7 +28,7 @@ namespace mediapipe { class GraphValidation { public: // Validates the specified CalculatorGraphConfig. - mediapipe::Status Validate( + absl::Status Validate( const CalculatorGraphConfig& config, const std::map& side_packets = {}) { return graph_.Initialize(config, side_packets); @@ -40,12 +40,11 @@ class GraphValidation { // CalclatorGraphConfig.type. A subgraph can be validated directly by // specifying its type in |graph_type|. A template graph can be validated // directly by specifying its template arguments in |arguments|. 
- mediapipe::Status Validate( - const std::vector& configs, - const std::vector& templates, - const std::map& side_packets = {}, - const std::string& graph_type = "", - const Subgraph::SubgraphOptions* options = nullptr) { + absl::Status Validate(const std::vector& configs, + const std::vector& templates, + const std::map& side_packets = {}, + const std::string& graph_type = "", + const Subgraph::SubgraphOptions* options = nullptr) { return graph_.Initialize(configs, templates, side_packets, graph_type, options); } diff --git a/mediapipe/framework/graph_validation_test.cc b/mediapipe/framework/graph_validation_test.cc index 57af7cdf3..aa75016d2 100644 --- a/mediapipe/framework/graph_validation_test.cc +++ b/mediapipe/framework/graph_validation_test.cc @@ -106,9 +106,8 @@ TEST(GraphValidationTest, InitializeGraphFromProtos) { TEST(GraphValidationTest, InitializeGraphFromLinker) { EXPECT_FALSE(SubgraphRegistry::IsRegistered("DubQuadTestSubgraph")); ValidatedGraphConfig builder_1; - mediapipe::Status status_1 = - builder_1.Initialize({}, {}, "DubQuadTestSubgraph"); - EXPECT_EQ(status_1.code(), mediapipe::StatusCode::kNotFound); + absl::Status status_1 = builder_1.Initialize({}, {}, "DubQuadTestSubgraph"); + EXPECT_EQ(status_1.code(), absl::StatusCode::kNotFound); EXPECT_THAT(status_1.message(), testing::HasSubstr( R"(No registered object with name: DubQuadTestSubgraph)")); @@ -313,8 +312,8 @@ TEST(GraphValidationTest, OptionalSubgraphStreamsMismatched) { )"); GraphValidation validation_1; - mediapipe::Status status = validation_1.Validate({config_1, config_2}, {}); - ASSERT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = validation_1.Validate({config_1, config_2}, {}); + ASSERT_EQ(status.code(), absl::StatusCode::kInvalidArgument); ASSERT_THAT(status.ToString(), testing::HasSubstr( "PassThroughCalculator must use matching tags and indexes")); @@ -323,22 +322,22 @@ TEST(GraphValidationTest, OptionalSubgraphStreamsMismatched) { // A 
calculator that optionally accepts an input-side-packet. class OptionalSideInputTestCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag("SIDEINPUT").Set().Optional(); cc->Inputs().Tag("SELECT").Set().Optional(); cc->Inputs().Tag("ENABLE").Set().Optional(); cc->Outputs().Tag("OUTPUT").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { std::string value("default"); if (cc->InputSidePackets().HasTag("SIDEINPUT")) { value = cc->InputSidePackets().Tag("SIDEINPUT").Get(); } cc->Outputs().Tag("OUTPUT").Add(new std::string(value), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(OptionalSideInputTestCalculator); @@ -451,5 +450,54 @@ TEST(GraphValidationTest, MultipleOptionalInputsForSubgraph) { MP_EXPECT_OK(graph_1.WaitUntilDone()); } +// Shows a calculator graph running with and without one optional side packet. +TEST(GraphValidationTest, OptionalInputsForGraph) { + // A subgraph defining one optional input-side-packet. 
+ auto config_1 = ParseTextProtoOrDie(R"( + type: "PassThroughGraph" + input_side_packet: "side_input_0" + input_stream: "stream_input_0" + input_stream: "stream_input_1" + output_stream: "OUTPUT:output_0" + node { + calculator: "OptionalSideInputTestCalculator" + input_side_packet: "SIDEINPUT:side_input_0" + input_stream: "SELECT:stream_input_0" + input_stream: "ENABLE:stream_input_1" + output_stream: "OUTPUT:output_0" + } + )"); + GraphValidation validation_1; + MP_EXPECT_OK(validation_1.Validate({config_1}, {})); + CalculatorGraph graph_1; + MP_EXPECT_OK(graph_1.Initialize({config_1}, {})); + auto out_poller = graph_1.AddOutputStreamPoller("output_0"); + MP_ASSERT_OK(out_poller); + + // Run the graph specifying the optional side packet. + std::map side_packets; + side_packets.insert({"side_input_0", MakePacket("side_in")}); + MP_EXPECT_OK(graph_1.StartRun(side_packets)); + MP_EXPECT_OK(graph_1.AddPacketToInputStream( + "stream_input_0", MakePacket(22).At(Timestamp(3000)))); + MP_EXPECT_OK(graph_1.AddPacketToInputStream( + "stream_input_1", MakePacket(true).At(Timestamp(3000)))); + Packet out_packet, options_packet; + EXPECT_TRUE(out_poller->Next(&out_packet)); + EXPECT_EQ(out_packet.Get(), "side_in"); + MP_EXPECT_OK(graph_1.CloseAllPacketSources()); + MP_EXPECT_OK(graph_1.WaitUntilDone()); + + // Run the graph omitting the optional inputs. 
+ MP_EXPECT_OK(graph_1.StartRun({})); + MP_EXPECT_OK(graph_1.CloseInputStream("stream_input_1")); + MP_EXPECT_OK(graph_1.AddPacketToInputStream( + "stream_input_0", MakePacket(22).At(Timestamp(3000)))); + EXPECT_TRUE(out_poller->Next(&out_packet)); + EXPECT_EQ(out_packet.Get(), "default"); + MP_EXPECT_OK(graph_1.CloseAllPacketSources()); + MP_EXPECT_OK(graph_1.WaitUntilDone()); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/input_side_packet_handler.cc b/mediapipe/framework/input_side_packet_handler.cc index 87517ee5d..9b01cc31a 100644 --- a/mediapipe/framework/input_side_packet_handler.cc +++ b/mediapipe/framework/input_side_packet_handler.cc @@ -21,11 +21,11 @@ namespace mediapipe { -mediapipe::Status InputSidePacketHandler::PrepareForRun( +absl::Status InputSidePacketHandler::PrepareForRun( const PacketTypeSet* input_side_packet_types, const std::map& all_side_packets, std::function input_side_packets_ready_callback, - std::function error_callback) { + std::function error_callback) { int missing_input_side_packet_count; prev_input_side_packets_ = std::move(input_side_packets_); ASSIGN_OR_RETURN( @@ -39,7 +39,7 @@ mediapipe::Status InputSidePacketHandler::PrepareForRun( input_side_packets_ready_callback_ = std::move(input_side_packets_ready_callback); error_callback_ = std::move(error_callback); - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool InputSidePacketHandler::InputSidePacketsChanged() { @@ -49,14 +49,14 @@ bool InputSidePacketHandler::InputSidePacketsChanged() { } void InputSidePacketHandler::Set(CollectionItemId id, const Packet& packet) { - mediapipe::Status status = SetInternal(id, packet); + absl::Status status = SetInternal(id, packet); if (!status.ok()) { TriggerErrorCallback(status); } } -mediapipe::Status InputSidePacketHandler::SetInternal(CollectionItemId id, - const Packet& packet) { +absl::Status InputSidePacketHandler::SetInternal(CollectionItemId id, + const Packet& packet) { 
RET_CHECK_GT(missing_input_side_packet_count_, 0); Packet& side_packet = input_side_packets_->Get(id); @@ -64,7 +64,7 @@ mediapipe::Status InputSidePacketHandler::SetInternal(CollectionItemId id, return mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC) << "Input side packet with id " << id << " was already set."; } - mediapipe::Status result = input_side_packet_types_->Get(id).Validate(packet); + absl::Status result = input_side_packet_types_->Get(id).Validate(packet); if (!result.ok()) { return mediapipe::StatusBuilder(result, MEDIAPIPE_LOC).SetPrepend() << absl::StrCat( @@ -77,11 +77,11 @@ mediapipe::Status InputSidePacketHandler::SetInternal(CollectionItemId id, 1, std::memory_order_acq_rel) == 1) { input_side_packets_ready_callback_(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void InputSidePacketHandler::TriggerErrorCallback( - const mediapipe::Status& status) const { + const absl::Status& status) const { CHECK(error_callback_); error_callback_(status); } diff --git a/mediapipe/framework/input_side_packet_handler.h b/mediapipe/framework/input_side_packet_handler.h index 022432414..c0a3249a6 100644 --- a/mediapipe/framework/input_side_packet_handler.h +++ b/mediapipe/framework/input_side_packet_handler.h @@ -41,11 +41,11 @@ class InputSidePacketHandler { // Resets the input side packet handler and its underlying input side packets // for another run of the graph. - mediapipe::Status PrepareForRun( + absl::Status PrepareForRun( const PacketTypeSet* input_side_packet_types, const std::map& all_side_packets, std::function input_side_packets_ready_callback, - std::function error_callback); + std::function error_callback); // Sets a particular input side packet. void Set(CollectionItemId id, const Packet& packet); @@ -63,11 +63,11 @@ class InputSidePacketHandler { private: // Called by Set(). 
- mediapipe::Status SetInternal(CollectionItemId id, const Packet& packet); + absl::Status SetInternal(CollectionItemId id, const Packet& packet); - // Triggers the error callback with mediapipe::Status info when an error + // Triggers the error callback with absl::Status info when an error // occurs. - void TriggerErrorCallback(const mediapipe::Status& status) const; + void TriggerErrorCallback(const absl::Status& status) const; const PacketTypeSet* input_side_packet_types_; @@ -77,7 +77,7 @@ class InputSidePacketHandler { std::atomic missing_input_side_packet_count_{0}; std::function input_side_packets_ready_callback_; - std::function error_callback_; + std::function error_callback_; }; } // namespace mediapipe diff --git a/mediapipe/framework/input_stream_handler.cc b/mediapipe/framework/input_stream_handler.cc index 541001d52..66d9bfede 100644 --- a/mediapipe/framework/input_stream_handler.cc +++ b/mediapipe/framework/input_stream_handler.cc @@ -24,13 +24,13 @@ namespace mediapipe { using SyncSet = InputStreamHandler::SyncSet; -mediapipe::Status InputStreamHandler::InitializeInputStreamManagers( +absl::Status InputStreamHandler::InitializeInputStreamManagers( InputStreamManager* flat_input_stream_managers) { for (CollectionItemId id = input_stream_managers_.BeginId(); id < input_stream_managers_.EndId(); ++id) { input_stream_managers_.Get(id) = &flat_input_stream_managers[id.value()]; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } InputStreamManager* InputStreamHandler::GetInputStreamManager( @@ -38,7 +38,7 @@ InputStreamManager* InputStreamHandler::GetInputStreamManager( return input_stream_managers_.Get(id); } -mediapipe::Status InputStreamHandler::SetupInputShards( +absl::Status InputStreamHandler::SetupInputShards( InputStreamShardSet* input_shards) { RET_CHECK(input_shards); for (CollectionItemId id = input_stream_managers_.BeginId(); @@ -48,7 +48,7 @@ mediapipe::Status InputStreamHandler::SetupInputShards( 
input_shards->Get(id).SetName(&manager->Name()); input_shards->Get(id).SetHeader(manager->Header()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } std::vector> @@ -68,7 +68,7 @@ void InputStreamHandler::PrepareForRun( std::function headers_ready_callback, std::function notification_callback, std::function schedule_callback, - std::function error_callback) { + std::function error_callback) { headers_ready_callback_ = std::move(headers_ready_callback); notification_ = std::move(notification_callback); schedule_callback_ = std::move(schedule_callback); @@ -94,7 +94,7 @@ void InputStreamHandler::SetQueueSizeCallbacks( } void InputStreamHandler::SetHeader(CollectionItemId id, const Packet& header) { - mediapipe::Status result = input_stream_managers_.Get(id)->SetHeader(header); + absl::Status result = input_stream_managers_.Get(id)->SetHeader(header); if (!result.ok()) { error_callback_(result); return; @@ -260,7 +260,7 @@ void InputStreamHandler::AddPackets(CollectionItemId id, LogQueuedPackets(GetCalculatorContext(calculator_context_manager_), input_stream_managers_.Get(id), packets.back()); bool notify = false; - mediapipe::Status result = + absl::Status result = input_stream_managers_.Get(id)->AddPackets(packets, ¬ify); if (!result.ok()) { error_callback_(result); @@ -275,7 +275,7 @@ void InputStreamHandler::MovePackets(CollectionItemId id, LogQueuedPackets(GetCalculatorContext(calculator_context_manager_), input_stream_managers_.Get(id), packets->back()); bool notify = false; - mediapipe::Status result = + absl::Status result = input_stream_managers_.Get(id)->MovePackets(packets, ¬ify); if (!result.ok()) { error_callback_(result); @@ -288,7 +288,7 @@ void InputStreamHandler::MovePackets(CollectionItemId id, void InputStreamHandler::SetNextTimestampBound(CollectionItemId id, Timestamp bound) { bool notify = false; - mediapipe::Status result = + absl::Status result = input_stream_managers_.Get(id)->SetNextTimestampBound(bound, ¬ify); if 
(!result.ok()) { error_callback_(result); diff --git a/mediapipe/framework/input_stream_handler.h b/mediapipe/framework/input_stream_handler.h index 8e7f44b5f..940826932 100644 --- a/mediapipe/framework/input_stream_handler.h +++ b/mediapipe/framework/input_stream_handler.h @@ -84,13 +84,13 @@ class InputStreamHandler { // InputStreamHandler::input_stream_managers_ (meaning it should point // to somewhere in the middle of the master flat array of all input // stream managers). - mediapipe::Status InitializeInputStreamManagers( + absl::Status InitializeInputStreamManagers( InputStreamManager* flat_input_stream_managers); InputStreamManager* GetInputStreamManager(CollectionItemId id); // Sets up the InputStreamShardSet by propagating data from the managers. - mediapipe::Status SetupInputShards(InputStreamShardSet* input_shards); + absl::Status SetupInputShards(InputStreamShardSet* input_shards); // Returns a vector of pairs of stream name and queue size for monitoring // purpose. @@ -106,7 +106,7 @@ class InputStreamHandler { std::function headers_ready_callback, std::function notification_callback, std::function schedule_callback, - std::function error_callback); + std::function error_callback); int NumInputStreams() const { return input_stream_managers_.NumEntries(); } @@ -286,7 +286,7 @@ class InputStreamHandler { std::function notification_; // A callback to schedule the node with the prepared calculator context. std::function schedule_callback_; - std::function error_callback_; + std::function error_callback_; private: // Indicates when to fill the input set. 
If true, every input set will be diff --git a/mediapipe/framework/input_stream_manager.cc b/mediapipe/framework/input_stream_manager.cc index f5ce10d70..5b1917138 100644 --- a/mediapipe/framework/input_stream_manager.cc +++ b/mediapipe/framework/input_stream_manager.cc @@ -27,14 +27,14 @@ namespace mediapipe { -mediapipe::Status InputStreamManager::Initialize(const std::string& name, - const PacketType* packet_type, - bool back_edge) { +absl::Status InputStreamManager::Initialize(const std::string& name, + const PacketType* packet_type, + bool back_edge) { name_ = name; packet_type_ = packet_type; back_edge_ = back_edge; PrepareForRun(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } const std::string& InputStreamManager::Name() const { return name_; } @@ -70,29 +70,29 @@ Packet InputStreamManager::QueueHead() const { return queue_.front(); } -mediapipe::Status InputStreamManager::SetHeader(const Packet& header) { +absl::Status InputStreamManager::SetHeader(const Packet& header) { if (header.Timestamp() != Timestamp::Unset()) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Headers must not have a timestamp. 
Stream: \"" << name_ << "\"."; } header_ = header; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status InputStreamManager::AddPackets( - const std::list& container, bool* notify) { +absl::Status InputStreamManager::AddPackets(const std::list& container, + bool* notify) { return AddOrMovePacketsInternal&>(container, notify); } -mediapipe::Status InputStreamManager::MovePackets(std::list* container, - bool* notify) { +absl::Status InputStreamManager::MovePackets(std::list* container, + bool* notify) { return AddOrMovePacketsInternal&>(*container, notify); } template -mediapipe::Status InputStreamManager::AddOrMovePacketsInternal( - Container container, bool* notify) { +absl::Status InputStreamManager::AddOrMovePacketsInternal(Container container, + bool* notify) { *notify = false; bool queue_became_non_empty = false; bool queue_became_full = false; @@ -100,7 +100,7 @@ mediapipe::Status InputStreamManager::AddOrMovePacketsInternal( // Scope to prevent locking the stream when notification is called. absl::MutexLock stream_lock(&stream_mutex_); if (closed_) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Check if the queue was full before packets came in. bool was_queue_full = @@ -108,7 +108,7 @@ mediapipe::Status InputStreamManager::AddOrMovePacketsInternal( // Check if the queue becomes non-empty. 
queue_became_non_empty = queue_.empty() && !container.empty(); for (auto& packet : container) { - mediapipe::Status result = packet_type_->Validate(packet); + absl::Status result = packet_type_->Validate(packet); if (!result.ok()) { return tool::AddStatusPrefix( absl::StrCat( @@ -177,17 +177,17 @@ mediapipe::Status InputStreamManager::AddOrMovePacketsInternal( becomes_full_callback_(this, &last_reported_stream_full_); } *notify = queue_became_non_empty; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status InputStreamManager::SetNextTimestampBound( - const Timestamp bound, bool* notify) { +absl::Status InputStreamManager::SetNextTimestampBound(const Timestamp bound, + bool* notify) { *notify = false; { // Scope to prevent locking the stream when notification is called. absl::MutexLock stream_lock(&stream_mutex_); if (closed_) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (enable_timestamps_ && bound < next_timestamp_bound_) { @@ -211,7 +211,7 @@ mediapipe::Status InputStreamManager::SetNextTimestampBound( } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void InputStreamManager::DisableTimestamps() { enable_timestamps_ = false; } diff --git a/mediapipe/framework/input_stream_manager.h b/mediapipe/framework/input_stream_manager.h index 0541247ff..042ef8d83 100644 --- a/mediapipe/framework/input_stream_manager.h +++ b/mediapipe/framework/input_stream_manager.h @@ -57,8 +57,8 @@ class InputStreamManager { InputStreamManager() = default; // Initializes the InputStreamManager. - mediapipe::Status Initialize(const std::string& name, - const PacketType* packet_type, bool back_edge); + absl::Status Initialize(const std::string& name, + const PacketType* packet_type, bool back_edge); // Returns the stream name. const std::string& Name() const; @@ -67,7 +67,7 @@ class InputStreamManager { bool BackEdge() const { return back_edge_; } // Sets the header Packet. 
- mediapipe::Status SetHeader(const Packet& header); + absl::Status SetHeader(const Packet& header); const Packet& Header() const { return header_; } @@ -87,13 +87,12 @@ class InputStreamManager { // Timestamp::PostStream(), the packet must be the only packet in the // stream. // Violation of any of these conditions causes an error status. - mediapipe::Status AddPackets(const std::list& container, - bool* notify); + absl::Status AddPackets(const std::list& container, bool* notify); // Move a list of timestamped packets. Sets "notify" to true if the queue // becomes non-empty. Does nothing if the input stream is closed. After the // move, all packets in the container must be empty. - mediapipe::Status MovePackets(std::list* container, bool* notify); + absl::Status MovePackets(std::list* container, bool* notify); // Closes the input stream. This function can be called multiple times. void Close() ABSL_LOCKS_EXCLUDED(stream_mutex_); @@ -103,7 +102,7 @@ class InputStreamManager { // empty. Returns an error status if this decreases the bound, unless // DisableTimestamps() is called. Does nothing if the input stream is // closed. - mediapipe::Status SetNextTimestampBound(Timestamp bound, bool* notify) + absl::Status SetNextTimestampBound(Timestamp bound, bool* notify) ABSL_LOCKS_EXCLUDED(stream_mutex_); // Returns the smallest timestamp at which we might see an input in @@ -182,7 +181,7 @@ class InputStreamManager { // Otherwise, the caller must be MovePackets() and Container should be // non-const reference. template - mediapipe::Status AddOrMovePacketsInternal(Container container, bool* notify) + absl::Status AddOrMovePacketsInternal(Container container, bool* notify) ABSL_LOCKS_EXCLUDED(stream_mutex_); // Returns true if the next timestamp bound reaches Timestamp::Done(). 
diff --git a/mediapipe/framework/input_stream_manager_test.cc b/mediapipe/framework/input_stream_manager_test.cc index 6b158928b..f1c1185f1 100644 --- a/mediapipe/framework/input_stream_manager_test.cc +++ b/mediapipe/framework/input_stream_manager_test.cc @@ -133,7 +133,7 @@ TEST_F(InputStreamManagerTest, AddPacketUnset) { packets.push_back(MakePacket("packet 1").At(Timestamp::Unset())); EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr("Timestamp::Unset()")); EXPECT_FALSE(notify_); @@ -145,7 +145,7 @@ TEST_F(InputStreamManagerTest, AddPacketUnstarted) { MakePacket("packet 1").At(Timestamp::Unstarted())); EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr("Timestamp::Unstarted()")); EXPECT_FALSE(notify_); @@ -157,7 +157,7 @@ TEST_F(InputStreamManagerTest, AddPacketOneOverPostStream) { MakePacket("packet 1").At(Timestamp::OneOverPostStream())); EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr("Timestamp::OneOverPostStream()")); @@ -169,7 +169,7 @@ TEST_F(InputStreamManagerTest, AddPacketDone) { packets.push_back(MakePacket("packet 1").At(Timestamp::Done())); EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr("Timestamp::Done()")); EXPECT_FALSE(notify_); @@ -196,7 +196,7 @@ TEST_F(InputStreamManagerTest, AddPacketsAfterPreStream) { packets.push_back(MakePacket("packet 2").At(Timestamp(10))); 
EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr("Timestamp::OneOverPostStream()")); @@ -224,7 +224,7 @@ TEST_F(InputStreamManagerTest, AddPacketsBeforePostStream) { MakePacket("packet 2").At(Timestamp::PostStream())); EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr("Timestamp::PostStream()")); EXPECT_FALSE(notify_); @@ -237,7 +237,7 @@ TEST_F(InputStreamManagerTest, AddPacketsReverseTimestamps) { packets.push_back(MakePacket("packet 3").At(Timestamp(30))); EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr( @@ -398,7 +398,7 @@ TEST_F(InputStreamManagerTest, BadPacketType) { packets.push_back(MakePacket(10).At(Timestamp(10))); EXPECT_TRUE(input_stream_manager_->IsEmpty()); - mediapipe::Status result = + absl::Status result = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result.message(), testing::HasSubstr("Packet type mismatch")); EXPECT_FALSE(notify_); @@ -543,7 +543,7 @@ TEST_F(InputStreamManagerTest, BackwardsInTime) { EXPECT_FALSE(notify_); notify_ = false; - mediapipe::Status result = input_stream_manager_->SetNextTimestampBound( + absl::Status result = input_stream_manager_->SetNextTimestampBound( Timestamp(40), ¬ify_); // Set Timestamp bound backwards in time. 
ASSERT_THAT(result.message(), testing::HasSubstr("40")); ASSERT_THAT(result.message(), testing::HasSubstr("50")); @@ -554,7 +554,7 @@ TEST_F(InputStreamManagerTest, BackwardsInTime) { packets.clear(); packets.push_back(MakePacket("packet 3") .At(Timestamp(30))); // Backwards in time - mediapipe::Status result2 = + absl::Status result2 = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result2.message(), testing::HasSubstr("50")); ASSERT_THAT(result2.message(), testing::HasSubstr("30")); @@ -585,7 +585,7 @@ TEST_F(InputStreamManagerTest, BackwardsInTime) { packets.clear(); packets.push_back(MakePacket("packet 5") .At(Timestamp(130))); // Backwards in time. - mediapipe::Status result3 = + absl::Status result3 = input_stream_manager_->AddPackets(packets, ¬ify_); // No notification ASSERT_THAT(result3.message(), testing::HasSubstr("151")); ASSERT_THAT(result3.message(), testing::HasSubstr("130")); diff --git a/mediapipe/framework/mediapipe_cc_test.bzl b/mediapipe/framework/mediapipe_cc_test.bzl new file mode 100644 index 000000000..4991992dd --- /dev/null +++ b/mediapipe/framework/mediapipe_cc_test.bzl @@ -0,0 +1,30 @@ +"""Macro for multi-platform C++ tests.""" + +DEFAULT_ADDITIONAL_TEST_DEPS = [] + +def mediapipe_cc_test( + name, + srcs = [], + data = [], + deps = [], + size = None, + timeout = None, + additional_deps = DEFAULT_ADDITIONAL_TEST_DEPS, + **kwargs): + # Note: additional_deps are MediaPipe-specific test support deps added by default. + # They are provided as a default argument so they can be disabled if desired. 
+ native.cc_library( + name = name + "_lib", + testonly = 1, + srcs = srcs, + data = data, + deps = deps + additional_deps, + alwayslink = 1, + ) + + native.cc_test( + name = name, + size = size, + timeout = timeout, + deps = [":{}_lib".format(name)], + ) diff --git a/mediapipe/framework/output_side_packet_impl.cc b/mediapipe/framework/output_side_packet_impl.cc index 0cb8be047..94bc518f8 100644 --- a/mediapipe/framework/output_side_packet_impl.cc +++ b/mediapipe/framework/output_side_packet_impl.cc @@ -20,21 +20,21 @@ namespace mediapipe { -mediapipe::Status OutputSidePacketImpl::Initialize( - const std::string& name, const PacketType* packet_type) { +absl::Status OutputSidePacketImpl::Initialize(const std::string& name, + const PacketType* packet_type) { name_ = name; packet_type_ = packet_type; - return mediapipe::OkStatus(); + return absl::OkStatus(); } void OutputSidePacketImpl::PrepareForRun( - std::function error_callback) { + std::function error_callback) { error_callback_ = std::move(error_callback); initialized_ = false; } void OutputSidePacketImpl::Set(const Packet& packet) { - mediapipe::Status status = SetInternal(packet); + absl::Status status = SetInternal(packet); if (!status.ok()) { TriggerErrorCallback(status); } @@ -46,7 +46,7 @@ void OutputSidePacketImpl::AddMirror( mirrors_.emplace_back(input_side_packet_handler, id); } -mediapipe::Status OutputSidePacketImpl::SetInternal(const Packet& packet) { +absl::Status OutputSidePacketImpl::SetInternal(const Packet& packet) { if (initialized_) { return mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC) << "Output side packet \"" << name_ << "\" was already set."; @@ -63,7 +63,7 @@ mediapipe::Status OutputSidePacketImpl::SetInternal(const Packet& packet) { << packet.Timestamp().DebugString() << "."; } - mediapipe::Status result = packet_type_->Validate(packet); + absl::Status result = packet_type_->Validate(packet); if (!result.ok()) { return mediapipe::StatusBuilder(result, MEDIAPIPE_LOC).SetPrepend() 
<< absl::StrCat( @@ -76,11 +76,11 @@ mediapipe::Status OutputSidePacketImpl::SetInternal(const Packet& packet) { for (const auto& mirror : mirrors_) { mirror.input_side_packet_handler->Set(mirror.id, packet_); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void OutputSidePacketImpl::TriggerErrorCallback( - const mediapipe::Status& status) const { + const absl::Status& status) const { CHECK(error_callback_); error_callback_(status); } diff --git a/mediapipe/framework/output_side_packet_impl.h b/mediapipe/framework/output_side_packet_impl.h index c38c65912..7e7d639cd 100644 --- a/mediapipe/framework/output_side_packet_impl.h +++ b/mediapipe/framework/output_side_packet_impl.h @@ -35,13 +35,13 @@ class OutputSidePacketImpl : public OutputSidePacket { ~OutputSidePacketImpl() override = default; // Initializes the OutputSidePacketImpl. - mediapipe::Status Initialize(const std::string& name, - const PacketType* packet_type); + absl::Status Initialize(const std::string& name, + const PacketType* packet_type); // Prepares this for processing. If an error occurs in a user called function // (such as Set()) then error_callback will be called before returning // control to the user. - void PrepareForRun(std::function error_callback); + void PrepareForRun(std::function error_callback); // Gets the output side packet. Packet GetPacket() const { return packet_; } @@ -70,15 +70,15 @@ class OutputSidePacketImpl : public OutputSidePacket { }; // Called by Set(). - mediapipe::Status SetInternal(const Packet& packet); + absl::Status SetInternal(const Packet& packet); - // Triggers the error callback with mediapipe::Status info when an error + // Triggers the error callback with absl::Status info when an error // occurs. 
- void TriggerErrorCallback(const mediapipe::Status& status) const; + void TriggerErrorCallback(const absl::Status& status) const; std::string name_; const PacketType* packet_type_; - std::function error_callback_; + std::function error_callback_; Packet packet_; bool initialized_ = false; diff --git a/mediapipe/framework/output_stream_handler.cc b/mediapipe/framework/output_stream_handler.cc index 86990d343..ba8f46718 100644 --- a/mediapipe/framework/output_stream_handler.cc +++ b/mediapipe/framework/output_stream_handler.cc @@ -20,16 +20,16 @@ namespace mediapipe { -mediapipe::Status OutputStreamHandler::InitializeOutputStreamManagers( +absl::Status OutputStreamHandler::InitializeOutputStreamManagers( OutputStreamManager* flat_output_stream_managers) { for (CollectionItemId id = output_stream_managers_.BeginId(); id < output_stream_managers_.EndId(); ++id) { output_stream_managers_.Get(id) = &flat_output_stream_managers[id.value()]; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status OutputStreamHandler::SetupOutputShards( +absl::Status OutputStreamHandler::SetupOutputShards( OutputStreamShardSet* output_shards) { CHECK(output_shards); for (CollectionItemId id = output_stream_managers_.BeginId(); @@ -37,11 +37,11 @@ mediapipe::Status OutputStreamHandler::SetupOutputShards( OutputStreamManager* manager = output_stream_managers_.Get(id); output_shards->Get(id).SetSpec(manager->Spec()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void OutputStreamHandler::PrepareForRun( - const std::function& error_callback) { + const std::function& error_callback) { for (auto& manager : output_stream_managers_) { manager->PrepareForRun(error_callback); } diff --git a/mediapipe/framework/output_stream_handler.h b/mediapipe/framework/output_stream_handler.h index 994f356e8..f134139c6 100644 --- a/mediapipe/framework/output_stream_handler.h +++ b/mediapipe/framework/output_stream_handler.h @@ -76,11 +76,11 @@ class OutputStreamHandler { 
// OutputStreamHandler::output_stream_managers_ (meaning it should // point to somewhere in the middle of the master flat array of all // output stream managers). - mediapipe::Status InitializeOutputStreamManagers( + absl::Status InitializeOutputStreamManagers( OutputStreamManager* flat_output_stream_managers); // Sets up output shards by connecting to the managers. - mediapipe::Status SetupOutputShards(OutputStreamShardSet* output_shards); + absl::Status SetupOutputShards(OutputStreamShardSet* output_shards); int NumOutputStreams() const { return output_stream_managers_.NumEntries(); } @@ -91,8 +91,7 @@ class OutputStreamHandler { // Calls OutputStreamManager::PrepareForRun(error_callback) per stream, and // resets data memebers. - void PrepareForRun( - const std::function& error_callback) + void PrepareForRun(const std::function& error_callback) ABSL_LOCKS_EXCLUDED(timestamp_mutex_); // Marks the output streams as started and propagates any changes made in diff --git a/mediapipe/framework/output_stream_manager.cc b/mediapipe/framework/output_stream_manager.cc index 572245bf6..0784bdccc 100644 --- a/mediapipe/framework/output_stream_manager.cc +++ b/mediapipe/framework/output_stream_manager.cc @@ -20,17 +20,17 @@ namespace mediapipe { -mediapipe::Status OutputStreamManager::Initialize( - const std::string& name, const PacketType* packet_type) { +absl::Status OutputStreamManager::Initialize(const std::string& name, + const PacketType* packet_type) { output_stream_spec_.name = name; output_stream_spec_.packet_type = packet_type; output_stream_spec_.offset_enabled = false; PrepareForRun(nullptr); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void OutputStreamManager::PrepareForRun( - std::function error_callback) { + std::function error_callback) { output_stream_spec_.error_callback = std::move(error_callback); output_stream_spec_.locked_intro_data = false; @@ -117,7 +117,6 @@ Timestamp OutputStreamManager::ComputeOutputTimestampBound( // 
MaxOutputTimestamp(completed_timestamp) + 1) // Note that "MaxOutputTimestamp()" must consider both output packet // timetstamp and SetNextTimestampBound values. - // See the timestamp mapping section in go/mediapipe-bounds for details. Timestamp input_bound; if (output_stream_spec_.offset_enabled && input_timestamp != Timestamp::Unstarted()) { diff --git a/mediapipe/framework/output_stream_manager.h b/mediapipe/framework/output_stream_manager.h index 6fec67d08..b0d97a99e 100644 --- a/mediapipe/framework/output_stream_manager.h +++ b/mediapipe/framework/output_stream_manager.h @@ -40,13 +40,13 @@ class OutputStreamManager { OutputStreamManager() = default; // Initializes the OutputStreamManager. - mediapipe::Status Initialize(const std::string& name, - const PacketType* packet_type); + absl::Status Initialize(const std::string& name, + const PacketType* packet_type); // Prepares this for processing. If an error occurs in a user called function // (such as AddPacket()) then error_callback will be called before returning // control to the user. - void PrepareForRun(std::function error_callback); + void PrepareForRun(std::function error_callback); // Gets the stream name. const std::string& Name() const { return output_stream_spec_.name; } @@ -85,8 +85,7 @@ class OutputStreamManager { // Computes the output timestamp bound based on the input timestamp, the // timestamp of the last added packet, and the next timestamp bound from - // the OutputStreamShard. See the timestamp mapping section in - // go/mediapipe-bounds for details. + // the OutputStreamShard. // The function is invoked by OutputStreamHandler after the calculator node // finishes a call to Calculator::Process(). 
Timestamp ComputeOutputTimestampBound( diff --git a/mediapipe/framework/output_stream_manager_test.cc b/mediapipe/framework/output_stream_manager_test.cc index 77b15e341..c8aaec011 100644 --- a/mediapipe/framework/output_stream_manager_test.cc +++ b/mediapipe/framework/output_stream_manager_test.cc @@ -58,13 +58,13 @@ class OutputStreamManagerTest : public ::testing::Test { output_stream_shard_.SetSpec(output_stream_manager_->Spec()); output_stream_manager_->ResetShard(&output_stream_shard_); - std::shared_ptr tag_map = tool::CreateTagMap(1).ValueOrDie(); - mediapipe::StatusOr> + std::shared_ptr tag_map = tool::CreateTagMap(1).value(); + absl::StatusOr> status_or_handler = InputStreamHandlerRegistry::CreateByName( "DefaultInputStreamHandler", tag_map, /*cc_manager=*/nullptr, MediaPipeOptions(), /*calculator_run_in_parallel=*/false); ASSERT_TRUE(status_or_handler.ok()); - input_stream_handler_ = std::move(status_or_handler.ValueOrDie()); + input_stream_handler_ = std::move(status_or_handler.value()); const CollectionItemId& id = tag_map->BeginId(); MP_ASSERT_OK(input_stream_manager_.Initialize("a_test", &packet_type_, @@ -85,7 +85,7 @@ class OutputStreamManagerTest : public ::testing::Test { void ScheduleNoOp(CalculatorContext* cc) {} - void RecordError(const mediapipe::Status& error) { errors_.push_back(error); } + void RecordError(const absl::Status& error) { errors_.push_back(error); } void ReportQueueNoOp(InputStreamManager* stream, bool* stream_was_full) {} @@ -104,7 +104,7 @@ class OutputStreamManagerTest : public ::testing::Test { std::function headers_ready_callback_; std::function notification_callback_; std::function schedule_callback_; - std::function error_callback_; + std::function error_callback_; InputStreamManager::QueueSizeCallback queue_full_callback_; InputStreamManager::QueueSizeCallback queue_not_full_callback_; @@ -114,7 +114,7 @@ class OutputStreamManagerTest : public ::testing::Test { InputStreamManager input_stream_manager_; // Vector of 
errors encountered while using the stream. - std::vector errors_; + std::vector errors_; }; TEST_F(OutputStreamManagerTest, Init) {} diff --git a/mediapipe/framework/output_stream_shard.cc b/mediapipe/framework/output_stream_shard.cc index 0a600be86..704a18d8c 100644 --- a/mediapipe/framework/output_stream_shard.cc +++ b/mediapipe/framework/output_stream_shard.cc @@ -128,7 +128,7 @@ Status OutputStreamShard::AddPacketInternal(T&& packet) { // TODO debug log? - return mediapipe::OkStatus(); + return absl::OkStatus(); } void OutputStreamShard::AddPacket(const Packet& packet) { diff --git a/mediapipe/framework/output_stream_shard.h b/mediapipe/framework/output_stream_shard.h index ad8ac5995..fdc5fe077 100644 --- a/mediapipe/framework/output_stream_shard.h +++ b/mediapipe/framework/output_stream_shard.h @@ -31,16 +31,16 @@ class OutputStreamManager; // The output stream spec shared across all output stream shards and their // output stream manager. struct OutputStreamSpec { - // Triggers the error callback with mediapipe::Status info when an error + // Triggers the error callback with absl::Status info when an error // occurs. - void TriggerErrorCallback(const mediapipe::Status& status) const { + void TriggerErrorCallback(const absl::Status& status) const { CHECK(error_callback); error_callback(status); } std::string name; const PacketType* packet_type; - std::function error_callback; + std::function error_callback; bool locked_intro_data; // Those three variables are the intro data protected by locked_intro_data. bool offset_enabled; @@ -102,7 +102,7 @@ class OutputStreamShard : public OutputStream { // AddPacketInternal template is called by either AddPacket(Packet&& packet) // or AddPacket(const Packet& packet). template - mediapipe::Status AddPacketInternal(T&& packet); + absl::Status AddPacketInternal(T&& packet); // Returns a pointer to the output queue. 
std::list* OutputQueue() { return &output_queue_; } diff --git a/mediapipe/framework/packet.cc b/mediapipe/framework/packet.cc index 3d25a4d0f..d1810871e 100644 --- a/mediapipe/framework/packet.cc +++ b/mediapipe/framework/packet.cc @@ -51,8 +51,8 @@ const HolderBase* GetHolder(const Packet& packet) { return packet.holder_.get(); } -mediapipe::StatusOr PacketFromDynamicProto( - const std::string& type_name, const std::string& serialized) { +absl::StatusOr PacketFromDynamicProto(const std::string& type_name, + const std::string& serialized) { ASSIGN_OR_RETURN( auto message_holder, packet_internal::MessageHolderRegistry::CreateByName(type_name)); @@ -105,16 +105,16 @@ std::string Packet::DebugString() const { return result; } -mediapipe::Status Packet::ValidateAsProtoMessageLite() const { +absl::Status Packet::ValidateAsProtoMessageLite() const { if (ABSL_PREDICT_FALSE(IsEmpty())) { - return mediapipe::InternalError("Packet is empty."); + return absl::InternalError("Packet is empty."); } if (ABSL_PREDICT_FALSE(holder_->GetProtoMessageLite() == nullptr)) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("The Packet stores \"", holder_->DebugTypeName(), "\"", "which is not convertible to proto_ns::MessageLite.")); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -129,7 +129,7 @@ const proto_ns::MessageLite& Packet::GetProtoMessageLite() const { StatusOr> Packet::GetVectorOfProtoMessageLitePtrs() { if (holder_ == nullptr) { - return mediapipe::InternalError("Packet is empty."); + return absl::InternalError("Packet is empty."); } return holder_->GetVectorOfProtoMessageLite(); } diff --git a/mediapipe/framework/packet.h b/mediapipe/framework/packet.h index f079afb89..03a5e8109 100644 --- a/mediapipe/framework/packet.h +++ b/mediapipe/framework/packet.h @@ -53,8 +53,9 @@ Packet Create(HolderBase* holder, Timestamp timestamp); Packet Create(std::shared_ptr holder, Timestamp timestamp); const HolderBase* 
GetHolder(const Packet& packet); const std::shared_ptr& GetHolderShared(const Packet& packet); -mediapipe::StatusOr PacketFromDynamicProto( - const std::string& type_name, const std::string& serialized); +std::shared_ptr GetHolderShared(Packet&& packet); +absl::StatusOr PacketFromDynamicProto(const std::string& type_name, + const std::string& serialized); } // namespace packet_internal // A generic container class which can hold data of any type. The type of @@ -111,7 +112,7 @@ class Packet { // holder. Otherwise, returns error when the packet can't be consumed. // See ConsumeOrCopy for threading requirements and example usage. template - mediapipe::StatusOr> Consume(); + absl::StatusOr> Consume(); // Consumes the packet and transfers the ownership of the data to a // unique pointer if the packet is the sole owner of a non-foreign @@ -130,28 +131,28 @@ class Packet { // // The unique_ptr type can be omitted with auto. // ASSIGN_OR_RETURN(auto detection, p.ConsumeOrCopy()); // If you would like to crash on failure (prefer ASSIGN_OR_RETURN): - // auto detection = p.ConsumeOrCopy().ValueOrDie(); - // // In functions which do not return mediapipe::Status use an adaptor + // auto detection = p.ConsumeOrCopy().value(); + // // In functions which do not return absl::Status use an adaptor // // function as the third argument to ASSIGN_OR_RETURN. In tests, // // use an adaptor which returns void. // ASSIGN_OR_RETURN(auto detection, p.ConsumeOrCopy(), - // _.With([](const mediapipe::Status& status) { + // _.With([](const absl::Status& status) { // MP_EXPECT_OK(status); // // Use CHECK_OK to crash and report a usable line - // // number (which the ValueOrDie alternative does not). + // // number (which the value() alternative does not). // // Include a return statement if the return value is // // non-void. For example: return 1; // })); // // Version for non-arrays. 
template - mediapipe::StatusOr> ConsumeOrCopy( + absl::StatusOr> ConsumeOrCopy( bool* was_copied = nullptr, typename std::enable_if::value>::type* = nullptr); // Version for bounded array. template - mediapipe::StatusOr> ConsumeOrCopy( + absl::StatusOr> ConsumeOrCopy( bool* was_copied = nullptr, typename std::enable_if::value && std::extent::value != 0>::type* = nullptr); @@ -160,7 +161,7 @@ class Packet { // delete helper. // Version for unbounded array. template - mediapipe::StatusOr> ConsumeOrCopy( + absl::StatusOr> ConsumeOrCopy( bool* was_copied = nullptr, typename std::enable_if::value && std::extent::value == 0>::type* = nullptr); @@ -178,11 +179,11 @@ class Packet { // Returns an error if the packet does not contain data of type T. template - mediapipe::Status ValidateAsType() const; + absl::Status ValidateAsType() const; // Returns an error if the packet is not an instance of // a protocol buffer message. - mediapipe::Status ValidateAsProtoMessageLite() const; + absl::Status ValidateAsProtoMessageLite() const; // Get the type id for the underlying type stored in the Packet. // Crashes if IsEmpty() == true. @@ -214,6 +215,8 @@ class Packet { const Packet& packet); friend const std::shared_ptr& packet_internal::GetHolderShared(const Packet& packet); + friend std::shared_ptr + packet_internal::GetHolderShared(Packet&& packet); std::shared_ptr holder_; class Timestamp timestamp_; @@ -326,6 +329,21 @@ T* GetFromUniquePtr(const Packet& packet) { return packet.Get>().get(); } +// Returns a shared_ptr to the payload of the packet which retains its object +// through a copy of the packet. +// Use std::const_pointer_cast if you need a shared_ptr, but remember that +// you must not change the payload if the packet has other owners. Use Consume +// if you want to try and modify the payload directly. 
+template +std::shared_ptr SharedPtrWithPacket(Packet packet) { + // This needs to be a separate statement because the evaluation order of + // function arguments is unspecified, and if the lambda is created first it + // moves the packet. + const T* ptr = &packet.Get(); + return std::shared_ptr( + ptr, [packet = std::move(packet)](const T* ptr) mutable { packet = {}; }); +} + //// Implementation details. namespace packet_internal { @@ -406,7 +424,7 @@ template StatusOr> ConvertToVectorOfProtoMessageLitePtrs(const T* data, /*is_proto_vector=*/std::false_type) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "The Packet stores \"", tool::TypeId().name(), "\"", "which is not convertible to vector.")); } @@ -496,7 +514,7 @@ class Holder : public HolderBase { // This method is dangerous and is only used by Packet::Consume() if the // packet is the only owner of the holder. template - mediapipe::StatusOr> Release( + absl::StatusOr> Release( typename std::enable_if::value || std::extent::value != 0>::type* = 0) { // Since C++ doesn't allow virtual, templated functions, check holder @@ -513,10 +531,10 @@ class Holder : public HolderBase { // TODO: support unbounded array after fixing the bug in holder's // delete helper. template - mediapipe::StatusOr> Release( + absl::StatusOr> Release( typename std::enable_if::value && std::extent::value == 0>::type* = 0) { - return mediapipe::InternalError("Release T[] isn't supported."); + return absl::InternalError("Release T[] isn't supported."); } const std::string DebugTypeName() const final { return MediaPipeTypeStringOrDemangled(); @@ -580,8 +598,8 @@ class ForeignHolder : public Holder { this->ptr_ = nullptr; } // Foreign holder can't release data pointer without ownership. 
- mediapipe::StatusOr> Release() { - return mediapipe::InternalError( + absl::StatusOr> Release() { + return absl::InternalError( "Foreign holder can't release data ptr without ownership."); } }; @@ -621,14 +639,14 @@ inline Packet& Packet::operator=(const Packet& packet) { } template -inline mediapipe::StatusOr> Packet::Consume() { +inline absl::StatusOr> Packet::Consume() { // If type validation fails, returns error. MP_RETURN_IF_ERROR(ValidateAsType()); // Clients who use this function are responsible for ensuring that no // other thread is doing anything with this Packet. if (holder_.unique()) { VLOG(2) << "Consuming the data of " << DebugString(); - mediapipe::StatusOr> release_result = + absl::StatusOr> release_result = holder_->As()->Release(); if (release_result.ok()) { VLOG(2) << "Setting " << DebugString() << " to empty."; @@ -638,12 +656,12 @@ inline mediapipe::StatusOr> Packet::Consume() { } // If packet isn't the sole owner of the holder, returns kFailedPrecondition // error with message. 
- return mediapipe::Status(mediapipe::StatusCode::kFailedPrecondition, - "Packet isn't the sole owner of the holder."); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Packet isn't the sole owner of the holder."); } template -inline mediapipe::StatusOr> Packet::ConsumeOrCopy( +inline absl::StatusOr> Packet::ConsumeOrCopy( bool* was_copied, typename std::enable_if::value>::type*) { MP_RETURN_IF_ERROR(ValidateAsType()); @@ -651,7 +669,7 @@ inline mediapipe::StatusOr> Packet::ConsumeOrCopy( if (!holder_->HolderIsOfType>() && holder_.unique()) { VLOG(2) << "Consuming the data of " << DebugString(); - mediapipe::StatusOr> release_result = + absl::StatusOr> release_result = holder_->As()->Release(); if (release_result.ok()) { VLOG(2) << "Setting " << DebugString() << " to empty."; @@ -673,7 +691,7 @@ inline mediapipe::StatusOr> Packet::ConsumeOrCopy( } template -inline mediapipe::StatusOr> Packet::ConsumeOrCopy( +inline absl::StatusOr> Packet::ConsumeOrCopy( bool* was_copied, typename std::enable_if::value && std::extent::value != 0>::type*) { @@ -682,7 +700,7 @@ inline mediapipe::StatusOr> Packet::ConsumeOrCopy( if (!holder_->HolderIsOfType>() && holder_.unique()) { VLOG(2) << "Consuming the data of " << DebugString(); - mediapipe::StatusOr> release_result = + absl::StatusOr> release_result = holder_->As()->Release(); if (release_result.ok()) { VLOG(2) << "Setting " << DebugString() << " to empty."; @@ -710,11 +728,11 @@ inline mediapipe::StatusOr> Packet::ConsumeOrCopy( } template -inline mediapipe::StatusOr> Packet::ConsumeOrCopy( +inline absl::StatusOr> Packet::ConsumeOrCopy( bool* was_copied, typename std::enable_if::value && std::extent::value == 0>::type*) { - return mediapipe::InternalError("Unbounded array isn't supported."); + return absl::InternalError("Unbounded array isn't supported."); } inline Packet::Packet(Packet&& packet) { @@ -746,25 +764,25 @@ inline const T& Packet::Get() const { packet_internal::Holder* holder = IsEmpty() ? 
nullptr : holder_->As(); if (holder == nullptr) { // Produce a good error message. - mediapipe::Status status = ValidateAsType(); + absl::Status status = ValidateAsType(); LOG(FATAL) << "Packet::Get() failed: " << status.message(); } return holder->data(); } template -mediapipe::Status Packet::ValidateAsType() const { +absl::Status Packet::ValidateAsType() const { if (ABSL_PREDICT_FALSE(IsEmpty())) { - return mediapipe::InternalError(absl::StrCat( + return absl::InternalError(absl::StrCat( "Expected a Packet of type: ", MediaPipeTypeStringOrDemangled(), ", but received an empty Packet.")); } if (ABSL_PREDICT_FALSE(holder_->As() == nullptr)) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "The Packet stores \"", holder_->DebugTypeName(), "\", but \"", MediaPipeTypeStringOrDemangled(), "\" was requested.")); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } inline Timestamp Packet::Timestamp() const { return timestamp_; } @@ -796,6 +814,10 @@ inline const std::shared_ptr& GetHolderShared( return packet.holder_; } +inline std::shared_ptr GetHolderShared(Packet&& packet) { + return std::move(packet.holder_); +} + } // namespace packet_internal } // namespace mediapipe diff --git a/mediapipe/framework/packet_generator.h b/mediapipe/framework/packet_generator.h index 8c8c1185a..dbc1f07e0 100644 --- a/mediapipe/framework/packet_generator.h +++ b/mediapipe/framework/packet_generator.h @@ -49,12 +49,12 @@ class PacketGenerator { // and // produce output side packets. 
// - // static mediapipe::Status FillExpectations( + // static absl::Status FillExpectations( // const PacketGeneratorOptions& extendable_options, // PacketTypeSet* input_side_packets, // PacketTypeSet* output_side_packets); // - // static mediapipe::Status Generate( + // static absl::Status Generate( // const PacketGeneratorOptions& extendable_options, // const PacketSet& input_side_packets, // PacketSet* output_side_packets); @@ -69,11 +69,11 @@ namespace internal { class StaticAccessToGenerator { public: virtual ~StaticAccessToGenerator() {} - virtual mediapipe::Status FillExpectations( + virtual absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, // PacketTypeSet* input_side_packets, // PacketTypeSet* output_side_packets) = 0; - virtual mediapipe::Status Generate( + virtual absl::Status Generate( const PacketGeneratorOptions& extendable_options, // const PacketSet& input_side_packets, // PacketSet* output_side_packets) = 0; @@ -87,7 +87,7 @@ using StaticAccessToGeneratorRegistry = template constexpr bool PacketGeneratorHasFillExpectations( decltype(&T::FillExpectations) /*unused*/) { - typedef mediapipe::Status (*FillExpectationsType)( + typedef absl::Status (*FillExpectationsType)( const PacketGeneratorOptions& extendable_options, // PacketTypeSet* input_side_packets, // PacketTypeSet* output_side_packets); @@ -100,7 +100,7 @@ constexpr bool PacketGeneratorHasFillExpectations(...) 
{ } template constexpr bool PacketGeneratorHasGenerate(decltype(&T::Generate) /*unused*/) { - typedef mediapipe::Status (*GenerateType)( + typedef absl::Status (*GenerateType)( const PacketGeneratorOptions& extendable_options, // const PacketSet& input_side_packets, // PacketSet* output_side_packets); @@ -129,7 +129,7 @@ class StaticAccessToGeneratorTyped : public StaticAccessToGenerator { "Generate() must be defined with the correct signature in " "every PacketGenerator."); - mediapipe::Status FillExpectations( + absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, // PacketTypeSet* input_side_packets, // PacketTypeSet* output_side_packets) final { @@ -137,10 +137,9 @@ class StaticAccessToGeneratorTyped : public StaticAccessToGenerator { extendable_options, input_side_packets, output_side_packets); } - mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, // - const PacketSet& input_side_packets, // - PacketSet* output_side_packets) final { + absl::Status Generate(const PacketGeneratorOptions& extendable_options, // + const PacketSet& input_side_packets, // + PacketSet* output_side_packets) final { return PacketGeneratorSubclass::Generate( extendable_options, input_side_packets, output_side_packets); } diff --git a/mediapipe/framework/packet_generator_graph.cc b/mediapipe/framework/packet_generator_graph.cc index ee7472d7b..8f93e8383 100644 --- a/mediapipe/framework/packet_generator_graph.cc +++ b/mediapipe/framework/packet_generator_graph.cc @@ -44,7 +44,7 @@ namespace { // generator cannot be run given the currently available side packets // (and false otherwise). If an error occurs then unrunnable and // input_side_packet_set are undefined. 
-mediapipe::Status CreateInputsForGenerator( +absl::Status CreateInputsForGenerator( const ValidatedGraphConfig& validated_graph, int generator_index, const std::map& side_packets, PacketSet* input_side_packet_set, bool* unrunnable) { @@ -55,7 +55,7 @@ mediapipe::Status CreateInputsForGenerator( .packet_generator(); // Fill the PacketSet (if possible). *unrunnable = false; - std::vector statuses; + std::vector statuses; for (CollectionItemId id = node_type_info.InputSidePacketTypes().BeginId(); id < node_type_info.InputSidePacketTypes().EndId(); ++id) { const std::string& name = @@ -67,7 +67,7 @@ mediapipe::Status CreateInputsForGenerator( continue; } input_side_packet_set->Get(id) = it->second; - mediapipe::Status status = + absl::Status status = node_type_info.InputSidePacketTypes().Get(id).Validate( input_side_packet_set->Get(id)); if (!status.ok()) { @@ -82,15 +82,15 @@ mediapipe::Status CreateInputsForGenerator( return tool::CombinedStatus( absl::StrCat(generator_name, " had invalid configuration."), statuses); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Generate the packets from a PacketGenerator, place them in // output_side_packet_set, and validate their types. 
-mediapipe::Status Generate(const ValidatedGraphConfig& validated_graph, - int generator_index, - const PacketSet& input_side_packet_set, - PacketSet* output_side_packet_set) { +absl::Status Generate(const ValidatedGraphConfig& validated_graph, + int generator_index, + const PacketSet& input_side_packet_set, + PacketSet* output_side_packet_set) { const NodeTypeInfo& node_type_info = validated_graph.GeneratorInfos()[generator_index]; const PacketGeneratorConfig& generator_config = @@ -113,7 +113,7 @@ mediapipe::Status Generate(const ValidatedGraphConfig& validated_graph, .SetPrepend() << generator_name << "::Generate() output packets were of incorrect type: "; - return mediapipe::OkStatus(); + return absl::OkStatus(); } // GeneratorScheduler schedules the packet generators in a validated graph for @@ -149,7 +149,7 @@ class GeneratorScheduler { // rather, not executed) in non_scheduled_generators. Returns the combined // error status if there were errors while running the packet generators. // NOTE: This method should only be called when there are no pending tasks. - mediapipe::Status GetNonScheduledGenerators( + absl::Status GetNonScheduledGenerators( std::vector* non_scheduled_generators) const; private: @@ -169,7 +169,7 @@ class GeneratorScheduler { // This condition variable is signaled when num_tasks_ becomes 0. absl::CondVar idle_condvar_; // Accumulates the error statuses while running the packet generators. - std::vector statuses_ ABSL_GUARDED_BY(mutex_); + std::vector statuses_ ABSL_GUARDED_BY(mutex_); // scheduled_generators_[i] is true if the packet generator with index i was // scheduled (or rather, executed). 
std::vector scheduled_generators_ ABSL_GUARDED_BY(mutex_); @@ -219,7 +219,7 @@ void GeneratorScheduler::GenerateAndScheduleNext( .OutputSidePacketTypes() .TagMap()); VLOG(1) << "Running generator " << generator_index; - mediapipe::Status status = + absl::Status status = Generate(*validated_graph_, generator_index, *input_side_packet_set, &output_side_packet_set); @@ -235,7 +235,7 @@ void GeneratorScheduler::GenerateAndScheduleNext( const auto& name = output_side_packet_set.TagMap()->Names()[id.value()]; auto item = side_packets->emplace(name, output_side_packet_set.Get(id)); if (!item.second) { - statuses_.push_back(mediapipe::AlreadyExistsError( + statuses_.push_back(absl::AlreadyExistsError( absl::StrCat("Side packet \"", name, "\" was defined twice."))); } } @@ -266,7 +266,7 @@ void GeneratorScheduler::ScheduleAllRunnableGenerators( .InputSidePacketTypes() .TagMap()); - mediapipe::Status status = + absl::Status status = CreateInputsForGenerator(*validated_graph_, index, *side_packets, input_side_packet_set.get(), &is_unrunnable); if (!status.ok()) { @@ -313,7 +313,7 @@ void GeneratorScheduler::WaitUntilIdle() { } } -mediapipe::Status GeneratorScheduler::GetNonScheduledGenerators( +absl::Status GeneratorScheduler::GetNonScheduledGenerators( std::vector* non_scheduled_generators) const { non_scheduled_generators->clear(); @@ -326,7 +326,7 @@ mediapipe::Status GeneratorScheduler::GetNonScheduledGenerators( non_scheduled_generators->push_back(i); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void GeneratorScheduler::AddApplicationThreadTask(std::function task) { @@ -356,7 +356,7 @@ void GeneratorScheduler::RunApplicationThreadTasks() { PacketGeneratorGraph::~PacketGeneratorGraph() {} -mediapipe::Status PacketGeneratorGraph::Initialize( +absl::Status PacketGeneratorGraph::Initialize( const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor, const std::map& input_side_packets) { validated_graph_ = validated_graph; @@ -368,14 
+368,14 @@ mediapipe::Status PacketGeneratorGraph::Initialize( /*initial=*/true); } -mediapipe::Status PacketGeneratorGraph::RunGraphSetup( +absl::Status PacketGeneratorGraph::RunGraphSetup( const std::map& input_side_packets, std::map* output_side_packets) const { *output_side_packets = base_packets_; for (const std::pair& item : input_side_packets) { auto iter = output_side_packets->find(item.first); if (iter != output_side_packets->end()) { - return mediapipe::AlreadyExistsError( + return absl::AlreadyExistsError( absl::StrCat("Side packet \"", iter->first, "\" was defined twice.")); } output_side_packets->insert(iter, item); @@ -394,10 +394,10 @@ mediapipe::Status PacketGeneratorGraph::RunGraphSetup( << "Some Generators were unrunnable (validation should have failed).\n" "Generator indexes: " << absl::StrJoin(non_scheduled_generators, ", "); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PacketGeneratorGraph::ExecuteGenerators( +absl::Status PacketGeneratorGraph::ExecuteGenerators( std::map* output_side_packets, std::vector* non_scheduled_generators, bool initial) const { VLOG(1) << "ExecuteGenerators initial == " << initial; diff --git a/mediapipe/framework/packet_generator_graph.h b/mediapipe/framework/packet_generator_graph.h index a1c493ca9..533f55d08 100644 --- a/mediapipe/framework/packet_generator_graph.h +++ b/mediapipe/framework/packet_generator_graph.h @@ -67,14 +67,14 @@ class PacketGeneratorGraph { // stage and will be common to all calls to CalculatorGraph::Run(). // Any generators which are runnable at this stage (that only depend on // things in the graph or input_side_packets) will be run at this time. - virtual mediapipe::Status Initialize( + virtual absl::Status Initialize( const ValidatedGraphConfig* validated_graph, mediapipe::Executor* executor, const std::map& input_side_packets); // Add the input_side_packets and run any remaining generators (which // must now be runnable) to produce output_side_packets. 
- virtual mediapipe::Status RunGraphSetup( + virtual absl::Status RunGraphSetup( const std::map& input_side_packets, std::map* output_side_packets) const; @@ -96,7 +96,7 @@ class PacketGeneratorGraph { // packets and unrunnable generators. "initial" must be set to true for // the first pass and false for subsequent passes. output_side_packets // must be set to include the input side packets before calling. - mediapipe::Status ExecuteGenerators( + absl::Status ExecuteGenerators( std::map* output_side_packets, std::vector* non_scheduled_generators, bool initial) const; diff --git a/mediapipe/framework/packet_generator_test.cc b/mediapipe/framework/packet_generator_test.cc index 05197bc9e..c34949a82 100644 --- a/mediapipe/framework/packet_generator_test.cc +++ b/mediapipe/framework/packet_generator_test.cc @@ -27,7 +27,7 @@ namespace mediapipe { namespace { class DoNothingGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) { for (CollectionItemId id = input_side_packets->BeginId(); @@ -38,17 +38,17 @@ class DoNothingGenerator : public PacketGenerator { id < output_side_packets->EndId(); ++id) { output_side_packets->Get(id).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( - const PacketGeneratorOptions& extendable_options, - const PacketSet& input_side_packets, PacketSet* output_side_packets) { + static absl::Status Generate(const PacketGeneratorOptions& extendable_options, + const PacketSet& input_side_packets, + PacketSet* output_side_packets) { for (CollectionItemId id = output_side_packets->BeginId(); id < output_side_packets->EndId(); ++id) { output_side_packets->Get(id) = MakePacket(true); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; diff --git 
a/mediapipe/framework/packet_registration_test.cc b/mediapipe/framework/packet_registration_test.cc index 0860adb1d..ed3260f97 100644 --- a/mediapipe/framework/packet_registration_test.cc +++ b/mediapipe/framework/packet_registration_test.cc @@ -28,17 +28,17 @@ namespace test_ns { class TestSinkCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("IN").Set(); cc->Outputs().Tag("OUT").Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { int x = cc->Inputs().Tag("IN").Get().x(); cc->Outputs().Tag("OUT").AddPacket( MakePacket(x).At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TestSinkCalculator); diff --git a/mediapipe/framework/packet_test.cc b/mediapipe/framework/packet_test.cc index db863248d..d5b8d0c58 100644 --- a/mediapipe/framework/packet_test.cc +++ b/mediapipe/framework/packet_test.cc @@ -210,8 +210,8 @@ TEST(PacketTest, ValidateAsProtoMessageLite) { Packet packet = Adopt(proto_ptr.release()); MP_EXPECT_OK(packet.ValidateAsProtoMessageLite()); Packet packet2 = MakePacket(3); - mediapipe::Status status = packet2.ValidateAsProtoMessageLite(); - EXPECT_EQ(status.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status status = packet2.ValidateAsProtoMessageLite(); + EXPECT_EQ(status.code(), absl::StatusCode::kInvalidArgument); } TEST(PacketTest, SyncedPacket) { @@ -283,11 +283,10 @@ TEST(PacketTest, TestPacketMoveConstructor) { TEST(PacketTest, TestPacketConsume) { Packet packet1 = MakePacket(33); Packet packet_copy = packet1; - mediapipe::StatusOr> result1 = - packet_copy.Consume(); + absl::StatusOr> result1 = packet_copy.Consume(); // Both packet1 and packet_copy own the data, Consume() should return error. 
- mediapipe::Status status1 = result1.status(); - EXPECT_EQ(status1.code(), mediapipe::StatusCode::kFailedPrecondition); + absl::Status status1 = result1.status(); + EXPECT_EQ(status1.code(), absl::StatusCode::kFailedPrecondition); EXPECT_THAT(status1.message(), testing::HasSubstr("isn't the sole owner of the holder")); ASSERT_FALSE(packet1.IsEmpty()); @@ -297,8 +296,7 @@ TEST(PacketTest, TestPacketConsume) { Packet packet2 = MakePacket(33); // Types don't match (int vs float). - mediapipe::StatusOr> result2 = - packet2.Consume(); + absl::StatusOr> result2 = packet2.Consume(); EXPECT_THAT( result2.status().message(), testing::AllOf(testing::HasSubstr("int"), testing::HasSubstr("float"))); @@ -307,11 +305,11 @@ TEST(PacketTest, TestPacketConsume) { // packet3 is the sole owner of the data. Packet packet3 = MakePacket(42); - mediapipe::StatusOr> result3 = packet3.Consume(); + absl::StatusOr> result3 = packet3.Consume(); // After Consume(), packet3 should be empty and result3 owns the data. EXPECT_TRUE(result3.ok()); - ASSERT_NE(nullptr, result3.ValueOrDie()); - EXPECT_EQ(42, *result3.ValueOrDie()); + ASSERT_NE(nullptr, result3.value()); + EXPECT_EQ(42, *result3.value()); EXPECT_TRUE(packet3.IsEmpty()); } @@ -319,14 +317,14 @@ TEST(PacketTest, TestPacketConsumeOrCopy) { Packet packet1 = MakePacket(33); Packet packet_copy = packet1; bool was_copied1 = false; - mediapipe::StatusOr> result1 = + absl::StatusOr> result1 = packet_copy.ConsumeOrCopy(&was_copied1); // Both packet1 and packet_copy own the data, ConsumeOrCopy() returns a copy // of the data and sets packet_copy to empty. EXPECT_TRUE(result1.ok()); EXPECT_TRUE(was_copied1); - ASSERT_NE(nullptr, result1.ValueOrDie()); - EXPECT_EQ(33, *result1.ValueOrDie()); + ASSERT_NE(nullptr, result1.value()); + EXPECT_EQ(33, *result1.value()); EXPECT_TRUE(packet_copy.IsEmpty()); // ConsumeOrCopy() doesn't affect packet1. 
ASSERT_FALSE(packet1.IsEmpty()); @@ -334,7 +332,7 @@ TEST(PacketTest, TestPacketConsumeOrCopy) { Packet packet2 = MakePacket(33); // Types don't match (int vs float). - mediapipe::StatusOr> result2 = + absl::StatusOr> result2 = packet2.ConsumeOrCopy(); EXPECT_THAT( result2.status().message(), @@ -346,21 +344,21 @@ TEST(PacketTest, TestPacketConsumeOrCopy) { bool was_copied3 = false; // packet3 is the sole owner of the data. ConsumeOrCopy() transfers the // ownership to result3 and makes packet3 empty. - mediapipe::StatusOr> result3 = + absl::StatusOr> result3 = packet3.ConsumeOrCopy(&was_copied3); EXPECT_FALSE(was_copied3); EXPECT_TRUE(result3.ok()); - ASSERT_NE(nullptr, result3.ValueOrDie()); - EXPECT_EQ(42, *result3.ValueOrDie()); + ASSERT_NE(nullptr, result3.value()); + EXPECT_EQ(42, *result3.value()); EXPECT_TRUE(packet3.IsEmpty()); } TEST(PacketTest, TestConsumeForeignHolder) { std::unique_ptr data(new int(33)); Packet packet = PointToForeign(data.get()); - mediapipe::StatusOr> result = packet.Consume(); + absl::StatusOr> result = packet.Consume(); EXPECT_FALSE(result.ok()); - EXPECT_EQ(result.status().code(), mediapipe::StatusCode::kInternal); + EXPECT_EQ(result.status().code(), absl::StatusCode::kInternal); EXPECT_EQ(result.status().message(), "Foreign holder can't release data ptr without ownership."); ASSERT_FALSE(packet.IsEmpty()); @@ -372,15 +370,15 @@ TEST(PacketTest, TestForeignHolderConsumeOrCopy) { Packet packet1 = PointToForeign(data1.get()); Packet packet_copy = packet1; bool was_copied1 = false; - mediapipe::StatusOr> result1 = + absl::StatusOr> result1 = packet_copy.ConsumeOrCopy(&was_copied1); // After ConsumeOrCopy(), result1 gets the copy of packet_copy's data and // packet_copy is set to empty. 
EXPECT_TRUE(packet_copy.IsEmpty()); EXPECT_TRUE(was_copied1); EXPECT_TRUE(result1.ok()); - ASSERT_NE(nullptr, result1.ValueOrDie()); - EXPECT_EQ(42, *result1.ValueOrDie()); + ASSERT_NE(nullptr, result1.value()); + EXPECT_EQ(42, *result1.value()); // ConsumeOrCopy() doesn't affect packet1. ASSERT_FALSE(packet1.IsEmpty()); EXPECT_EQ(42, packet1.Get()); @@ -388,25 +386,25 @@ TEST(PacketTest, TestForeignHolderConsumeOrCopy) { std::unique_ptr data2(new int(33)); Packet packet2 = PointToForeign(data2.get()); bool was_copied2 = false; - mediapipe::StatusOr> result2 = + absl::StatusOr> result2 = packet2.ConsumeOrCopy(&was_copied2); // After ConsumeOrCopy(), result2 gets the copy of packet2's data and packet2 // is set to empty. EXPECT_TRUE(packet2.IsEmpty()); EXPECT_TRUE(was_copied2); EXPECT_TRUE(result2.ok()); - ASSERT_NE(nullptr, result2.ValueOrDie()); - EXPECT_EQ(33, *result2.ValueOrDie()); + ASSERT_NE(nullptr, result2.value()); + EXPECT_EQ(33, *result2.value()); } TEST(PacketTest, TestConsumeBoundedArray) { Packet packet1 = MakePacket(10, 20, 30); Packet packet_copy = packet1; - mediapipe::StatusOr> result1 = + absl::StatusOr> result1 = packet_copy.Consume(); // Both packet1 and packet_copy own the data, Consume() should return error. - mediapipe::Status status1 = result1.status(); - EXPECT_EQ(status1.code(), mediapipe::StatusCode::kFailedPrecondition); + absl::Status status1 = result1.status(); + EXPECT_EQ(status1.code(), absl::StatusCode::kFailedPrecondition); EXPECT_THAT(status1.message(), testing::HasSubstr("isn't the sole owner of the holder")); ASSERT_FALSE(packet1.IsEmpty()); @@ -422,10 +420,9 @@ TEST(PacketTest, TestConsumeBoundedArray) { Packet packet2 = MakePacket(40, 50, 60); // After Consume(), packet2 should be empty and result2 owns the data. 
- mediapipe::StatusOr> result2 = - packet2.Consume(); - ASSERT_NE(nullptr, result2.ValueOrDie()); - auto value3 = result2.ValueOrDie().get(); + absl::StatusOr> result2 = packet2.Consume(); + ASSERT_NE(nullptr, result2.value()); + auto value3 = result2.value().get(); EXPECT_EQ(40, (*value3)[0]); EXPECT_EQ(50, (*value3)[1]); EXPECT_EQ(60, (*value3)[2]); @@ -436,14 +433,14 @@ TEST(PacketTest, TestConsumeOrCopyBoundedArray) { Packet packet1 = MakePacket(10, 20, 30); Packet packet_copy = packet1; bool was_copied1 = false; - mediapipe::StatusOr> result1 = + absl::StatusOr> result1 = packet_copy.ConsumeOrCopy(&was_copied1); // Both packet1 and packet_copy own the data, ConsumeOrCopy() returns a copy // of the data and sets packet_copy to empty. EXPECT_TRUE(result1.ok()); EXPECT_TRUE(was_copied1); - ASSERT_NE(nullptr, result1.ValueOrDie()); - auto value1 = result1.ValueOrDie().get(); + ASSERT_NE(nullptr, result1.value()); + auto value1 = result1.value().get(); EXPECT_EQ(10, (*value1)[0]); EXPECT_EQ(20, (*value1)[1]); EXPECT_EQ(30, (*value1)[2]); @@ -459,12 +456,12 @@ TEST(PacketTest, TestConsumeOrCopyBoundedArray) { bool was_copied2 = false; // packet2 is the sole owner of the data. ConsumeOrCopy() transfers the // ownership to result2 and makes packet2 empty. 
- mediapipe::StatusOr> result2 = + absl::StatusOr> result2 = packet2.ConsumeOrCopy(&was_copied2); EXPECT_TRUE(result2.ok()); EXPECT_FALSE(was_copied2); - ASSERT_NE(nullptr, result2.ValueOrDie()); - auto value3 = result2.ValueOrDie().get(); + ASSERT_NE(nullptr, result2.value()); + auto value3 = result2.value().get(); EXPECT_EQ(40, (*value3)[0]); EXPECT_EQ(50, (*value3)[1]); EXPECT_EQ(60, (*value3)[2]); @@ -487,10 +484,23 @@ TEST(PacketTest, PacketFromSerializedProto) { StatusOr maybe_packet = packet_internal::PacketFromDynamicProto( "mediapipe.SimpleProto", serialized); MP_ASSERT_OK(maybe_packet); - Packet packet = maybe_packet.ValueOrDie(); + Packet packet = maybe_packet.value(); MP_EXPECT_OK(packet.ValidateAsType<::mediapipe::SimpleProto>()); EXPECT_FALSE(packet.ValidateAsType<::mediapipe::PacketTestProto>().ok()); } +TEST(PacketTest, SharedPtrWithPacketOwnership) { + bool exist; + Packet packet = MakePacket(&exist); + ASSERT_EQ(exist, true); + std::shared_ptr ptr = SharedPtrWithPacket(packet); + packet = {}; + // The shared_ptr should still be retaining the object. + EXPECT_EQ(exist, true); + ptr = nullptr; + // Now it should be released. 
+ EXPECT_EQ(exist, false); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/framework/packet_type.cc b/mediapipe/framework/packet_type.cc index ecf9f8041..bbcd84d80 100644 --- a/mediapipe/framework/packet_type.cc +++ b/mediapipe/framework/packet_type.cc @@ -125,9 +125,9 @@ const std::string PacketType::DebugTypeName() const { return type_name_; } -mediapipe::Status PacketType::Validate(const Packet& packet) const { +absl::Status PacketType::Validate(const Packet& packet) const { if (!initialized_) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Uninitialized PacketType was used for validation."); } if (same_as_) { @@ -147,7 +147,7 @@ mediapipe::Status PacketType::Validate(const Packet& packet) const { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Empty packets are not allowed for type: " << type_name_; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool PacketType::IsConsistentWith(const PacketType& other) const { @@ -176,7 +176,7 @@ bool PacketType::IsConsistentWith(const PacketType& other) const { return type1->validate_method_ == type2->validate_method_; } -mediapipe::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set) { +absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set) { std::vector errors; if (packet_type_set.GetErrorHandler().HasError()) { errors = packet_type_set.GetErrorHandler().ErrorMessages(); @@ -190,25 +190,24 @@ mediapipe::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set) { } } if (!errors.empty()) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "ValidatePacketTypeSet failed:\n", absl::StrJoin(errors, "\n"))); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatePacketSet(const PacketTypeSet& packet_type_set, - const PacketSet& packet_set) { - std::vector errors; +absl::Status ValidatePacketSet(const PacketTypeSet& 
packet_type_set, + const PacketSet& packet_set) { + std::vector errors; if (!packet_type_set.TagMap()->SameAs(*packet_set.TagMap())) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "TagMaps do not match. PacketTypeSet TagMap:\n", packet_type_set.TagMap()->DebugString(), "\n\nPacketSet TagMap:\n", packet_set.TagMap()->DebugString())); } for (CollectionItemId id = packet_type_set.BeginId(); id < packet_type_set.EndId(); ++id) { - mediapipe::Status status = - packet_type_set.Get(id).Validate(packet_set.Get(id)); + absl::Status status = packet_type_set.Get(id).Validate(packet_set.Get(id)); if (!status.ok()) { std::pair tag_index = packet_type_set.TagAndIndexFromId(id); @@ -222,7 +221,7 @@ mediapipe::Status ValidatePacketSet(const PacketTypeSet& packet_type_set, if (!errors.empty()) { return tool::CombinedStatus("ValidatePacketSet failed:", errors); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/framework/packet_type.h b/mediapipe/framework/packet_type.h index 1a3ee81e1..868d75243 100644 --- a/mediapipe/framework/packet_type.h +++ b/mediapipe/framework/packet_type.h @@ -41,8 +41,10 @@ class PacketType { public: // Creates an uninitialized PacketType. PacketType(); - PacketType(const PacketType&) = delete; - PacketType& operator=(const PacketType&) = delete; + + // PacketType can be passed by value. + PacketType(const PacketType&) = default; + PacketType& operator=(const PacketType&) = default; // False for a PacketType that has not had any Set*() function called. bool IsInitialized() const; @@ -85,7 +87,7 @@ class PacketType { bool IsConsistentWith(const PacketType& other) const; // Returns OK if the packet contains an object of the appropriate type. 
- mediapipe::Status Validate(const Packet& packet) const; + absl::Status Validate(const Packet& packet) const; // Returns a pointer to the Registered type name, or nullptr if the type // is not registered. Do not use this for validation, use Validate() @@ -98,7 +100,7 @@ class PacketType { private: // Typedef for the ValidateAsType() method in Packet that is used for // type validation and identification. - typedef mediapipe::Status (Packet::*ValidateMethodType)() const; + typedef absl::Status (Packet::*ValidateMethodType)() const; // Records whether the packet type was set in any way. bool initialized_; @@ -213,15 +215,15 @@ using PacketTypeSet = // Returns OK if the packets in the PacketSet are of the appropriate type. // packet_type_set must be valid before this is called (but packet_set // may be in any state). -mediapipe::Status ValidatePacketSet(const PacketTypeSet& packet_type_set, - const PacketSet& packet_set); +absl::Status ValidatePacketSet(const PacketTypeSet& packet_type_set, + const PacketSet& packet_set); // Validates that the PacketTypeSet was initialized properly. // An error is returned if // 1) Tag() or Index() is called with an invalid argument (however, // a valid PacketType is still returned by the function). // 2) Any PacketType is not initialized. -mediapipe::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set); +absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set); // Templated function definitions. diff --git a/mediapipe/framework/port.h b/mediapipe/framework/port.h index 521620520..8994c652e 100644 --- a/mediapipe/framework/port.h +++ b/mediapipe/framework/port.h @@ -44,9 +44,9 @@ // These platforms do not support OpenGL ES Compute Shaders (v3.1 and up), // but may or may not still be able to run other OpenGL code. 
-#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \ - (defined(__APPLE__) || defined(__EMSCRIPTEN__) || \ - defined(MEDIAPIPE_DISABLE_GPU) || MEDIAPIPE_USING_SWIFTSHADER) +#if !defined(MEDIAPIPE_DISABLE_GL_COMPUTE) && \ + (defined(__APPLE__) || defined(__EMSCRIPTEN__) || MEDIAPIPE_DISABLE_GPU || \ + MEDIAPIPE_USING_SWIFTSHADER) #define MEDIAPIPE_DISABLE_GL_COMPUTE #endif @@ -56,7 +56,7 @@ #define MEDIAPIPE_OPENGL_ES_30 300 #define MEDIAPIPE_OPENGL_ES_31 310 -#if defined(MEDIAPIPE_DISABLE_GPU) +#if MEDIAPIPE_DISABLE_GPU #define MEDIAPIPE_OPENGL_ES_VERSION 0 #define MEDIAPIPE_METAL_ENABLED 0 #else diff --git a/mediapipe/framework/port/BUILD b/mediapipe/framework/port/BUILD index 78fa44739..cc15572d6 100644 --- a/mediapipe/framework/port/BUILD +++ b/mediapipe/framework/port/BUILD @@ -62,9 +62,8 @@ cc_library( ":advanced_proto_lite", ":core_proto", "//mediapipe/framework:port", - ] + select({ - "//conditions:default": ["@com_google_protobuf//:protobuf"], - }), + "@com_google_protobuf//:protobuf", + ], ) cc_library( @@ -77,9 +76,8 @@ cc_library( deps = [ ":core_proto", "//mediapipe/framework:port", - ] + select({ - "//conditions:default": ["@com_google_protobuf//:protobuf"], - }), + "@com_google_protobuf//:protobuf", + ], ) cc_library( @@ -98,8 +96,8 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework:port", "//third_party:glog", + "@com_google_absl//absl/flags:flag", ], ) @@ -112,9 +110,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", - ] + select({ - "//conditions:default": ["@com_google_protobuf//:protobuf"], - }), + "@com_google_protobuf//:protobuf", + ], ) cc_library( @@ -395,7 +392,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:port", - "//mediapipe/framework/deps:statusor", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/framework/port/advanced_proto_lite_inc.h b/mediapipe/framework/port/advanced_proto_lite_inc.h index 
a6627cbb9..13b72967b 100644 --- a/mediapipe/framework/port/advanced_proto_lite_inc.h +++ b/mediapipe/framework/port/advanced_proto_lite_inc.h @@ -18,6 +18,7 @@ #ifndef MEDIAPIPE_PORT_ADVANCED_PROTO_LITE_INC_H_ #define MEDIAPIPE_PORT_ADVANCED_PROTO_LITE_INC_H_ +#include "google/protobuf/io/zero_copy_stream.h" #include "google/protobuf/io/zero_copy_stream_impl_lite.h" #include "google/protobuf/wire_format_lite.h" #include "mediapipe/framework/port.h" diff --git a/mediapipe/framework/port/build_config.bzl b/mediapipe/framework/port/build_config.bzl index f07c96c16..3a571b219 100644 --- a/mediapipe/framework/port/build_config.bzl +++ b/mediapipe/framework/port/build_config.bzl @@ -58,7 +58,7 @@ def mediapipe_proto_library( portable_deps: the portable_proto_library targets for all referenced protobufs. visibility: visibility of this target. testonly: true means the proto can be used for testing only. - compatible_with: see go/target-constraints. + compatible_with: a list of environments the rule is compatible with. 
def_proto: define the proto_library target def_cc_proto: define the cc_proto_library target def_py_proto: define the py_proto_library target @@ -111,7 +111,6 @@ def mediapipe_proto_library( native.java_lite_proto_library(**provided_args( name = replace_suffix(name, "_proto", "_java_proto_lite"), deps = proto_deps, - strict_deps = 0, visibility = visibility, testonly = testonly, compatible_with = compatible_with, diff --git a/mediapipe/framework/port/commandlineflags.h b/mediapipe/framework/port/commandlineflags.h index 9c54576ff..a3d17c71e 100644 --- a/mediapipe/framework/port/commandlineflags.h +++ b/mediapipe/framework/port/commandlineflags.h @@ -16,5 +16,15 @@ #define MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ #include "gflags/gflags.h" +namespace absl { +template +T GetFlag(const T& f) { + return f; +} +template +void SetFlag(T* f, const U& u) { + *f = u; +} +} // namespace absl #endif // MEDIAPIPE_PORT_COMMANDLINEFLAGS_H_ diff --git a/mediapipe/framework/port/core_proto_inc.h b/mediapipe/framework/port/core_proto_inc.h index f56a44e36..20f50097b 100644 --- a/mediapipe/framework/port/core_proto_inc.h +++ b/mediapipe/framework/port/core_proto_inc.h @@ -22,6 +22,7 @@ #include "google/protobuf/repeated_field.h" #include "mediapipe/framework/port.h" #include "mediapipe/framework/port/proto_ns.h" + #if !defined(MEDIAPIPE_PROTO_LITE) #include "google/protobuf/text_format.h" #endif // !defined(MEDIAPIPE_PROTO_LITE) diff --git a/mediapipe/framework/port/file_helpers.h b/mediapipe/framework/port/file_helpers.h index 68ccb4605..b276852c9 100644 --- a/mediapipe/framework/port/file_helpers.h +++ b/mediapipe/framework/port/file_helpers.h @@ -16,5 +16,6 @@ #define MEDIAPIPE_PORT_FILE_HELPERS_H_ #include "mediapipe/framework/deps/file_helpers.h" +#include "mediapipe/framework/deps/file_path.h" #endif // MEDIAPIPE_PORT_FILE_HELPERS_H_ diff --git a/mediapipe/framework/port/proto_ns.h b/mediapipe/framework/port/proto_ns.h index 6e13ca45e..83aecdf49 100644 --- 
a/mediapipe/framework/port/proto_ns.h +++ b/mediapipe/framework/port/proto_ns.h @@ -29,4 +29,10 @@ namespace proto_ns = ::google::protobuf; typedef ::std::string ProtoString; } // namespace mediapipe. +// Legacy namespace support. +namespace mediapipe { +namespace proto_ns = mediapipe::proto_ns; +typedef ::std::string ProtoString; +} // namespace mediapipe + #endif // MEDIAPIPE_PORT_PROTO_NS_H_ diff --git a/mediapipe/framework/port/statusor.h b/mediapipe/framework/port/statusor.h index 9373dc65a..c38ce0f87 100644 --- a/mediapipe/framework/port/statusor.h +++ b/mediapipe/framework/port/statusor.h @@ -15,6 +15,14 @@ #ifndef MEDIAPIPE_PORT_STATUSOR_H_ #define MEDIAPIPE_PORT_STATUSOR_H_ -#include "mediapipe/framework/deps/statusor.h" +#include "absl/status/statusor.h" + +namespace mediapipe { + +template +using StatusOr ABSL_DEPRECATED("Use absl::StatusOr directly") = + absl::StatusOr; + +} // namespace mediapipe #endif // MEDIAPIPE_PORT_STATUSOR_H_ diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 3ef48ed94..cabc980d2 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -235,6 +235,7 @@ cc_test( "//mediapipe/framework/deps:clock", "//mediapipe/framework/deps:message_matchers", "//mediapipe/framework/port:advanced_proto", + "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:integral_types", @@ -246,7 +247,6 @@ cc_test( "//mediapipe/framework/tool:simulation_clock", "//mediapipe/framework/tool:simulation_clock_executor", "//mediapipe/framework/tool:status_util", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/time", ], ) @@ -285,12 +285,11 @@ cc_library( "//mediapipe/framework:mediapipe_internal", ], deps = [ - "@com_google_absl//absl/strings", + "@com_google_absl//absl/flags:flag", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", 
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:status", - "@com_google_absl//absl/flags:flag", "//mediapipe/framework/deps:file_path", ] + select({ "//conditions:default": [ @@ -298,7 +297,7 @@ cc_library( ], "//mediapipe:android": [ "//mediapipe/java/com/google/mediapipe/framework/jni:jni_util", - "//mediapipe/util/android/file/base", + "//mediapipe/framework/port:file_helpers", ], "//mediapipe:apple": [ "//mediapipe/framework/port:file_helpers", diff --git a/mediapipe/framework/profiler/graph_profiler.cc b/mediapipe/framework/profiler/graph_profiler.cc index 10a2e742e..eb7d80c62 100644 --- a/mediapipe/framework/profiler/graph_profiler.cc +++ b/mediapipe/framework/profiler/graph_profiler.cc @@ -199,7 +199,7 @@ void GraphProfiler::Reset() { } // Begins profiling for a single graph run. -mediapipe::Status GraphProfiler::Start(mediapipe::Executor* executor) { +absl::Status GraphProfiler::Start(mediapipe::Executor* executor) { // If specified, start periodic profile output while the graph runs. Resume(); if (is_tracing_ && IsTraceIntervalEnabled(profiler_config_, tracer()) && @@ -220,18 +220,18 @@ mediapipe::Status GraphProfiler::Start(mediapipe::Executor* executor) { } }); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Ends profiling for a single graph run. -mediapipe::Status GraphProfiler::Stop() { +absl::Status GraphProfiler::Stop() { is_running_ = false; Pause(); // If specified, write a final profile. 
if (IsTraceLogEnabled(profiler_config_)) { MP_RETURN_IF_ERROR(WriteProfile()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void GraphProfiler::LogEvent(const TraceEvent& event) { @@ -281,7 +281,7 @@ void GraphProfiler::AddPacketInfo(const TraceEvent& packet_info) { production_time_usec, production_time_usec); } -mediapipe::Status GraphProfiler::GetCalculatorProfiles( +absl::Status GraphProfiler::GetCalculatorProfiles( std::vector* profiles) const { absl::ReaderMutexLock lock(&profiler_mutex_); RET_CHECK(is_initialized_) @@ -289,7 +289,7 @@ mediapipe::Status GraphProfiler::GetCalculatorProfiles( for (auto& entry : calculator_profiles_) { profiles->push_back(entry.second); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void GraphProfiler::InitializeTimeHistogram(int64 interval_size_usec, @@ -308,7 +308,7 @@ void GraphProfiler::InitializeInputStreams( const CalculatorGraphConfig::Node& node_config, int64 interval_size_usec, int64 num_intervals, CalculatorProfile* calculator_profile) { std::shared_ptr input_tag_map = - TagMap::Create(node_config.input_stream()).ValueOrDie(); + TagMap::Create(node_config.input_stream()).value(); std::set back_edge_ids = GetBackEdgeIds(node_config, *input_tag_map); auto input_tag_map_names = input_tag_map->Names(); for (int i = 0; i < input_tag_map_names.size(); ++i) { @@ -566,9 +566,9 @@ void AssignNodeNames(GraphProfile* profile) { } } -mediapipe::StatusOr GraphProfiler::GetTraceLogPath() { +absl::StatusOr GraphProfiler::GetTraceLogPath() { if (!IsTraceLogEnabled(profiler_config_)) { - return mediapipe::InternalError( + return absl::InternalError( "Trace log writing is disabled, unable to get trace_log_path."); } if (profiler_config_.trace_log_path().empty()) { @@ -581,42 +581,49 @@ mediapipe::StatusOr GraphProfiler::GetTraceLogPath() { } } -mediapipe::Status GraphProfiler::WriteProfile() { - if (profiler_config_.trace_log_disabled()) { - // Logging is disabled, so we can exit writing without error. 
- return mediapipe::OkStatus(); - } - ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath()); - int log_interval_count = GetLogIntervalCount(profiler_config_); - int log_file_count = GetLogFileCount(profiler_config_); - +absl::Status GraphProfiler::CaptureProfile(GraphProfile* result) { // Record the GraphTrace events since the previous WriteProfile. // The end_time is chosen to be trace_log_margin_usec in the past, // providing time for events to be appended to the TraceBuffer. absl::Time end_time = clock_->TimeNow() - absl::Microseconds(profiler_config_.trace_log_margin_usec()); - GraphProfile profile; - GraphTrace* trace = profile.add_graph_trace(); + GraphTrace* trace = result->add_graph_trace(); if (!profiler_config_.trace_log_instant_events()) { tracer()->GetTrace(previous_log_end_time_, end_time, trace); } else { tracer()->GetLog(previous_log_end_time_, end_time, trace); } previous_log_end_time_ = end_time; - // If there are no trace events, skip log writing. - if (is_tracing_ && trace->calculator_trace().empty()) { - return mediapipe::OkStatus(); - } // Record the latest CalculatorProfiles. Status status; std::vector profiles; status.Update(GetCalculatorProfiles(&profiles)); for (CalculatorProfile& p : profiles) { - *profile.mutable_calculator_profiles()->Add() = std::move(p); + *result->mutable_calculator_profiles()->Add() = std::move(p); } this->Reset(); + AssignNodeNames(result); + return status; +} + +absl::Status GraphProfiler::WriteProfile() { + if (profiler_config_.trace_log_disabled()) { + // Logging is disabled, so we can exit writing without error. + return absl::OkStatus(); + } + ASSIGN_OR_RETURN(std::string trace_log_path, GetTraceLogPath()); + int log_interval_count = GetLogIntervalCount(profiler_config_); + int log_file_count = GetLogFileCount(profiler_config_); + GraphProfile profile; + MP_RETURN_IF_ERROR(CaptureProfile(&profile)); + + // If there are no trace events, skip log writing. 
+ const GraphTrace& trace = *profile.graph_trace().rbegin(); + if (is_tracing_ && trace.calculator_trace().empty()) { + return absl::OkStatus(); + } // Record the CalculatorGraphConfig, once per log file. ++previous_log_index_; @@ -638,7 +645,7 @@ mediapipe::Status GraphProfiler::WriteProfile() { OstreamStream out(&ofs); RET_CHECK(profile.SerializeToZeroCopyStream(&out)) << "Could not write binary GraphProfile to: " << log_path; - return status; + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/framework/profiler/graph_profiler.h b/mediapipe/framework/profiler/graph_profiler.h index 94fc9fcc1..946208ceb 100644 --- a/mediapipe/framework/profiler/graph_profiler.h +++ b/mediapipe/framework/profiler/graph_profiler.h @@ -140,9 +140,9 @@ class GraphProfiler : public std::enable_shared_from_this { // Process() and does NOT affect information for Open() and Close() methods. void Reset() ABSL_LOCKS_EXCLUDED(profiler_mutex_); // Begins profiling for a single graph run. - mediapipe::Status Start(mediapipe::Executor* executor); + absl::Status Start(mediapipe::Executor* executor); // Ends profiling for a single graph run. - mediapipe::Status Stop(); + absl::Status Stop(); // Record a tracing event. void LogEvent(const TraceEvent& event); @@ -150,12 +150,16 @@ class GraphProfiler : public std::enable_shared_from_this { // Collects the runtime profile for Open(), Process(), and Close() of each // calculator in the graph. May be called at any time after the graph has been // initialized. - mediapipe::Status GetCalculatorProfiles(std::vector*) const + absl::Status GetCalculatorProfiles(std::vector*) const ABSL_LOCKS_EXCLUDED(profiler_mutex_); + // Records recent profiling and tracing data. Includes events since the + // previous call to CaptureProfile. + absl::Status CaptureProfile(GraphProfile* result); + // Writes recent profiling and tracing data to a file specified in the // ProfilerConfig. Includes events since the previous call to WriteProfile. 
- mediapipe::Status WriteProfile(); + absl::Status WriteProfile(); // Returns the trace event buffer. GraphTracer* tracer() { return packet_tracer_.get(); } @@ -294,7 +298,7 @@ class GraphProfiler : public std::enable_shared_from_this { // Helper method to get trace_log_path. If the trace_log_path is empty and // tracing is enabled, this function returns a default platform dependent // trace_log_path. - mediapipe::StatusOr GetTraceLogPath(); + absl::StatusOr GetTraceLogPath(); // Helper method to get the clock time in microsecond. int64 TimeNowUsec() { return ToUnixMicros(clock_->TimeNow()); } diff --git a/mediapipe/framework/profiler/graph_profiler_ios_test.mm b/mediapipe/framework/profiler/graph_profiler_ios_test.mm index 99e480c3c..ad95159ad 100644 --- a/mediapipe/framework/profiler/graph_profiler_ios_test.mm +++ b/mediapipe/framework/profiler/graph_profiler_ios_test.mm @@ -61,7 +61,7 @@ static const char* kOutputStream = "counter"; success = [graph waitUntilDoneWithError:&error]; XCTAssertTrue(success, @"%@", error.localizedDescription); - mediapipe::StatusOr getTraceLogDir = mediapipe::GetDefaultTraceLogDirectory(); + absl::StatusOr getTraceLogDir = mediapipe::GetDefaultTraceLogDirectory(); XCTAssertTrue(getTraceLogDir.ok(), "GetDefaultTraceLogDirectory failed."); NSString* directoryPath = [NSString stringWithCString:(*getTraceLogDir).c_str() diff --git a/mediapipe/framework/profiler/graph_profiler_stub.h b/mediapipe/framework/profiler/graph_profiler_stub.h index 16a12abf0..6621c0192 100644 --- a/mediapipe/framework/profiler/graph_profiler_stub.h +++ b/mediapipe/framework/profiler/graph_profiler_stub.h @@ -81,17 +81,17 @@ class GraphProfilerStub { inline void Initialize(const ValidatedGraphConfig& validated_graph_config) {} inline void SetClock(const std::shared_ptr& clock) {} inline void LogEvent(const TraceEvent& event) {} - inline mediapipe::Status GetCalculatorProfiles( + inline absl::Status GetCalculatorProfiles( std::vector*) const { - return 
mediapipe::OkStatus(); + return absl::OkStatus(); } inline void Pause() {} inline void Resume() {} inline void Reset() {} - inline mediapipe::Status Start(mediapipe::Executor* executor) { - return mediapipe::OkStatus(); + inline absl::Status Start(mediapipe::Executor* executor) { + return absl::OkStatus(); } - inline mediapipe::Status Stop() { return mediapipe::OkStatus(); } + inline absl::Status Stop() { return absl::OkStatus(); } inline GraphTracer* tracer() { return nullptr; } inline std::unique_ptr CreateGlProfilingHelper() { return nullptr; diff --git a/mediapipe/framework/profiler/graph_profiler_test.cc b/mediapipe/framework/profiler/graph_profiler_test.cc index d06b0eb6c..0d7c51eb5 100644 --- a/mediapipe/framework/profiler/graph_profiler_test.cc +++ b/mediapipe/framework/profiler/graph_profiler_test.cc @@ -328,8 +328,6 @@ TEST_F(GraphProfilerTestPeer, InitializeConfigWithoutStreamLatency) { // Tests that Initialize() reads all the configs defined in the graph // definition. -// The best way to understand this test in case of debugging is to visualize -// the graph using go/mediapipe-vis. 
TEST_F(GraphProfilerTestPeer, Initialize) { InitializeProfilerWithGraphConfig(R"( profiler_config { @@ -1176,7 +1174,7 @@ TEST(GraphProfilerTest, ParallelReads) { MP_ASSERT_OK(graph.ObserveOutputStream("out_1", [&](const Packet& packet) { absl::MutexLock lock(&out_1_mutex); out_1_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_EXPECT_OK(graph.StartRun( {{"range_step", MakePacket>(1000, 1)}})); diff --git a/mediapipe/framework/profiler/graph_tracer_test.cc b/mediapipe/framework/profiler/graph_tracer_test.cc index 1d91b0ab1..4c50a6c91 100644 --- a/mediapipe/framework/profiler/graph_tracer_test.cc +++ b/mediapipe/framework/profiler/graph_tracer_test.cc @@ -28,6 +28,7 @@ #include "mediapipe/framework/deps/clock.h" #include "mediapipe/framework/deps/message_matchers.h" #include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gtest.h" @@ -475,19 +476,19 @@ class GraphTracerE2ETest : public ::testing::Test { } // A Calculator::Process callback function. - typedef std::function + typedef std::function ProcessFunction; // A testing callback function that passes through all packets. 
- mediapipe::Status PassThrough(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { + absl::Status PassThrough(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void RunPassThroughGraph() { @@ -511,7 +512,7 @@ class GraphTracerE2ETest : public ::testing::Test { MP_ASSERT_OK( graph_.ObserveOutputStream("output_0", [&](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); simulation_clock_->ThreadStart(); MP_ASSERT_OK(graph_.StartRun({})); @@ -557,7 +558,7 @@ class GraphTracerE2ETest : public ::testing::Test { clock_->Sleep(absl::Microseconds(packets.front().first)); outputs->Index(0).AddPacket(packets.front().second); packets.erase(packets.begin()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } return tool::StatusStop(); }; @@ -580,7 +581,7 @@ class GraphTracerE2ETest : public ::testing::Test { MP_ASSERT_OK(graph_.ObserveOutputStream("output_packets_0", [&](const Packet& packet) { out_packets.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); simulation_clock_->ThreadStart(); MP_ASSERT_OK(graph_.StartRun({})); @@ -676,6 +677,7 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLog) { calculator_trace { node_id: 2 input_timestamp: 10000 } calculator_trace { node_id: 3 input_timestamp: 10000 } calculator_trace { node_id: 3 input_timestamp: 10000 } + calculator_trace { node_id: 4 input_timestamp: 10000 } calculator_trace { node_id: 2 input_timestamp: 10000 } calculator_trace { node_id: 3 input_timestamp: 10000 } calculator_trace { node_id: 0 input_timestamp: 20000 } @@ -988,14 +990,13 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLog) { } // Read a GraphProfile from a file path. 
-mediapipe::Status ReadGraphProfile(const std::string& path, - GraphProfile* profile) { +absl::Status ReadGraphProfile(const std::string& path, GraphProfile* profile) { std::ifstream ifs; ifs.open(path); proto_ns::io::IstreamInputStream in_stream(&ifs); profile->ParseFromZeroCopyStream(&in_stream); - return ifs.is_open() ? mediapipe::OkStatus() - : mediapipe::UnavailableError("Cannot open"); + return ifs.is_open() ? absl::OkStatus() + : absl::UnavailableError("Cannot open"); } TEST_F(GraphTracerE2ETest, DemuxGraphLogFile) { @@ -1007,7 +1008,7 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFile) { GraphProfile profile; MP_EXPECT_OK( ReadGraphProfile(absl::StrCat(log_path, 0, ".binarypb"), &profile)); - EXPECT_EQ(112, profile.graph_trace(0).calculator_trace().size()); + EXPECT_EQ(113, profile.graph_trace(0).calculator_trace().size()); } TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { @@ -1036,7 +1037,7 @@ TEST_F(GraphTracerE2ETest, DemuxGraphLogFiles) { // The expected counts of calculator_trace records in each of the log files. // The processing spans three 12.5ms log files, because // RunDemuxInFlightGraph adds packets over 30ms. 
- std::vector expected = {49, 64, 12}; + std::vector expected = {50, 64, 12}; EXPECT_EQ(event_counts, expected); GraphProfile& profile_2 = graph_profiles[2]; profile_2.clear_calculator_profiles(); @@ -1242,7 +1243,7 @@ TEST_F(GraphTracerE2ETest, DisableLoggingToDisk) { graph_config_.mutable_profiler_config()->set_trace_log_path(log_path); graph_config_.mutable_profiler_config()->set_trace_log_disabled(true); RunDemuxInFlightGraph(); - EXPECT_TRUE(mediapipe::IsNotFound( + EXPECT_TRUE(absl::IsNotFound( mediapipe::file::Exists(absl::StrCat(log_path, 0, ".binarypb")))); } diff --git a/mediapipe/framework/profiler/profiler_resource_util_android.cc b/mediapipe/framework/profiler/profiler_resource_util_android.cc index 0bdf51db7..09f656f4b 100644 --- a/mediapipe/framework/profiler/profiler_resource_util_android.cc +++ b/mediapipe/framework/profiler/profiler_resource_util_android.cc @@ -27,13 +27,13 @@ StatusOr GetDefaultTraceLogDirectory() { StatusOr* result = new StatusOr(); bool has_jvm = java::HasJavaVM(); if (!has_jvm) { - *result = mediapipe::InternalError("JavaVM not available."); + *result = absl::InternalError("JavaVM not available."); return result; } JNIEnv* env = java::GetJNIEnv(); if (!env) { - *result = mediapipe::InternalError("JNIEnv not available."); + *result = absl::InternalError("JNIEnv not available."); return result; } diff --git a/mediapipe/framework/profiler/profiler_resource_util_common.cc b/mediapipe/framework/profiler/profiler_resource_util_common.cc index bd205f8b7..d75ea972e 100644 --- a/mediapipe/framework/profiler/profiler_resource_util_common.cc +++ b/mediapipe/framework/profiler/profiler_resource_util_common.cc @@ -14,42 +14,25 @@ #include "absl/flags/flag.h" #include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/profiler/profiler_resource_util.h" -// TODO: Move this 
Android include to port/file_helpers. -// Also move this from resource_util.cc. -#ifdef __ANDROID__ -#include "mediapipe/util/android/file/base/filesystem.h" -#else -#include "mediapipe/framework/port/file_helpers.h" -#endif - ABSL_FLAG(std::string, log_root_dir, "", "The absolute path to the logging output directory. If specified, " "log_root_dir will be prepended to each specified log file path."); -#ifdef __ANDROID__ -namespace mediapipe { -namespace file { -mediapipe::Status RecursivelyCreateDir(absl::string_view path) { - return RecursivelyCreateDir(path, file::Options()); -} -} // namespace file -} // namespace mediapipe -#endif - namespace mediapipe { -mediapipe::StatusOr GetLogDirectory() { - if (!FLAGS_log_root_dir.CurrentValue().empty()) { - return FLAGS_log_root_dir.CurrentValue(); +absl::StatusOr GetLogDirectory() { + if (!absl::GetFlag(FLAGS_log_root_dir).empty()) { + return absl::GetFlag(FLAGS_log_root_dir); } return GetDefaultTraceLogDirectory(); } -mediapipe::StatusOr PathToLogFile(const std::string& path) { +absl::StatusOr PathToLogFile(const std::string& path) { ASSIGN_OR_RETURN(std::string log_dir, GetLogDirectory()); std::string result = file::JoinPath(log_dir, path); MP_RETURN_IF_ERROR( diff --git a/mediapipe/framework/profiler/profiler_resource_util_ios.cc b/mediapipe/framework/profiler/profiler_resource_util_ios.cc index b0f72f9db..b878e2fb0 100644 --- a/mediapipe/framework/profiler/profiler_resource_util_ios.cc +++ b/mediapipe/framework/profiler/profiler_resource_util_ios.cc @@ -37,7 +37,7 @@ StatusOr GetDefaultTraceLogDirectory() { error:&error]; if (!success) { // TODO: Use NSError+util_status to get status from NSError. 
- return mediapipe::InternalError([[error localizedDescription] UTF8String]); + return absl::InternalError([[error localizedDescription] UTF8String]); } std::string trace_log_directory = [ns_documents_directory UTF8String]; diff --git a/mediapipe/framework/profiler/reporter/BUILD b/mediapipe/framework/profiler/reporter/BUILD index 6e97a5f3f..3d92efd8d 100644 --- a/mediapipe/framework/profiler/reporter/BUILD +++ b/mediapipe/framework/profiler/reporter/BUILD @@ -51,7 +51,6 @@ cc_binary( deps = [ ":reporter_lib", "//mediapipe/framework/port:advanced_proto", - "//mediapipe/framework/port:commandlineflags", "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:status", "@com_google_absl//absl/container:btree", diff --git a/mediapipe/framework/profiler/reporter/README.md b/mediapipe/framework/profiler/reporter/README.md index 12fa4bb06..b7cfef6df 100644 --- a/mediapipe/framework/profiler/reporter/README.md +++ b/mediapipe/framework/profiler/reporter/README.md @@ -3,9 +3,6 @@ Allows a user to analyze trace files generated by MediaPipe from the command line. If you would prefer to see this information visually (or if you can't build the tool), you can see the same information within viz.mediapipe.dev. -For more information on this, see [Profile Visualization](https://docs.google.com/document/d/1inBoRzKDyKEjtVws8Seceoa0xRU3mqpPgM_myDBilS4/edit#heading=h.bnft45odm046) at go/mediapipe-profiler-guide - -See go/mediapipe-profiler-guide for a detailed user's guide. 
--- diff --git a/mediapipe/framework/profiler/reporter/print_profile.cc b/mediapipe/framework/profiler/reporter/print_profile.cc index a15403872..dd6e0846d 100644 --- a/mediapipe/framework/profiler/reporter/print_profile.cc +++ b/mediapipe/framework/profiler/reporter/print_profile.cc @@ -24,7 +24,6 @@ #include "absl/flags/parse.h" #include "absl/flags/usage.h" #include "mediapipe/framework/port/advanced_proto_inc.h" -#include "mediapipe/framework/port/commandlineflags.h" #include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/profiler/reporter/reporter.h" diff --git a/mediapipe/framework/profiler/reporter/reporter.cc b/mediapipe/framework/profiler/reporter/reporter.cc index 8a754ddc5..b61afa363 100644 --- a/mediapipe/framework/profiler/reporter/reporter.cc +++ b/mediapipe/framework/profiler/reporter/reporter.cc @@ -297,8 +297,7 @@ void Reporter::Accumulate(const mediapipe::GraphProfile& profile) { } } -mediapipe::Status Reporter::set_columns( - const std::vector& columns) { +absl::Status Reporter::set_columns(const std::vector& columns) { bool error = false; std::stringstream warnings; std::vector new_columns({"calculator"}); @@ -337,9 +336,9 @@ mediapipe::Status Reporter::set_columns( columns_.swap(new_columns); } if (!error) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } - return mediapipe::InvalidArgumentError(warnings.str()); + return absl::InvalidArgumentError(warnings.str()); } class ReportImpl : public Report { diff --git a/mediapipe/framework/profiler/reporter/reporter.h b/mediapipe/framework/profiler/reporter/reporter.h index 07f50e227..47716c920 100644 --- a/mediapipe/framework/profiler/reporter/reporter.h +++ b/mediapipe/framework/profiler/reporter/reporter.h @@ -106,7 +106,7 @@ class Reporter { // Accepts names of of columns or wildcard patterns (* or ?) to // select which statistics columns will be included in a generated // report. 
- mediapipe::Status set_columns(const std::vector& columns); + absl::Status set_columns(const std::vector& columns); // Generates a report based on the current accumulated statistics. std::unique_ptr Report(); diff --git a/mediapipe/framework/profiler/sharded_map_test.cc b/mediapipe/framework/profiler/sharded_map_test.cc index b9981a80b..e551b25c8 100644 --- a/mediapipe/framework/profiler/sharded_map_test.cc +++ b/mediapipe/framework/profiler/sharded_map_test.cc @@ -57,7 +57,7 @@ void TestWriteAndRead(Map& time_map) { // Tests writing, reading and erasing in a ShardedMap. TEST(ShardedMapTest, TestWriteAndRead) { - std::unordered_map simple_map; + absl::node_hash_map simple_map; TestWriteAndRead(simple_map); ShardedMap safe_map(4999, 1); TestWriteAndRead(safe_map); diff --git a/mediapipe/framework/profiler/test_context_builder.h b/mediapipe/framework/profiler/test_context_builder.h index 986a1ad8f..abf9ee749 100644 --- a/mediapipe/framework/profiler/test_context_builder.h +++ b/mediapipe/framework/profiler/test_context_builder.h @@ -74,8 +74,8 @@ class TestContextBuilder { state_ = absl::make_unique( node_name, node_id, "PCalculator", CalculatorGraphConfig::Node(), nullptr); - input_map_ = tool::CreateTagMap(inputs).ValueOrDie(); - output_map_ = tool::CreateTagMap(outputs).ValueOrDie(); + input_map_ = tool::CreateTagMap(inputs).value(); + output_map_ = tool::CreateTagMap(outputs).value(); input_handler_ = absl::make_unique( input_map_, nullptr, MediaPipeOptions(), false); input_managers_.reset(new InputStreamManager[input_map_->NumEntries()]); @@ -91,7 +91,7 @@ class TestContextBuilder { OutputStreamSpec spec; spec.name = output_map_->Names()[id.value()]; spec.packet_type = packet_type; - spec.error_callback = [](const mediapipe::Status& status) { + spec.error_callback = [](const absl::Status& status) { LOG(ERROR) << status; }; output_specs_[spec.name] = spec; diff --git a/mediapipe/framework/profiler/testing/simple_calculator.cc 
b/mediapipe/framework/profiler/testing/simple_calculator.cc index 2126b052d..18ba67b9b 100644 --- a/mediapipe/framework/profiler/testing/simple_calculator.cc +++ b/mediapipe/framework/profiler/testing/simple_calculator.cc @@ -19,15 +19,15 @@ namespace mediapipe { class SimpleCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); if (cc->InputSidePackets().HasTag("MAX_COUNT")) { cc->InputSidePackets().Tag("MAX_COUNT").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { LOG(WARNING) << "Simple Calculator Process called, count_: " << count_; int max_count = 1; if (cc->InputSidePackets().HasTag("MAX_COUNT")) { @@ -38,7 +38,7 @@ class SimpleCalculator : public CalculatorBase { } cc->Outputs().Index(0).Add(new int(count_), Timestamp(count_)); ++count_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/framework/profiler/trace_builder.cc b/mediapipe/framework/profiler/trace_builder.cc index ce6c6c57c..6797cd0d9 100644 --- a/mediapipe/framework/profiler/trace_builder.cc +++ b/mediapipe/framework/profiler/trace_builder.cc @@ -112,11 +112,13 @@ class StringIdMap { return string_id->second; } void clear() { pointer_id_map_.clear(), string_id_map_.clear(); } - const std::unordered_map& map() { return string_id_map_; } + const absl::node_hash_map& map() { + return string_id_map_; + } private: std::unordered_map pointer_id_map_; - std::unordered_map string_id_map_; + absl::node_hash_map string_id_map_; int32 next_id = 0; }; diff --git a/mediapipe/framework/scheduler.cc b/mediapipe/framework/scheduler.cc index 30d0b355d..afef4f383 100644 --- a/mediapipe/framework/scheduler.cc +++ b/mediapipe/framework/scheduler.cc @@ -83,8 +83,8 @@ void 
Scheduler::SetExecutor(Executor* executor) { } // TODO: Consider renaming this method CreateNonDefaultQueue. -mediapipe::Status Scheduler::SetNonDefaultExecutor(const std::string& name, - Executor* executor) { +absl::Status Scheduler::SetNonDefaultExecutor(const std::string& name, + Executor* executor) { RET_CHECK_EQ(state_, STATE_NOT_STARTED) << "SetNonDefaultExecutor must not " "be called after the scheduler " "has started"; @@ -99,7 +99,7 @@ mediapipe::Status Scheduler::SetNonDefaultExecutor(const std::string& name, std::placeholders::_1)); queue->SetExecutor(executor); scheduler_queues_.push_back(queue); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void Scheduler::SetQueuesRunning(bool running) { @@ -252,7 +252,7 @@ void Scheduler::EmittedObservedOutput() { } } -mediapipe::Status Scheduler::WaitForObservedOutput() { +absl::Status Scheduler::WaitForObservedOutput() { bool observed = false; ApplicationThreadAwait( [this, &observed]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mutex_) { @@ -262,8 +262,7 @@ mediapipe::Status Scheduler::WaitForObservedOutput() { // Wait until the member waiting_for_observed_output_ becomes false. return !waiting_for_observed_output_; }); - return observed ? mediapipe::OkStatus() - : mediapipe::OutOfRangeError("Graph is done."); + return observed ? absl::OkStatus() : absl::OutOfRangeError("Graph is done."); } // Idleness requires: @@ -273,18 +272,18 @@ mediapipe::Status Scheduler::WaitForObservedOutput() { // no source nodes. (This is enforced by CalculatorGraph::WaitUntilIdle().) // The application must ensure no other threads are adding packets to graph // input streams while a WaitUntilIdle() call is in progress. 
-mediapipe::Status Scheduler::WaitUntilIdle() { +absl::Status Scheduler::WaitUntilIdle() { RET_CHECK_NE(state_, STATE_NOT_STARTED); ApplicationThreadAwait(std::bind(&Scheduler::IsIdle, this)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Scheduler::WaitUntilDone() { +absl::Status Scheduler::WaitUntilDone() { RET_CHECK_NE(state_, STATE_NOT_STARTED); ApplicationThreadAwait([this]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mutex_) { return state_ == STATE_TERMINATED; }); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void Scheduler::ApplicationThreadAwait( @@ -379,7 +378,7 @@ bool Scheduler::TryToScheduleNextSourceLayer() { // If no graph input streams are open, then there are no packet sources in // the graph. It's a deadlock. if (graph_input_streams_closed_) { - graph_->RecordError(mediapipe::UnknownError( + graph_->RecordError(absl::UnknownError( "Detected a deadlock because source nodes cannot be activated when a " "source node at a lower layer is still not opened.")); } @@ -496,7 +495,7 @@ void Scheduler::Cancel() { if (state_ != STATE_RUNNING && state_ != STATE_PAUSED) { return; } - graph_->RecordError(mediapipe::CancelledError()); + graph_->RecordError(absl::CancelledError()); if (state_ == STATE_PAUSED) { // Keep the scheduler queue running, since we need to exhaust it. SetQueuesRunning(true); diff --git a/mediapipe/framework/scheduler.h b/mediapipe/framework/scheduler.h index 6c3050a11..dd1572d99 100644 --- a/mediapipe/framework/scheduler.h +++ b/mediapipe/framework/scheduler.h @@ -56,8 +56,8 @@ class Scheduler { // Sets the executor that will run the nodes assigned to the executor // named |name|. Must be called before the scheduler is started. - mediapipe::Status SetNonDefaultExecutor(const std::string& name, - Executor* executor); + absl::Status SetNonDefaultExecutor(const std::string& name, + Executor* executor); // Resets the data members at the beginning of each graph run. 
void Reset(); @@ -70,13 +70,13 @@ class Scheduler { // have been closed, and no more calculators can be run). // This function can be called only after Start(). // Runs application thread tasks while waiting. - mediapipe::Status WaitUntilDone() ABSL_LOCKS_EXCLUDED(state_mutex_); + absl::Status WaitUntilDone() ABSL_LOCKS_EXCLUDED(state_mutex_); // Wait until the running graph is in the idle mode, which is when nothing can // be scheduled and nothing is running in the worker threads. This function // can be called only after Start(). // Runs application thread tasks while waiting. - mediapipe::Status WaitUntilIdle() ABSL_LOCKS_EXCLUDED(state_mutex_); + absl::Status WaitUntilIdle() ABSL_LOCKS_EXCLUDED(state_mutex_); // Wait until any graph input stream has been unthrottled. // This is meant to be used by CalculatorGraph::AddPacketToInputStream, which @@ -93,8 +93,8 @@ class Scheduler { // this function returns immediately if an observed packet has already been // emitted since the previous call. This relies on the fact that the calls are // in sequence. Runs application thread tasks while waiting. - // Returns mediapipe::OutOfRangeError if the graph terminated. - mediapipe::Status WaitForObservedOutput() ABSL_LOCKS_EXCLUDED(state_mutex_); + // Returns absl::OutOfRangeError if the graph terminated. + absl::Status WaitForObservedOutput() ABSL_LOCKS_EXCLUDED(state_mutex_); // Callback that is invoked by a node when it wants to be scheduled. // If the node is throttled, the call is ignored. diff --git a/mediapipe/framework/scheduler_queue.cc b/mediapipe/framework/scheduler_queue.cc index f6f92abd8..efad97282 100644 --- a/mediapipe/framework/scheduler_queue.cc +++ b/mediapipe/framework/scheduler_queue.cc @@ -245,8 +245,8 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node, // source node always reuses the same CalculatorContext and Close() doesn't // access any inputs. // TODO: Should we pass tool::StatusStop() in this case? 
- const mediapipe::Status result = - node->CloseNode(mediapipe::OkStatus(), /*graph_run_ended=*/false); + const absl::Status result = + node->CloseNode(absl::OkStatus(), /*graph_run_ended=*/false); shared_->timer.EndNode(start_time); if (!result.ok()) { VLOG(3) << node->DebugName() @@ -257,7 +257,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node, // Note that we don't need a lock because only one thread can execute this // due to the lock on running_nodes. int64 start_time = shared_->timer.StartNode(); - const mediapipe::Status result = node->ProcessNode(cc); + const absl::Status result = node->ProcessNode(cc); shared_->timer.EndNode(start_time); if (!result.ok()) { @@ -284,7 +284,7 @@ void SchedulerQueue::RunCalculatorNode(CalculatorNode* node, void SchedulerQueue::OpenCalculatorNode(CalculatorNode* node) { VLOG(3) << "Opening " << node->DebugName(); int64 start_time = shared_->timer.StartNode(); - const mediapipe::Status result = node->OpenNode(); + const absl::Status result = node->OpenNode(); shared_->timer.EndNode(start_time); if (!result.ok()) { VLOG(3) << node->DebugName() << " had an error!"; diff --git a/mediapipe/framework/scheduler_shared.h b/mediapipe/framework/scheduler_shared.h index adbd6801a..e62d785db 100644 --- a/mediapipe/framework/scheduler_shared.h +++ b/mediapipe/framework/scheduler_shared.h @@ -105,7 +105,7 @@ struct SchedulerShared { // flag indicates that the graph is in that mode. std::atomic stopping; std::atomic has_error; - std::function error_callback; + std::function error_callback; // Collects timing information for measuring overhead. 
internal::SchedulerTimer timer; }; diff --git a/mediapipe/framework/status_handler.h b/mediapipe/framework/status_handler.h index 225561db2..f1bc29b13 100644 --- a/mediapipe/framework/status_handler.h +++ b/mediapipe/framework/status_handler.h @@ -48,19 +48,19 @@ class StatusHandler { // All subclasses of StatusHandler must implement these static functions with // the following signatures: // - // static mediapipe::Status FillExpectations( + // static absl::Status FillExpectations( // const MediaPipeOptions& extendable_options, // PacketTypeSet* input_side_packets); // - // static mediapipe::Status HandlePreRunStatus( + // static absl::Status HandlePreRunStatus( // const MediaPipeOptions& extendable_options, // const PacketSet& input_side_packets, - // const mediapipe::Status& pre_run_status); + // const absl::Status& pre_run_status); // - // static mediapipe::Status HandleStatus( + // static absl::Status HandleStatus( // const MediaPipeOptions& extendable_options, // const PacketSet& input_side_packets, - // const mediapipe::Status& run_status); + // const absl::Status& run_status); // // FillExpectations() is used to validate the graph and it is analogous to the // function in calculator.h, packet_generator.h, and packet_factory.h. 
@@ -90,17 +90,16 @@ namespace internal { class StaticAccessToStatusHandler { public: virtual ~StaticAccessToStatusHandler() {} - virtual mediapipe::Status FillExpectations( + virtual absl::Status FillExpectations( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets) = 0; - virtual mediapipe::Status HandlePreRunStatus( + virtual absl::Status HandlePreRunStatus( const MediaPipeOptions& extendable_options, const PacketSet& input_side_packets, - const mediapipe::Status& pre_run_status) = 0; - virtual mediapipe::Status HandleStatus( - const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, // - const mediapipe::Status& run_status) = 0; + const absl::Status& pre_run_status) = 0; + virtual absl::Status HandleStatus(const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, // + const absl::Status& run_status) = 0; }; using StaticAccessToStatusHandlerRegistry = @@ -111,7 +110,7 @@ using StaticAccessToStatusHandlerRegistry = template constexpr bool StatusHandlerHasFillExpectations( decltype(&T::FillExpectations) /* unused */) { - typedef mediapipe::Status (*FillExpectationsType)( + typedef absl::Status (*FillExpectationsType)( const MediaPipeOptions& extendable_options, PacketTypeSet* input_side_packets); return std::is_same constexpr bool StatusHandlerHasHandlePreRunStatus( decltype(&T::HandlePreRunStatus) /* unused */) { - typedef mediapipe::Status (*HandlePreRunStatusType)( + typedef absl::Status (*HandlePreRunStatusType)( const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, - const mediapipe::Status& pre_run_status); + const PacketSet& input_side_packets, const absl::Status& pre_run_status); return std::is_same::value; } template constexpr bool StatusHandlerHasHandleStatus( decltype(&T::HandleStatus) /* unused */) { - typedef mediapipe::Status (*HandleStatusType)( + typedef absl::Status (*HandleStatusType)( const MediaPipeOptions& extendable_options, - const 
PacketSet& input_side_packets, const mediapipe::Status& run_status); + const PacketSet& input_side_packets, const absl::Status& run_status); return std::is_same::value; } template @@ -166,23 +164,22 @@ class StaticAccessToStatusHandlerTyped : public StaticAccessToStatusHandler { "HandleStatus() must be defined with the correct signature in " "every StatusHandler."); - mediapipe::Status FillExpectations(const MediaPipeOptions& extendable_options, - PacketTypeSet* input_side_packets) final { + absl::Status FillExpectations(const MediaPipeOptions& extendable_options, + PacketTypeSet* input_side_packets) final { return StatusHandlerSubclass::FillExpectations(extendable_options, input_side_packets); } - mediapipe::Status HandlePreRunStatus( - const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, - const mediapipe::Status& pre_run_status) final { + absl::Status HandlePreRunStatus(const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, + const absl::Status& pre_run_status) final { return StatusHandlerSubclass::HandlePreRunStatus( extendable_options, input_side_packets, pre_run_status); } - mediapipe::Status HandleStatus(const MediaPipeOptions& extendable_options, - const PacketSet& input_side_packets, - const mediapipe::Status& run_status) final { + absl::Status HandleStatus(const MediaPipeOptions& extendable_options, + const PacketSet& input_side_packets, + const absl::Status& run_status) final { return StatusHandlerSubclass::HandleStatus(extendable_options, input_side_packets, run_status); } diff --git a/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc b/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc index 5a2f30296..ece873b1e 100644 --- a/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/barrier_input_stream_handler.cc @@ -37,7 +37,7 @@ class BarrierInputStreamHandler : public InputStreamHandler { std::function 
headers_ready_callback, std::function notification_callback, std::function schedule_callback, - std::function error_callback) override { + std::function error_callback) override { InputStreamHandler::PrepareForRun( std::move(headers_ready_callback), std::move(notification_callback), std::move(schedule_callback), std::move(error_callback)); diff --git a/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc index e60ba9133..9f341ba54 100644 --- a/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/barrier_input_stream_handler_test.cc @@ -57,7 +57,7 @@ class BarrierInputStreamHandlerTest : public ::testing::Test { std::placeholders::_1, std::placeholders::_2); std::shared_ptr input_tag_map = - tool::CreateTagMap({"input_a", "input_b", "input_c"}).ValueOrDie(); + tool::CreateTagMap({"input_a", "input_b", "input_c"}).value(); input_stream_managers_.reset( new InputStreamManager[input_tag_map->NumEntries()]); @@ -79,16 +79,16 @@ class BarrierInputStreamHandlerTest : public ::testing::Test { nullptr); calculator_context_manager_.Initialize( calculator_state_.get(), input_tag_map, - /*output_tag_map=*/tool::CreateTagMap({"output_a"}).ValueOrDie(), + /*output_tag_map=*/tool::CreateTagMap({"output_a"}).value(), /*calculator_run_in_parallel=*/false); - mediapipe::StatusOr> + absl::StatusOr> status_or_handler = InputStreamHandlerRegistry::CreateByName( "BarrierInputStreamHandler", input_tag_map, &calculator_context_manager_, MediaPipeOptions(), /*calculator_run_in_parallel=*/false); ASSERT_TRUE(status_or_handler.ok()); - input_stream_handler_ = std::move(status_or_handler.ValueOrDie()); + input_stream_handler_ = std::move(status_or_handler.value()); MP_ASSERT_OK(input_stream_handler_->InitializeInputStreamManagers( input_stream_managers_.get())); MP_ASSERT_OK( @@ -109,10 +109,10 @@ class BarrierInputStreamHandlerTest : public 
::testing::Test { calculator_context_ = calculator_context; } - void RecordError(const mediapipe::Status& error) { errors_.push_back(error); } + void RecordError(const absl::Status& error) { errors_.push_back(error); } - mediapipe::Status SetupShardsNoOp(CalculatorContext* calculator_context) { - return mediapipe::OkStatus(); + absl::Status SetupShardsNoOp(CalculatorContext* calculator_context) { + return absl::OkStatus(); } void ReportQueueNoOp(InputStreamManager* stream, bool* stream_was_full) {} @@ -121,13 +121,13 @@ class BarrierInputStreamHandlerTest : public ::testing::Test { std::function headers_ready_callback_; std::function notification_callback_; std::function schedule_callback_; - std::function error_callback_; - std::function setup_shards_callback_; + std::function error_callback_; + std::function setup_shards_callback_; InputStreamManager::QueueSizeCallback queue_full_callback_; InputStreamManager::QueueSizeCallback queue_not_full_callback_; // Vector of errors encountered while using the stream. 
- std::vector errors_; + std::vector errors_; std::unique_ptr calculator_state_; CalculatorContextManager calculator_context_manager_; diff --git a/mediapipe/framework/stream_handler/default_input_stream_handler.cc b/mediapipe/framework/stream_handler/default_input_stream_handler.cc index 4c95241a9..b30ed5bbd 100644 --- a/mediapipe/framework/stream_handler/default_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/default_input_stream_handler.cc @@ -49,7 +49,7 @@ void DefaultInputStreamHandler::PrepareForRun( std::function headers_ready_callback, std::function notification_callback, std::function schedule_callback, - std::function error_callback) { + std::function error_callback) { sync_set_.PrepareForRun(); InputStreamHandler::PrepareForRun( std::move(headers_ready_callback), std::move(notification_callback), diff --git a/mediapipe/framework/stream_handler/default_input_stream_handler.h b/mediapipe/framework/stream_handler/default_input_stream_handler.h index edf41fbea..f3b78f85d 100644 --- a/mediapipe/framework/stream_handler/default_input_stream_handler.h +++ b/mediapipe/framework/stream_handler/default_input_stream_handler.h @@ -36,11 +36,10 @@ class DefaultInputStreamHandler : public InputStreamHandler { protected: // Reinitializes this InputStreamHandler before each CalculatorGraph run. 
- void PrepareForRun( - std::function headers_ready_callback, - std::function notification_callback, - std::function schedule_callback, - std::function error_callback) override; + void PrepareForRun(std::function headers_ready_callback, + std::function notification_callback, + std::function schedule_callback, + std::function error_callback) override; // In DefaultInputStreamHandler, a node is "ready" if: // - all streams are done (need to call Close() in this case), or diff --git a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc index 075bbef31..903fcb776 100644 --- a/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/fixed_size_input_stream_handler_test.cc @@ -46,17 +46,17 @@ bool g_source_done ABSL_GUARDED_BY(g_source_mutex); class TestSourceCalculator : public CalculatorBase { public: TestSourceCalculator() : current_packet_id_(0) {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { absl::MutexLock lock(&g_source_mutex); g_source_counter = 0; g_source_done = false; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (current_packet_id_ == kMaxPacketId) { absl::MutexLock lock(&g_source_mutex); g_source_done = true; @@ -70,7 +70,7 @@ class TestSourceCalculator : public CalculatorBase { g_source_mutex.Await( absl::Condition(this, &TestSourceCalculator::CanProceed)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -86,17 +86,17 @@ 
REGISTER_CALCULATOR(TestSourceCalculator); class TestSlowCalculator : public CalculatorBase { public: TestSlowCalculator() = default; - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { absl::MutexLock lock(&g_source_mutex); g_slow_counter = 0; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { cc->Outputs().Index(0).Add(new int64(0), cc->Inputs().Index(0).Value().Timestamp()); { @@ -105,7 +105,7 @@ class TestSlowCalculator : public CalculatorBase { g_source_mutex.Await( absl::Condition(this, &TestSlowCalculator::CanProceed)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc index 4bcaff59c..c34fc96b3 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler.cc @@ -41,11 +41,10 @@ class ImmediateInputStreamHandler : public InputStreamHandler { protected: // Reinitializes this InputStreamHandler before each CalculatorGraph run. 
- void PrepareForRun( - std::function headers_ready_callback, - std::function notification_callback, - std::function schedule_callback, - std::function error_callback) override; + void PrepareForRun(std::function headers_ready_callback, + std::function notification_callback, + std::function schedule_callback, + std::function error_callback) override; // Returns kReadyForProcess whenever a Packet is available at any of // the input streams, or any input stream becomes done. @@ -83,7 +82,7 @@ void ImmediateInputStreamHandler::PrepareForRun( std::function headers_ready_callback, std::function notification_callback, std::function schedule_callback, - std::function error_callback) { + std::function error_callback) { { absl::MutexLock lock(&mutex_); for (int i = 0; i < sync_sets_.size(); ++i) { diff --git a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc index 3bcc1358d..e721afb02 100644 --- a/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/immediate_input_stream_handler_test.cc @@ -57,7 +57,7 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test { std::placeholders::_1, std::placeholders::_2); std::shared_ptr input_tag_map = - tool::CreateTagMap({"input_a", "input_b", "input_c"}).ValueOrDie(); + tool::CreateTagMap({"input_a", "input_b", "input_c"}).value(); input_stream_managers_.reset( new InputStreamManager[input_tag_map->NumEntries()]); @@ -79,16 +79,16 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test { nullptr); cc_manager_.Initialize( calculator_state_.get(), input_tag_map, - /*output_tag_map=*/tool::CreateTagMap({"output_a"}).ValueOrDie(), + /*output_tag_map=*/tool::CreateTagMap({"output_a"}).value(), /*calculator_run_in_parallel=*/false); - mediapipe::StatusOr> + absl::StatusOr> status_or_handler = InputStreamHandlerRegistry::CreateByName( 
"ImmediateInputStreamHandler", input_tag_map, &cc_manager_, MediaPipeOptions(), /*calculator_run_in_parallel=*/false); ASSERT_TRUE(status_or_handler.ok()); - input_stream_handler_ = std::move(status_or_handler.ValueOrDie()); + input_stream_handler_ = std::move(status_or_handler.value()); MP_ASSERT_OK(input_stream_handler_->InitializeInputStreamManagers( input_stream_managers_.get())); MP_ASSERT_OK(cc_manager_.PrepareForRun(setup_shards_callback_)); @@ -108,10 +108,10 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test { cc_ = cc; } - void RecordError(const mediapipe::Status& error) { errors_.push_back(error); } + void RecordError(const absl::Status& error) { errors_.push_back(error); } - mediapipe::Status SetupShardsNoOp(CalculatorContext* calculator_context) { - return mediapipe::OkStatus(); + absl::Status SetupShardsNoOp(CalculatorContext* calculator_context) { + return absl::OkStatus(); } void ReportQueueNoOp(InputStreamManager* stream, bool* stream_was_full) {} @@ -140,13 +140,13 @@ class ImmediateInputStreamHandlerTest : public ::testing::Test { std::function headers_ready_callback_; std::function notification_callback_; std::function schedule_callback_; - std::function error_callback_; - std::function setup_shards_callback_; + std::function error_callback_; + std::function setup_shards_callback_; InputStreamManager::QueueSizeCallback queue_full_callback_; InputStreamManager::QueueSizeCallback queue_not_full_callback_; // Vector of errors encountered while using the stream. 
- std::vector errors_; + std::vector errors_; std::unique_ptr calculator_state_; CalculatorContextManager cc_manager_; diff --git a/mediapipe/framework/stream_handler/set_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/set_input_stream_handler_test.cc index ff01bc92e..c6866c294 100644 --- a/mediapipe/framework/stream_handler/set_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/set_input_stream_handler_test.cc @@ -108,9 +108,9 @@ TEST(MuxInputStreamHandlerTest, AtomicAccessToControlAndDataStreams) { // ignored. class FixedPassThroughCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { if (!cc->Inputs().TagMap()->SameAs(*cc->Outputs().TagMap())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and output streams to PassThroughCalculator must use " "matching tags and indexes."); } @@ -126,7 +126,7 @@ class FixedPassThroughCalculator : public CalculatorBase { if (cc->OutputSidePackets().NumEntries() != 0) { if (!cc->InputSidePackets().TagMap()->SameAs( *cc->OutputSidePackets().TagMap())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Input and output side packets to PassThroughCalculator must use " "matching tags and indexes."); } @@ -148,10 +148,10 @@ class FixedPassThroughCalculator : public CalculatorBase { ->set_target_queue_size(2); cc->SetInputStreamHandlerOptions(options); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { if (!cc->Inputs().Get(id).Header().IsEmpty()) { @@ -165,10 +165,10 @@ class FixedPassThroughCalculator : public CalculatorBase { } } cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return 
absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { cc->GetCounter("PassThrough")->Increment(); if (cc->Inputs().NumEntries() == 0) { return tool::StatusStop(); @@ -182,7 +182,7 @@ class FixedPassThroughCalculator : public CalculatorBase { cc->Outputs().Get(id).AddPacket(cc->Inputs().Get(id).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(FixedPassThroughCalculator); diff --git a/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc b/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc index d597ab5db..1001d64f7 100644 --- a/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/sync_set_input_stream_handler.cc @@ -44,11 +44,10 @@ class SyncSetInputStreamHandler : public InputStreamHandler { const MediaPipeOptions& extendable_options, bool calculator_run_in_parallel); - void PrepareForRun( - std::function headers_ready_callback, - std::function notification_callback, - std::function schedule_callback, - std::function error_callback) override; + void PrepareForRun(std::function headers_ready_callback, + std::function notification_callback, + std::function schedule_callback, + std::function error_callback) override; protected: // In SyncSetInputStreamHandler, a node is "ready" if any @@ -94,7 +93,7 @@ void SyncSetInputStreamHandler::PrepareForRun( std::function headers_ready_callback, std::function notification_callback, std::function schedule_callback, - std::function error_callback) { + std::function error_callback) { const auto& handler_options = options_.GetExtension(SyncSetInputStreamHandlerOptions::ext); { diff --git a/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc b/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc index cd3379e6a..8c1f9b2f5 100644 --- 
a/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc +++ b/mediapipe/framework/stream_handler/sync_set_input_stream_handler_test.cc @@ -36,8 +36,8 @@ namespace mediapipe { namespace { // The type LambdaCalculator takes. -typedef std::function +typedef std::function ProcessFunction; // Helper function to create a tuple (inside an initializer list). @@ -50,8 +50,8 @@ std::tuple> CommandTuple( // Function to take the inputs and produce a diagnostic output std::string // and output a packet with a diagnostic output std::string which includes // the input timestamp and the ids of each input which is present. -mediapipe::Status InputsToDebugString(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status InputsToDebugString(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { std::string output; Timestamp output_timestamp; for (CollectionItemId id = inputs.BeginId(); id < inputs.EndId(); ++id) { @@ -79,7 +79,7 @@ mediapipe::Status InputsToDebugString(const InputStreamShardSet& inputs, // TODO Output at output_timestamp once unordered output stream // handlers are allowed. 
outputs->Index(0).AddPacket(output_packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); } TEST(SyncSetInputStreamHandlerTest, OrdinaryOperation) { @@ -273,7 +273,7 @@ TEST(SyncSetInputStreamHandlerTest, OrdinaryOperation) { MP_ASSERT_OK( graph.ObserveOutputStream("output", [&outputs](const Packet& packet) { outputs.push_back(packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); MP_ASSERT_OK(graph.StartRun({})); for (int command_index = 0; command_index < shuffled_commands.size(); diff --git a/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc index bcc180d6c..ae075d788 100644 --- a/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc +++ b/mediapipe/framework/stream_handler/timestamp_align_input_stream_handler.cc @@ -52,11 +52,10 @@ class TimestampAlignInputStreamHandler : public InputStreamHandler { const MediaPipeOptions& options, bool calculator_run_in_parallel); - void PrepareForRun( - std::function headers_ready_callback, - std::function notification_callback, - std::function schedule_callback, - std::function error_callback) override; + void PrepareForRun(std::function headers_ready_callback, + std::function notification_callback, + std::function schedule_callback, + std::function error_callback) override; protected: // In TimestampAlignInputStreamHandler, a node is "ready" if: @@ -107,7 +106,7 @@ void TimestampAlignInputStreamHandler::PrepareForRun( std::function headers_ready_callback, std::function notification_callback, std::function schedule_callback, - std::function error_callback) { + std::function error_callback) { { absl::MutexLock lock(&mutex_); offsets_initialized_ = (input_stream_managers_.NumEntries() == 1); diff --git a/mediapipe/framework/subgraph.cc b/mediapipe/framework/subgraph.cc index 231aeb6bc..5ca99aec3 100644 --- a/mediapipe/framework/subgraph.cc +++ 
b/mediapipe/framework/subgraph.cc @@ -32,7 +32,7 @@ ProtoSubgraph::ProtoSubgraph(const CalculatorGraphConfig& config) ProtoSubgraph::~ProtoSubgraph() {} -mediapipe::StatusOr ProtoSubgraph::GetConfig( +absl::StatusOr ProtoSubgraph::GetConfig( const Subgraph::SubgraphOptions& options) { return config_; } @@ -42,7 +42,7 @@ TemplateSubgraph::TemplateSubgraph(const CalculatorGraphTemplate& templ) TemplateSubgraph::~TemplateSubgraph() {} -mediapipe::StatusOr TemplateSubgraph::GetConfig( +absl::StatusOr TemplateSubgraph::GetConfig( const Subgraph::SubgraphOptions& options) { TemplateDict arguments = Subgraph::GetOptions(options).dict(); @@ -91,19 +91,19 @@ bool GraphRegistry::IsRegistered(const std::string& ns, global_factories_->IsRegistered(ns, type_name); } -mediapipe::StatusOr GraphRegistry::CreateByName( +absl::StatusOr GraphRegistry::CreateByName( const std::string& ns, const std::string& type_name, const Subgraph::SubgraphOptions* options) const { Subgraph::SubgraphOptions graph_options; if (options) { graph_options = *options; } - mediapipe::StatusOr> maker = + absl::StatusOr> maker = local_factories_.IsRegistered(ns, type_name) ? local_factories_.Invoke(ns, type_name) : global_factories_->Invoke(ns, type_name); MP_RETURN_IF_ERROR(maker.status()); - return maker.ValueOrDie()->GetConfig(graph_options); + return maker.value()->GetConfig(graph_options); } } // namespace mediapipe diff --git a/mediapipe/framework/subgraph.h b/mediapipe/framework/subgraph.h index 8529cbad3..64ebc313c 100644 --- a/mediapipe/framework/subgraph.h +++ b/mediapipe/framework/subgraph.h @@ -41,7 +41,7 @@ class Subgraph { // the parent graph. // Subclasses may use the options argument to parameterize the config. // TODO: make this static? - virtual mediapipe::StatusOr GetConfig( + virtual absl::StatusOr GetConfig( const SubgraphOptions& options) = 0; // Returns options of a specific type. 
@@ -71,7 +71,7 @@ class ProtoSubgraph : public Subgraph { public: ProtoSubgraph(const CalculatorGraphConfig& config); virtual ~ProtoSubgraph(); - virtual mediapipe::StatusOr GetConfig( + virtual absl::StatusOr GetConfig( const Subgraph::SubgraphOptions& options); private: @@ -83,7 +83,7 @@ class TemplateSubgraph : public Subgraph { public: TemplateSubgraph(const CalculatorGraphTemplate& templ); virtual ~TemplateSubgraph(); - virtual mediapipe::StatusOr GetConfig( + virtual absl::StatusOr GetConfig( const Subgraph::SubgraphOptions& options); private: @@ -118,7 +118,7 @@ class GraphRegistry { bool IsRegistered(const std::string& ns, const std::string& type_name) const; // Returns the specified graph config. - mediapipe::StatusOr CreateByName( + absl::StatusOr CreateByName( const std::string& ns, const std::string& type_name, const Subgraph::SubgraphOptions* options = nullptr) const; diff --git a/mediapipe/framework/test_calculators.cc b/mediapipe/framework/test_calculators.cc index 4addce93b..92f15e3a5 100644 --- a/mediapipe/framework/test_calculators.cc +++ b/mediapipe/framework/test_calculators.cc @@ -39,21 +39,21 @@ using RandomEngine = std::mt19937_64; // A Calculator that outputs twice the value of its input packet (an int). 
class DoubleIntCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new int(2 * value), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(DoubleIntCalculator); @@ -62,16 +62,16 @@ REGISTER_CALCULATOR(DoubleIntCalculator); // holds the high order bits and the second the low order ones. class IntSplitterPacketGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, // PacketTypeSet* input_side_packets, // PacketTypeSet* output_side_packets) { input_side_packets->Index(0).Set(); output_side_packets->Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( + static absl::Status Generate( const PacketGeneratorOptions& extendable_options, // const PacketSet& input_side_packets, // PacketSet* output_side_packets) { @@ -80,7 +80,7 @@ class IntSplitterPacketGenerator : public PacketGenerator { uint32 low = value & 0xFFFFFFFF; output_side_packets->Index(0) = Adopt(new std::pair(high, low)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(IntSplitterPacketGenerator); @@ -90,7 +90,7 @@ REGISTER_PACKET_GENERATOR(IntSplitterPacketGenerator); // with both the 
high and low order bits. class TaggedIntSplitterPacketGenerator : public PacketGenerator { public: - static mediapipe::Status FillExpectations( + static absl::Status FillExpectations( const PacketGeneratorOptions& extendable_options, // PacketTypeSet* input_side_packets, // PacketTypeSet* output_side_packets) { @@ -98,10 +98,10 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator { output_side_packets->Tag("HIGH").Set(); output_side_packets->Tag("LOW").Set(); output_side_packets->Tag("PAIR").Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status Generate( + static absl::Status Generate( const PacketGeneratorOptions& extendable_options, // const PacketSet& input_side_packets, // PacketSet* output_side_packets) { @@ -112,7 +112,7 @@ class TaggedIntSplitterPacketGenerator : public PacketGenerator { output_side_packets->Tag("LOW") = Adopt(new uint32(low)); output_side_packets->Tag("PAIR") = Adopt(new std::pair(high, low)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_PACKET_GENERATOR(TaggedIntSplitterPacketGenerator); @@ -129,22 +129,22 @@ class RangeCalculator : public CalculatorBase { public: RangeCalculator() : initialized_(false) {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); cc->Outputs().Index(1).Set(); cc->Outputs().Index(2).Set(); cc->InputSidePackets().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { Initialize(cc); // Fail if requested, without setting any stream headers. This tests that // the downstream Calculators will not try to access the headers in case // this one failed. 
if (k_ == 0) { - return mediapipe::Status(mediapipe::StatusCode::kCancelled, "k_ == 0"); + return absl::Status(absl::StatusCode::kCancelled, "k_ == 0"); } cc->Outputs().Index(0).SetHeader( Adopt(new std::string(absl::StrCat(cc->CalculatorType(), k_)))); @@ -154,21 +154,21 @@ class RangeCalculator : public CalculatorBase { cc->Outputs().Index(1).SetNextTimestampBound(Timestamp::PostStream()); cc->Outputs().Index(2).SetNextTimestampBound(Timestamp::PreStream()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // Output at timestamps 1:N-1 that are divisible by K. index_ += k_; if (index_ < n_) { cc->Outputs().Index(0).AddPacket(GetNextPacket().At(Timestamp(index_))); - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { return tool::StatusStop(); } } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { // Output at timestamp N. cc->Outputs().Index(0).AddPacket(GetNextPacket().At(Timestamp(n_))); // Output: ints from a range specified in the input side packet. 
@@ -177,7 +177,7 @@ class RangeCalculator : public CalculatorBase { new double(static_cast(total_) / static_cast(count_)), Timestamp::PreStream()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -220,19 +220,19 @@ class StdDevCalculator : public CalculatorBase { public: StdDevCalculator() {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag("DATA").Set(); cc->Inputs().Tag("MEAN").Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).SetNextTimestampBound(Timestamp::PostStream()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (cc->InputTimestamp() == Timestamp::PreStream()) { RET_CHECK(cc->Inputs().Tag("DATA").Value().IsEmpty()); RET_CHECK(!cc->Inputs().Tag("MEAN").Value().IsEmpty()); @@ -246,15 +246,15 @@ class StdDevCalculator : public CalculatorBase { cummulative_variance_ += diff * diff; ++count_; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { cc->Outputs().Index(0).Add( new int(mediapipe::MathUtil::SafeRound( sqrt(cummulative_variance_ / count_) * 100.0)), Timestamp::PostStream()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -272,15 +272,15 @@ REGISTER_CALCULATOR(StdDevCalculator); // concatenation of the input stream headers. 
class MergeCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { cc->Inputs().Index(i).Set(); } cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { auto header = absl::make_unique(); for (auto& input : cc->Inputs()) { if (!input.Header().IsEmpty()) { @@ -291,10 +291,10 @@ class MergeCalculator : public CalculatorBase { } } cc->Outputs().Index(0).SetHeader(Adopt(header.release())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { std::string result; if (cc->InputTimestamp().IsSpecialValue()) { absl::StrAppend(&result, cc->InputTimestamp().DebugString()); @@ -312,7 +312,7 @@ class MergeCalculator : public CalculatorBase { } } cc->Outputs().Index(0).Add(new std::string(result), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(MergeCalculator); @@ -324,28 +324,28 @@ class SaverCalculator : public CalculatorBase { public: SaverCalculator() : result_(new std::string) {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).SetNextTimestampBound(Timestamp::PostStream()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (!result_->empty()) { 
result_->append("/"); } result_->append(cc->Inputs().Index(0).Get()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { cc->Outputs().Index(0).Add(result_.release(), Timestamp::PostStream()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -359,13 +359,13 @@ REGISTER_CALCULATOR(SaverCalculator); // as an input side packet. class RandomMatrixCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Outputs().Index(0).Set(); cc->InputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { auto& options = cc->Options(); CHECK_LT(0, options.timestamp_step()); CHECK_LT(0, options.rows()); @@ -380,10 +380,10 @@ class RandomMatrixCalculator : public CalculatorBase { std::vector seed(1); seq.generate(seed.begin(), seed.end()); random_ = absl::make_unique(seed[0]); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { auto& options = cc->Options(); Matrix* matrix = new Matrix(options.rows(), options.cols()); @@ -398,7 +398,7 @@ class RandomMatrixCalculator : public CalculatorBase { if (current_timestamp_ >= Timestamp(options.limit_timestamp())) { return tool::StatusStop(); } else { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } @@ -417,21 +417,21 @@ REGISTER_CALCULATOR(RandomMatrixCalculator); // effect of round off error). 
class MeanAndCovarianceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->Outputs().Index(0).SetNextTimestampBound(Timestamp::PostStream()); rows_ = -1; num_samples_ = 0; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { const Eigen::MatrixXd sample = cc->Inputs().Index(0).Get().cast(); CHECK_EQ(1, sample.cols()); @@ -446,10 +446,10 @@ class MeanAndCovarianceCalculator : public CalculatorBase { outer_product_sum_ += sample * sample.transpose(); ++num_samples_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { + absl::Status Close(CalculatorContext* cc) override { Eigen::VectorXd mean_vector = sum_vector_ / num_samples_; Eigen::MatrixXd covariance_matrix(rows_, rows_); @@ -469,7 +469,7 @@ class MeanAndCovarianceCalculator : public CalculatorBase { new std::pair( mean_vector.cast(), covariance_matrix.cast()), Timestamp::PostStream()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -485,27 +485,27 @@ REGISTER_CALCULATOR(MeanAndCovarianceCalculator); // increases by 1 for each packet. 
class SidePacketToOutputPacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).SetAny(); for (int i = 1; i < cc->InputSidePackets().NumEntries(); ++i) { cc->InputSidePackets().Index(i).SetSameAs( &cc->InputSidePackets().Index(0)); } cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { int current_timestamp = 0; for (const Packet& packet : cc->InputSidePackets()) { cc->Outputs().Index(0).AddPacket(packet.At(Timestamp(current_timestamp))); ++current_timestamp; } cc->Outputs().Index(0).Close(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return tool::StatusStop(); } }; @@ -516,46 +516,46 @@ REGISTER_CALCULATOR(SidePacketToOutputPacketCalculator); class ABSL_DEPRECATED("Use SidePacketToOutputPacketCalculator instead") ExternalInputToOutputPacketCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).SetAny(); for (int i = 1; i < cc->InputSidePackets().NumEntries(); ++i) { cc->InputSidePackets().Index(i).SetSameAs( &cc->InputSidePackets().Index(0)); } cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { int current_timestamp = 0; for (const Packet& packet : cc->InputSidePackets()) { cc->Outputs().Index(0).AddPacket(packet.At(Timestamp(current_timestamp))); 
++current_timestamp; } cc->Outputs().Index(0).Close(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { return tool::StatusStop(); } }; REGISTER_CALCULATOR(ExternalInputToOutputPacketCalculator); // A Calculator::Process callback function. -typedef std::function +typedef std::function ProcessFunction; // A callback function for Calculator::Open, Process, or Close. -typedef std::function +typedef std::function CalculatorContextFunction; // A Calculator that runs a testing callback function in Process, // Open, or Close, which is specified as an input side packet. class LambdaCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { cc->Inputs().Get(id).SetAny(); @@ -572,31 +572,31 @@ class LambdaCalculator : public CalculatorBase { cc->InputSidePackets().Tag(tag).Set(); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { if (cc->InputSidePackets().HasTag("OPEN")) { return GetContextFn(cc, "OPEN")(cc); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { if (cc->InputSidePackets().HasTag("PROCESS")) { return GetContextFn(cc, "PROCESS")(cc); } if (cc->InputSidePackets().HasTag("") > 0) { return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { if (cc->InputSidePackets().HasTag("CLOSE")) { return GetContextFn(cc, "CLOSE")(cc); } - return 
mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -618,7 +618,7 @@ REGISTER_CALCULATOR(LambdaCalculator); // stream connections. class DummyTestCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (CollectionItemId id = cc->Inputs().BeginId(); id < cc->Inputs().EndId(); ++id) { cc->Inputs().Get(id).SetAny(); @@ -631,12 +631,10 @@ class DummyTestCalculator : public CalculatorBase { id < cc->InputSidePackets().EndId(); ++id) { cc->InputSidePackets().Get(id).SetAny(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { - return mediapipe::OkStatus(); - } + absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); } }; REGISTER_CALCULATOR(DummyTestCalculator); @@ -644,27 +642,27 @@ REGISTER_CALCULATOR(DummyTestCalculator); // a set number of microseconds. class PassThroughWithSleepCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); cc->InputSidePackets().Tag("SLEEP_MICROS").Set(); cc->InputSidePackets().Tag("CLOCK").Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); sleep_micros_ = cc->InputSidePackets().Tag("SLEEP_MICROS").Get(); if (sleep_micros_ < 0) { - return mediapipe::InternalError("SLEEP_MICROS should be >= 0"); + return absl::InternalError("SLEEP_MICROS should be >= 0"); } clock_ = cc->InputSidePackets().Tag("CLOCK").Get>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* 
cc) final { clock_->Sleep(absl::Microseconds(sleep_micros_)); int value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new int(value), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -676,25 +674,74 @@ REGISTER_CALCULATOR(PassThroughWithSleepCalculator); // A Calculator that multiples two input values. class MultiplyIntCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Inputs().Index(1).SetSameAs(&cc->Inputs().Index(0)); // cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); RET_CHECK(cc->Outputs().HasTag("OUT")); cc->Outputs().Tag("OUT").SetSameAs(&cc->Inputs().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int x = cc->Inputs().Index(0).Value().Get(); int y = cc->Inputs().Index(1).Value().Get(); cc->Outputs().Tag("OUT").Add(new int(x * y), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(MultiplyIntCalculator); +// A calculator that forwards nested input packets to the output stream if they +// are not empty, otherwise it transforms them into timestamp bound updates. 
+class ForwardNestedPacketOrEmitBoundUpdateCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).SetAny(); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + const auto& nested_packet = cc->Inputs().Index(0).Get(); + if (!nested_packet.IsEmpty()) { + cc->Outputs().Index(0).AddPacket(nested_packet); + } else { + // Add 1 so that Process() of the downstream calculator is called with + // exactly this timestamp. + cc->Outputs().Index(0).SetNextTimestampBound(nested_packet.Timestamp() + + 1); + } + return absl::OkStatus(); + } +}; +REGISTER_CALCULATOR(ForwardNestedPacketOrEmitBoundUpdateCalculator); + +// A calculator that outputs timestamp bound updates emitted by the upstream +// calculator. +class TimestampBoundReceiverCalculator : public CalculatorBase { + public: + static absl::Status GetContract(CalculatorContract* cc) { + cc->Inputs().Index(0).SetAny(); + cc->Outputs().Index(0).Set(); + cc->SetProcessTimestampBounds(true); + return absl::OkStatus(); + } + + absl::Status Process(CalculatorContext* cc) override { + if (cc->Inputs().Index(0).IsEmpty()) { + // Add 1 to get the exact value that was passed to SetNextTimestampBound() + // in the upstream calculator. 
+ const Timestamp bound = cc->InputTimestamp() + 1; + cc->Outputs().Index(0).AddPacket( + mediapipe::MakePacket(bound).At(bound)); + } + return absl::OkStatus(); + } +}; +REGISTER_CALCULATOR(TimestampBoundReceiverCalculator); + } // namespace mediapipe diff --git a/mediapipe/framework/test_service.cc b/mediapipe/framework/test_service.cc index a5139e8b2..79bbc4340 100644 --- a/mediapipe/framework/test_service.cc +++ b/mediapipe/framework/test_service.cc @@ -19,26 +19,26 @@ namespace mediapipe { const GraphService kTestService("test_service"); const GraphService kAnotherService("another_service"); -mediapipe::Status TestServiceCalculator::GetContract(CalculatorContract* cc) { +absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)); // This service will be required. The graph won't start without it. cc->UseService(kTestService); // This service is optional for this calculator. cc->UseService(kAnotherService).Optional(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TestServiceCalculator::Open(CalculatorContext* cc) { +absl::Status TestServiceCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); // For an optional service, check whether it's available. if (cc->Service(kAnotherService).IsAvailable()) { optional_bias_ = cc->Service(kAnotherService).GetObject(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TestServiceCalculator::Process(CalculatorContext* cc) { +absl::Status TestServiceCalculator::Process(CalculatorContext* cc) { int value = cc->Inputs().Index(0).Value().Get(); // A required service is sure to be available, so we can just GetObject. 
TestServiceObject& service_object = cc->Service(kTestService).GetObject(); @@ -46,7 +46,7 @@ mediapipe::Status TestServiceCalculator::Process(CalculatorContext* cc) { service_object["count"] += 1; int x = value + delta + optional_bias_; cc->Outputs().Index(0).Add(new int(x), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(TestServiceCalculator); diff --git a/mediapipe/framework/test_service.h b/mediapipe/framework/test_service.h index 4d76f84e5..e726f7c15 100644 --- a/mediapipe/framework/test_service.h +++ b/mediapipe/framework/test_service.h @@ -27,9 +27,9 @@ extern const GraphService kAnotherService; // Use a service. class TestServiceCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) final; - mediapipe::Status Process(CalculatorContext* cc) final; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) final; + absl::Status Process(CalculatorContext* cc) final; private: int optional_bias_ = 0; diff --git a/mediapipe/framework/thread_pool_executor.cc b/mediapipe/framework/thread_pool_executor.cc index d7c3a7886..27dce0aeb 100644 --- a/mediapipe/framework/thread_pool_executor.cc +++ b/mediapipe/framework/thread_pool_executor.cc @@ -25,12 +25,12 @@ namespace mediapipe { // static -mediapipe::StatusOr ThreadPoolExecutor::Create( +absl::StatusOr ThreadPoolExecutor::Create( const MediaPipeOptions& extendable_options) { auto& options = extendable_options.GetExtension(ThreadPoolExecutorOptions::ext); if (!options.has_num_threads()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "num_threads is not specified in ThreadPoolExecutorOptions."); } if (options.num_threads() <= 0) { diff --git a/mediapipe/framework/thread_pool_executor.h b/mediapipe/framework/thread_pool_executor.h index a7700c4bb..7f9b48ea1 100644 --- 
a/mediapipe/framework/thread_pool_executor.h +++ b/mediapipe/framework/thread_pool_executor.h @@ -25,7 +25,7 @@ namespace mediapipe { // A multithreaded executor based on a thread pool. class ThreadPoolExecutor : public Executor { public: - static mediapipe::StatusOr Create( + static absl::StatusOr Create( const MediaPipeOptions& extendable_options); explicit ThreadPoolExecutor(int num_threads); diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index 4aee33f5e..991814515 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -96,6 +96,7 @@ cc_library( ":validate_name", "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/port:map_util", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -128,16 +129,24 @@ cc_test( cc_library( name = "options_util", + srcs = ["options_util.cc"], hdrs = ["options_util.h"], visibility = ["//mediapipe/framework:mediapipe_internal"], deps = [ + ":proto_util_lite", "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:collection", "//mediapipe/framework:input_stream_shard", + "//mediapipe/framework:output_side_packet", "//mediapipe/framework:packet", - "//mediapipe/framework:packet_generator_cc_proto", "//mediapipe/framework:packet_set", + "//mediapipe/framework:packet_type", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:any_proto", + "//mediapipe/framework/port:status", "//mediapipe/framework/tool:type_util", + "@com_google_absl//absl/strings", ], ) @@ -280,7 +289,6 @@ cc_library( cc_library( name = "tag_map_helper", - testonly = 1, srcs = ["tag_map_helper.cc"], hdrs = ["tag_map_helper.h"], visibility = ["//visibility:public"], @@ -578,26 +586,16 @@ cc_library( deps = [ "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/deps:no_destructor", 
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:advanced_proto", + "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:logging", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "//mediapipe/framework/deps:no_destructor", - "//mediapipe/framework/port:logging", - ] + select({ - "//conditions:default": [ - "//mediapipe/framework/port:file_helpers", - ], - "//mediapipe:android": [ - "//mediapipe/util/android/file/base", - ], - "//mediapipe:ios": [], - "//mediapipe:macos": [ - "//mediapipe/framework/port:file_helpers", - ], - }), + ], ) cc_library( diff --git a/mediapipe/framework/tool/container_util.cc b/mediapipe/framework/tool/container_util.cc index 9939ef277..a40db96e9 100644 --- a/mediapipe/framework/tool/container_util.cc +++ b/mediapipe/framework/tool/container_util.cc @@ -12,7 +12,7 @@ std::string ChannelTag(const std::string& tag, int channel) { // Parses a tag name starting with a channel prefix, like "C2__". 
bool ParseChannelTag(const std::string& channel_name, std::string* name, std::string* num) { - int pos = channel_name.find("C"); + int pos = channel_name.find('C'); int sep = channel_name.find("__"); if (pos != 0 || sep == std::string::npos) { return false; diff --git a/mediapipe/framework/tool/fill_packet_set.cc b/mediapipe/framework/tool/fill_packet_set.cc index 365f238c4..05f5b4e5d 100644 --- a/mediapipe/framework/tool/fill_packet_set.cc +++ b/mediapipe/framework/tool/fill_packet_set.cc @@ -25,14 +25,14 @@ namespace mediapipe { namespace tool { -mediapipe::StatusOr> FillPacketSet( +absl::StatusOr> FillPacketSet( const PacketTypeSet& input_side_packet_types, const std::map& input_side_packets, int* missing_packet_count_ptr) { if (missing_packet_count_ptr != nullptr) { *missing_packet_count_ptr = 0; } - std::vector errors; + std::vector errors; auto packet_set = absl::make_unique(input_side_packet_types.TagMap()); const auto& names = input_side_packet_types.TagMap()->Names(); @@ -51,7 +51,7 @@ mediapipe::StatusOr> FillPacketSet( } packet_set->Get(id) = iter->second; // Check the type. - mediapipe::Status status = + absl::Status status = input_side_packet_types.Get(id).Validate(iter->second); if (!status.ok()) { std::pair tag_index = diff --git a/mediapipe/framework/tool/fill_packet_set.h b/mediapipe/framework/tool/fill_packet_set.h index a2cd72130..a7dbf0698 100644 --- a/mediapipe/framework/tool/fill_packet_set.h +++ b/mediapipe/framework/tool/fill_packet_set.h @@ -32,7 +32,7 @@ namespace tool { // missing_packet_count_ptr is not null, the number of missing packets // is returned in *missing_packet_count_ptr. Otherwise, an error is // returned if any packets are missing. 
-mediapipe::StatusOr> FillPacketSet( +absl::StatusOr> FillPacketSet( const PacketTypeSet& input_side_packet_types, const std::map& input_side_packets, int* missing_packet_count_ptr); diff --git a/mediapipe/framework/tool/fill_packet_set_test.cc b/mediapipe/framework/tool/fill_packet_set_test.cc index ae399831e..a5707fabb 100644 --- a/mediapipe/framework/tool/fill_packet_set_test.cc +++ b/mediapipe/framework/tool/fill_packet_set_test.cc @@ -31,7 +31,7 @@ TEST(FillPacketSetTest, Success) { node.add_input_side_packet("DOUBLE:1:side_packet4"); PacketTypeSet input_side_packet_types( - tool::TagMap::Create(node.input_side_packet()).ValueOrDie()); + tool::TagMap::Create(node.input_side_packet()).value()); input_side_packet_types.Index(0).Set( // An age ); @@ -57,7 +57,7 @@ TEST(FillPacketSetTest, Success) { std::unique_ptr input_side_packets = tool::FillPacketSet(input_side_packet_types, all_side_packets, nullptr) - .ValueOrDie(); + .value(); ASSERT_EQ(4, input_side_packets->NumEntries()); EXPECT_EQ(input_side_packets->Index(0).Get(), 70); EXPECT_EQ(input_side_packets->Index(1).Get(), "Dennis Ritchie"); @@ -73,7 +73,7 @@ TEST(FillPacketSetTest, MissingSidePacketError) { node.add_input_side_packet("DOUBLE:1:side_packet4"); PacketTypeSet input_side_packet_types( - tool::TagMap::Create(node.input_side_packet()).ValueOrDie()); + tool::TagMap::Create(node.input_side_packet()).value()); input_side_packet_types.Index(0).Set( // An age ); @@ -111,7 +111,7 @@ TEST(FillPacketSetTest, MissingSidePacketOk) { node.add_input_side_packet("DOUBLE:1:side_packet4"); PacketTypeSet input_side_packet_types( - tool::TagMap::Create(node.input_side_packet()).ValueOrDie()); + tool::TagMap::Create(node.input_side_packet()).value()); input_side_packet_types.Index(0).Set( // An age ); @@ -138,7 +138,7 @@ TEST(FillPacketSetTest, MissingSidePacketOk) { std::unique_ptr input_side_packets = tool::FillPacketSet(input_side_packet_types, all_side_packets, &missing_packet_count) - .ValueOrDie(); + .value(); 
ASSERT_EQ(4, input_side_packets->NumEntries()); EXPECT_EQ(1, missing_packet_count); EXPECT_EQ(input_side_packets->Index(0).Get(), 70); @@ -155,7 +155,7 @@ TEST(FillPacketSetTest, WrongSidePacketType) { node.add_input_side_packet("DOUBLE:1:side_packet4"); PacketTypeSet input_side_packet_types( - tool::TagMap::Create(node.input_side_packet()).ValueOrDie()); + tool::TagMap::Create(node.input_side_packet()).value()); input_side_packet_types.Index(0).Set( // An age ); diff --git a/mediapipe/framework/tool/name_util.cc b/mediapipe/framework/tool/name_util.cc index 695b5b2e4..4784441d7 100644 --- a/mediapipe/framework/tool/name_util.cc +++ b/mediapipe/framework/tool/name_util.cc @@ -17,6 +17,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "mediapipe/framework/port/map_util.h" @@ -47,16 +48,14 @@ std::string GetUnusedNodeName(const CalculatorGraphConfig& config, std::string GetUnusedSidePacketName( const CalculatorGraphConfig& config, const std::string& input_side_packet_name_base) { - std::unordered_map> - input_side_packets; + absl::flat_hash_set input_side_packets; for (const ::mediapipe::CalculatorGraphConfig::Node& node : config.node()) { for (const auto& tag_and_name : node.input_side_packet()) { std::string tag; std::string name; int index; MEDIAPIPE_CHECK_OK(ParseTagIndexName(tag_and_name, &tag, &index, &name)); - input_side_packets[name].push_back(node); + input_side_packets.insert(name); } } std::string candidate = input_side_packet_name_base; diff --git a/mediapipe/framework/tool/options_util.cc b/mediapipe/framework/tool/options_util.cc new file mode 100644 index 000000000..20734b953 --- /dev/null +++ b/mediapipe/framework/tool/options_util.cc @@ -0,0 +1,17 @@ + +#include "mediapipe/framework/tool/options_util.h" + +#include "mediapipe/framework/port/proto_ns.h" + +namespace mediapipe { +namespace tool { + +// TODO: Return registered protobuf Descriptors when 
available. +const proto_ns::Descriptor* GetProtobufDescriptor( + const std::string& type_name) { + return proto_ns::DescriptorPool::generated_pool()->FindMessageTypeByName( + type_name); +} + +} // namespace tool +} // namespace mediapipe diff --git a/mediapipe/framework/tool/options_util.h b/mediapipe/framework/tool/options_util.h index 51d5bed4b..da943a121 100644 --- a/mediapipe/framework/tool/options_util.h +++ b/mediapipe/framework/tool/options_util.h @@ -18,7 +18,6 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/input_stream_shard.h" #include "mediapipe/framework/packet.h" -#include "mediapipe/framework/packet_generator.pb.h" #include "mediapipe/framework/packet_set.h" #include "mediapipe/framework/port/any_proto.h" #include "mediapipe/framework/tool/type_util.h" @@ -162,6 +161,9 @@ class OptionsMap { TypeMap options_; }; +// Finds the descriptor for a protobuf. +const proto_ns::Descriptor* GetProtobufDescriptor(const std::string& type_name); + } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/options_util_test.cc b/mediapipe/framework/tool/options_util_test.cc new file mode 100644 index 000000000..7935b165e --- /dev/null +++ b/mediapipe/framework/tool/options_util_test.cc @@ -0,0 +1,46 @@ +// Copyright 2019 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace { + +// Tests for calculator and graph options. +// +class OptionsUtilTest : public ::testing::Test { + protected: + void SetUp() override {} + void TearDown() override {} +}; + +// Retrieves the description of a protobuf. +TEST_F(OptionsUtilTest, GetProtobufDescriptor) { + const proto_ns::Descriptor* descriptor = + tool::GetProtobufDescriptor("mediapipe.CalculatorGraphConfig"); +#ifndef MEDIAPIPE_MOBILE + EXPECT_NE(nullptr, descriptor); +#else + EXPECT_EQ(nullptr, descriptor); +#endif +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/framework/tool/proto_util_lite.cc b/mediapipe/framework/tool/proto_util_lite.cc index 1eb2c812d..e4844d5cd 100644 --- a/mediapipe/framework/tool/proto_util_lite.cc +++ b/mediapipe/framework/tool/proto_util_lite.cc @@ -42,8 +42,8 @@ bool IsLengthDelimited(WireFormatLite::WireType wire_type) { } // Reads a single data value for a wire type. -mediapipe::Status ReadFieldValue(uint32 tag, CodedInputStream* in, - std::string* result) { +absl::Status ReadFieldValue(uint32 tag, CodedInputStream* in, + std::string* result) { WireFormatLite::WireType wire_type = WireFormatLite::GetTagWireType(tag); if (IsLengthDelimited(wire_type)) { uint32 length; @@ -59,13 +59,13 @@ mediapipe::Status ReadFieldValue(uint32 tag, CodedInputStream* in, cos.Trim(); result->assign(field_data, tag_size, std::string::npos); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Reads the packed sequence of data values for a wire type. 
-mediapipe::Status ReadPackedValues(WireFormatLite::WireType wire_type, - CodedInputStream* in, - std::vector* field_values) { +absl::Status ReadPackedValues(WireFormatLite::WireType wire_type, + CodedInputStream* in, + std::vector* field_values) { uint32 data_size; RET_CHECK(in->ReadVarint32(&data_size)); // fake_tag encodes the wire-type for calls to WireFormatLite::SkipField. @@ -77,15 +77,14 @@ mediapipe::Status ReadPackedValues(WireFormatLite::WireType wire_type, field_values->push_back(number); data_size -= number.size(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Extracts the data value(s) for one field from a serialized message. // The message with these field values removed is written to |out|. -mediapipe::Status GetFieldValues(uint32 field_id, - WireFormatLite::WireType wire_type, - CodedInputStream* in, CodedOutputStream* out, - std::vector* field_values) { +absl::Status GetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, + CodedInputStream* in, CodedOutputStream* out, + std::vector* field_values) { uint32 tag; while ((tag = in->ReadTag()) != 0) { int field_number = WireFormatLite::GetTagFieldNumber(tag); @@ -102,7 +101,7 @@ mediapipe::Status GetFieldValues(uint32 field_id, RET_CHECK(WireFormatLite::SkipField(in, tag, out)); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Injects the data value(s) for one field into a serialized message. 
@@ -122,7 +121,7 @@ void SetFieldValues(uint32 field_id, WireFormatLite::WireType wire_type, FieldAccess::FieldAccess(uint32 field_id, FieldType field_type) : field_id_(field_id), field_type_(field_type) {} -mediapipe::Status FieldAccess::SetMessage(const std::string& message) { +absl::Status FieldAccess::SetMessage(const std::string& message) { ArrayInputStream ais(message.data(), message.size()); CodedInputStream in(&ais); StringOutputStream sos(&message_); @@ -146,7 +145,7 @@ std::vector* FieldAccess::mutable_field_values() { } // Replaces a range of field values for one field nested within a protobuf. -mediapipe::Status ProtoUtilLite::ReplaceFieldRange( +absl::Status ProtoUtilLite::ReplaceFieldRange( FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, const std::vector& field_values) { int field_id, index; @@ -169,11 +168,11 @@ mediapipe::Status ProtoUtilLite::ReplaceFieldRange( } message->clear(); access.GetMessage(message); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns a range of field values from one field nested within a protobuf. -mediapipe::Status ProtoUtilLite::GetFieldRange( +absl::Status ProtoUtilLite::GetFieldRange( const FieldValue& message, ProtoPath proto_path, int length, FieldType field_type, std::vector* field_values) { int field_id, index; @@ -194,40 +193,40 @@ mediapipe::Status ProtoUtilLite::GetFieldRange( field_values->insert(field_values->begin(), v.begin() + index, v.begin() + index + length); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // If ok, returns OkStatus, otherwise returns InvalidArgumentError. template -mediapipe::Status SyntaxStatus(bool ok, const std::string& text, T* result) { - return ok ? mediapipe::OkStatus() - : mediapipe::InvalidArgumentError(absl::StrCat( +absl::Status SyntaxStatus(bool ok, const std::string& text, T* result) { + return ok ? 
absl::OkStatus() + : absl::InvalidArgumentError(absl::StrCat( "Syntax error: \"", text, "\"", " for type: ", MediaPipeTypeStringOrDemangled(), ".")); } // Templated parsing of a std::string value. template -mediapipe::Status ParseValue(const std::string& text, T* result) { +absl::Status ParseValue(const std::string& text, T* result) { return SyntaxStatus(absl::SimpleAtoi(text, result), text, result); } template <> -mediapipe::Status ParseValue(const std::string& text, double* result) { +absl::Status ParseValue(const std::string& text, double* result) { return SyntaxStatus(absl::SimpleAtod(text, result), text, result); } template <> -mediapipe::Status ParseValue(const std::string& text, float* result) { +absl::Status ParseValue(const std::string& text, float* result) { return SyntaxStatus(absl::SimpleAtof(text, result), text, result); } template <> -mediapipe::Status ParseValue(const std::string& text, bool* result) { +absl::Status ParseValue(const std::string& text, bool* result) { return SyntaxStatus(absl::SimpleAtob(text, result), text, result); } template <> -mediapipe::Status ParseValue(const std::string& text, - std::string* result) { +absl::Status ParseValue(const std::string& text, + std::string* result) { *result = text; - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Templated formatting of a primitive value. @@ -238,20 +237,19 @@ std::string FormatValue(T v) { // A helper function to parse and serialize one primtive value. template -mediapipe::Status WritePrimitive( - void (*writer)(T, proto_ns::io::CodedOutputStream*), - const std::string& text, CodedOutputStream* out) { +absl::Status WritePrimitive(void (*writer)(T, proto_ns::io::CodedOutputStream*), + const std::string& text, CodedOutputStream* out) { T value; MP_RETURN_IF_ERROR(ParseValue(text, &value)); (*writer)(value, out); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Serializes a protobuf FieldValue. 
-static mediapipe::Status SerializeValue(const std::string& text, - FieldType field_type, - FieldValue* field_value) { - mediapipe::Status status; +static absl::Status SerializeValue(const std::string& text, + FieldType field_type, + FieldValue* field_value) { + absl::Status status; StringOutputStream sos(field_value); CodedOutputStream out(&sos); @@ -277,11 +275,11 @@ static mediapipe::Status SerializeValue(const std::string& text, case W::TYPE_BYTES: case W::TYPE_STRING: { out.WriteRaw(text.data(), text.size()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } case W::TYPE_GROUP: case W::TYPE_MESSAGE: - return mediapipe::UnimplementedError( + return absl::UnimplementedError( "SerializeValue cannot serialize a Message."); case W::TYPE_UINT32: return WritePrimitive(W::WriteUInt32NoTag, text, &out); @@ -296,27 +294,27 @@ static mediapipe::Status SerializeValue(const std::string& text, case W::TYPE_SINT64: return WritePrimitive(W::WriteSInt64NoTag, text, &out); } - return mediapipe::UnimplementedError("SerializeValue unimplemented type."); + return absl::UnimplementedError("SerializeValue unimplemented type."); } // A helper function for deserializing one text value. template -static mediapipe::Status ReadPrimitive(CodedInputStream* input, - std::string* result) { +static absl::Status ReadPrimitive(CodedInputStream* input, + std::string* result) { CType value; if (!WireFormatLite::ReadPrimitive(input, &value)) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Bad serialized value: ", MediaPipeTypeStringOrDemangled(), ".")); } *result = FormatValue(value); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Deserializes a protobuf FieldValue. 
-static mediapipe::Status DeserializeValue(const FieldValue& bytes, - FieldType field_type, - std::string* result) { +static absl::Status DeserializeValue(const FieldValue& bytes, + FieldType field_type, + std::string* result) { ArrayInputStream ais(bytes.data(), bytes.size()); CodedInputStream input(&ais); typedef WireFormatLite W; @@ -340,7 +338,7 @@ static mediapipe::Status DeserializeValue(const FieldValue& bytes, case W::TYPE_BYTES: case W::TYPE_STRING: { *result = bytes; - return mediapipe::OkStatus(); + return absl::OkStatus(); } case W::TYPE_GROUP: case W::TYPE_MESSAGE: @@ -358,10 +356,10 @@ static mediapipe::Status DeserializeValue(const FieldValue& bytes, case W::TYPE_SINT64: return ReadPrimitive(&input, result); } - return mediapipe::UnimplementedError("DeserializeValue unimplemented type."); + return absl::UnimplementedError("DeserializeValue unimplemented type."); } -mediapipe::Status ProtoUtilLite::Serialize( +absl::Status ProtoUtilLite::Serialize( const std::vector& text_values, FieldType field_type, std::vector* result) { result->clear(); @@ -371,10 +369,10 @@ mediapipe::Status ProtoUtilLite::Serialize( MP_RETURN_IF_ERROR(SerializeValue(text_value, field_type, &field_value)); result->push_back(field_value); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ProtoUtilLite::Deserialize( +absl::Status ProtoUtilLite::Deserialize( const std::vector& field_values, FieldType field_type, std::vector* result) { result->clear(); @@ -384,7 +382,7 @@ mediapipe::Status ProtoUtilLite::Deserialize( MP_RETURN_IF_ERROR(DeserializeValue(field_value, field_type, &text_value)); result->push_back(text_value); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace tool diff --git a/mediapipe/framework/tool/proto_util_lite.h b/mediapipe/framework/tool/proto_util_lite.h index f39fbd963..fcbcd7469 100644 --- a/mediapipe/framework/tool/proto_util_lite.h +++ b/mediapipe/framework/tool/proto_util_lite.h @@ -47,7 +47,7 @@ 
class ProtoUtilLite { FieldAccess(uint32 field_id, FieldType field_type); // Specifies the original serialized protobuf message. - mediapipe::Status SetMessage(const FieldValue& message); + absl::Status SetMessage(const FieldValue& message); // Returns the serialized protobuf message with updated field values. void GetMessage(FieldValue* result); @@ -64,26 +64,26 @@ class ProtoUtilLite { // Replace a range of field values nested within a protobuf. // Starting at the proto_path index, "length" values are replaced. - static mediapipe::Status ReplaceFieldRange( + static absl::Status ReplaceFieldRange( FieldValue* message, ProtoPath proto_path, int length, FieldType field_type, const std::vector& field_values); // Retrieve a range of field values nested within a protobuf. // Starting at the proto_path index, "length" values are retrieved. - static mediapipe::Status GetFieldRange(const FieldValue& message, - ProtoPath proto_path, int length, - FieldType field_type, - std::vector* field_values); + static absl::Status GetFieldRange(const FieldValue& message, + ProtoPath proto_path, int length, + FieldType field_type, + std::vector* field_values); // Serialize one or more protobuf field values from text. - static mediapipe::Status Serialize( - const std::vector& text_values, FieldType field_type, - std::vector* result); + static absl::Status Serialize(const std::vector& text_values, + FieldType field_type, + std::vector* result); // Deserialize one or more protobuf field values to text. 
- static mediapipe::Status Deserialize( - const std::vector& field_values, FieldType field_type, - std::vector* result); + static absl::Status Deserialize(const std::vector& field_values, + FieldType field_type, + std::vector* result); }; } // namespace tool diff --git a/mediapipe/framework/tool/simple_subgraph_template.cc b/mediapipe/framework/tool/simple_subgraph_template.cc index 1978a2955..606694632 100644 --- a/mediapipe/framework/tool/simple_subgraph_template.cc +++ b/mediapipe/framework/tool/simple_subgraph_template.cc @@ -20,24 +20,26 @@ namespace mediapipe { +// clang-format off static const char binary_graph[] = #include "{{SUBGRAPH_INC_FILE_PATH}}" ; // NOLINT(whitespace/semicolon) class {{SUBGRAPH_CLASS_NAME}} : public Subgraph { public: - ::mediapipe::StatusOr GetConfig( - const SubgraphOptions& /*options*/) { + absl::StatusOr GetConfig( + const SubgraphOptions& /*options*/) { CalculatorGraphConfig config; // Note: this is a binary protobuf serialization, and may include NUL // bytes. The trailing NUL added to the std::string literal should be excluded. if (config.ParseFromArray(binary_graph, sizeof(binary_graph) - 1)) { return config; } else { - return ::mediapipe::InternalError("Could not parse subgraph."); + return absl::InternalError("Could not parse subgraph."); } } }; REGISTER_MEDIAPIPE_GRAPH({{SUBGRAPH_CLASS_NAME}}); +// clang-format on } // namespace mediapipe diff --git a/mediapipe/framework/tool/simulation_clock_test.cc b/mediapipe/framework/tool/simulation_clock_test.cc index fca9ebbd0..85b4ea58e 100644 --- a/mediapipe/framework/tool/simulation_clock_test.cc +++ b/mediapipe/framework/tool/simulation_clock_test.cc @@ -177,19 +177,19 @@ TEST_F(SimulationClockTest, DuplicateWakeTimes) { } // A Calculator::Process callback function. -typedef std::function +typedef std::function ProcessFunction; // A testing callback function that passes through all packets. 
-mediapipe::Status PassThrough(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status PassThrough(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (int i = 0; i < inputs.NumEntries(); ++i) { if (!inputs.Index(i).Value().IsEmpty()) { outputs->Index(i).AddPacket(inputs.Index(i).Value()); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // This test shows sim clock synchronizing a bunch of parallel tasks. @@ -267,7 +267,7 @@ TEST_F(SimulationClockTest, DestroyClock) { if (++input_count < 4) { outputs->Index(0).AddPacket( MakePacket(input_count).At(Timestamp(input_count))); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } else { return tool::StatusStop(); } @@ -279,7 +279,7 @@ TEST_F(SimulationClockTest, DestroyClock) { }; std::vector out_packets; - ::mediapipe::Status status; + absl::Status status; { CalculatorGraph graph; auto executor = std::make_shared(4); diff --git a/mediapipe/framework/tool/sink.cc b/mediapipe/framework/tool/sink.cc index 5382975b7..2ac17f8e1 100644 --- a/mediapipe/framework/tool/sink.cc +++ b/mediapipe/framework/tool/sink.cc @@ -45,20 +45,20 @@ namespace { class MediaPipeInternalSidePacketToPacketStreamCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Index(0).SetAny(); cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->Outputs().Index(0).AddPacket( cc->InputSidePackets().Index(0).At(Timestamp::PostStream())); cc->Outputs().Index(0).Close(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { // The framework treats 
this calculator as a source calculator. return mediapipe::tool::StatusStop(); } @@ -222,7 +222,7 @@ void AddCallbackWithHeaderCalculator(const std::string& stream_name, // CallbackCalculator // static -mediapipe::Status CallbackCalculator::GetContract(CalculatorContract* cc) { +absl::Status CallbackCalculator::GetContract(CalculatorContract* cc) { bool allow_multiple_streams = false; // If the input side packet is specified using tag "CALLBACK" it must contain // a std::function, which may be generated by CallbackPacketCalculator. @@ -246,10 +246,10 @@ mediapipe::Status CallbackCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Index(i).SetAny(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CallbackCalculator::Open(CalculatorContext* cc) { +absl::Status CallbackCalculator::Open(CalculatorContext* cc) { if (cc->InputSidePackets().HasTag("CALLBACK")) { callback_ = cc->InputSidePackets() .Tag("CALLBACK") @@ -266,10 +266,10 @@ mediapipe::Status CallbackCalculator::Open(CalculatorContext* cc) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "missing callback."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CallbackCalculator::Process(CalculatorContext* cc) { +absl::Status CallbackCalculator::Process(CalculatorContext* cc) { if (callback_) { callback_(cc->Inputs().Index(0).Value()); } else if (vector_callback_) { @@ -281,7 +281,7 @@ mediapipe::Status CallbackCalculator::Process(CalculatorContext* cc) { } vector_callback_(packets); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(CallbackCalculator); @@ -289,8 +289,7 @@ REGISTER_CALCULATOR(CallbackCalculator); // CallbackWithHeaderCalculator // static -mediapipe::Status CallbackWithHeaderCalculator::GetContract( - CalculatorContract* cc) { +absl::Status CallbackWithHeaderCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag("INPUT").SetAny(); 
cc->Inputs().Tag("HEADER").SetAny(); @@ -303,10 +302,10 @@ mediapipe::Status CallbackWithHeaderCalculator::GetContract( return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "InputSidePackets must use tags."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CallbackWithHeaderCalculator::Open(CalculatorContext* cc) { +absl::Status CallbackWithHeaderCalculator::Open(CalculatorContext* cc) { if (cc->InputSidePackets().UsesTags()) { callback_ = cc->InputSidePackets() .Tag("CALLBACK") @@ -333,10 +332,10 @@ mediapipe::Status CallbackWithHeaderCalculator::Open(CalculatorContext* cc) { if (!cc->Inputs().Tag("INPUT").Header().IsEmpty()) { header_packet_ = cc->Inputs().Tag("INPUT").Header(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CallbackWithHeaderCalculator::Process(CalculatorContext* cc) { +absl::Status CallbackWithHeaderCalculator::Process(CalculatorContext* cc) { if (!cc->Inputs().Tag("INPUT").Value().IsEmpty() && header_packet_.IsEmpty()) { // Header packet should be available before we receive any normal input @@ -351,7 +350,7 @@ mediapipe::Status CallbackWithHeaderCalculator::Process(CalculatorContext* cc) { if (!cc->Inputs().Tag("INPUT").Value().IsEmpty()) { callback_(cc->Inputs().Tag("INPUT").Value(), header_packet_); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } REGISTER_CALCULATOR(CallbackWithHeaderCalculator); diff --git a/mediapipe/framework/tool/sink.h b/mediapipe/framework/tool/sink.h index f563f603f..8f09269fc 100644 --- a/mediapipe/framework/tool/sink.h +++ b/mediapipe/framework/tool/sink.h @@ -166,10 +166,10 @@ class CallbackCalculator : public CalculatorBase { ~CallbackCalculator() override {} - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status 
Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: std::function callback_; @@ -185,10 +185,10 @@ class CallbackWithHeaderCalculator : public CalculatorBase { ~CallbackWithHeaderCalculator() override {} - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: std::function callback_; diff --git a/mediapipe/framework/tool/sink_test.cc b/mediapipe/framework/tool/sink_test.cc index 4e64a5e8d..16a8a7208 100644 --- a/mediapipe/framework/tool/sink_test.cc +++ b/mediapipe/framework/tool/sink_test.cc @@ -31,21 +31,21 @@ namespace mediapipe { namespace { class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->OutputSidePackets().Index(0).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { ++count_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) final { + absl::Status Close(CalculatorContext* cc) final { cc->OutputSidePackets().Index(0).Set( MakePacket(count_).At(Timestamp::Unset())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } int count_ = 0; diff --git a/mediapipe/framework/tool/source.cc b/mediapipe/framework/tool/source.cc index 3bdc0faa2..9ea1fd2a1 100644 --- a/mediapipe/framework/tool/source.cc +++ b/mediapipe/framework/tool/source.cc @@ -43,19 +43,19 @@ class SidePacketsToStreamsCalculator : public CalculatorBase { const 
SidePacketsToStreamsCalculator&) = delete; ~SidePacketsToStreamsCalculator() override {} - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { auto& options = cc->Options(); if (options.has_num_inputs() && (options.num_inputs() != cc->InputSidePackets().NumEntries() || options.num_inputs() != cc->Outputs().NumEntries())) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "If num_inputs is specified it must be equal to the number of " "input side packets and output streams."); } if (!options.vectors_of_packets() && options.set_timestamp() == SidePacketsToStreamsCalculatorOptions::NONE) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "If set_timestamp is NONE, vectors_of_packets must not be false."); } for (int i = 0; i < cc->InputSidePackets().NumEntries(); ++i) { @@ -72,10 +72,10 @@ class SidePacketsToStreamsCalculator : public CalculatorBase { cc->Outputs().Index(i).SetSameAs(&cc->InputSidePackets().Index(i)); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { const auto& options = cc->Options(); // The i-th input side packet contains a vector of packets corresponding // to the values of this input for all batch elements. 
@@ -87,7 +87,7 @@ class SidePacketsToStreamsCalculator : public CalculatorBase { const auto& packets = input_side_packet.Get>(); if (batch_size >= 0) { if (packets.size() != batch_size) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "The specified input side packets contain vectors of different " "sizes."); } diff --git a/mediapipe/framework/tool/status_util.cc b/mediapipe/framework/tool/status_util.cc index ad5a69a2c..57faa3899 100644 --- a/mediapipe/framework/tool/status_util.cc +++ b/mediapipe/framework/tool/status_util.cc @@ -22,46 +22,44 @@ namespace mediapipe { namespace tool { -mediapipe::Status StatusInvalid(const std::string& message) { - return mediapipe::Status(mediapipe::StatusCode::kInvalidArgument, message); +absl::Status StatusInvalid(const std::string& message) { + return absl::Status(absl::StatusCode::kInvalidArgument, message); } -mediapipe::Status StatusFail(const std::string& message) { - return mediapipe::Status(mediapipe::StatusCode::kUnknown, message); +absl::Status StatusFail(const std::string& message) { + return absl::Status(absl::StatusCode::kUnknown, message); } -mediapipe::Status StatusStop() { - return mediapipe::Status(mediapipe::StatusCode::kOutOfRange, - "mediapipe::tool::StatusStop()"); +absl::Status StatusStop() { + return absl::Status(absl::StatusCode::kOutOfRange, + "mediapipe::tool::StatusStop()"); } -mediapipe::Status AddStatusPrefix(const std::string& prefix, - const mediapipe::Status& status) { - return mediapipe::Status(status.code(), - absl::StrCat(prefix, status.message())); +absl::Status AddStatusPrefix(const std::string& prefix, + const absl::Status& status) { + return absl::Status(status.code(), absl::StrCat(prefix, status.message())); } -mediapipe::Status CombinedStatus( - const std::string& general_comment, - const std::vector& statuses) { - // The final error code is mediapipe::StatusCode::kUnknown if not all +absl::Status CombinedStatus(const std::string& general_comment, + const 
std::vector& statuses) { + // The final error code is absl::StatusCode::kUnknown if not all // the error codes are the same. Otherwise it is the same error code // as all of the (non-OK) statuses. If statuses is empty or they are - // all OK, then mediapipe::OkStatus() is returned. - mediapipe::StatusCode error_code = mediapipe::StatusCode::kOk; + // all OK, then absl::OkStatus() is returned. + absl::StatusCode error_code = absl::StatusCode::kOk; std::vector errors; - for (const mediapipe::Status& status : statuses) { + for (const absl::Status& status : statuses) { if (!status.ok()) { errors.emplace_back(status.message()); - if (error_code == mediapipe::StatusCode::kOk) { + if (error_code == absl::StatusCode::kOk) { error_code = status.code(); } else if (error_code != status.code()) { - error_code = mediapipe::StatusCode::kUnknown; + error_code = absl::StatusCode::kUnknown; } } } if (error_code == StatusCode::kOk) return OkStatus(); - Status combined = mediapipe::Status( + Status combined = absl::Status( error_code, absl::StrCat(general_comment, "\n", absl::StrJoin(errors, "\n"))); return combined; diff --git a/mediapipe/framework/tool/status_util.h b/mediapipe/framework/tool/status_util.h index 92e0c9fab..039f55609 100644 --- a/mediapipe/framework/tool/status_util.h +++ b/mediapipe/framework/tool/status_util.h @@ -29,31 +29,30 @@ namespace tool { // be called on it again). When returned from a non-source Calculator // it signals that the graph should be cancelled (which is handled by // closing all source Calculators and waiting for the graph to finish). -mediapipe::Status StatusStop(); +absl::Status StatusStop(); // Return a status which signals an invalid initial condition (for // example an InputSidePacket does not include all necessary fields). 
-ABSL_DEPRECATED("Use mediapipe::InvalidArgumentError(error_message) instead.") -mediapipe::Status StatusInvalid(const std::string& error_message); +ABSL_DEPRECATED("Use absl::InvalidArgumentError(error_message) instead.") +absl::Status StatusInvalid(const std::string& error_message); // Return a status which signals that something unexpectedly failed. -ABSL_DEPRECATED("Use mediapipe::UnknownError(error_message) instead.") -mediapipe::Status StatusFail(const std::string& error_message); +ABSL_DEPRECATED("Use absl::UnknownError(error_message) instead.") +absl::Status StatusFail(const std::string& error_message); // Prefixes the given std::string to the error message in status. // This function should be considered internal to the framework. // TODO Replace usage of AddStatusPrefix with util::Annotate(). -mediapipe::Status AddStatusPrefix(const std::string& prefix, - const mediapipe::Status& status); +absl::Status AddStatusPrefix(const std::string& prefix, + const absl::Status& status); -// Combine a vector of mediapipe::Status into a single composite status. -// If statuses is empty or all statuses are OK then mediapipe::OkStatus() +// Combine a vector of absl::Status into a single composite status. +// If statuses is empty or all statuses are OK then absl::OkStatus() // will be returned. // This function should be considered internal to the framework. // TODO Move this function to somewhere with less visibility. 
-mediapipe::Status CombinedStatus( - const std::string& general_comment, - const std::vector& statuses); +absl::Status CombinedStatus(const std::string& general_comment, + const std::vector& statuses); } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/status_util_test.cc b/mediapipe/framework/tool/status_util_test.cc index 711cb4148..c7e845aa6 100644 --- a/mediapipe/framework/tool/status_util_test.cc +++ b/mediapipe/framework/tool/status_util_test.cc @@ -36,24 +36,24 @@ TEST(StatusTest, StatusStopIsNotOk) { EXPECT_FALSE(tool::StatusStop().ok()); } TEST(StatusTest, Prefix) { const std::string base_error_message("error_with_this_string"); const std::string prefix_error_message("error_with_prefix: "); - mediapipe::Status base_status = mediapipe::Status( - mediapipe::StatusCode::kInvalidArgument, base_error_message); - mediapipe::Status status = + absl::Status base_status = + absl::Status(absl::StatusCode::kInvalidArgument, base_error_message); + absl::Status status = tool::AddStatusPrefix(prefix_error_message, base_status); EXPECT_THAT(status.ToString(), HasSubstr(base_error_message)); EXPECT_THAT(status.ToString(), HasSubstr(prefix_error_message)); - EXPECT_EQ(mediapipe::StatusCode::kInvalidArgument, status.code()); + EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()); } TEST(StatusTest, CombinedStatus) { - std::vector errors; + std::vector errors; const std::string prefix_error_message("error_with_prefix: "); - mediapipe::Status status; + absl::Status status; errors.clear(); - errors.emplace_back(mediapipe::StatusCode::kInvalidArgument, + errors.emplace_back(absl::StatusCode::kInvalidArgument, "error_with_this_string"); - errors.emplace_back(mediapipe::StatusCode::kInvalidArgument, + errors.emplace_back(absl::StatusCode::kInvalidArgument, "error_with_that_string"); errors.back().SetPayload("test payload type", absl::Cord(absl::string_view("hello"))); @@ -61,30 +61,29 @@ TEST(StatusTest, CombinedStatus) { 
EXPECT_THAT(status.ToString(), HasSubstr(std::string(errors[0].message()))); EXPECT_THAT(status.ToString(), HasSubstr(std::string(errors[1].message()))); EXPECT_THAT(status.ToString(), HasSubstr(prefix_error_message)); - EXPECT_EQ(mediapipe::StatusCode::kInvalidArgument, status.code()); + EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()); errors.clear(); - errors.emplace_back(mediapipe::StatusCode::kNotFound, - "error_with_this_string"); - errors.emplace_back(mediapipe::StatusCode::kInvalidArgument, + errors.emplace_back(absl::StatusCode::kNotFound, "error_with_this_string"); + errors.emplace_back(absl::StatusCode::kInvalidArgument, "error_with_that_string"); status = tool::CombinedStatus(prefix_error_message, errors); EXPECT_THAT(status.ToString(), HasSubstr(std::string(errors[0].message()))); EXPECT_THAT(status.ToString(), HasSubstr(std::string(errors[1].message()))); EXPECT_THAT(status.ToString(), HasSubstr(prefix_error_message)); - EXPECT_EQ(mediapipe::StatusCode::kUnknown, status.code()); + EXPECT_EQ(absl::StatusCode::kUnknown, status.code()); errors.clear(); - errors.emplace_back(mediapipe::StatusCode::kOk, "error_with_this_string"); - errors.emplace_back(mediapipe::StatusCode::kInvalidArgument, + errors.emplace_back(absl::StatusCode::kOk, "error_with_this_string"); + errors.emplace_back(absl::StatusCode::kInvalidArgument, "error_with_that_string"); status = tool::CombinedStatus(prefix_error_message, errors); EXPECT_THAT(status.ToString(), HasSubstr(std::string(errors[1].message()))); EXPECT_THAT(status.ToString(), HasSubstr(prefix_error_message)); - EXPECT_EQ(mediapipe::StatusCode::kInvalidArgument, status.code()); + EXPECT_EQ(absl::StatusCode::kInvalidArgument, status.code()); errors.clear(); - errors.emplace_back(mediapipe::StatusCode::kOk, "error_with_this_string"); - errors.emplace_back(mediapipe::StatusCode::kOk, "error_with_that_string"); + errors.emplace_back(absl::StatusCode::kOk, "error_with_this_string"); + 
errors.emplace_back(absl::StatusCode::kOk, "error_with_that_string"); MP_EXPECT_OK(tool::CombinedStatus(prefix_error_message, errors)); errors.clear(); @@ -93,13 +92,13 @@ TEST(StatusTest, CombinedStatus) { // Verify tool::StatusInvalid() and tool::StatusFail() and the alternatives // recommended by their ABSL_DEPRECATED messages return the same -// mediapipe::Status objects. +// absl::Status objects. TEST(StatusTest, Deprecated) { const std::string error_message = "an error message"; EXPECT_EQ(tool::StatusInvalid(error_message), // NOLINT - mediapipe::InvalidArgumentError(error_message)); + absl::InvalidArgumentError(error_message)); EXPECT_EQ(tool::StatusFail(error_message), // NOLINT - mediapipe::UnknownError(error_message)); + absl::UnknownError(error_message)); } } // namespace diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc index 9483048fa..ab5a7c464 100644 --- a/mediapipe/framework/tool/subgraph_expansion.cc +++ b/mediapipe/framework/tool/subgraph_expansion.cc @@ -42,22 +42,22 @@ namespace mediapipe { namespace tool { -mediapipe::Status TransformStreamNames( +absl::Status TransformStreamNames( proto_ns::RepeatedPtrField* streams, const std::function& transform) { for (auto& stream : *streams) { absl::string_view port_and_name(stream); - auto colon_pos = port_and_name.find_last_of(":"); + auto colon_pos = port_and_name.find_last_of(':'); auto name_pos = colon_pos == absl::string_view::npos ? 0 : colon_pos + 1; stream = absl::StrCat(port_and_name.substr(0, name_pos), transform(absl::ClippedSubstr(port_and_name, name_pos))); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns subgraph streams not requested by a subgraph-node. 
-mediapipe::Status FindIgnoredStreams( +absl::Status FindIgnoredStreams( const proto_ns::RepeatedPtrField& src_streams, const proto_ns::RepeatedPtrField& dst_streams, std::set* result) { @@ -69,11 +69,11 @@ mediapipe::Status FindIgnoredStreams( result->insert(src_map->Names()[id.value()]); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Removes subgraph streams not requested by a subgraph-node. -mediapipe::Status RemoveIgnoredStreams( +absl::Status RemoveIgnoredStreams( proto_ns::RepeatedPtrField* streams, const std::set& missing_streams) { for (int i = streams->size() - 1; i >= 0; --i) { @@ -84,10 +84,10 @@ mediapipe::Status RemoveIgnoredStreams( streams->DeleteSubrange(i, 1); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TransformNames( +absl::Status TransformNames( CalculatorGraphConfig* config, const std::function& transform) { RET_CHECK_EQ(config->packet_factory().size(), 0); @@ -122,7 +122,7 @@ mediapipe::Status TransformNames( MP_RETURN_IF_ERROR(TransformStreamNames( status_handler.mutable_input_side_packet(), transform)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Adds a prefix to the name of each stream, side packet and node in the @@ -131,8 +131,8 @@ mediapipe::Status TransformNames( // 2, { foo, bar } --PrefixNames-> { rsg__foo, rsg__bar } // This means that two copies of the same subgraph will not interfere with // each other. 
-static mediapipe::Status PrefixNames(std::string prefix, - CalculatorGraphConfig* config) { +static absl::Status PrefixNames(std::string prefix, + CalculatorGraphConfig* config) { std::transform(prefix.begin(), prefix.end(), prefix.begin(), ::tolower); std::replace(prefix.begin(), prefix.end(), '.', '_'); std::replace(prefix.begin(), prefix.end(), ' ', '_'); @@ -144,7 +144,7 @@ static mediapipe::Status PrefixNames(std::string prefix, return TransformNames(config, add_prefix); } -mediapipe::Status FindCorrespondingStreams( +absl::Status FindCorrespondingStreams( std::map* stream_map, const proto_ns::RepeatedPtrField& src_streams, const proto_ns::RepeatedPtrField& dst_streams) { @@ -175,14 +175,14 @@ mediapipe::Status FindCorrespondingStreams( (*stream_map)[src_name] = dst_name; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // The following fields can be used in a Node message for a subgraph: // name, calculator, input_stream, output_stream, input_side_packet, // output_side_packet, options. // All other fields are only applicable to calculators. 
-mediapipe::Status ValidateSubgraphFields( +absl::Status ValidateSubgraphFields( const CalculatorGraphConfig::Node& subgraph_node) { if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() || subgraph_node.has_input_stream_handler() || @@ -193,10 +193,10 @@ mediapipe::Status ValidateSubgraphFields( << "Subgraph \"" << subgraph_node.name() << "\" has a field that is only applicable to calculators."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ConnectSubgraphStreams( +absl::Status ConnectSubgraphStreams( const CalculatorGraphConfig::Node& subgraph_node, CalculatorGraphConfig* subgraph_config) { std::map stream_map; @@ -269,11 +269,11 @@ mediapipe::Status ConnectSubgraphStreams( MP_RETURN_IF_ERROR(RemoveIgnoredStreams( generator.mutable_input_side_packet(), ignored_input_side_packets)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ExpandSubgraphs(CalculatorGraphConfig* config, - const GraphRegistry* graph_registry) { +absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, + const GraphRegistry* graph_registry) { graph_registry = graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; RET_CHECK(config); @@ -313,7 +313,7 @@ mediapipe::Status ExpandSubgraphs(CalculatorGraphConfig* config, config->mutable_status_handler())); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } CalculatorGraphConfig MakeSingleNodeGraph(CalculatorGraphConfig::Node node) { diff --git a/mediapipe/framework/tool/subgraph_expansion.h b/mediapipe/framework/tool/subgraph_expansion.h index 6c422a72a..5c4e1c5cf 100644 --- a/mediapipe/framework/tool/subgraph_expansion.h +++ b/mediapipe/framework/tool/subgraph_expansion.h @@ -29,13 +29,13 @@ namespace tool { // Apply the given transformation function to the names of streams and // side packets. 
-mediapipe::Status TransformStreamNames( +absl::Status TransformStreamNames( proto_ns::RepeatedPtrField* streams, const std::function& transform); // Apply the given transformation function to the names of streams, // side packets, and nodes. -mediapipe::Status TransformNames( +absl::Status TransformNames( CalculatorGraphConfig* config, const std::function& transform); @@ -48,7 +48,7 @@ mediapipe::Status TransformNames( // src: FOO:abc dst: FOO:bob // BAR:def // The entry 'abc' -> 'bob' is added to the map. -mediapipe::Status FindCorrespondingStreams( +absl::Status FindCorrespondingStreams( std::map* stream_map, const proto_ns::RepeatedPtrField& src_streams, const proto_ns::RepeatedPtrField& dst_streams); @@ -56,21 +56,20 @@ mediapipe::Status FindCorrespondingStreams( // Validates the fields in the given Node message that specifies a subgraph. // Returns an error status if the Node message contains any field that is only // applicable to calculators. -mediapipe::Status ValidateSubgraphFields( +absl::Status ValidateSubgraphFields( const CalculatorGraphConfig::Node& subgraph_node); // Renames the streams in a subgraph config to match the connections on the // wrapping node. -mediapipe::Status ConnectSubgraphStreams( +absl::Status ConnectSubgraphStreams( const CalculatorGraphConfig::Node& subgraph_node, CalculatorGraphConfig* subgraph_config); // Replaces subgraph nodes in the given config with the contents of the // corresponding subgraphs. Nested subgraphs are retrieved from the // graph registry and expanded recursively. 
-mediapipe::Status ExpandSubgraphs( - CalculatorGraphConfig* config, - const GraphRegistry* graph_registry = nullptr); +absl::Status ExpandSubgraphs(CalculatorGraphConfig* config, + const GraphRegistry* graph_registry = nullptr); // Creates a graph wrapping the provided node and exposing all of its // connections diff --git a/mediapipe/framework/tool/subgraph_expansion_test.cc b/mediapipe/framework/tool/subgraph_expansion_test.cc index 7de3f80f6..07e0b512d 100644 --- a/mediapipe/framework/tool/subgraph_expansion_test.cc +++ b/mediapipe/framework/tool/subgraph_expansion_test.cc @@ -38,10 +38,10 @@ namespace { class SimpleTestCalculator : public CalculatorBase { public: - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { for (PacketType& type : cc->Inputs()) { type.Set(); } @@ -51,7 +51,7 @@ class SimpleTestCalculator : public CalculatorBase { for (PacketType& type : cc->InputSidePackets()) { type.Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(SimpleTestCalculator); @@ -66,7 +66,7 @@ REGISTER_CALCULATOR(SomeAggregator); class TestSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& /*options*/) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"( @@ -95,7 +95,7 @@ REGISTER_MEDIAPIPE_GRAPH(TestSubgraph); class PacketFactoryTestSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& /*options*/) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"( @@ -126,7 +126,7 @@ REGISTER_MEDIAPIPE_GRAPH(PacketFactoryTestSubgraph); // and the number of copies of the node are specified in subgraph options. 
class NodeChainSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& options) override { auto opts = Subgraph::GetOptions(options); @@ -152,7 +152,7 @@ REGISTER_MEDIAPIPE_GRAPH(NodeChainSubgraph); // subgraph contains a node with the executor field "custom_thread_pool". class NodeWithExecutorSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& options) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"( @@ -174,7 +174,7 @@ REGISTER_MEDIAPIPE_GRAPH(NodeWithExecutorSubgraph); // subgraph contains a NodeWithExecutorSubgraph. class EnclosingSubgraph : public Subgraph { public: - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const SubgraphOptions& options) override { CalculatorGraphConfig config = mediapipe::ParseTextProtoOrDie(R"( @@ -474,12 +474,12 @@ TEST(SubgraphExpansionTest, ValidateSubgraphFields) { buffer_size_hint: -1 # This field is only applicable to calculators. 
} )"); - mediapipe::Status s1 = tool::ValidateSubgraphFields(supergraph.node(1)); - EXPECT_EQ(s1.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status s1 = tool::ValidateSubgraphFields(supergraph.node(1)); + EXPECT_EQ(s1.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(s1.message(), testing::HasSubstr("foo_subgraph")); - mediapipe::Status s2 = tool::ExpandSubgraphs(&supergraph); - EXPECT_EQ(s2.code(), mediapipe::StatusCode::kInvalidArgument); + absl::Status s2 = tool::ExpandSubgraphs(&supergraph); + EXPECT_EQ(s2.code(), absl::StatusCode::kInvalidArgument); EXPECT_THAT(s2.message(), testing::HasSubstr("foo_subgraph")); } diff --git a/mediapipe/framework/tool/switch_container.cc b/mediapipe/framework/tool/switch_container.cc index b923460a3..f91275be7 100644 --- a/mediapipe/framework/tool/switch_container.cc +++ b/mediapipe/framework/tool/switch_container.cc @@ -62,7 +62,7 @@ using mediapipe::SwitchContainerOptions; class SwitchContainer : public Subgraph { public: SwitchContainer() = default; - mediapipe::StatusOr GetConfig( + absl::StatusOr GetConfig( const Subgraph::SubgraphOptions& options) override; }; REGISTER_MEDIAPIPE_GRAPH(SwitchContainer); @@ -157,7 +157,7 @@ void GetContainerNodeStreams(const CalculatorGraphConfig::Node& node, } // Validate all subgraph inputs and outputs. 
-mediapipe::Status ValidateContract( +absl::Status ValidateContract( const CalculatorGraphConfig::Node& subgraph_node, const Subgraph::SubgraphOptions& subgraph_options) { auto options = @@ -166,20 +166,20 @@ mediapipe::Status ValidateContract( ParseTags(subgraph_node.input_stream(), &input_tags); ParseTags(subgraph_node.input_side_packet(), &side_tags); if (options.has_select() && options.has_enable()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Only one of SwitchContainer options 'enable' and 'select' can be " "specified"); } if (side_tags.count({"SELECT", 0}) + side_tags.count({"ENABLE", 0}) > 1 || input_tags.count({"SELECT", 0}) + input_tags.count({"ENABLE", 0}) > 1) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Only one of SwitchContainer inputs 'ENABLE' and 'SELECT' can be " "specified"); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::StatusOr SwitchContainer::GetConfig( +absl::StatusOr SwitchContainer::GetConfig( const Subgraph::SubgraphOptions& options) { CalculatorGraphConfig config; std::vector subnodes; diff --git a/mediapipe/framework/tool/switch_container_test.cc b/mediapipe/framework/tool/switch_container_test.cc index b214dc105..402b1991e 100644 --- a/mediapipe/framework/tool/switch_container_test.cc +++ b/mediapipe/framework/tool/switch_container_test.cc @@ -34,7 +34,7 @@ namespace { // It also accepts a side packet tagged "TIMEZONE", but doesn't use it. 
class TripleIntCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Index(0).Set().Optional(); cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0)).Optional(); cc->InputSidePackets().Index(0).Set().Optional(); @@ -43,22 +43,22 @@ class TripleIntCalculator : public CalculatorBase { .SetSameAs(&cc->InputSidePackets().Index(0)) .Optional(); cc->InputSidePackets().Tag("TIMEZONE").Set().Optional(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) final { + absl::Status Open(CalculatorContext* cc) final { cc->SetOffset(TimestampDiff(0)); if (cc->OutputSidePackets().HasTag("")) { cc->OutputSidePackets().Index(0).Set( MakePacket(cc->InputSidePackets().Index(0).Get() * 3)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) final { + absl::Status Process(CalculatorContext* cc) final { int value = cc->Inputs().Index(0).Value().Get(); cc->Outputs().Index(0).Add(new int(3 * value), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(TripleIntCalculator); @@ -188,7 +188,7 @@ void RunTestSideContainer(CalculatorGraphConfig supergraph) { })); MP_ASSERT_OK(graph.CloseAllInputStreams()); MP_ASSERT_OK(graph.WaitUntilDone()); - Packet side_output = graph.GetOutputSidePacket("output_bar").ValueOrDie(); + Packet side_output = graph.GetOutputSidePacket("output_bar").value(); EXPECT_EQ(side_output.Get(), 12); MP_ASSERT_OK(graph.StartRun({ @@ -197,7 +197,7 @@ void RunTestSideContainer(CalculatorGraphConfig supergraph) { })); MP_ASSERT_OK(graph.CloseAllInputStreams()); MP_ASSERT_OK(graph.WaitUntilDone()); - side_output = graph.GetOutputSidePacket("output_bar").ValueOrDie(); + side_output = graph.GetOutputSidePacket("output_bar").value(); EXPECT_EQ(side_output.Get(), 4); } @@ -359,7 
+359,7 @@ TEST(SwitchContainerTest, ValidateSideInputs) { )"); auto status = tool::ExpandSubgraphs(&supergraph); EXPECT_EQ(std::pair(status.code(), std::string(status.message())), - std::pair(mediapipe::StatusCode::kInvalidArgument, + std::pair(absl::StatusCode::kInvalidArgument, std::string("Only one of SwitchContainer inputs " "'ENABLE' and 'SELECT' can be specified"))); } diff --git a/mediapipe/framework/tool/switch_demux_calculator.cc b/mediapipe/framework/tool/switch_demux_calculator.cc index 0a16b1f3e..35f9cc0a0 100644 --- a/mediapipe/framework/tool/switch_demux_calculator.cc +++ b/mediapipe/framework/tool/switch_demux_calculator.cc @@ -57,10 +57,10 @@ class SwitchDemuxCalculator : public CalculatorBase { static constexpr char kEnableTag[] = "ENABLE"; public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: int channel_index_; @@ -68,7 +68,7 @@ class SwitchDemuxCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(SwitchDemuxCalculator); -mediapipe::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { +absl::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { // Allow any one of kSelectTag, kEnableTag. 
if (cc->Inputs().HasTag(kSelectTag)) { cc->Inputs().Tag(kSelectTag).Set(); @@ -121,10 +121,10 @@ mediapipe::Status SwitchDemuxCalculator::GetContract(CalculatorContract* cc) { } cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { +absl::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_tags_ = ChannelTags(cc->Outputs().TagMap()); @@ -145,10 +145,10 @@ mediapipe::Status SwitchDemuxCalculator::Open(CalculatorContext* cc) { } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) { +absl::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) { // Update the input channel index if specified. channel_index_ = tool::GetChannelIndex(*cc, channel_index_); @@ -164,7 +164,7 @@ mediapipe::Status SwitchDemuxCalculator::Process(CalculatorContext* cc) { } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/framework/tool/switch_mux_calculator.cc b/mediapipe/framework/tool/switch_mux_calculator.cc index aaa36ac04..dd120a2ed 100644 --- a/mediapipe/framework/tool/switch_mux_calculator.cc +++ b/mediapipe/framework/tool/switch_mux_calculator.cc @@ -60,10 +60,10 @@ class SwitchMuxCalculator : public CalculatorBase { static constexpr char kEnableTag[] = "ENABLE"; public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: int channel_index_; @@ -71,7 +71,7 @@ class SwitchMuxCalculator 
: public CalculatorBase { }; REGISTER_CALCULATOR(SwitchMuxCalculator); -mediapipe::Status SwitchMuxCalculator::GetContract(CalculatorContract* cc) { +absl::Status SwitchMuxCalculator::GetContract(CalculatorContract* cc) { // Allow any one of kSelectTag, kEnableTag. if (cc->Inputs().HasTag(kSelectTag)) { cc->Inputs().Tag(kSelectTag).Set(); @@ -124,10 +124,10 @@ mediapipe::Status SwitchMuxCalculator::GetContract(CalculatorContract* cc) { } cc->SetInputStreamHandler("ImmediateInputStreamHandler"); cc->SetProcessTimestampBounds(true); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SwitchMuxCalculator::Open(CalculatorContext* cc) { +absl::Status SwitchMuxCalculator::Open(CalculatorContext* cc) { channel_index_ = tool::GetChannelIndex(*cc, channel_index_); channel_tags_ = ChannelTags(cc->Inputs().TagMap()); @@ -140,10 +140,10 @@ mediapipe::Status SwitchMuxCalculator::Open(CalculatorContext* cc) { cc->OutputSidePackets().Get(tag, index).Set(input); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SwitchMuxCalculator::Process(CalculatorContext* cc) { +absl::Status SwitchMuxCalculator::Process(CalculatorContext* cc) { // Update the input channel index if specified. 
channel_index_ = tool::GetChannelIndex(*cc, channel_index_); @@ -156,7 +156,7 @@ mediapipe::Status SwitchMuxCalculator::Process(CalculatorContext* cc) { tool::Relay(input, &output); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/framework/tool/tag_map.cc b/mediapipe/framework/tool/tag_map.cc index 5a39faf81..25fb07f0e 100644 --- a/mediapipe/framework/tool/tag_map.cc +++ b/mediapipe/framework/tool/tag_map.cc @@ -37,7 +37,7 @@ void TagMap::InitializeNames( } } -mediapipe::Status TagMap::Initialize( +absl::Status TagMap::Initialize( const proto_ns::RepeatedPtrField& tag_index_names) { std::map> tag_to_names; for (const auto& tag_index_name : tag_index_names) { @@ -100,10 +100,10 @@ mediapipe::Status TagMap::Initialize( num_entries_ = current_index; InitializeNames(tag_to_names); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TagMap::Initialize(const TagAndNameInfo& info) { +absl::Status TagMap::Initialize(const TagAndNameInfo& info) { if (info.tags.empty()) { if (!info.names.empty()) { mapping_.emplace( @@ -115,7 +115,7 @@ mediapipe::Status TagMap::Initialize(const TagAndNameInfo& info) { } else { std::map> tag_to_names; if (info.tags.size() != info.names.size()) { - return mediapipe::FailedPreconditionError( + return absl::FailedPreconditionError( "Expected info.tags.size() == info.names.size()"); } @@ -139,7 +139,7 @@ mediapipe::Status TagMap::Initialize(const TagAndNameInfo& info) { // Now create the names_ array in the correctly sorted order. InitializeNames(tag_to_names); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } proto_ns::RepeatedPtrField TagMap::CanonicalEntries() const { diff --git a/mediapipe/framework/tool/tag_map.h b/mediapipe/framework/tool/tag_map.h index a6b4a6b2a..ff9f64c12 100644 --- a/mediapipe/framework/tool/tag_map.h +++ b/mediapipe/framework/tool/tag_map.h @@ -53,7 +53,7 @@ class TagMap { // TAG::name. 
This is the most common usage: // ASSIGN_OR_RETURN(std::shared_ptr tag_map, // tool::TagMap::Create(node.input_streams())); - static mediapipe::StatusOr> Create( + static absl::StatusOr> Create( const proto_ns::RepeatedPtrField& tag_index_names) { std::shared_ptr output(new TagMap()); MP_RETURN_IF_ERROR(output->Initialize(tag_index_names)); @@ -64,7 +64,7 @@ class TagMap { // TODO: Migrate callers and delete this method. ABSL_DEPRECATED( "Use mediapipe::tool::TagMap::Create(tag_index_names) instead.") - static mediapipe::StatusOr> Create( + static absl::StatusOr> Create( const TagAndNameInfo& info) { std::shared_ptr output(new TagMap()); MP_RETURN_IF_ERROR(output->Initialize(info)); @@ -108,12 +108,12 @@ class TagMap { // Initialize the TagMap. Due to only having a factory function for // creation, there is no way for a user to have an uninitialized TagMap. - mediapipe::Status Initialize( + absl::Status Initialize( const proto_ns::RepeatedPtrField& tag_index_names); // Initialize from a TagAndNameInfo. ABSL_DEPRECATED("Use Initialize(tag_index_names) instead.") - mediapipe::Status Initialize(const TagAndNameInfo& info); + absl::Status Initialize(const TagAndNameInfo& info); // Initialize names_ using a map from tag to the names for that tag. void InitializeNames( diff --git a/mediapipe/framework/tool/tag_map_helper.cc b/mediapipe/framework/tool/tag_map_helper.cc index 8213a503f..e13bb0e48 100644 --- a/mediapipe/framework/tool/tag_map_helper.cc +++ b/mediapipe/framework/tool/tag_map_helper.cc @@ -31,7 +31,7 @@ namespace mediapipe { namespace tool { // Create using a vector of TAG::name. -mediapipe::StatusOr> CreateTagMap( +absl::StatusOr> CreateTagMap( const std::vector& tag_index_names) { proto_ns::RepeatedPtrField fields; for (const auto& tag_index_name : tag_index_names) { @@ -41,7 +41,7 @@ mediapipe::StatusOr> CreateTagMap( } // Create using an integer number of entries (for tag ""). 
-mediapipe::StatusOr> CreateTagMap(int num_entries) { +absl::StatusOr> CreateTagMap(int num_entries) { RET_CHECK_LE(0, num_entries); proto_ns::RepeatedPtrField fields; for (int i = 0; i < num_entries; ++i) { @@ -51,7 +51,7 @@ mediapipe::StatusOr> CreateTagMap(int num_entries) { } // Create using a vector of just tag names. -mediapipe::StatusOr> CreateTagMapFromTags( +absl::StatusOr> CreateTagMapFromTags( const std::vector& tags) { proto_ns::RepeatedPtrField fields; for (int i = 0; i < tags.size(); ++i) { diff --git a/mediapipe/framework/tool/tag_map_helper.h b/mediapipe/framework/tool/tag_map_helper.h index 6d4f67db5..48dd25249 100644 --- a/mediapipe/framework/tool/tag_map_helper.h +++ b/mediapipe/framework/tool/tag_map_helper.h @@ -23,14 +23,14 @@ namespace mediapipe { namespace tool { // Create a TagMap using a vector of TAG::name. -mediapipe::StatusOr> CreateTagMap( +absl::StatusOr> CreateTagMap( const std::vector& tag_index_names); // Create a TagMap using an integer number of entries (for tag ""). -mediapipe::StatusOr> CreateTagMap(int num_entries); +absl::StatusOr> CreateTagMap(int num_entries); // Create a TagMap using a vector of just tag names. 
-mediapipe::StatusOr> CreateTagMapFromTags( +absl::StatusOr> CreateTagMapFromTags( const std::vector& tags); } // namespace tool diff --git a/mediapipe/framework/tool/tag_map_test.cc b/mediapipe/framework/tool/tag_map_test.cc index 759f4fe70..39b2e1921 100644 --- a/mediapipe/framework/tool/tag_map_test.cc +++ b/mediapipe/framework/tool/tag_map_test.cc @@ -91,9 +91,9 @@ void TestSuccessTagMap(const std::vector& tag_index_names, const std::vector& names) { std::shared_ptr tag_map; if (create_from_tags) { - tag_map = tool::CreateTagMapFromTags(tag_index_names).ValueOrDie(); + tag_map = tool::CreateTagMapFromTags(tag_index_names).value(); } else { - tag_map = tool::CreateTagMap(tag_index_names).ValueOrDie(); + tag_map = tool::CreateTagMap(tag_index_names).value(); } EXPECT_EQ(num_entries, tag_map->NumEntries()) @@ -295,11 +295,11 @@ TEST(TagMapTest, SameAs) { auto statusor_tag_map = tool::CreateTagMapFromTags(std::get<2>(parameters)); MP_ASSERT_OK(statusor_tag_map); - tag_maps.push_back(std::move(statusor_tag_map.ValueOrDie())); + tag_maps.push_back(std::move(statusor_tag_map.value())); } else { auto statusor_tag_map = tool::CreateTagMap(std::get<2>(parameters)); MP_ASSERT_OK(statusor_tag_map); - tag_maps.push_back(std::move(statusor_tag_map.ValueOrDie())); + tag_maps.push_back(std::move(statusor_tag_map.value())); } } @@ -322,11 +322,11 @@ TEST(TagMapTest, SameAs) { // debug std::string each satisfy a matcher. 
template void TestDebugString( - const mediapipe::StatusOr>& statusor_tag_map, + const absl::StatusOr>& statusor_tag_map, const std::vector& canonical_entries, Matcher short_string_matcher) { MP_ASSERT_OK(statusor_tag_map); - tool::TagMap& tag_map = *statusor_tag_map.ValueOrDie(); + tool::TagMap& tag_map = *statusor_tag_map.value(); std::string debug_string = tag_map.DebugString(); std::string short_string = tag_map.ShortDebugString(); LOG(INFO) << "ShortDebugString:\n" << short_string << "\n"; diff --git a/mediapipe/framework/tool/template_expander.cc b/mediapipe/framework/tool/template_expander.cc index 4c0d8f13b..bd5cd97a0 100644 --- a/mediapipe/framework/tool/template_expander.cc +++ b/mediapipe/framework/tool/template_expander.cc @@ -86,8 +86,8 @@ std::unique_ptr CloneMessage(const MessageLite& message) { // Returns the (tag, index) pairs in a field path. // For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]". -mediapipe::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { - mediapipe::Status status; +absl::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { + absl::Status status; std::vector ids = absl::StrSplit(path, '/'); for (const std::string& id : ids) { if (id.length() > 0) { @@ -98,7 +98,7 @@ mediapipe::Status ProtoPathSplit(const std::string& path, ProtoPath* result) { bool ok = absl::SimpleAtoi(id_pair.first, &tag) && absl::SimpleAtoi(id_pair.second, &index); if (!ok) { - status.Update(mediapipe::InvalidArgumentError(path)); + status.Update(absl::InvalidArgumentError(path)); } result->push_back(std::make_pair(tag, index)); } @@ -146,7 +146,7 @@ int FieldCount(const FieldValue& base, ProtoPath field_path, // The default implementation for the mediapipe template rule interpreter. 
class TemplateExpanderImpl { public: - explicit TemplateExpanderImpl(std::vector* errors) + explicit TemplateExpanderImpl(std::vector* errors) : errors_(errors) {} // Applies the rules specified in a CalculatorGraphTemplate to a @@ -215,21 +215,21 @@ class TemplateExpanderImpl { } // Return the field values addressed by a template rule. - mediapipe::Status GetBaseValue(const std::string& base_path, - const TemplateExpression& rule, - const FieldValue& output, - std::vector* base) { + absl::Status GetBaseValue(const std::string& base_path, + const TemplateExpression& rule, + const FieldValue& output, + std::vector* base) { if (!rule.has_path()) { base->push_back(output); - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (rule.has_field_value()) { // For a non-repeated field, the field value is stored only in the rule. base->push_back(rule.field_value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } ProtoPath field_path; - mediapipe::Status status = + absl::Status status = ProtoPathSplit(ProtoPathRelative(rule.path(), base_path), &field_path); if (!status.ok()) return status; return ProtoUtilLite::GetFieldRange(output, field_path, 1, @@ -237,12 +237,13 @@ class TemplateExpanderImpl { } // Replace the field values addressed by a template rule. - mediapipe::Status ReplaceBaseValue( - const std::string& base_path, const TemplateExpression& rule, - const std::vector& field_values, FieldValue* output) { + absl::Status ReplaceBaseValue(const std::string& base_path, + const TemplateExpression& rule, + const std::vector& field_values, + FieldValue* output) { if (!rule.has_path()) { *output = field_values[0]; - return mediapipe::OkStatus(); + return absl::OkStatus(); } ProtoPath field_path; RET_CHECK_OK( @@ -252,7 +253,7 @@ class TemplateExpanderImpl { // For a non-repeated field, only one value can be specified. 
if (!field_values.empty() && FieldCount(*output, field_path, GetFieldType(rule)) > 0) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Multiple values specified for non-repeated field: ", rule.path())); } // For a non-repeated field, the field value is stored only in the rule. @@ -267,7 +268,7 @@ class TemplateExpanderImpl { bool ExpandNestedRules(int base_index, const std::string& base_path, const FieldValue& base_message, std::vector* result) { - mediapipe::Status status; + absl::Status status; FieldValue output = base_message; // Evaluate the rules nested below base_path in lexical order. @@ -280,7 +281,7 @@ class TemplateExpanderImpl { if (!status.ok()) break; std::vector values; if (!ExpandTemplateRule(rules[i], base[0], &values)) { - status = mediapipe::InternalError("ExpandTemplateRule failed"); + status = absl::InternalError("ExpandTemplateRule failed"); break; } edits.push_back(values); @@ -348,7 +349,7 @@ class TemplateExpanderImpl { // Retrieve the var param and the range expression. 
const TemplateExpression& rule = template_rules_.rule().Get(base_index); if (rule.arg().empty() || rule.arg().size() > 2) { - RecordError(mediapipe::InvalidArgumentError( + RecordError(absl::InvalidArgumentError( "Param declaration must specify a parameter name and " "may specify a single default value.")); } @@ -386,8 +387,8 @@ class TemplateExpanderImpl { const TemplateExpression& rule = template_rules_.rule().Get(base_index); TemplateArgument item = EvalExpression(rule); std::vector values; - mediapipe::Status status = AsFieldValues( - std::vector{item}, GetFieldType(rule), &values); + absl::Status status = AsFieldValues(std::vector{item}, + GetFieldType(rule), &values); if (!status.ok()) { RecordError(status); return false; @@ -400,8 +401,7 @@ class TemplateExpanderImpl { TemplateArgument EvalParam(const TemplateExpression& expr) { TemplateArgument* result = GetItem(&environment_, expr.param()); if (result == nullptr) { - RecordError( - mediapipe::NotFoundError(absl::StrCat("param: ", expr.param()))); + RecordError(absl::NotFoundError(absl::StrCat("param: ", expr.param()))); return AsArgument(0.0); } return *result; @@ -412,7 +412,7 @@ class TemplateExpanderImpl { TemplateArgument lhs = EvalExpression(expr.arg(0)); TemplateArgument* result = GetItem(lhs.mutable_dict(), expr.arg(1).param()); if (result == nullptr) { - RecordError(mediapipe::NotFoundError( + RecordError(absl::NotFoundError( absl::StrCat("param field: ", expr.arg(1).param()))); return AsArgument(0.0); } @@ -427,7 +427,7 @@ class TemplateExpanderImpl { } if (value.has_str()) { if (!absl::SimpleAtod(value.str(), &result)) { - RecordError(mediapipe::InvalidArgumentError(value.str())); + RecordError(absl::InvalidArgumentError(value.str())); } } return result; @@ -452,7 +452,7 @@ class TemplateExpanderImpl { return value.num() != 0; } else if (value.has_str()) { if (!absl::SimpleAtob(value.str(), &result)) { - RecordError(mediapipe::InvalidArgumentError(value.str())); + 
RecordError(absl::InvalidArgumentError(value.str())); } } return result; @@ -462,7 +462,7 @@ class TemplateExpanderImpl { TemplateArgument AsDict(const std::vector& args) { TemplateArgument result; if (args.size() % 2 != 0) { - RecordError(mediapipe::InvalidArgumentError(absl::StrCat( + RecordError(absl::InvalidArgumentError(absl::StrCat( "Dict requires an even number of arguments, got: ", args.size()))); return result; } @@ -595,9 +595,9 @@ class TemplateExpanderImpl { } // Convert between a proto feild value and a template argument. - mediapipe::Status AsFieldValues(const std::vector& args, - FieldType field_type, - std::vector* result) { + absl::Status AsFieldValues(const std::vector& args, + FieldType field_type, + std::vector* result) { for (int i = 0; i < args.size(); ++i) { if (args[i].has_dict()) { FieldValue dict_bytes; @@ -613,11 +613,11 @@ class TemplateExpanderImpl { result->push_back(r[0]); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Record a Status if it indicates an error. - void RecordError(const mediapipe::Status& status) { + void RecordError(const absl::Status& status) { if (!status.ok()) { errors_->push_back(status); } @@ -631,23 +631,23 @@ class TemplateExpanderImpl { TemplateDict environment_; // List of errors found in template parameters. - std::vector* errors_; + std::vector* errors_; }; TemplateExpander::TemplateExpander() {} // Expands template rules within a proto message. // Replaces template rules with expanded sub-messages. 
-mediapipe::Status TemplateExpander::ExpandTemplates( +absl::Status TemplateExpander::ExpandTemplates( const TemplateDict& args, const CalculatorGraphTemplate& templ, CalculatorGraphConfig* output) { errors_.clear(); TemplateExpanderImpl expander(&errors_); if (!expander.ExpandTemplates(args, templ, output)) { - errors_.push_back(mediapipe::InternalError("ExpandTemplates failed")); + errors_.push_back(absl::InternalError("ExpandTemplates failed")); } - mediapipe::Status status; - for (const mediapipe::Status& error : errors_) { + absl::Status status; + for (const absl::Status& error : errors_) { LOG(ERROR) << error; status.Update(error); } diff --git a/mediapipe/framework/tool/template_expander.h b/mediapipe/framework/tool/template_expander.h index f62e5b747..5e6696f62 100644 --- a/mediapipe/framework/tool/template_expander.h +++ b/mediapipe/framework/tool/template_expander.h @@ -33,13 +33,13 @@ class TemplateExpander { // Applies the rules specified in a CalculatorGraphTemplate to a // CalculatorGraphConfig. Each rule references a nested field-value or // message and defines zero or more replacement values for it. - mediapipe::Status ExpandTemplates(const TemplateDict& args, - const CalculatorGraphTemplate& templ, - CalculatorGraphConfig* output); + absl::Status ExpandTemplates(const TemplateDict& args, + const CalculatorGraphTemplate& templ, + CalculatorGraphConfig* output); private: // List of errors found in template parameters. - std::vector errors_; + std::vector errors_; }; } // namespace tool diff --git a/mediapipe/framework/tool/template_parser.cc b/mediapipe/framework/tool/template_parser.cc index b6d2fb371..1897dd292 100644 --- a/mediapipe/framework/tool/template_parser.cc +++ b/mediapipe/framework/tool/template_parser.cc @@ -1332,20 +1332,20 @@ bool IsFunctionOperator(const std::string& token) { // by the DynamicMessageFactory ("output"). 
These two Messages have // different Descriptors so Message::MergeFrom cannot be applied directly, // but they are expected to be equivalent. -mediapipe::Status MergeFields(const Message& source, Message* dest) { +absl::Status MergeFields(const Message& source, Message* dest) { std::unique_ptr temp(dest->New()); std::string temp_str; RET_CHECK(TextFormat::PrintToString(source, &temp_str)); RET_CHECK(TextFormat::ParseFromString(temp_str, temp.get())); dest->MergeFrom(*temp); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Returns the (tag, index) pairs in a field path. // For example, returns {{1, 1}, {2, 1}, {3, 1}} for path "/1[1]/2[1]/3[1]". -mediapipe::Status ProtoPathSplit(const std::string& path, - ProtoUtilLite::ProtoPath* result) { - mediapipe::Status status; +absl::Status ProtoPathSplit(const std::string& path, + ProtoUtilLite::ProtoPath* result) { + absl::Status status; std::vector ids = absl::StrSplit(path, '/'); for (const std::string& id : ids) { if (id.length() > 0) { @@ -1356,7 +1356,7 @@ mediapipe::Status ProtoPathSplit(const std::string& path, bool ok = absl::SimpleAtoi(id_pair.first, &tag) && absl::SimpleAtoi(id_pair.second, &index); if (!ok) { - status.Update(mediapipe::InvalidArgumentError(path)); + status.Update(absl::InvalidArgumentError(path)); } result->push_back(std::make_pair(tag, index)); } diff --git a/mediapipe/framework/tool/test_util.cc b/mediapipe/framework/tool/test_util.cc index e78c2cf85..15e100ee0 100644 --- a/mediapipe/framework/tool/test_util.cc +++ b/mediapipe/framework/tool/test_util.cc @@ -27,17 +27,10 @@ #include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/no_destructor.h" #include "mediapipe/framework/port/advanced_proto_inc.h" +#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/proto_ns.h" -#ifdef __APPLE__ -#include -#elif defined(__ANDROID__) -#include "mediapipe/util/android/file/base/helpers.h" 
-#else -#include "mediapipe/framework/port/file_helpers.h" -#endif - namespace mediapipe { namespace { @@ -235,28 +228,18 @@ bool CompareImageFrames(const ImageFrame& image1, const ImageFrame& image2, } std::string GetTestRootDir() { -#ifdef __APPLE__ - char path[1024]; - CFURLRef bundle_url = CFBundleCopyBundleURL(CFBundleGetMainBundle()); - Boolean success = CFURLGetFileSystemRepresentation( - bundle_url, true, reinterpret_cast(path), sizeof(path)); - CHECK(success); - CFRelease(bundle_url); - return path; -#elif defined(__ANDROID__) +#if defined(__ANDROID__) char path[1024]; char* ptr = getcwd(path, sizeof(path)); CHECK_EQ(ptr, path); return path; #else return ::mediapipe::file::JoinPath(std::getenv("TEST_SRCDIR"), "mediapipe"); -#endif // defined(__APPLE__) +#endif // defined(__ANDROID__) } std::string GetTestDataDir(const std::string& package_base_path) { -#ifdef __APPLE__ - return ::mediapipe::file::JoinPath(GetTestRootDir(), "testdata/"); -#elif defined(__ANDROID__) +#if defined(__ANDROID__) std::string data_dir = GetTestRootDir(); std::string binary_dir = GetBinaryDirectory(); // In Mobile Harness, the cwd is "/" and the run dir is "/data/local/tmp". 
diff --git a/mediapipe/framework/tool/text_to_binary_graph.cc b/mediapipe/framework/tool/text_to_binary_graph.cc index 1b4d53b01..4282f748e 100644 --- a/mediapipe/framework/tool/text_to_binary_graph.cc +++ b/mediapipe/framework/tool/text_to_binary_graph.cc @@ -41,9 +41,8 @@ DEFINE_string( namespace mediapipe { -mediapipe::Status ReadProto(proto_ns::io::ZeroCopyInputStream* in, - bool read_text, const std::string& source, - proto_ns::Message* result) { +absl::Status ReadProto(proto_ns::io::ZeroCopyInputStream* in, bool read_text, + const std::string& source, proto_ns::Message* result) { if (read_text) { RET_CHECK(proto_ns::TextFormat::Parse(in, result)) << "could not parse text proto: " << source; @@ -51,12 +50,12 @@ mediapipe::Status ReadProto(proto_ns::io::ZeroCopyInputStream* in, RET_CHECK(result->ParseFromZeroCopyStream(in)) << "could not parse binary proto: " << source; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status WriteProto(const proto_ns::Message& message, bool write_text, - const std::string& dest, - proto_ns::io::ZeroCopyOutputStream* out) { +absl::Status WriteProto(const proto_ns::Message& message, bool write_text, + const std::string& dest, + proto_ns::io::ZeroCopyOutputStream* out) { if (write_text) { RET_CHECK(proto_ns::TextFormat::Print(message, out)) << "could not write text proto to: " << dest; @@ -64,21 +63,21 @@ mediapipe::Status WriteProto(const proto_ns::Message& message, bool write_text, RET_CHECK(message.SerializeToZeroCopyStream(out)) << "could not write binary proto to: " << dest; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Read a proto from a text or a binary file. 
-mediapipe::Status ReadFile(const std::string& proto_source, bool read_text, - proto_ns::Message* result) { +absl::Status ReadFile(const std::string& proto_source, bool read_text, + proto_ns::Message* result) { std::ifstream ifs(proto_source); proto_ns::io::IstreamInputStream in(&ifs); MP_RETURN_IF_ERROR(ReadProto(&in, read_text, proto_source, result)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Write a proto to a text or a binary file. -mediapipe::Status WriteFile(const std::string& proto_output, bool write_text, - const proto_ns::Message& message) { +absl::Status WriteFile(const std::string& proto_output, bool write_text, + const proto_ns::Message& message) { std::ios_base::openmode mode = std::ios_base::out | std::ios_base::trunc; if (!write_text) { mode |= std::ios_base::binary; @@ -86,7 +85,7 @@ mediapipe::Status WriteFile(const std::string& proto_output, bool write_text, std::ofstream ofs(proto_output, mode); proto_ns::io::OstreamOutputStream out(&ofs); MP_RETURN_IF_ERROR(WriteProto(message, write_text, proto_output, &out)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe @@ -96,20 +95,22 @@ int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); // Validate command line options. 
- mediapipe::Status status; - if (FLAGS_proto_source.empty()) { + absl::Status status; + if (absl::GetFlag(FLAGS_proto_source).empty()) { status.Update( - mediapipe::InvalidArgumentError("--proto_source must be specified")); + absl::InvalidArgumentError("--proto_source must be specified")); } - if (FLAGS_proto_output.empty()) { + if (absl::GetFlag(FLAGS_proto_output).empty()) { status.Update( - mediapipe::InvalidArgumentError("--proto_output must be specified")); + absl::InvalidArgumentError("--proto_output must be specified")); } if (!status.ok()) { return EXIT_FAILURE; } mediapipe::CalculatorGraphConfig config; - EXIT_IF_ERROR(mediapipe::ReadFile(FLAGS_proto_source, true, &config)); - EXIT_IF_ERROR(mediapipe::WriteFile(FLAGS_proto_output, false, config)); + EXIT_IF_ERROR( + mediapipe::ReadFile(absl::GetFlag(FLAGS_proto_source), true, &config)); + EXIT_IF_ERROR( + mediapipe::WriteFile(absl::GetFlag(FLAGS_proto_output), false, config)); return EXIT_SUCCESS; } diff --git a/mediapipe/framework/tool/validate.cc b/mediapipe/framework/tool/validate.cc index c15a268eb..8db6d6278 100644 --- a/mediapipe/framework/tool/validate.cc +++ b/mediapipe/framework/tool/validate.cc @@ -26,7 +26,7 @@ namespace mediapipe { namespace tool { -mediapipe::Status ValidateInput(const InputCollection& input_collection) { +absl::Status ValidateInput(const InputCollection& input_collection) { if (!input_collection.name().empty()) { MP_RETURN_IF_ERROR(tool::ValidateName(input_collection.name())).SetPrepend() << "InputCollection " << input_collection.name() @@ -34,14 +34,14 @@ mediapipe::Status ValidateInput(const InputCollection& input_collection) { } if (input_collection.input_type() <= InputCollection::UNKNOWN || input_collection.input_type() >= InputCollection::INVALID_UPPER_BOUND) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "InputCollection must specify a valid input_type."); } if (input_collection.file_name().empty()) { - return 
mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "InputCollection must specify a file_name."); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace tool diff --git a/mediapipe/framework/tool/validate.h b/mediapipe/framework/tool/validate.h index 545e7387c..2311e1ad9 100644 --- a/mediapipe/framework/tool/validate.h +++ b/mediapipe/framework/tool/validate.h @@ -24,12 +24,12 @@ namespace mediapipe { namespace tool { -// Returns mediapipe::OkStatus() if the InputCollection is valid. An input +// Returns absl::OkStatus() if the InputCollection is valid. An input // collection is invalid if it does not have the proper fields set // depending on what its input_type field is. Furthermore, if it uses // INLINE, then the number of value fields in each inputs must match // the number of input_side_packet_name fields. -mediapipe::Status ValidateInput(const InputCollection& input); +absl::Status ValidateInput(const InputCollection& input); } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/validate_name.cc b/mediapipe/framework/tool/validate_name.cc index e98fd3bf9..d87a3063e 100644 --- a/mediapipe/framework/tool/validate_name.cc +++ b/mediapipe/framework/tool/validate_name.cc @@ -41,7 +41,7 @@ namespace tool { #define MEDIAPIPE_TAG_INDEX_REGEX \ "(" MEDIAPIPE_TAG_REGEX ")?(:" MEDIAPIPE_NUMBER_REGEX ")?" -mediapipe::Status GetTagAndNameInfo( +absl::Status GetTagAndNameInfo( const proto_ns::RepeatedPtrField& tags_and_names, TagAndNameInfo* info) { RET_CHECK(info); @@ -59,15 +59,15 @@ mediapipe::Status GetTagAndNameInfo( if (info->tags.size() > 0 && info->names.size() != info->tags.size()) { info->tags.clear(); info->names.clear(); - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Each set of names must use exclusively either tags or indexes. 
" "Encountered: \"", absl::StrJoin(tags_and_names, "\", \""), "\"")); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status SetFromTagAndNameInfo( +absl::Status SetFromTagAndNameInfo( const TagAndNameInfo& info, proto_ns::RepeatedPtrField* tags_and_names) { tags_and_names->Clear(); @@ -88,52 +88,52 @@ mediapipe::Status SetFromTagAndNameInfo( *tags_and_names->Add() = absl::StrCat(info.tags[i], ":", info.names[i]); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidateName(const std::string& name) { +absl::Status ValidateName(const std::string& name) { return name.length() > 0 && (name[0] == '_' || islower(name[0])) && std::all_of(name.begin() + 1, name.end(), [](char c) { return c == '_' || isdigit(c) || islower(c); }) - ? mediapipe::OkStatus() - : mediapipe::InvalidArgumentError(absl::StrCat( + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::StrCat( "Name \"", absl::CEscape(name), "\" does not match \"" MEDIAPIPE_NAME_REGEX "\".")); } -mediapipe::Status ValidateNumber(const std::string& number) { +absl::Status ValidateNumber(const std::string& number) { return (number.length() == 1 && isdigit(number[0])) || (number.length() > 1 && isdigit(number[0]) && number[0] != '0' && std::all_of(number.begin() + 1, number.end(), [](char c) { return isdigit(c); })) - ? mediapipe::OkStatus() - : mediapipe::InvalidArgumentError(absl::StrCat( + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::StrCat( "Number \"", absl::CEscape(number), "\" does not match \"" MEDIAPIPE_NUMBER_REGEX "\".")); } -mediapipe::Status ValidateTag(const std::string& tag) { +absl::Status ValidateTag(const std::string& tag) { return tag.length() > 0 && (tag[0] == '_' || isupper(tag[0])) && std::all_of(tag.begin() + 1, tag.end(), [](char c) { return c == '_' || isdigit(c) || isupper(c); }) - ? mediapipe::OkStatus() - : mediapipe::InvalidArgumentError(absl::StrCat( + ? 
absl::OkStatus() + : absl::InvalidArgumentError(absl::StrCat( "Tag \"", absl::CEscape(tag), "\" does not match \"" MEDIAPIPE_TAG_REGEX "\".")); } -mediapipe::Status ParseTagAndName(const std::string& tag_and_name, - std::string* tag, std::string* name) { +absl::Status ParseTagAndName(const std::string& tag_and_name, std::string* tag, + std::string* name) { // An optional tag and colon, followed by a name. RET_CHECK(tag); RET_CHECK(name); - mediapipe::Status tag_status = mediapipe::OkStatus(); - mediapipe::Status name_status = mediapipe::UnknownError(""); + absl::Status tag_status = absl::OkStatus(); + absl::Status name_status = absl::UnknownError(""); int name_index = 0; std::vector v = absl::StrSplit(tag_and_name, ':'); if (v.size() == 1) { @@ -144,11 +144,11 @@ mediapipe::Status ParseTagAndName(const std::string& tag_and_name, name_status = ValidateName(v[1]); name_index = 1; } - if (name_index == -1 || tag_status != mediapipe::OkStatus() || - name_status != mediapipe::OkStatus()) { + if (name_index == -1 || tag_status != absl::OkStatus() || + name_status != absl::OkStatus()) { tag->clear(); name->clear(); - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("\"tag and name\" is invalid, \"", tag_and_name, "\" does not match " "\"" MEDIAPIPE_TAG_AND_NAME_REGEX @@ -156,20 +156,20 @@ mediapipe::Status ParseTagAndName(const std::string& tag_and_name, } *tag = name_index == 1 ? v[0] : ""; *name = v[name_index]; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ParseTagIndexName(const std::string& tag_index_name, - std::string* tag, int* index, - std::string* name) { +absl::Status ParseTagIndexName(const std::string& tag_index_name, + std::string* tag, int* index, + std::string* name) { // An optional tag and colon, an optional index and color, followed by a name. 
RET_CHECK(tag); RET_CHECK(index); RET_CHECK(name); - mediapipe::Status tag_status = mediapipe::OkStatus(); - mediapipe::Status number_status = mediapipe::OkStatus(); - mediapipe::Status name_status = mediapipe::UnknownError(""); + absl::Status tag_status = absl::OkStatus(); + absl::Status number_status = absl::OkStatus(); + absl::Status name_status = absl::UnknownError(""); int name_index = -1; int the_index = 0; std::vector v = absl::StrSplit(tag_index_name, ':'); @@ -195,7 +195,7 @@ mediapipe::Status ParseTagIndexName(const std::string& tag_index_name, } // else omitted, name_index == -1, triggering error. if (name_index == -1 || !tag_status.ok() || !number_status.ok() || !name_status.ok()) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "TAG:index:name is invalid, \"", tag_index_name, "\" does not match " "\"" MEDIAPIPE_TAG_INDEX_NAME_REGEX @@ -204,16 +204,16 @@ mediapipe::Status ParseTagIndexName(const std::string& tag_index_name, *tag = name_index != 0 ? v[0] : ""; *index = the_index; *name = v[name_index]; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ParseTagIndex(const std::string& tag_index, std::string* tag, - int* index) { +absl::Status ParseTagIndex(const std::string& tag_index, std::string* tag, + int* index) { RET_CHECK(tag); RET_CHECK(index); - mediapipe::Status tag_status = mediapipe::OkStatus(); - mediapipe::Status number_status = mediapipe::OkStatus(); + absl::Status tag_status = absl::OkStatus(); + absl::Status number_status = absl::OkStatus(); int the_index = -1; std::vector v = absl::StrSplit(tag_index, ':'); if (v.size() == 1) { @@ -234,14 +234,14 @@ mediapipe::Status ParseTagIndex(const std::string& tag_index, std::string* tag, } } // else omitted, the_index == -1, triggering error. 
if (the_index == -1 || !tag_status.ok() || !number_status.ok()) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "TAG:index is invalid, \"", tag_index, "\" does not match " "\"" MEDIAPIPE_TAG_INDEX_REGEX "\" (examples: \"TAG\" \"VIDEO:2\").")); } *tag = v[0]; *index = the_index; - return mediapipe::OkStatus(); + return absl::OkStatus(); } #undef MEDIAPIPE_NAME_REGEX diff --git a/mediapipe/framework/tool/validate_name.h b/mediapipe/framework/tool/validate_name.h index 1e9299f75..8a21f1fbb 100644 --- a/mediapipe/framework/tool/validate_name.h +++ b/mediapipe/framework/tool/validate_name.h @@ -52,7 +52,7 @@ ABSL_DEPRECATED( "support the TAG:INDEX:name notation. You can use Create() to create the " "tag map, and then Names(), Mapping(), and other methods to access the " "tag, index and name information.") -mediapipe::Status GetTagAndNameInfo( +absl::Status GetTagAndNameInfo( const proto_ns::RepeatedPtrField& tags_and_names, TagAndNameInfo* info); @@ -62,7 +62,7 @@ ABSL_DEPRECATED( "Prefer using mediapipe::tool::TagMap instead, since this method does not " "support the TAG:INDEX:name notation. You can use CanonicalEntries() to " "translate a tag map to a RepeatedPtrField of tag and names.") -mediapipe::Status SetFromTagAndNameInfo( +absl::Status SetFromTagAndNameInfo( const TagAndNameInfo& info, proto_ns::RepeatedPtrField* tags_and_names); @@ -76,17 +76,17 @@ mediapipe::Status SetFromTagAndNameInfo( // trainer/calculator names. // (3) Because input side packet names end up in model directory names, // where lower case naming is the norm. -mediapipe::Status ValidateName(const std::string& name); +absl::Status ValidateName(const std::string& name); // The std::string is a valid tag name. Tags use only upper case letters, // numbers, and underscores. -mediapipe::Status ValidateTag(const std::string& tag); +absl::Status ValidateTag(const std::string& tag); // Parse a "Tag and Name" std::string into a tag and a name. 
// The format is an optional tag and colon, followed by a name. // Example 1: "VIDEO:frames2" -> tag: "VIDEO", name: "frames2" // Example 2: "video_frames_1" -> tag: "", name: "video_frames_1" -mediapipe::Status ParseTagAndName(const std::string& tag_and_name, - std::string* tag, std::string* name); +absl::Status ParseTagAndName(const std::string& tag_and_name, std::string* tag, + std::string* name); // Parse a generic TAG:index:name std::string. The format is a tag, then an // index, then a name. The tag and index are optional. If the index @@ -96,9 +96,8 @@ mediapipe::Status ParseTagAndName(const std::string& tag_and_name, // "VIDEO:frames2" -> tag: "VIDEO", index: 0, name: "frames2" // "VIDEO:1:frames" -> tag: "VIDEO", index: 1, name: "frames" // "raw_frames" -> tag: "", index: -1, name: "raw_frames" -mediapipe::Status ParseTagIndexName(const std::string& tag_and_name, - std::string* tag, int* index, - std::string* name); +absl::Status ParseTagIndexName(const std::string& tag_and_name, + std::string* tag, int* index, std::string* name); // Parse a generic TAG:index std::string. The format is a tag, then an index // with both being optional. 
If the tag is missing it is assumed to be @@ -109,8 +108,8 @@ mediapipe::Status ParseTagIndexName(const std::string& tag_and_name, // "VIDEO:1" -> tag: "VIDEO", index: 1 // ":2" -> tag: "", index: 2 // "" -> tag: "", index: 0 -mediapipe::Status ParseTagIndex(const std::string& tag_and_index, - std::string* tag, int* index); +absl::Status ParseTagIndex(const std::string& tag_and_index, std::string* tag, + int* index); } // namespace tool } // namespace mediapipe diff --git a/mediapipe/framework/tool/validate_type.cc b/mediapipe/framework/tool/validate_type.cc index 0e83e9caa..ed5b42e95 100644 --- a/mediapipe/framework/tool/validate_type.cc +++ b/mediapipe/framework/tool/validate_type.cc @@ -39,7 +39,7 @@ namespace mediapipe { namespace tool { -mediapipe::Status RunGeneratorFillExpectations( +absl::Status RunGeneratorFillExpectations( const PacketGeneratorConfig& input_config, const std::string& package) { // TODO Remove conversion after everyone uses input/output // side packet. @@ -65,7 +65,7 @@ mediapipe::Status RunGeneratorFillExpectations( } // Check that everything got initialized. - std::vector statuses; + std::vector statuses; statuses.push_back(ValidatePacketTypeSet(contract.InputSidePackets())); statuses.push_back(ValidatePacketTypeSet(contract.OutputSidePackets())); return tool::CombinedStatus( @@ -73,7 +73,7 @@ mediapipe::Status RunGeneratorFillExpectations( statuses); } -mediapipe::Status RunGenerateAndValidateTypes( +absl::Status RunGenerateAndValidateTypes( const std::string& packet_generator_name, const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets, @@ -96,7 +96,7 @@ mediapipe::Status RunGenerateAndValidateTypes( .SetPrepend() << packet_generator_name << "::FillExpectations failed: "; // Check that the types were filled well. 
- std::vector statuses; + std::vector statuses; statuses.push_back(ValidatePacketTypeSet(input_side_packet_types)); statuses.push_back(ValidatePacketTypeSet(output_side_packet_types)); MP_RETURN_IF_ERROR(tool::CombinedStatus( @@ -119,7 +119,7 @@ mediapipe::Status RunGenerateAndValidateTypes( << packet_generator_name << "::FillExpectations expected different " "output type than those produced: "; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace tool diff --git a/mediapipe/framework/tool/validate_type.h b/mediapipe/framework/tool/validate_type.h index 5f87f40f5..a3d020e6f 100644 --- a/mediapipe/framework/tool/validate_type.h +++ b/mediapipe/framework/tool/validate_type.h @@ -26,14 +26,14 @@ namespace mediapipe { namespace tool { // Equivalent functions for PacketGenerators. -mediapipe::Status RunGeneratorFillExpectations( +absl::Status RunGeneratorFillExpectations( const PacketGeneratorConfig& config, const std::string& package = "mediapipe"); // Run PacketGenerator::Generate() on the given generator, options, // and inputs to produce outputs. Validate the types of the inputs and // outputs using PacketGenerator::FillExpectations. -mediapipe::Status RunGenerateAndValidateTypes( +absl::Status RunGenerateAndValidateTypes( const std::string& packet_generator_name, const PacketGeneratorOptions& extendable_options, const PacketSet& input_side_packets, PacketSet* output_side_packets, diff --git a/mediapipe/framework/type_map.h b/mediapipe/framework/type_map.h index b37522cb9..7aac86ac7 100644 --- a/mediapipe/framework/type_map.h +++ b/mediapipe/framework/type_map.h @@ -79,9 +79,9 @@ class HolderBase; // These functions use HolderBase to hide the type T from the function // definition. This allows these functions to be placed into an untyped // struct in the map of MediaPipeTypeData objects. 
-using SerializeFn = std::function; -using DeserializeFn = std::function* holder_base)>; @@ -294,18 +294,18 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); // seperated by double colons. // // Example 1: register type with non-std::string proxy. -// mediapipe::Status ToProxyFn( +// absl::Status ToProxyFn( // const ClassType& obj, ProxyType* proxy) // { // ... -// return mediapipe::OkStatus(); +// return absl::OkStatus(); // } // -// mediapipe::Status FromProxyFn( +// absl::Status FromProxyFn( // const ProxyType& proxy, ClassType* obj) // { // ... -// return mediapipe::OkStatus(); +// return absl::OkStatus(); // } // // MEDIAPIPE_REGISTER_TYPE_WITH_PROXY( @@ -316,16 +316,16 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string); // ProxyType>, ToProxyFn, FromProxyFn); // // Example 2: register type with std::string proxy. -// mediapipe::Status ToProxyFn(const ClassType& obj, std::string* encoding) +// absl::Status ToProxyFn(const ClassType& obj, std::string* encoding) // { // ... -// return mediapipe::OkStatus(); +// return absl::OkStatus(); // } // -// mediapipe::Status FromProxyFn( +// absl::Status FromProxyFn( // const ProxyType& proxy, std::string* encoding) { // ... -// return mediapipe::OkStatus(); +// return absl::OkStatus(); // } // // MEDIAPIPE_REGISTER_TYPE_WITH_PROXY( diff --git a/mediapipe/framework/validated_graph_config.cc b/mediapipe/framework/validated_graph_config.cc index cd2608738..559a4a53c 100644 --- a/mediapipe/framework/validated_graph_config.cc +++ b/mediapipe/framework/validated_graph_config.cc @@ -114,13 +114,12 @@ std::string DebugName(const CalculatorGraphConfig& config, // // Converts the graph-level num_threads field to an ExecutorConfig for the // default executor with the executor type unspecified. 
-mediapipe::Status AddPredefinedExecutorConfigs( - CalculatorGraphConfig* graph_config) { +absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) { bool has_default_executor_config = false; for (ExecutorConfig& executor_config : *graph_config->mutable_executor()) { if (executor_config.name().empty()) { if (graph_config->num_threads()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "ExecutorConfig for the default executor and the graph-level " "num_threads field should not both be specified."); } @@ -137,10 +136,10 @@ mediapipe::Status AddPredefinedExecutorConfigs( graph_config->clear_num_threads(); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status PerformBasicTransforms( +absl::Status PerformBasicTransforms( const CalculatorGraphConfig& input_graph_config, const GraphRegistry* graph_registry, CalculatorGraphConfig* output_graph_config) { @@ -164,7 +163,7 @@ mediapipe::Status PerformBasicTransforms( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -187,7 +186,7 @@ std::string NodeTypeInfo::NodeTypeToString(NodeType node_type) { << static_cast(node_type); } -mediapipe::Status NodeTypeInfo::Initialize( +absl::Status NodeTypeInfo::Initialize( const ValidatedGraphConfig& validated_graph, const CalculatorGraphConfig::Node& node, int node_index) { node_.type = NodeType::CALCULATOR; @@ -245,8 +244,8 @@ mediapipe::Status NodeTypeInfo::Initialize( << node_class << ": "; // Validate result of FillExpectations or GetContract. 
- std::vector statuses; - mediapipe::Status status = ValidatePacketTypeSet(contract_.Inputs()); + std::vector statuses; + absl::Status status = ValidatePacketTypeSet(contract_.Inputs()); if (!status.ok()) { statuses.push_back( mediapipe::StatusBuilder(std::move(status), MEDIAPIPE_LOC).SetPrepend() @@ -270,10 +269,10 @@ mediapipe::Status NodeTypeInfo::Initialize( " failed to validate: "), statuses); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status NodeTypeInfo::Initialize( +absl::Status NodeTypeInfo::Initialize( const ValidatedGraphConfig& validated_graph, const PacketGeneratorConfig& node, int node_index) { node_.type = NodeType::PACKET_GENERATOR; @@ -297,9 +296,8 @@ mediapipe::Status NodeTypeInfo::Initialize( } // Validate result of FillExpectations. - std::vector statuses; - mediapipe::Status status = - ValidatePacketTypeSet(contract_.InputSidePackets()); + std::vector statuses; + absl::Status status = ValidatePacketTypeSet(contract_.InputSidePackets()); if (!status.ok()) { statuses.push_back(std::move(status)); } @@ -312,10 +310,10 @@ mediapipe::Status NodeTypeInfo::Initialize( absl::StrCat(node_class, "::FillExpectations failed to validate: "), statuses); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status NodeTypeInfo::Initialize( +absl::Status NodeTypeInfo::Initialize( const ValidatedGraphConfig& validated_graph, const StatusHandlerConfig& node, int node_index) { node_.type = NodeType::STATUS_HANDLER; @@ -341,10 +339,10 @@ mediapipe::Status NodeTypeInfo::Initialize( MP_RETURN_IF_ERROR(ValidatePacketTypeSet(contract_.InputSidePackets())) .SetPrepend() << node_class << "::FillExpectations failed to validate: "; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::Initialize( +absl::Status ValidatedGraphConfig::Initialize( const CalculatorGraphConfig& input_config, const GraphRegistry* graph_registry) { RET_CHECK(!initialized_) @@ -426,20 +424,20 @@ 
mediapipe::Status ValidatedGraphConfig::Initialize( << config_.DebugString(); #endif initialized_ = true; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::Initialize( +absl::Status ValidatedGraphConfig::Initialize( const std::string& graph_type, const Subgraph::SubgraphOptions* options, const GraphRegistry* graph_registry) { graph_registry = graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; auto status_or_config = graph_registry->CreateByName("", graph_type, options); MP_RETURN_IF_ERROR(status_or_config.status()); - return Initialize(status_or_config.ValueOrDie(), graph_registry); + return Initialize(status_or_config.value(), graph_registry); } -mediapipe::Status ValidatedGraphConfig::Initialize( +absl::Status ValidatedGraphConfig::Initialize( const std::vector& input_configs, const std::vector& input_templates, const std::string& graph_type, const Subgraph::SubgraphOptions* options) { @@ -453,12 +451,12 @@ mediapipe::Status ValidatedGraphConfig::Initialize( return Initialize(graph_type, options, &graph_registry); } -mediapipe::Status ValidatedGraphConfig::InitializeCalculatorInfo() { - std::vector statuses; +absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() { + std::vector statuses; calculators_.reserve(config_.node_size()); for (const auto& node : config_.node()) { calculators_.emplace_back(); - mediapipe::Status status = + absl::Status status = calculators_.back().Initialize(*this, node, calculators_.size() - 1); if (!status.ok()) { statuses.push_back(status); @@ -468,12 +466,12 @@ mediapipe::Status ValidatedGraphConfig::InitializeCalculatorInfo() { statuses); } -mediapipe::Status ValidatedGraphConfig::InitializeGeneratorInfo() { - std::vector statuses; +absl::Status ValidatedGraphConfig::InitializeGeneratorInfo() { + std::vector statuses; generators_.reserve(config_.packet_generator_size()); for (const auto& node : config_.packet_generator()) { generators_.emplace_back(); - 
mediapipe::Status status = + absl::Status status = generators_.back().Initialize(*this, node, generators_.size() - 1); if (!status.ok()) { statuses.push_back(status); @@ -483,12 +481,12 @@ mediapipe::Status ValidatedGraphConfig::InitializeGeneratorInfo() { statuses); } -mediapipe::Status ValidatedGraphConfig::InitializeStatusHandlerInfo() { - std::vector statuses; +absl::Status ValidatedGraphConfig::InitializeStatusHandlerInfo() { + std::vector statuses; status_handlers_.reserve(config_.status_handler_size()); for (const auto& node : config_.status_handler()) { status_handlers_.emplace_back(); - mediapipe::Status status = status_handlers_.back().Initialize( + absl::Status status = status_handlers_.back().Initialize( *this, node, status_handlers_.size() - 1); if (!status.ok()) { statuses.push_back(status); @@ -498,7 +496,7 @@ mediapipe::Status ValidatedGraphConfig::InitializeStatusHandlerInfo() { statuses); } -mediapipe::Status ValidatedGraphConfig::InitializeSidePacketInfo( +absl::Status ValidatedGraphConfig::InitializeSidePacketInfo( bool* need_sorting_ptr) { for (NodeTypeInfo* node_type_info : sorted_nodes_) { MP_RETURN_IF_ERROR(AddInputSidePacketsForNode(node_type_info)); @@ -506,7 +504,7 @@ mediapipe::Status ValidatedGraphConfig::InitializeSidePacketInfo( AddOutputSidePacketsForNode(node_type_info, need_sorting_ptr)); } if (need_sorting_ptr && *need_sorting_ptr) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } for (int index = 0; index < config_.status_handler_size(); ++index) { NodeTypeInfo* node_type_info = &status_handlers_[index]; @@ -515,10 +513,10 @@ mediapipe::Status ValidatedGraphConfig::InitializeSidePacketInfo( RET_CHECK_EQ(node_type_info->Node().index, index); MP_RETURN_IF_ERROR(AddInputSidePacketsForNode(node_type_info)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::AddInputSidePacketsForNode( +absl::Status ValidatedGraphConfig::AddInputSidePacketsForNode( NodeTypeInfo* 
node_type_info) { node_type_info->SetInputSidePacketBaseIndex(input_side_packets_.size()); const tool::TagMap& tag_map = @@ -541,10 +539,10 @@ mediapipe::Status ValidatedGraphConfig::AddInputSidePacketsForNode( edge_info.name = name; edge_info.packet_type = &node_type_info->InputSidePacketTypes().Get(id); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::AddOutputSidePacketsForNode( +absl::Status ValidatedGraphConfig::AddOutputSidePacketsForNode( NodeTypeInfo* node_type_info, bool* need_sorting_ptr) { node_type_info->SetOutputSidePacketBaseIndex(output_side_packets_.size()); const tool::TagMap& tag_map = @@ -575,10 +573,10 @@ mediapipe::Status ValidatedGraphConfig::AddOutputSidePacketsForNode( } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::InitializeStreamInfo( +absl::Status ValidatedGraphConfig::InitializeStreamInfo( bool* need_sorting_ptr) { // Define output streams for graph input streams. ASSIGN_OR_RETURN(std::shared_ptr graph_input_streams, @@ -607,10 +605,10 @@ mediapipe::Status ValidatedGraphConfig::InitializeStreamInfo( // Validate tag-name-indexes for graph output streams. MP_RETURN_IF_ERROR(tool::TagMap::Create(config_.output_stream()).status()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::AddOutputStreamsForNode( +absl::Status ValidatedGraphConfig::AddOutputStreamsForNode( NodeTypeInfo* node_type_info) { // Define output streams connecting calculators. 
node_type_info->SetOutputStreamBaseIndex(output_streams_.size()); @@ -620,12 +618,12 @@ mediapipe::Status ValidatedGraphConfig::AddOutputStreamsForNode( AddOutputStream(node_type_info->Node(), tag_map.Names()[id.value()], &node_type_info->OutputStreamTypes().Get(id))); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::AddOutputStream( - NodeTypeInfo::NodeRef node, const std::string& name, - PacketType* packet_type) { +absl::Status ValidatedGraphConfig::AddOutputStream(NodeTypeInfo::NodeRef node, + const std::string& name, + PacketType* packet_type) { output_streams_.emplace_back(); auto& edge_info = output_streams_.back(); @@ -638,10 +636,10 @@ mediapipe::Status ValidatedGraphConfig::AddOutputStream( return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) << "Output Stream \"" << name << "\" defined twice."; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::AddInputStreamsForNode( +absl::Status ValidatedGraphConfig::AddInputStreamsForNode( NodeTypeInfo* node_type_info, bool* need_sorting_ptr) { node_type_info->SetInputStreamBaseIndex(input_streams_.size()); const int node_index = node_type_info->Node().index; @@ -704,7 +702,7 @@ mediapipe::Status ValidatedGraphConfig::AddInputStreamsForNode( edge_info.name = name; edge_info.packet_type = &node_type_info->InputStreamTypes().Get(id); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } int ValidatedGraphConfig::SorterIndexForNode(NodeTypeInfo::NodeRef node) const { @@ -728,7 +726,7 @@ NodeTypeInfo::NodeRef ValidatedGraphConfig::NodeForSorterIndex( } } -mediapipe::Status ValidatedGraphConfig::TopologicalSortNodes() { +absl::Status ValidatedGraphConfig::TopologicalSortNodes() { #if !(defined(MEDIAPIPE_LITE) || defined(MEDIAPIPE_MOBILE)) VLOG(2) << "BEFORE TOPOLOGICAL SORT:\n" << config_.DebugString(); #endif // !(MEDIAPIPE_LITE || MEDIAPIPE_MOBILE) @@ -836,10 +834,10 @@ mediapipe::Status 
ValidatedGraphConfig::TopologicalSortNodes() { #if !(defined(MEDIAPIPE_LITE) || defined(MEDIAPIPE_MOBILE)) VLOG(2) << "AFTER TOPOLOGICAL SORT:\n" << config_.DebugString(); #endif // !(MEDIAPIPE_LITE || MEDIAPIPE_MOBILE) - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::FillUpstreamFieldForBackEdges() { +absl::Status ValidatedGraphConfig::FillUpstreamFieldForBackEdges() { for (int index = 0; index < input_streams_.size(); ++index) { auto& input_stream = input_streams_[index]; if (input_stream.back_edge) { @@ -854,10 +852,10 @@ mediapipe::Status ValidatedGraphConfig::FillUpstreamFieldForBackEdges() { input_stream.upstream = iter->second; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::ValidateSidePacketTypes() { +absl::Status ValidatedGraphConfig::ValidateSidePacketTypes() { for (const auto& side_packet : input_side_packets_) { // TODO Add a check to ensure multiple input side packets // connected to a side packet that will be provided later all have @@ -865,7 +863,7 @@ mediapipe::Status ValidatedGraphConfig::ValidateSidePacketTypes() { if (side_packet.upstream != -1 && !side_packet.packet_type->IsConsistentWith( *output_side_packets_[side_packet.upstream].packet_type)) { - return mediapipe::UnknownError(absl::Substitute( + return absl::UnknownError(absl::Substitute( "Input side packet \"$0\" of $1 \"$2\" expected a packet of type " "\"$3\" but the connected output side packet will be of type \"$4\"", side_packet.name, @@ -877,10 +875,10 @@ mediapipe::Status ValidatedGraphConfig::ValidateSidePacketTypes() { .packet_type->DebugTypeName())); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::ResolveAnyTypes( +absl::Status ValidatedGraphConfig::ResolveAnyTypes( std::vector* input_edges, std::vector* output_edges) { for (EdgeInfo& input_edge : *input_edges) { if (input_edge.upstream == -1) { @@ -895,15 
+893,15 @@ mediapipe::Status ValidatedGraphConfig::ResolveAnyTypes( output_root->SetSameAs(input_edge.packet_type); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::ValidateStreamTypes() { +absl::Status ValidatedGraphConfig::ValidateStreamTypes() { for (const EdgeInfo& stream : input_streams_) { RET_CHECK_NE(stream.upstream, -1); if (!stream.packet_type->IsConsistentWith( *output_streams_[stream.upstream].packet_type)) { - return mediapipe::UnknownError(absl::Substitute( + return absl::UnknownError(absl::Substitute( "Input stream \"$0\" of calculator \"$1\" expects packets of type " "\"$2\" but the connected output stream will contain packets of type " "\"$3\"", @@ -913,10 +911,10 @@ mediapipe::Status ValidatedGraphConfig::ValidateStreamTypes() { output_streams_[stream.upstream].packet_type->DebugTypeName())); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::ValidateExecutors() { +absl::Status ValidatedGraphConfig::ValidateExecutors() { absl::flat_hash_set declared_names; for (const ExecutorConfig& executor_config : config_.executor()) { if (IsReservedExecutorName(executor_config.name())) { @@ -926,7 +924,7 @@ mediapipe::Status ValidatedGraphConfig::ValidateExecutors() { } if (!declared_names.emplace(executor_config.name()).second) { if (executor_config.name().empty()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "ExecutorConfig for the default executor is duplicate."); } else { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) @@ -953,7 +951,7 @@ mediapipe::Status ValidatedGraphConfig::ValidateExecutors() { << "\" is not declared in an ExecutorConfig."; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // static @@ -961,19 +959,27 @@ bool ValidatedGraphConfig::IsReservedExecutorName(const std::string& name) { return name == "default" || name == "gpu" || absl::StartsWith(name, "__"); } 
-mediapipe::Status ValidatedGraphConfig::ValidateRequiredSidePackets( +absl::Status ValidatedGraphConfig::ValidateRequiredSidePackets( const std::map& side_packets) const { - std::vector statuses; + std::vector statuses; for (const auto& required_item : required_side_packets_) { auto iter = side_packets.find(required_item.first); if (iter == side_packets.end()) { + bool is_optional = true; + for (int index : required_item.second) { + is_optional &= input_side_packets_[index].packet_type->IsOptional(); + } + if (is_optional) { + // Side packets that are optional and not provided are ignored. + continue; + } statuses.push_back(mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Side packet \"" << required_item.first << "\" is required but was not provided."); continue; } for (int index : required_item.second) { - mediapipe::Status status = + absl::Status status = input_side_packets_[index].packet_type->Validate(iter->second); if (!status.ok()) { statuses.push_back( @@ -988,12 +994,12 @@ mediapipe::Status ValidatedGraphConfig::ValidateRequiredSidePackets( return tool::CombinedStatus( "ValidateRequiredSidePackets failed to validate: ", statuses); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes( +absl::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes( const std::map& side_packet_types) const { - std::vector statuses; + std::vector statuses; for (const auto& required_item : required_side_packets_) { auto iter = side_packet_types.find(required_item.first); if (iter == side_packet_types.end()) { @@ -1015,10 +1021,10 @@ mediapipe::Status ValidatedGraphConfig::ValidateRequiredSidePacketTypes( return tool::CombinedStatus( "ValidateRequiredSidePackets failed to validate: ", statuses); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidatedGraphConfig::ComputeSourceDependence() { +absl::Status 
ValidatedGraphConfig::ComputeSourceDependence() { for (int node_index = 0; node_index < calculators_.size(); ++node_index) { NodeTypeInfo& node_type_info = calculators_[node_index]; if (node_type_info.InputStreamTypes().NumEntries() == 0) { @@ -1059,11 +1065,11 @@ mediapipe::Status ValidatedGraphConfig::ComputeSourceDependence() { } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::StatusOr -ValidatedGraphConfig::RegisteredSidePacketTypeName(const std::string& name) { +absl::StatusOr ValidatedGraphConfig::RegisteredSidePacketTypeName( + const std::string& name) { auto iter = side_packet_to_producer_.find(name); bool defined = false; if (iter != side_packet_to_producer_.end()) { @@ -1101,7 +1107,7 @@ ValidatedGraphConfig::RegisteredSidePacketTypeName(const std::string& name) { "determinable, or the type may be defined but not registered."; } -mediapipe::StatusOr ValidatedGraphConfig::RegisteredStreamTypeName( +absl::StatusOr ValidatedGraphConfig::RegisteredStreamTypeName( const std::string& name) { auto iter = stream_to_producer_.find(name); if (iter == stream_to_producer_.end()) { diff --git a/mediapipe/framework/validated_graph_config.h b/mediapipe/framework/validated_graph_config.h index 0a01accee..f509707f5 100644 --- a/mediapipe/framework/validated_graph_config.h +++ b/mediapipe/framework/validated_graph_config.h @@ -65,14 +65,13 @@ class NodeTypeInfo { // node_index is the index of this node among the nodes of the same type // in the validated graph config. 
- mediapipe::Status Initialize(const ValidatedGraphConfig& validated_graph, - const CalculatorGraphConfig::Node& node, - int node_index); - mediapipe::Status Initialize(const ValidatedGraphConfig& validated_graph, - const PacketGeneratorConfig& node, - int node_index); - mediapipe::Status Initialize(const ValidatedGraphConfig& validated_graph, - const StatusHandlerConfig& node, int node_index); + absl::Status Initialize(const ValidatedGraphConfig& validated_graph, + const CalculatorGraphConfig::Node& node, + int node_index); + absl::Status Initialize(const ValidatedGraphConfig& validated_graph, + const PacketGeneratorConfig& node, int node_index); + absl::Status Initialize(const ValidatedGraphConfig& validated_graph, + const StatusHandlerConfig& node, int node_index); // TODO: many of these accessors can be replaced by Contract(). const PacketTypeSet& InputSidePacketTypes() const { @@ -195,17 +194,16 @@ class ValidatedGraphConfig { // Initializes the ValidatedGraphConfig. This function must be called // before any other functions. Subgraphs are specified through the // global graph registry or an optional local graph registry. - mediapipe::Status Initialize(const CalculatorGraphConfig& input_config, - const GraphRegistry* graph_registry = nullptr); + absl::Status Initialize(const CalculatorGraphConfig& input_config, + const GraphRegistry* graph_registry = nullptr); // Initializes the ValidatedGraphConfig from registered graph and subgraph // configs. Subgraphs are retrieved from the specified graph registry or from // the global graph registry. A subgraph can be instantiated directly by // specifying its type in |graph_type|. 
- mediapipe::Status Initialize( - const std::string& graph_type, - const Subgraph::SubgraphOptions* options = nullptr, - const GraphRegistry* graph_registry = nullptr); + absl::Status Initialize(const std::string& graph_type, + const Subgraph::SubgraphOptions* options = nullptr, + const GraphRegistry* graph_registry = nullptr); // Initializes the ValidatedGraphConfig from the specified graph and subgraph // configs. Template graph and subgraph configs can be specified through @@ -213,7 +211,7 @@ class ValidatedGraphConfig { // CalclatorGraphConfig.type. A subgraph can be instantiated directly by // specifying its type in |graph_type|. A template graph can be instantiated // directly by specifying its template arguments in |arguments|. - mediapipe::Status Initialize( + absl::Status Initialize( const std::vector& input_configs, const std::vector& input_templates, const std::string& graph_type = "", @@ -225,15 +223,15 @@ class ValidatedGraphConfig { // Returns an error if the provided side packets will be generated by // the PacketGenerators in this graph. template - mediapipe::Status CanAcceptSidePackets( + absl::Status CanAcceptSidePackets( const std::map& side_packets) const; // Validate that all the required side packets are provided, and the // packets have the required type. - mediapipe::Status ValidateRequiredSidePackets( + absl::Status ValidateRequiredSidePackets( const std::map& side_packets) const; // Same as ValidateRequiredSidePackets but only provide the type. - mediapipe::Status ValidateRequiredSidePacketTypes( + absl::Status ValidateRequiredSidePacketTypes( const std::map& side_packet_types) const; // The proto configuration (canonicalized). @@ -280,12 +278,11 @@ class ValidatedGraphConfig { // Returns the registered type name of the specified side packet if // it can be determined, otherwise an appropriate error is returned. 
- mediapipe::StatusOr RegisteredSidePacketTypeName( + absl::StatusOr RegisteredSidePacketTypeName( const std::string& name); // Returns the registered type name of the specified stream if it can // be determined, otherwise an appropriate error is returned. - mediapipe::StatusOr RegisteredStreamTypeName( - const std::string& name); + absl::StatusOr RegisteredStreamTypeName(const std::string& name); // The namespace used for class name lookup. std::string Package() const { return config_.package(); } @@ -293,13 +290,18 @@ class ValidatedGraphConfig { // Returns true if |name| is a reserved executor name. static bool IsReservedExecutorName(const std::string& name); + // Returns true if a side packet is provided as an input to the graph. + bool IsExternalSidePacket(const std::string& name) const { + return required_side_packets_.count(name) > 0; + } + private: // Initialize the PacketGenerator information. - mediapipe::Status InitializeGeneratorInfo(); + absl::Status InitializeGeneratorInfo(); // Initialize the Calculator information. - mediapipe::Status InitializeCalculatorInfo(); + absl::Status InitializeCalculatorInfo(); // Initialize the StatusHandler information. - mediapipe::Status InitializeStatusHandlerInfo(); + absl::Status InitializeStatusHandlerInfo(); // Initialize the EdgeInfo objects for side packets. // @@ -311,7 +313,7 @@ class ValidatedGraphConfig { // // If need_sorting_ptr is nullptr then an error will be returned if the // nodes in the side packet graph are not in topologically sorted order. - mediapipe::Status InitializeSidePacketInfo(bool* need_sorting_ptr); + absl::Status InitializeSidePacketInfo(bool* need_sorting_ptr); // Adds EdgeInfo objects to input_side_packets_ for all the input side // packets required by the node_type_info. 
If nodes are processed // with AddInputSidePacketsForNode and AddOutputSidePacketsForNode @@ -319,7 +321,7 @@ class ValidatedGraphConfig { // required_side_packets_ are used to ensure that the graph is // topologically sorted. node_type_info is updated with the proper // initial index for input side packets. - mediapipe::Status AddInputSidePacketsForNode(NodeTypeInfo* node_type_info); + absl::Status AddInputSidePacketsForNode(NodeTypeInfo* node_type_info); // Adds EdgeInfo objects to output_side_packets_ for all the output side // packets produced by the node_type_info. side_packet_to_producer_ is // updated. need_sorting_ptr will be set to true if the nodes in the @@ -327,21 +329,21 @@ class ValidatedGraphConfig { // is output after something that required it), otherwise need_sorting_ptr // is left as is. node_type_info is updated with the proper initial index // for output side packets. - mediapipe::Status AddOutputSidePacketsForNode(NodeTypeInfo* node_type_info, - bool* need_sorting_ptr); + absl::Status AddOutputSidePacketsForNode(NodeTypeInfo* node_type_info, + bool* need_sorting_ptr); // These functions are analogous to the same operations for side // packets, with the small difference that it is an error to use an // undefined stream (whereas it is allowed to use an undefined side // packet). - mediapipe::Status InitializeStreamInfo(bool* need_sorting_ptr); - mediapipe::Status AddOutputStreamsForNode(NodeTypeInfo* node_type_info); - mediapipe::Status AddInputStreamsForNode(NodeTypeInfo* node_type_info, - bool* need_sorting_ptr); + absl::Status InitializeStreamInfo(bool* need_sorting_ptr); + absl::Status AddOutputStreamsForNode(NodeTypeInfo* node_type_info); + absl::Status AddInputStreamsForNode(NodeTypeInfo* node_type_info, + bool* need_sorting_ptr); // A helper function for adding a single output stream EdgeInfo. 
- mediapipe::Status AddOutputStream(NodeTypeInfo::NodeRef node, - const std::string& name, - PacketType* packet_type); + absl::Status AddOutputStream(NodeTypeInfo::NodeRef node, + const std::string& name, + PacketType* packet_type); // Return the index of the node adjusted for the topological sorter. int SorterIndexForNode(NodeTypeInfo::NodeRef node) const; @@ -360,31 +362,31 @@ class ValidatedGraphConfig { // two node types, graph input streams and status handlers, can be safely // ignored in the analysis of output side packet generation or stream // header packet propagation. - mediapipe::Status TopologicalSortNodes(); + absl::Status TopologicalSortNodes(); // TODO Add InputStreamHandler. // TODO Add OutputStreamHandler. // Fill the "upstream" field for all back edges. - mediapipe::Status FillUpstreamFieldForBackEdges(); + absl::Status FillUpstreamFieldForBackEdges(); // Compute the dependence of nodes on sources. - mediapipe::Status ComputeSourceDependence(); + absl::Status ComputeSourceDependence(); // Infer the type of types set to "Any" by what they are connected to. - mediapipe::Status ResolveAnyTypes(std::vector* input_edges, - std::vector* output_edges); + absl::Status ResolveAnyTypes(std::vector* input_edges, + std::vector* output_edges); // Returns an error if the generator graph does not have consistent // type specifications for side packets. - mediapipe::Status ValidateSidePacketTypes(); + absl::Status ValidateSidePacketTypes(); // Returns an error if the graph of calculators does not have consistent // type specifications for streams. - mediapipe::Status ValidateStreamTypes(); + absl::Status ValidateStreamTypes(); // Returns an error if the graph does not have valid ExecutorConfigs, or // if the executor name in a node config is reserved or is not declared // in an ExecutorConfig. 
- mediapipe::Status ValidateExecutors(); + absl::Status ValidateExecutors(); bool initialized_ = false; @@ -422,7 +424,7 @@ class ValidatedGraphConfig { }; template -mediapipe::Status ValidatedGraphConfig::CanAcceptSidePackets( +absl::Status ValidatedGraphConfig::CanAcceptSidePackets( const std::map& side_packets) const { for (const auto& output_side_packet : output_side_packets_) { if (ContainsKey(side_packets, output_side_packet.name)) { @@ -431,7 +433,7 @@ mediapipe::Status ValidatedGraphConfig::CanAcceptSidePackets( << "\" is both provided and generated by a PacketGenerator."; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 42ee76a81..e09f85407 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -15,6 +15,7 @@ load("//mediapipe/gpu:metal.bzl", "metal_library") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") +load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test") licenses(["notice"]) @@ -166,6 +167,7 @@ cc_library( deps = [ ":gl_base", ":gl_thread_collector", + ":gpu_buffer_format", "//mediapipe/framework:executor", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -364,13 +366,13 @@ cc_library( hdrs = [ "gpu_shared_data_internal.h", ], - defines = ["MEDIAPIPE_DISABLE_GPU"], visibility = ["//visibility:private"], deps = [ ":graph_support", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework:executor", + "//mediapipe/framework:port", "//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/port:ret_check", "//mediapipe/gpu:gl_context_options_cc_proto", @@ -560,6 +562,7 @@ cc_library( "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image", 
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", @@ -597,6 +600,7 @@ objc_library( ":gpu_shared_data_internal", ":shader_util", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/formats:image", "//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:util", ], @@ -904,6 +908,23 @@ objc_library( alwayslink = 1, ) +objc_library( + name = "mps_threshold_calculator", + srcs = ["MPSThresholdCalculator.mm"], + copts = ["-std=c++17"], + sdk_frameworks = [ + "CoreVideo", + "Metal", + "MetalPerformanceShaders", + ], + visibility = ["//visibility:public"], + deps = [ + ":MPPMetalHelper", + "//mediapipe/objc:mediapipe_framework_ios", + ], + alwayslink = 1, +) + MIN_IOS_VERSION = "9.0" # For thread_local. test_suite( diff --git a/mediapipe/gpu/MPPMetalHelper.h b/mediapipe/gpu/MPPMetalHelper.h index 293e9acdc..f3662422e 100644 --- a/mediapipe/gpu/MPPMetalHelper.h +++ b/mediapipe/gpu/MPPMetalHelper.h @@ -41,7 +41,7 @@ NS_ASSUME_NONNULL_BEGIN /// Configures a calculator's contract for accessing GPU resources. /// Calculators should use this in GetContract. -+ (::mediapipe::Status)updateContract:(mediapipe::CalculatorContract*)cc; ++ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc; /// Deprecated initializer. - (instancetype)initWithSidePackets:(const mediapipe::PacketSet&)inputSidePackets; @@ -51,7 +51,7 @@ NS_ASSUME_NONNULL_BEGIN /// Configures a calculator's side packets for accessing GPU resources. /// Calculators should use this in FillExpectations. -+ (::mediapipe::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; ++ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets; /// Get a metal command buffer. 
/// Calculators should use this method instead of getting a buffer from the diff --git a/mediapipe/gpu/MPPMetalHelper.mm b/mediapipe/gpu/MPPMetalHelper.mm index cc4fbd6e7..aeb6cd58c 100644 --- a/mediapipe/gpu/MPPMetalHelper.mm +++ b/mediapipe/gpu/MPPMetalHelper.mm @@ -54,7 +54,7 @@ class MetalHelperLegacySupport { return [self initWithGpuResources:&cc->Service(mediapipe::kGpuService).GetObject()]; } -+ (::mediapipe::Status)updateContract:(mediapipe::CalculatorContract*)cc { ++ (absl::Status)updateContract:(mediapipe::CalculatorContract*)cc { cc->UseService(mediapipe::kGpuService); // Allow the legacy side packet to be provided, too, for backwards // compatibility with existing graphs. It will just be ignored. @@ -63,7 +63,7 @@ class MetalHelperLegacySupport { if (id.IsValid()) { input_side_packets.Get(id).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Legacy support. @@ -85,7 +85,7 @@ class MetalHelperLegacySupport { } // Legacy support. -+ (::mediapipe::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets { ++ (absl::Status)setupInputSidePackets:(mediapipe::PacketTypeSet*)inputSidePackets { auto cc = mediapipe::MetalHelperLegacySupport::GetCalculatorContract(); if (cc) { CHECK_EQ(inputSidePackets, &cc->InputSidePackets()); @@ -101,7 +101,7 @@ class MetalHelperLegacySupport { << "A " << mediapipe::kGpuSharedTagName << " input side packet is required here."; inputSidePackets->Get(id).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } - (id)mtlDevice { diff --git a/mediapipe/gpu/gl_base.h b/mediapipe/gpu/gl_base.h index 9aa8b02dc..12a04e0bb 100644 --- a/mediapipe/gpu/gl_base.h +++ b/mediapipe/gpu/gl_base.h @@ -69,7 +69,7 @@ #include // When using the Linux EGL headers, we may end up pulling a -// "#define Status int" from Xlib.h, which interferes with mediapipe::Status. +// "#define Status int" from Xlib.h, which interferes with absl::Status. 
#undef Status // More crud from X diff --git a/mediapipe/gpu/gl_calculator_helper.cc b/mediapipe/gpu/gl_calculator_helper.cc index 8b6373506..aa708e731 100644 --- a/mediapipe/gpu/gl_calculator_helper.cc +++ b/mediapipe/gpu/gl_calculator_helper.cc @@ -14,6 +14,7 @@ #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/legacy_calculator_support.h" #include "mediapipe/framework/port/canonical_errors.h" @@ -39,12 +40,12 @@ GlCalculatorHelper::GlCalculatorHelper() {} GlCalculatorHelper::~GlCalculatorHelper() {} -::mediapipe::Status GlCalculatorHelper::Open(CalculatorContext* cc) { +absl::Status GlCalculatorHelper::Open(CalculatorContext* cc) { CHECK(cc); // TODO return error from impl_ (needs two-stage init) impl_ = absl::make_unique( cc, &cc->Service(kGpuService).GetObject()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void GlCalculatorHelper::InitializeForTest(GpuSharedData* gpu_shared) { @@ -57,7 +58,7 @@ void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { } // static -::mediapipe::Status GlCalculatorHelper::UpdateContract(CalculatorContract* cc) { +absl::Status GlCalculatorHelper::UpdateContract(CalculatorContract* cc) { cc->UseService(kGpuService); // Allow the legacy side packet to be provided, too, for backwards // compatibility with existing graphs. It will just be ignored. 
@@ -66,11 +67,11 @@ void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { if (id.IsValid()) { input_side_packets.Get(id).Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // static -::mediapipe::Status GlCalculatorHelper::SetupInputSidePackets( +absl::Status GlCalculatorHelper::SetupInputSidePackets( PacketTypeSet* input_side_packets) { auto cc = LegacyCalculatorSupport::Scoped::current(); if (cc) { @@ -87,12 +88,12 @@ void GlCalculatorHelper::InitializeForTest(GpuResources* gpu_resources) { RET_CHECK(id.IsValid()) << "A " << mediapipe::kGpuSharedTagName << " input side packet is required here."; input_side_packets->Get(id).Set(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlCalculatorHelper::RunInGlContext( - std::function<::mediapipe::Status(void)> gl_func) { - if (!impl_) return ::mediapipe::InternalError("helper not initialized"); +absl::Status GlCalculatorHelper::RunInGlContext( + std::function gl_func) { + if (!impl_) return absl::InternalError("helper not initialized"); // TODO: Remove LegacyCalculatorSupport from MediaPipe OSS. 
auto calculator_context = LegacyCalculatorSupport::Scoped::current(); @@ -140,4 +141,21 @@ GlContext& GlCalculatorHelper::GetGlContext() const { return impl_->GetGlContext(); } +GlVersion GlCalculatorHelper::GetGlVersion() const { + return impl_->GetGlVersion(); +} + +GlTexture GlCalculatorHelper::CreateSourceTexture( + const mediapipe::Image& image) { + return impl_->CreateSourceTexture(image.GetGpuBuffer()); +} + +template <> +std::unique_ptr GlTexture::GetFrame() + const { + std::unique_ptr buf = GetFrame(); + auto output = absl::make_unique(*buf); + return output; +} + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_calculator_helper.h b/mediapipe/gpu/gl_calculator_helper.h index 53178da44..b5cc69990 100644 --- a/mediapipe/gpu/gl_calculator_helper.h +++ b/mediapipe/gpu/gl_calculator_helper.h @@ -18,6 +18,7 @@ #include "absl/memory/memory.h" #include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_contract.h" +#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet_set.h" @@ -68,7 +69,7 @@ class GlCalculatorHelper { ~GlCalculatorHelper(); // Call Open from the Open method of a calculator to initialize the helper. - ::mediapipe::Status Open(CalculatorContext* cc); + absl::Status Open(CalculatorContext* cc); // Can be used to initialize the helper outside of a calculator. Useful for // testing. @@ -77,33 +78,31 @@ class GlCalculatorHelper { // This method can be called from GetContract to set up the needed GPU // resources. - static ::mediapipe::Status UpdateContract(CalculatorContract* cc); + static absl::Status UpdateContract(CalculatorContract* cc); // This method can be called from FillExpectations to set the correct types // for the shared GL input side packet(s). 
- static ::mediapipe::Status SetupInputSidePackets( - PacketTypeSet* input_side_packets); + static absl::Status SetupInputSidePackets(PacketTypeSet* input_side_packets); // Execute the provided function within the helper's GL context. On some // platforms, this may be run on a different thread; however, this method // will still wait for the function to finish executing before returning. // The status result from the function is passed on to the caller. - ::mediapipe::Status RunInGlContext( - std::function<::mediapipe::Status(void)> gl_func); + absl::Status RunInGlContext(std::function gl_func); // Convenience version of RunInGlContext for arguments with a void result - // type. As with the ::mediapipe::Status version, this also waits for the + // type. As with the absl::Status version, this also waits for the // function to finish executing before returning. // // Implementation note: we cannot use a std::function argument // here, because that would break passing in a lambda that returns a status; // e.g.: - // RunInGlContext([]() -> ::mediapipe::Status { ... }); + // RunInGlContext([]() -> absl::Status { ... }); // // The reason is that std::function allows the implicit conversion // of a callable with any result type, as long as the argument types match. // As a result, the above lambda would be implicitly convertible to both - // std::function<::mediapipe::Status(void)> and std::function, and + // std::function and std::function, and // the invocation would be ambiguous. // // Therefore, instead of using std::function, we use a template @@ -113,7 +112,7 @@ class GlCalculatorHelper { void RunInGlContext(T f) { RunInGlContext([f] { f(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }).IgnoreError(); } @@ -125,6 +124,7 @@ class GlCalculatorHelper { // Creates a texture representing an input frame, and manages sync token. 
GlTexture CreateSourceTexture(const GpuBuffer& pixel_buffer); GlTexture CreateSourceTexture(const ImageFrame& image_frame); + GlTexture CreateSourceTexture(const mediapipe::Image& image); #ifdef __APPLE__ // Creates a texture from a plane of a planar buffer. @@ -153,6 +153,8 @@ class GlCalculatorHelper { GlContext& GetGlContext() const; + GlVersion GetGlVersion() const; + // Check if the calculator helper has been previously initialized. bool Initialized() { return impl_ != nullptr; } diff --git a/mediapipe/gpu/gl_calculator_helper_impl.h b/mediapipe/gpu/gl_calculator_helper_impl.h index 2a20b8438..438111183 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl.h +++ b/mediapipe/gpu/gl_calculator_helper_impl.h @@ -37,9 +37,8 @@ class GlCalculatorHelperImpl { GpuResources* gpu_resources); ~GlCalculatorHelperImpl(); - ::mediapipe::Status RunInGlContext( - std::function<::mediapipe::Status(void)> gl_func, - CalculatorContext* calculator_context); + absl::Status RunInGlContext(std::function gl_func, + CalculatorContext* calculator_context); GlTexture CreateSourceTexture(const ImageFrame& image_frame); GlTexture CreateSourceTexture(const GpuBuffer& gpu_buffer); @@ -54,9 +53,7 @@ class GlCalculatorHelperImpl { GLuint framebuffer() const { return framebuffer_; } void BindFramebuffer(const GlTexture& dst); -#ifdef __APPLE__ - GlVersion GetGlVersion(); -#endif + GlVersion GetGlVersion() const { return gl_context_->GetGlVersion(); } GlContext& GetGlContext() const; diff --git a/mediapipe/gpu/gl_calculator_helper_impl_android.cc b/mediapipe/gpu/gl_calculator_helper_impl_android.cc index 0fe5aa3cd..340734335 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_android.cc +++ b/mediapipe/gpu/gl_calculator_helper_impl_android.cc @@ -36,6 +36,15 @@ std::unique_ptr GlTexture::GetFrame() const { template <> std::unique_ptr GlTexture::GetFrame() const { +#ifdef __EMSCRIPTEN__ + // When WebGL is used, the GL context may be spontaneously lost which can + // cause GpuBuffer allocations 
to fail. In that case, return a dummy buffer + // to allow processing of the current frame complete. + if (!gpu_buffer_) { + return std::make_unique(); + } +#endif // __EMSCRIPTEN__ + CHECK(gpu_buffer_); // Inform the GlTextureBuffer that we have produced new content, and create // a producer sync point. diff --git a/mediapipe/gpu/gl_calculator_helper_impl_common.cc b/mediapipe/gpu/gl_calculator_helper_impl_common.cc index 5a460426f..2e4ab10b6 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_common.cc +++ b/mediapipe/gpu/gl_calculator_helper_impl_common.cc @@ -42,7 +42,7 @@ GlCalculatorHelperImpl::~GlCalculatorHelperImpl() { glDeleteFramebuffers(1, &framebuffer_); framebuffer_ = 0; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }, /*calculator_context=*/nullptr) .IgnoreError(); @@ -50,8 +50,8 @@ GlCalculatorHelperImpl::~GlCalculatorHelperImpl() { GlContext& GlCalculatorHelperImpl::GetGlContext() const { return *gl_context_; } -::mediapipe::Status GlCalculatorHelperImpl::RunInGlContext( - std::function<::mediapipe::Status(void)> gl_func, +absl::Status GlCalculatorHelperImpl::RunInGlContext( + std::function gl_func, CalculatorContext* calculator_context) { if (calculator_context) { return gl_context_->Run(std::move(gl_func), calculator_context->NodeId(), @@ -161,12 +161,14 @@ GlTexture GlCalculatorHelperImpl::MapGlTextureBuffer( texture.target_ = texture_buffer->target_; texture.name_ = texture_buffer->name_; - // TODO: do the params need to be reset here?? - glBindTexture(texture.target(), texture.name()); - GlTextureInfo info = - GlTextureInfoForGpuBufferFormat(texture_buffer->format(), texture.plane_); - SetStandardTextureParams(texture.target(), info.gl_internal_format); - glBindTexture(texture.target(), 0); + if (texture_buffer->format() != GpuBufferFormat::kUnknown) { + // TODO: do the params need to be reset here?? 
+ glBindTexture(texture.target(), texture.name()); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + texture_buffer->format(), texture.plane_, GetGlVersion()); + SetStandardTextureParams(texture.target(), info.gl_internal_format); + glBindTexture(texture.target(), 0); + } return texture; } @@ -178,11 +180,14 @@ GlTextureBufferSharedPtr GlCalculatorHelperImpl::MakeGlTextureBuffer( image_frame.Width(), image_frame.Height(), GpuBufferFormatForImageFormat(image_frame.Format()), image_frame.PixelData()); - glBindTexture(GL_TEXTURE_2D, buffer->name_); - GlTextureInfo info = - GlTextureInfoForGpuBufferFormat(buffer->format_, /*plane=*/0); - SetStandardTextureParams(buffer->target_, info.gl_internal_format); - glBindTexture(GL_TEXTURE_2D, 0); + + if (buffer->format_ != GpuBufferFormat::kUnknown) { + glBindTexture(GL_TEXTURE_2D, buffer->name_); + GlTextureInfo info = GlTextureInfoForGpuBufferFormat( + buffer->format_, /*plane=*/0, GetGlVersion()); + SetStandardTextureParams(buffer->target_, info.gl_internal_format); + glBindTexture(GL_TEXTURE_2D, 0); + } return buffer; } diff --git a/mediapipe/gpu/gl_calculator_helper_impl_ios.mm b/mediapipe/gpu/gl_calculator_helper_impl_ios.mm index 81cde6a9f..e91d36e2c 100644 --- a/mediapipe/gpu/gl_calculator_helper_impl_ios.mm +++ b/mediapipe/gpu/gl_calculator_helper_impl_ios.mm @@ -28,15 +28,6 @@ namespace mediapipe { -GlVersion GlCalculatorHelperImpl::GetGlVersion() { -#if TARGET_OS_OSX - return GlVersion::kGL; -#else - if (gl_context_->eagl_context().API == kEAGLRenderingAPIOpenGLES3) return GlVersion::kGLES3; - else return GlVersion::kGLES2; -#endif // TARGET_OS_OSX -} - #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlTexture GlCalculatorHelperImpl::CreateSourceTexture( const mediapipe::ImageFrame& image_frame) { @@ -122,9 +113,9 @@ std::unique_ptr GlTexture::GetFrame() const { ImageFormat::Format image_format = ImageFormatForGpuBufferFormat(gpu_buffer_.format()); - // TODO: handle gl version here. 
- GlTextureInfo info = GlTextureInfoForGpuBufferFormat( - gpu_buffer_.format(), plane_); + CHECK(helper_impl_); + GlTextureInfo info = + GlTextureInfoForGpuBufferFormat(gpu_buffer_.format(), plane_, helper_impl_->GetGlVersion()); auto output = absl::make_unique( image_format, width_, height_); diff --git a/mediapipe/gpu/gl_context.cc b/mediapipe/gpu/gl_context.cc index 58b6aab76..92c67035a 100644 --- a/mediapipe/gpu/gl_context.cc +++ b/mediapipe/gpu/gl_context.cc @@ -29,6 +29,7 @@ #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_builder.h" #include "mediapipe/gpu/gl_context_internal.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #ifndef __EMSCRIPTEN__ #include "absl/debugging/leak_check.h" @@ -140,13 +141,13 @@ void GlContext::DedicatedThread::ThreadBody() { #endif } -::mediapipe::Status GlContext::DedicatedThread::Run(GlStatusFunction gl_func) { +absl::Status GlContext::DedicatedThread::Run(GlStatusFunction gl_func) { // Neither ENDO_SCOPE nor ENDO_TASK seem to work here. if (IsCurrentThread()) { return gl_func(); } bool done = false; // Guarded by mutex_ after initialization. - ::mediapipe::Status status; + absl::Status status; PutJob([this, gl_func, &done, &status]() { status = gl_func(); absl::MutexLock lock(&mutex_); @@ -204,6 +205,14 @@ bool GlContext::ParseGlVersion(absl::string_view version_string, GLint* major, return true; } +GlVersion GlContext::GetGlVersion() const { +#ifdef GL_ES_VERSION_2_0 // This actually means "is GLES available". + return gl_major_version() < 3 ? GlVersion::kGLES2 : GlVersion::kGLES3; +#else // This is the "desktop GL" case. + return GlVersion::kGL; +#endif +} + bool GlContext::HasGlExtension(absl::string_view extension) const { return gl_extensions_.find(extension) != gl_extensions_.end(); } @@ -212,7 +221,7 @@ bool GlContext::HasGlExtension(absl::string_view extension) const { // in an easily-accessible set. 
The glGetString call is actually *not* required // to work with GL_EXTENSIONS for newer GL versions, so we must maintain both // variations of this function. -::mediapipe::Status GlContext::GetGlExtensions() { +absl::Status GlContext::GetGlExtensions() { gl_extensions_.clear(); // glGetStringi only introduced in GL 3.0+; so we exit out this function if // we don't have that function defined, regardless of version number reported. @@ -226,54 +235,52 @@ bool GlContext::HasGlExtension(absl::string_view extension) const { LOG(ERROR) << "GL major version > 3.0 indicated, but glGetStringi not " << "defined. Falling back to deprecated GL extensions querying " << "method."; - return ::mediapipe::InternalError("glGetStringi not defined, but queried"); + return absl::InternalError("glGetStringi not defined, but queried"); } int num_extensions = 0; glGetIntegerv(GL_NUM_EXTENSIONS, &num_extensions); if (glGetError() != 0) { - return ::mediapipe::InternalError( - "Error querying for number of extensions"); + return absl::InternalError("Error querying for number of extensions"); } for (int i = 0; i < num_extensions; ++i) { const GLubyte* res = glGetStringi(GL_EXTENSIONS, i); if (glGetError() != 0 || res == nullptr) { - return ::mediapipe::InternalError( - "Error querying for an extension by index"); + return absl::InternalError("Error querying for an extension by index"); } const char* signed_res = reinterpret_cast(res); gl_extensions_.insert(signed_res); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); #else - return ::mediapipe::InternalError("GL version mismatch in GlGetExtensions"); + return absl::InternalError("GL version mismatch in GlGetExtensions"); #endif // (GL_VERSION_3_0 || GL_ES_VERSION_3_0) && !defined(__EMSCRIPTEN__) } // Same as GetGlExtensions() above, but for pre-GL3.0, where glGetStringi did // not exist. 
-::mediapipe::Status GlContext::GetGlExtensionsCompat() { +absl::Status GlContext::GetGlExtensionsCompat() { gl_extensions_.clear(); const GLubyte* res = glGetString(GL_EXTENSIONS); if (glGetError() != 0 || res == nullptr) { LOG(ERROR) << "Error querying for GL extensions"; - return ::mediapipe::InternalError("Error querying for GL extensions"); + return absl::InternalError("Error querying for GL extensions"); } const char* signed_res = reinterpret_cast(res); gl_extensions_ = absl::StrSplit(signed_res, ' '); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlContext::FinishInitialization(bool create_thread) { +absl::Status GlContext::FinishInitialization(bool create_thread) { if (create_thread) { thread_ = absl::make_unique(); MP_RETURN_IF_ERROR(thread_->Run([this] { return EnterContext(nullptr); })); } - return Run([this]() -> ::mediapipe::Status { + return Run([this]() -> absl::Status { // Clear any GL errors at this point: as this is a fresh context // there shouldn't be any, but if we adopted an existing context (e.g. in // some Emscripten cases), there might be some existing tripped error. 
@@ -326,7 +333,7 @@ bool GlContext::HasGlExtension(absl::string_view extension) const { if (gl_major_version_ >= 3) { auto status = GetGlExtensions(); if (status.ok()) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } return GetGlExtensionsCompat(); @@ -368,7 +375,7 @@ void GlContext::SetProfilingContext( } } -::mediapipe::Status GlContext::SwitchContextAndRun(GlStatusFunction gl_func) { +absl::Status GlContext::SwitchContextAndRun(GlStatusFunction gl_func) { ContextBinding saved_context; MP_RETURN_IF_ERROR(EnterContext(&saved_context)) << " (entering GL context)"; auto status = gl_func(); @@ -377,9 +384,9 @@ void GlContext::SetProfilingContext( return status; } -::mediapipe::Status GlContext::Run(GlStatusFunction gl_func, int node_id, - Timestamp input_timestamp) { - ::mediapipe::Status status; +absl::Status GlContext::Run(GlStatusFunction gl_func, int node_id, + Timestamp input_timestamp) { + absl::Status status; if (profiling_helper_) { gl_func = [=] { profiling_helper_->MarkTimestamp(node_id, input_timestamp, @@ -416,7 +423,7 @@ void GlContext::RunWithoutWaiting(GlVoidFunction gl_func) { // TODO: queue up task instead. 
auto status = SwitchContextAndRun([gl_func] { gl_func(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); if (!status.ok()) { LOG(ERROR) << "Error in RunWithoutWaiting: " << status; @@ -433,8 +440,8 @@ std::weak_ptr& GlContext::CurrentContext() { return current_context; } -::mediapipe::Status GlContext::SwitchContext(ContextBinding* saved_context, - const ContextBinding& new_context) +absl::Status GlContext::SwitchContext(ContextBinding* saved_context, + const ContextBinding& new_context) ABSL_NO_THREAD_SAFETY_ANALYSIS { std::shared_ptr old_context_obj = CurrentContext().lock(); std::shared_ptr new_context_obj = @@ -452,7 +459,7 @@ std::weak_ptr& GlContext::CurrentContext() { } if (new_context_obj && (old_context_obj == new_context_obj)) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } if (old_context_obj) { @@ -479,13 +486,12 @@ std::weak_ptr& GlContext::CurrentContext() { } } -::mediapipe::Status GlContext::EnterContext(ContextBinding* saved_context) { +absl::Status GlContext::EnterContext(ContextBinding* saved_context) { DCHECK(HasContext()); return SwitchContext(saved_context, ThisContextBinding()); } -::mediapipe::Status GlContext::ExitContext( - const ContextBinding* saved_context) { +absl::Status GlContext::ExitContext(const ContextBinding* saved_context) { ContextBinding no_context; if (!saved_context) { saved_context = &no_context; @@ -822,4 +828,11 @@ void GlContext::LogUncheckedGlErrors(bool had_gl_errors) { } } +const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, + int plane) { + std::shared_ptr ctx = GlContext::GetCurrent(); + CHECK(ctx != nullptr); + return GlTextureInfoForGpuBufferFormat(format, plane, ctx->GetGlVersion()); +} + } // namespace mediapipe diff --git a/mediapipe/gpu/gl_context.h b/mediapipe/gpu/gl_context.h index 0e8990470..024ed9e5f 100644 --- a/mediapipe/gpu/gl_context.h +++ b/mediapipe/gpu/gl_context.h @@ -29,6 +29,7 @@ #include "mediapipe/framework/port/threadpool.h" #include 
"mediapipe/framework/timestamp.h" #include "mediapipe/gpu/gl_base.h" +#include "mediapipe/gpu/gpu_buffer_format.h" #ifdef __APPLE__ #include @@ -64,7 +65,7 @@ struct EAGLContext; namespace mediapipe { typedef std::function GlVoidFunction; -typedef std::function<::mediapipe::Status()> GlStatusFunction; +typedef std::function GlStatusFunction; class GlContext; @@ -141,10 +142,9 @@ constexpr PlatformGlContext kPlatformGlContextNone = nil; // - Managing the interaction between threads and GL contexts. // - Managing synchronization between different GL contexts. // -// See go/mediapipe-gl-context for details. class GlContext : public std::enable_shared_from_this { public: - using StatusOrGlContext = ::mediapipe::StatusOr>; + using StatusOrGlContext = absl::StatusOr>; // Creates a GlContext. // // The first argument (which can be a GlContext, or a platform-specific type) @@ -180,8 +180,8 @@ class GlContext : public std::enable_shared_from_this { // Executes a function in the GL context. Waits for the // function's execution to be complete before returning to the caller. - ::mediapipe::Status Run(GlStatusFunction gl_func, int node_id = -1, - Timestamp input_timestamp = Timestamp::Unset()); + absl::Status Run(GlStatusFunction gl_func, int node_id = -1, + Timestamp input_timestamp = Timestamp::Unset()); // Like Run, but does not wait. void RunWithoutWaiting(GlVoidFunction gl_func); @@ -237,6 +237,10 @@ class GlContext : public std::enable_shared_from_this { static bool ParseGlVersion(absl::string_view version_string, GLint* major, GLint* minor); + // Returns a GlVersion code used with GpuBufferFormat. + // TODO: make this more generally applicable. + GlVersion GetGlVersion() const; + // Simple query for GL extension support; only valid after GlContext has // finished its initialization successfully. 
bool HasGlExtension(absl::string_view extension) const; @@ -253,12 +257,12 @@ class GlContext : public std::enable_shared_from_this { // Implementation note: we cannot use a std::function argument // here, because that would break passing in a lambda that returns a status; // e.g.: - // RunInGlContext([]() -> ::mediapipe::Status { ... }); + // RunInGlContext([]() -> absl::Status { ... }); // // The reason is that std::function allows the implicit conversion // of a callable with any result type, as long as the argument types match. // As a result, the above lambda would be implicitly convertible to both - // std::function<::mediapipe::Status(void)> and std::function, and + // std::function and std::function, and // the invocation would be ambiguous. // // Therefore, instead of using std::function, we use a template @@ -268,7 +272,7 @@ class GlContext : public std::enable_shared_from_this { void Run(T f) { Run([f] { f(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }).IgnoreError(); } @@ -284,29 +288,27 @@ class GlContext : public std::enable_shared_from_this { GlContext(); #if defined(__EMSCRIPTEN__) - ::mediapipe::Status CreateContext( - EMSCRIPTEN_WEBGL_CONTEXT_HANDLE share_context); - ::mediapipe::Status CreateContextInternal( + absl::Status CreateContext(EMSCRIPTEN_WEBGL_CONTEXT_HANDLE share_context); + absl::Status CreateContextInternal( EMSCRIPTEN_WEBGL_CONTEXT_HANDLE share_context, int webgl_version); EMSCRIPTEN_WEBGL_CONTEXT_HANDLE context_ = 0; EmscriptenWebGLContextAttributes attrs_; #elif HAS_EGL - ::mediapipe::Status CreateContext(EGLContext share_context); - ::mediapipe::Status CreateContextInternal(EGLContext share_context, - int gl_version); + absl::Status CreateContext(EGLContext share_context); + absl::Status CreateContextInternal(EGLContext share_context, int gl_version); EGLDisplay display_ = EGL_NO_DISPLAY; EGLConfig config_; EGLSurface surface_ = EGL_NO_SURFACE; EGLContext context_ = EGL_NO_CONTEXT; #elif HAS_EAGL - 
::mediapipe::Status CreateContext(EAGLSharegroup* sharegroup); + absl::Status CreateContext(EAGLSharegroup* sharegroup); EAGLContext* context_; CFHolder texture_cache_; #elif HAS_NSGL - ::mediapipe::Status CreateContext(NSOpenGLContext* share_context); + absl::Status CreateContext(NSOpenGLContext* share_context); NSOpenGLContext* context_; NSOpenGLPixelFormat* pixel_format_; @@ -335,16 +337,16 @@ class GlContext : public std::enable_shared_from_this { #endif // HAS_EGL }; - ::mediapipe::Status FinishInitialization(bool create_thread); + absl::Status FinishInitialization(bool create_thread); // This wraps a thread_local. static std::weak_ptr& CurrentContext(); - static ::mediapipe::Status SwitchContext(ContextBinding* old_context, - const ContextBinding& new_context); + static absl::Status SwitchContext(ContextBinding* old_context, + const ContextBinding& new_context); - ::mediapipe::Status EnterContext(ContextBinding* previous_context); - ::mediapipe::Status ExitContext(const ContextBinding* previous_context); + absl::Status EnterContext(ContextBinding* previous_context); + absl::Status ExitContext(const ContextBinding* previous_context); void DestroyContext(); bool HasContext() const; @@ -363,13 +365,13 @@ class GlContext : public std::enable_shared_from_this { bool CheckForGlErrors(bool force); void LogUncheckedGlErrors(bool had_gl_errors); - ::mediapipe::Status GetGlExtensions(); - ::mediapipe::Status GetGlExtensionsCompat(); + absl::Status GetGlExtensions(); + absl::Status GetGlExtensionsCompat(); // Make the context current, run gl_func, and restore the previous context. // Internal helper only; callers should use Run or RunWithoutWaiting instead, // which delegates to the dedicated thread if required. - ::mediapipe::Status SwitchContextAndRun(GlStatusFunction gl_func); + absl::Status SwitchContextAndRun(GlStatusFunction gl_func); // The following ContextBinding functions have platform-specific // implementations. 
@@ -380,7 +382,7 @@ class GlContext : public std::enable_shared_from_this { // context is current on this thread. static void GetCurrentContextBinding(ContextBinding* binding); // Makes the context described by new_context current on this thread. - static ::mediapipe::Status SetCurrentContextBinding( + static absl::Status SetCurrentContextBinding( const ContextBinding& new_context); // If not null, a dedicated thread used to execute tasks on this context. @@ -415,5 +417,12 @@ class GlContext : public std::enable_shared_from_this { std::unique_ptr profiling_helper_ = nullptr; }; +// For backward compatibility. TODO: migrate remaining callers. +ABSL_DEPRECATED( + "Prefer passing an explicit GlVersion argument (use " + "GlContext::GetGlVersion)") +const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, + int plane); + } // namespace mediapipe #endif // MEDIAPIPE_GPU_GL_CONTEXT_H_ diff --git a/mediapipe/gpu/gl_context_eagl.cc b/mediapipe/gpu/gl_context_eagl.cc index ed8ba11c1..2811ea0b4 100644 --- a/mediapipe/gpu/gl_context_eagl.cc +++ b/mediapipe/gpu/gl_context_eagl.cc @@ -53,7 +53,7 @@ GlContext::StatusOrGlContext GlContext::Create(EAGLSharegroup* sharegroup, return std::move(context); } -::mediapipe::Status GlContext::CreateContext(EAGLSharegroup* sharegroup) { +absl::Status GlContext::CreateContext(EAGLSharegroup* sharegroup) { context_ = [[EAGLContext alloc] initWithAPI:kEAGLRenderingAPIOpenGLES3 sharegroup:sharegroup]; if (context_) { @@ -72,7 +72,7 @@ GlContext::StatusOrGlContext GlContext::Create(EAGLSharegroup* sharegroup, << "Error at CVOpenGLESTextureCacheCreate"; texture_cache_.adopt(cache); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void GlContext::DestroyContext() { @@ -95,11 +95,11 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) { binding->context = [EAGLContext currentContext]; } -::mediapipe::Status GlContext::SetCurrentContextBinding( +absl::Status 
GlContext::SetCurrentContextBinding( const ContextBinding& new_binding) { BOOL success = [EAGLContext setCurrentContext:new_binding.context]; RET_CHECK(success) << "Cannot set OpenGL context"; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } bool GlContext::HasContext() const { return context_ != nil; } diff --git a/mediapipe/gpu/gl_context_egl.cc b/mediapipe/gpu/gl_context_egl.cc index 016afd995..26b165f8e 100644 --- a/mediapipe/gpu/gl_context_egl.cc +++ b/mediapipe/gpu/gl_context_egl.cc @@ -85,8 +85,8 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, return std::move(context); } -::mediapipe::Status GlContext::CreateContextInternal( - EGLContext external_context, int gl_version) { +absl::Status GlContext::CreateContextInternal(EGLContext external_context, + int gl_version) { CHECK(gl_version == 2 || gl_version == 3); const EGLint config_attr[] = { @@ -146,10 +146,10 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, // GLES 2 does not have them, so let's set the major version here at least. 
gl_major_version_ = gl_version; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlContext::CreateContext(EGLContext external_context) { +absl::Status GlContext::CreateContext(EGLContext external_context) { EGLint major = 0; EGLint minor = 0; @@ -178,7 +178,7 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context, << "eglCreatePbufferSurface() returned error " << std::showbase << std::hex << eglGetError(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void GlContext::DestroyContext() { @@ -212,7 +212,7 @@ void GlContext::DestroyContext() { thread_ ->Run([] { eglReleaseThread(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }) .IgnoreError(); } @@ -271,7 +271,7 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) { binding->context = eglGetCurrentContext(); } -::mediapipe::Status GlContext::SetCurrentContextBinding( +absl::Status GlContext::SetCurrentContextBinding( const ContextBinding& new_binding) { EnsureEglThreadRelease(); EGLDisplay display = new_binding.display; @@ -286,7 +286,7 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) { new_binding.read_surface, new_binding.context); RET_CHECK(success) << "eglMakeCurrent() returned error " << std::showbase << std::hex << eglGetError(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } bool GlContext::HasContext() const { return context_ != EGL_NO_CONTEXT; } diff --git a/mediapipe/gpu/gl_context_internal.h b/mediapipe/gpu/gl_context_internal.h index 1cd0a3c15..16b7bf9bf 100644 --- a/mediapipe/gpu/gl_context_internal.h +++ b/mediapipe/gpu/gl_context_internal.h @@ -35,7 +35,7 @@ class GlContext::DedicatedThread { DedicatedThread(const DedicatedThread&) = delete; DedicatedThread& operator=(DedicatedThread) = delete; - ::mediapipe::Status Run(GlStatusFunction gl_func); + absl::Status Run(GlStatusFunction gl_func); void RunWithoutWaiting(GlVoidFunction gl_fund); bool 
IsCurrentThread(); diff --git a/mediapipe/gpu/gl_context_nsgl.cc b/mediapipe/gpu/gl_context_nsgl.cc index e26d942bf..d9a261e5b 100644 --- a/mediapipe/gpu/gl_context_nsgl.cc +++ b/mediapipe/gpu/gl_context_nsgl.cc @@ -44,7 +44,7 @@ GlContext::StatusOrGlContext GlContext::Create(NSOpenGLContext* share_context, return std::move(context); } -::mediapipe::Status GlContext::CreateContext(NSOpenGLContext* share_context) { +absl::Status GlContext::CreateContext(NSOpenGLContext* share_context) { // TODO: choose a better list? NSOpenGLPixelFormatAttribute attrs[] = { // This is required to get any OpenGL version 3.2 or higher. Note that @@ -96,8 +96,7 @@ GlContext::StatusOrGlContext GlContext::Create(NSOpenGLContext* share_context, [[NSOpenGLPixelFormat alloc] initWithAttributes:attrs_no_accel]; } if (!pixel_format_) - return ::mediapipe::InternalError( - "Could not create an NSOpenGLPixelFormat"); + return absl::InternalError("Could not create an NSOpenGLPixelFormat"); context_ = [[NSOpenGLContext alloc] initWithFormat:pixel_format_ shareContext:share_context]; @@ -123,7 +122,7 @@ GlContext::StatusOrGlContext GlContext::Create(NSOpenGLContext* share_context, RET_CHECK_EQ(err, kCVReturnSuccess) << "Error at CVOpenGLTextureCacheCreate"; texture_cache_.adopt(cache); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void GlContext::DestroyContext() { @@ -146,14 +145,14 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) { binding->context = [NSOpenGLContext currentContext]; } -::mediapipe::Status GlContext::SetCurrentContextBinding( +absl::Status GlContext::SetCurrentContextBinding( const ContextBinding& new_binding) { if (new_binding.context) { [new_binding.context makeCurrentContext]; } else { [NSOpenGLContext clearCurrentContext]; } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } bool GlContext::HasContext() const { return context_ != nil; } diff --git a/mediapipe/gpu/gl_context_webgl.cc 
b/mediapipe/gpu/gl_context_webgl.cc index 435f16e59..51df1dab1 100644 --- a/mediapipe/gpu/gl_context_webgl.cc +++ b/mediapipe/gpu/gl_context_webgl.cc @@ -46,7 +46,7 @@ GlContext::StatusOrGlContext GlContext::Create( return std::move(context); } -::mediapipe::Status GlContext::CreateContextInternal( +absl::Status GlContext::CreateContextInternal( EMSCRIPTEN_WEBGL_CONTEXT_HANDLE external_context, int webgl_version) { CHECK(webgl_version == 1 || webgl_version == 2); @@ -58,9 +58,12 @@ GlContext::StatusOrGlContext GlContext::Create( attrs.majorVersion = webgl_version; attrs.minorVersion = 0; - attrs.premultipliedAlpha = 0; - // New one to try out... TODO: see if actually necessary for - // pushing resulting texture through MediaPipe pipeline. + // This flag tells the page compositor that the image written to the canvas + // uses premultiplied alpha, and so can be used directly for compositing. + // Without this, it needs to make an additional full-canvas rendering pass. + attrs.premultipliedAlpha = 1; + + // TODO: Investigate this option in more detail, esp. on Safari. attrs.preserveDrawingBuffer = 0; // Since the Emscripten canvas target finding function is visible from here, @@ -124,10 +127,10 @@ GlContext::StatusOrGlContext GlContext::Create( // GLES 2 does not have them, so let's set the major version here at least. // WebGL 1.0 maps to GLES 2.0 and WebGL 2.0 maps to GLES 3.0, so we add 1. 
gl_major_version_ = webgl_version + 1; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlContext::CreateContext( +absl::Status GlContext::CreateContext( EMSCRIPTEN_WEBGL_CONTEXT_HANDLE external_context) { // TODO: If we're given a non-0 external_context, could try to use // that directly, since we're assuming a single-threaded single-context @@ -144,7 +147,7 @@ GlContext::StatusOrGlContext GlContext::Create( LOG(INFO) << "Successfully created a WebGL context with major version " << gl_major_version_ << " and handle " << context_; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void GlContext::DestroyContext() { @@ -177,13 +180,13 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) { binding->context = emscripten_webgl_get_current_context(); } -::mediapipe::Status GlContext::SetCurrentContextBinding( +absl::Status GlContext::SetCurrentContextBinding( const ContextBinding& new_binding) { if (new_binding.context == 0) { // Calling emscripten_webgl_make_context_current(0) is resulting in an error // so don't remove context for now, only replace! In the future, can // perhaps create a separate "do-nothing" context for this. - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // TODO: See if setting the same context to current multiple times // comes with a performance cost, and fix if so. 
@@ -191,7 +194,7 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) { emscripten_webgl_make_context_current(new_binding.context); RET_CHECK(res == EMSCRIPTEN_RESULT_SUCCESS) << "emscripten_webgl_make_context_current() returned error " << res; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } bool GlContext::HasContext() const { return context_ != 0; } diff --git a/mediapipe/gpu/gl_ios_test.mm b/mediapipe/gpu/gl_ios_test.mm index 9e52052c6..05147566a 100644 --- a/mediapipe/gpu/gl_ios_test.mm +++ b/mediapipe/gpu/gl_ios_test.mm @@ -98,7 +98,7 @@ - (void)testGlConverters { CFHolder originalPixelBuffer; - ::mediapipe::Status status = + absl::Status status = CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); XCTAssert(status.ok()); @@ -121,7 +121,7 @@ - (void)testGlConvertersNoOpInserted { CFHolder originalPixelBuffer; - ::mediapipe::Status status = + absl::Status status = CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); XCTAssert(status.ok()); @@ -149,7 +149,7 @@ - (void)testGlConvertersWithOptionalSidePackets { CFHolder originalPixelBuffer; - ::mediapipe::Status status = + absl::Status status = CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); XCTAssert(status.ok()); @@ -191,7 +191,7 @@ - (void)testSimpleConversionFromFormat:(OSType)cvPixelFormat { CFHolder originalPixelBuffer; - ::mediapipe::Status status = + absl::Status status = CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); XCTAssert(status.ok()); CVPixelBufferRef convertedPixelBuffer = @@ -225,7 +225,7 @@ NSLog(@"Metal tests skipped on Simulator."); #else CFHolder originalPixelBuffer; - ::mediapipe::Status status = + absl::Status status = CreateCVPixelBufferFromCGImage([_sourceImage CGImage], &originalPixelBuffer); XCTAssert(status.ok()); CVPixelBufferRef redPixelBuffer = [self redPixelBuffer:*originalPixelBuffer]; diff --git 
a/mediapipe/gpu/gl_quad_renderer.cc b/mediapipe/gpu/gl_quad_renderer.cc index 38a1f35ee..c25a37e48 100644 --- a/mediapipe/gpu/gl_quad_renderer.cc +++ b/mediapipe/gpu/gl_quad_renderer.cc @@ -54,11 +54,11 @@ FrameRotation FrameRotationFromDegrees(int degrees_ccw) { } } -::mediapipe::Status QuadRenderer::GlSetup() { +absl::Status QuadRenderer::GlSetup() { return GlSetup(kBasicTexturedFragmentShader, {"video_frame"}); } -::mediapipe::Status QuadRenderer::GlSetup( +absl::Status QuadRenderer::GlSetup( const GLchar* custom_frag_shader, const std::vector& custom_frame_uniforms) { // Load vertex and fragment shaders @@ -87,7 +87,7 @@ FrameRotation FrameRotationFromDegrees(int degrees_ccw) { glGenVertexArrays(1, &vao_); glGenBuffers(2, vbo_); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void QuadRenderer::GlTeardown() { @@ -106,10 +106,12 @@ void QuadRenderer::GlTeardown() { } } -::mediapipe::Status QuadRenderer::GlRender( - float frame_width, float frame_height, float view_width, float view_height, - FrameScaleMode scale_mode, FrameRotation rotation, bool flip_horizontal, - bool flip_vertical, bool flip_texture) { +absl::Status QuadRenderer::GlRender(float frame_width, float frame_height, + float view_width, float view_height, + FrameScaleMode scale_mode, + FrameRotation rotation, + bool flip_horizontal, bool flip_vertical, + bool flip_texture) { RET_CHECK(program_) << "Must setup the program before rendering."; glUseProgram(program_); @@ -195,15 +197,14 @@ void QuadRenderer::GlTeardown() { glBindBuffer(GL_ARRAY_BUFFER, 0); glBindVertexArray(0); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status FrameRotationFromInt(FrameRotation* rotation, - int degrees_ccw) { +absl::Status FrameRotationFromInt(FrameRotation* rotation, int degrees_ccw) { RET_CHECK(degrees_ccw % 90 == 0) << "rotation must be a multiple of 90; " << degrees_ccw << " was provided"; *rotation = FrameRotationFromDegrees(degrees_ccw % 360); - return 
::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/gpu/gl_quad_renderer.h b/mediapipe/gpu/gl_quad_renderer.h index e5fd06c70..7e2c44f1c 100644 --- a/mediapipe/gpu/gl_quad_renderer.h +++ b/mediapipe/gpu/gl_quad_renderer.h @@ -49,13 +49,12 @@ class QuadRenderer { QuadRenderer() {} // Creates the rendering program. Must be called within the GL context that // will be used for rendering. - ::mediapipe::Status GlSetup(); + absl::Status GlSetup(); // Creates the rendering program. Must be called within the GL context that // will be used for rendering. // This version allows you to customize the fragment shader. - ::mediapipe::Status GlSetup( - const GLchar* custom_frag_shader, - const std::vector& custom_frame_uniforms); + absl::Status GlSetup(const GLchar* custom_frag_shader, + const std::vector& custom_frame_uniforms); // Renders the texture bound to texture unit 1 onto the current viewport. // Note: mirroring and flipping are handled differently, by design. // - flip_texture is meant to be used when the texture image's rows are stored @@ -70,11 +69,10 @@ class QuadRenderer { // what's needed for the front-camera use case. // - flip_vertical is meant to be used to flip the output image vertically. // This flipping is applied AFTER rotation. - ::mediapipe::Status GlRender(float frame_width, float frame_height, - float view_width, float view_height, - FrameScaleMode scale_mode, - FrameRotation rotation, bool flip_horizontal, - bool flip_vertical, bool flip_texture); + absl::Status GlRender(float frame_width, float frame_height, float view_width, + float view_height, FrameScaleMode scale_mode, + FrameRotation rotation, bool flip_horizontal, + bool flip_vertical, bool flip_texture); // Deletes the rendering program. Must be called withn the GL context where // it was created. 
void GlTeardown(); @@ -87,8 +85,7 @@ class QuadRenderer { GLuint vbo_[2] = {0, 0}; // for vertex buffer storage }; -::mediapipe::Status FrameRotationFromInt(FrameRotation* rotation, - int degrees_ccw); +absl::Status FrameRotationFromInt(FrameRotation* rotation, int degrees_ccw); // Input degrees must be one of: [0, 90, 180, 270]. FrameRotation FrameRotationFromDegrees(int degrees_ccw); diff --git a/mediapipe/gpu/gl_scaler_calculator.cc b/mediapipe/gpu/gl_scaler_calculator.cc index 8806267be..e50ef2288 100644 --- a/mediapipe/gpu/gl_scaler_calculator.cc +++ b/mediapipe/gpu/gl_scaler_calculator.cc @@ -66,13 +66,13 @@ class GlScalerCalculator : public CalculatorBase { GlScalerCalculator() {} ~GlScalerCalculator(); - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status GlSetup(); - ::mediapipe::Status GlRender(const GlTexture& src, const GlTexture& dst); + absl::Status GlSetup(); + absl::Status GlRender(const GlTexture& src, const GlTexture& dst); void GetOutputDimensions(int src_width, int src_height, int* dst_width, int* dst_height); void GetOutputPadding(int src_width, int src_height, int dst_width, @@ -98,7 +98,7 @@ class GlScalerCalculator : public CalculatorBase { REGISTER_CALCULATOR(GlScalerCalculator); // static -::mediapipe::Status GlScalerCalculator::GetContract(CalculatorContract* cc) { +absl::Status GlScalerCalculator::GetContract(CalculatorContract* cc) { TagOrIndex(&cc->Inputs(), "VIDEO", 0).Set(); TagOrIndex(&cc->Outputs(), "VIDEO", 0).Set(); if (cc->Inputs().HasTag("ROTATION")) { @@ -126,10 +126,10 @@ REGISTER_CALCULATOR(GlScalerCalculator); cc->Outputs().Tag("TOP_BOTTOM_PADDING").Set(); cc->Outputs().Tag("LEFT_RIGHT_PADDING").Set(); } 
- return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlScalerCalculator::Open(CalculatorContext* cc) { +absl::Status GlScalerCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(mediapipe::TimestampDiff(0)); @@ -181,14 +181,14 @@ REGISTER_CALCULATOR(GlScalerCalculator); MP_RETURN_IF_ERROR(FrameRotationFromInt(&rotation_, rotation_ccw)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlScalerCalculator::Process(CalculatorContext* cc) { +absl::Status GlScalerCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag("OUTPUT_DIMENSIONS")) { if (cc->Inputs().Tag("OUTPUT_DIMENSIONS").IsEmpty()) { // OUTPUT_DIMENSIONS input stream is specified, but value is missing. - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& dimensions = @@ -197,7 +197,7 @@ REGISTER_CALCULATOR(GlScalerCalculator); dst_height_ = dimensions[1]; } - return helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { + return helper_.RunInGlContext([this, cc]() -> absl::Status { const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get(); QuadRenderer* renderer = nullptr; GlTexture src1; @@ -294,7 +294,7 @@ REGISTER_CALCULATOR(GlScalerCalculator); TagOrIndex(&cc->Outputs(), "VIDEO", 0) .Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } diff --git a/mediapipe/gpu/gl_simple_calculator.cc b/mediapipe/gpu/gl_simple_calculator.cc index 3d4219b84..32c360390 100644 --- a/mediapipe/gpu/gl_simple_calculator.cc +++ b/mediapipe/gpu/gl_simple_calculator.cc @@ -17,7 +17,7 @@ namespace mediapipe { // static -::mediapipe::Status GlSimpleCalculator::GetContract(CalculatorContract* cc) { +absl::Status GlSimpleCalculator::GetContract(CalculatorContract* cc) { TagOrIndex(&cc->Inputs(), "VIDEO", 0).Set(); TagOrIndex(&cc->Outputs(), "VIDEO", 0).Set(); // 
Currently we pass GL context information and other stuff as external @@ -25,7 +25,7 @@ namespace mediapipe { return GlCalculatorHelper::UpdateContract(cc); } -::mediapipe::Status GlSimpleCalculator::Open(CalculatorContext* cc) { +absl::Status GlSimpleCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(mediapipe::TimestampDiff(0)); @@ -34,8 +34,8 @@ namespace mediapipe { return helper_.Open(cc); } -::mediapipe::Status GlSimpleCalculator::Process(CalculatorContext* cc) { - return RunInGlContext([this, cc]() -> ::mediapipe::Status { +absl::Status GlSimpleCalculator::Process(CalculatorContext* cc) { + return RunInGlContext([this, cc]() -> absl::Status { const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get(); if (!initialized_) { MP_RETURN_IF_ERROR(GlSetup()); @@ -69,13 +69,12 @@ namespace mediapipe { TagOrIndex(&cc->Outputs(), "VIDEO", 0) .Add(output.release(), cc->InputTimestamp()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } -::mediapipe::Status GlSimpleCalculator::Close(CalculatorContext* cc) { - return RunInGlContext( - [this]() -> ::mediapipe::Status { return GlTeardown(); }); +absl::Status GlSimpleCalculator::Close(CalculatorContext* cc) { + return RunInGlContext([this]() -> absl::Status { return GlTeardown(); }); } } // namespace mediapipe diff --git a/mediapipe/gpu/gl_simple_calculator.h b/mediapipe/gpu/gl_simple_calculator.h index b0b1ac292..bb56fedb1 100644 --- a/mediapipe/gpu/gl_simple_calculator.h +++ b/mediapipe/gpu/gl_simple_calculator.h @@ -53,30 +53,29 @@ class GlSimpleCalculator : public CalculatorBase { GlSimpleCalculator& operator=(const GlSimpleCalculator&) = delete; ~GlSimpleCalculator() override = default; - static ::mediapipe::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; - ::mediapipe::Status 
Close(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; // This method is called once on the first frame. Use it to setup any objects // that will be reused throughout the calculator's life. - virtual ::mediapipe::Status GlSetup() = 0; + virtual absl::Status GlSetup() = 0; // You can use this optional method to do any pre-rendering setup that needs // to be redone after the context has been used by another calculator. // If your context is not shared, it will only be called once. - virtual ::mediapipe::Status GlBind() { return ::mediapipe::OkStatus(); } + virtual absl::Status GlBind() { return absl::OkStatus(); } // Do your rendering here. The source and destination textures have already // been created and bound for you. // - src: source texture (contains input frame); already bound to GL_TEXTURE1. // - dst: destination texture (write output frame here); already bound to // GL_TEXTURE0 and attached to the framebuffer. - virtual ::mediapipe::Status GlRender(const GlTexture& src, - const GlTexture& dst) = 0; + virtual absl::Status GlRender(const GlTexture& src, const GlTexture& dst) = 0; // The method is called to delete all the programs. - virtual ::mediapipe::Status GlTeardown() = 0; + virtual absl::Status GlTeardown() = 0; // You can override this method to compute the size of the destination // texture. By default, it will take the same size as the source texture. 
diff --git a/mediapipe/gpu/gl_surface_sink_calculator.cc b/mediapipe/gpu/gl_surface_sink_calculator.cc index 6ab0b094f..ca5ce0cee 100644 --- a/mediapipe/gpu/gl_surface_sink_calculator.cc +++ b/mediapipe/gpu/gl_surface_sink_calculator.cc @@ -42,10 +42,10 @@ class GlSurfaceSinkCalculator : public CalculatorBase { GlSurfaceSinkCalculator() : initialized_(false) {} ~GlSurfaceSinkCalculator() override; - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: GlCalculatorHelper helper_; @@ -57,8 +57,7 @@ class GlSurfaceSinkCalculator : public CalculatorBase { REGISTER_CALCULATOR(GlSurfaceSinkCalculator); // static -::mediapipe::Status GlSurfaceSinkCalculator::GetContract( - CalculatorContract* cc) { +absl::Status GlSurfaceSinkCalculator::GetContract(CalculatorContract* cc) { TagOrIndex(&(cc->Inputs()), "VIDEO", 0).Set(); cc->InputSidePackets() .Tag("SURFACE") @@ -68,7 +67,7 @@ REGISTER_CALCULATOR(GlSurfaceSinkCalculator); return GlCalculatorHelper::UpdateContract(cc); } -::mediapipe::Status GlSurfaceSinkCalculator::Open(CalculatorContext* cc) { +absl::Status GlSurfaceSinkCalculator::Open(CalculatorContext* cc) { surface_holder_ = cc->InputSidePackets() .Tag("SURFACE") .Get>() @@ -82,13 +81,13 @@ REGISTER_CALCULATOR(GlSurfaceSinkCalculator); return helper_.Open(cc); } -::mediapipe::Status GlSurfaceSinkCalculator::Process(CalculatorContext* cc) { - return helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { +absl::Status GlSurfaceSinkCalculator::Process(CalculatorContext* cc) { + return helper_.RunInGlContext([this, &cc]() -> absl::Status { absl::MutexLock lock(&surface_holder_->mutex); EGLSurface surface = surface_holder_->surface; if (surface == 
EGL_NO_SURFACE) { LOG_EVERY_N(INFO, 300) << "GlSurfaceSinkCalculator: no surface"; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& input = TagOrIndex(cc->Inputs(), "VIDEO", 0).Get(); @@ -139,7 +138,7 @@ REGISTER_CALCULATOR(GlSurfaceSinkCalculator); RET_CHECK(success) << "failed to restore old surface"; src.Release(); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } diff --git a/mediapipe/gpu/gl_texture_buffer.cc b/mediapipe/gpu/gl_texture_buffer.cc index d9ec94f0e..079c821f0 100644 --- a/mediapipe/gpu/gl_texture_buffer.cc +++ b/mediapipe/gpu/gl_texture_buffer.cc @@ -63,7 +63,8 @@ bool GlTextureBuffer::CreateInternal(const void* data) { if (!name_) return false; glBindTexture(target_, name_); - GlTextureInfo info = GlTextureInfoForGpuBufferFormat(format_, 0); + GlTextureInfo info = + GlTextureInfoForGpuBufferFormat(format_, 0, context->GetGlVersion()); // See b/70294573 for details about this. if (info.gl_internal_format == GL_RGBA16F && diff --git a/mediapipe/gpu/gpu_buffer_format.cc b/mediapipe/gpu/gpu_buffer_format.cc index 9d87274b6..8c219258b 100644 --- a/mediapipe/gpu/gpu_buffer_format.cc +++ b/mediapipe/gpu/gpu_buffer_format.cc @@ -37,6 +37,7 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { info->gl_internal_format = info->gl_format = GL_LUMINANCE; return; case GL_RG16F: + case GL_RG32F: // Should this be GL_RG_EXT instead? 
info->gl_internal_format = info->gl_format = GL_LUMINANCE_ALPHA; return; @@ -52,16 +53,6 @@ static void AdaptGlTextureInfoForGLES2(GlTextureInfo* info) { } #endif // GL_ES_VERSION_2_0 -const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, - int plane) { -#if defined(__APPLE__) && TARGET_OS_OSX - constexpr GlVersion default_version = GlVersion::kGL; -#else - constexpr GlVersion default_version = GlVersion::kGLES3; -#endif // defined(__APPLE__) && TARGET_OS_OSX - return GlTextureInfoForGpuBufferFormat(format, plane, default_version); -} - const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, int plane, GlVersion gl_version) { diff --git a/mediapipe/gpu/gpu_buffer_format.h b/mediapipe/gpu/gpu_buffer_format.h index fd008ba02..a92c5712c 100644 --- a/mediapipe/gpu/gpu_buffer_format.h +++ b/mediapipe/gpu/gpu_buffer_format.h @@ -63,8 +63,6 @@ struct GlTextureInfo { int downscale; }; -const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, - int plane); const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format, int plane, GlVersion gl_version); diff --git a/mediapipe/gpu/gpu_buffer_multi_pool.h b/mediapipe/gpu/gpu_buffer_multi_pool.h index 52f1deb1b..0a34a3017 100644 --- a/mediapipe/gpu/gpu_buffer_multi_pool.h +++ b/mediapipe/gpu/gpu_buffer_multi_pool.h @@ -82,7 +82,7 @@ class GpuBufferMultiPool { std::size_t operator()(const BufferSpec& spec) const { // Width and height are expected to be smaller than half the width of // size_t. We can combine them into a single integer, and then use - // std::hash, which is what go/hashing recommends for hashing numbers. + // std::hash. 
constexpr int kWidth = std::numeric_limits::digits; return std::hash{}( spec.width ^ RotateLeft(spec.height, kWidth / 2) ^ diff --git a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc index aa40ae19a..8bca0a27d 100644 --- a/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc +++ b/mediapipe/gpu/gpu_buffer_to_image_frame_calculator.cc @@ -31,10 +31,10 @@ class GpuBufferToImageFrameCalculator : public CalculatorBase { public: GpuBufferToImageFrameCalculator() {} - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER @@ -44,7 +44,7 @@ class GpuBufferToImageFrameCalculator : public CalculatorBase { REGISTER_CALCULATOR(GpuBufferToImageFrameCalculator); // static -::mediapipe::Status GpuBufferToImageFrameCalculator::GetContract( +absl::Status GpuBufferToImageFrameCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).SetAny(); cc->Outputs().Index(0).Set(); @@ -52,25 +52,23 @@ REGISTER_CALCULATOR(GpuBufferToImageFrameCalculator); // to ensure the calculator's contract is the same. In particular, the helper // enables support for the legacy side packet, which several graphs still use. MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GpuBufferToImageFrameCalculator::Open( - CalculatorContext* cc) { +absl::Status GpuBufferToImageFrameCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. 
cc->SetOffset(TimestampDiff(0)); #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GpuBufferToImageFrameCalculator::Process( - CalculatorContext* cc) { +absl::Status GpuBufferToImageFrameCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Index(0).Value().ValidateAsType().ok()) { cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value()); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } #ifdef HAVE_GPU_BUFFER @@ -87,7 +85,8 @@ REGISTER_CALCULATOR(GpuBufferToImageFrameCalculator); ImageFormatForGpuBufferFormat(input.format()), src.width(), src.height(), ImageFrame::kGlDefaultAlignmentBoundary); helper_.BindFramebuffer(src); - const auto info = GlTextureInfoForGpuBufferFormat(input.format(), 0); + const auto info = GlTextureInfoForGpuBufferFormat(input.format(), 0, + helper_.GetGlVersion()); glReadPixels(0, 0, src.width(), src.height(), info.gl_format, info.gl_type, frame->MutablePixelData()); glFlush(); @@ -95,12 +94,12 @@ REGISTER_CALCULATOR(GpuBufferToImageFrameCalculator); src.Release(); }); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } #endif // defined(HAVE_GPU_BUFFER) - return ::mediapipe::Status(::mediapipe::StatusCode::kInvalidArgument, - "Input packets must be ImageFrame or GpuBuffer."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Input packets must be ImageFrame or GpuBuffer."); } } // namespace mediapipe diff --git a/mediapipe/gpu/gpu_shared_data_internal.cc b/mediapipe/gpu/gpu_shared_data_internal.cc index 499894715..bd50bf5bf 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.cc +++ b/mediapipe/gpu/gpu_shared_data_internal.cc @@ -104,7 +104,7 @@ GpuResources::~GpuResources() { #endif } -mediapipe::Status GpuResources::PrepareGpuNode(CalculatorNode* node) { +absl::Status 
GpuResources::PrepareGpuNode(CalculatorNode* node) { CHECK(node->UsesGpu()); std::string node_id = node->GetCalculatorState().NodeName(); std::string node_type = node->GetCalculatorState().CalculatorType(); diff --git a/mediapipe/gpu/gpu_shared_data_internal.h b/mediapipe/gpu/gpu_shared_data_internal.h index cbe77c709..62d6bb27e 100644 --- a/mediapipe/gpu/gpu_shared_data_internal.h +++ b/mediapipe/gpu/gpu_shared_data_internal.h @@ -42,8 +42,7 @@ namespace mediapipe { // TODO: rename to GpuService or GpuManager or something. class GpuResources { public: - using StatusOrGpuResources = - ::mediapipe::StatusOr>; + using StatusOrGpuResources = absl::StatusOr>; static StatusOrGpuResources Create(); static StatusOrGpuResources Create(PlatformGlContext external_context); @@ -69,7 +68,7 @@ class GpuResources { MPPGraphGPUData* ios_gpu_data(); #endif // defined(__APPLE__)§ - mediapipe::Status PrepareGpuNode(CalculatorNode* node); + absl::Status PrepareGpuNode(CalculatorNode* node); // If the node requires custom GPU executors in the current configuration, // returns the executor's names and the executors themselves. 
@@ -124,7 +123,7 @@ struct GpuSharedData { auto status_or_resources = GpuResources::Create(external_context); MEDIAPIPE_CHECK_OK(status_or_resources.status()) << ": could not create GpuResources"; - return std::move(status_or_resources).ValueOrDie(); + return std::move(status_or_resources).value(); } }; diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index 8abb43d71..2a8331db8 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -28,10 +28,10 @@ class ImageFrameToGpuBufferCalculator : public CalculatorBase { public: ImageFrameToGpuBufferCalculator() {} - static ::mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - ::mediapipe::Status Open(CalculatorContext* cc) override; - ::mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER @@ -41,7 +41,7 @@ class ImageFrameToGpuBufferCalculator : public CalculatorBase { REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); // static -::mediapipe::Status ImageFrameToGpuBufferCalculator::GetContract( +absl::Status ImageFrameToGpuBufferCalculator::GetContract( CalculatorContract* cc) { cc->Inputs().Index(0).Set(); cc->Outputs().Index(0).Set(); @@ -49,22 +49,20 @@ REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); // to ensure the calculator's contract is the same. In particular, the helper // enables support for the legacy side packet, which several graphs still use. 
MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageFrameToGpuBufferCalculator::Open( - CalculatorContext* cc) { +absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { // Inform the framework that we always output at the same timestamp // as we receive a packet at. cc->SetOffset(TimestampDiff(0)); #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); #endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ImageFrameToGpuBufferCalculator::Process( - CalculatorContext* cc) { +absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER CFHolder buffer; MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket( @@ -80,7 +78,7 @@ REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); src.Release(); }); #endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/graphs/face_effect/BUILD b/mediapipe/graphs/face_effect/BUILD index 401ac043c..69d648e80 100644 --- a/mediapipe/graphs/face_effect/BUILD +++ b/mediapipe/graphs/face_effect/BUILD @@ -28,8 +28,9 @@ cc_library( "//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:immediate_mux_calculator", "//mediapipe/calculators/image:image_properties_calculator", - "//mediapipe/graphs/face_effect/subgraphs:single_face_smooth_landmark_gpu", - "//mediapipe/modules/face_geometry", + "//mediapipe/framework/tool:switch_container", + "//mediapipe/graphs/face_effect/subgraphs:single_face_geometry_from_detection_gpu", + "//mediapipe/graphs/face_effect/subgraphs:single_face_geometry_from_landmarks_gpu", "//mediapipe/modules/face_geometry:effect_renderer_calculator", "//mediapipe/modules/face_geometry:env_generator_calculator", 
], diff --git a/mediapipe/graphs/face_effect/data/BUILD b/mediapipe/graphs/face_effect/data/BUILD index 21368da60..999369970 100644 --- a/mediapipe/graphs/face_effect/data/BUILD +++ b/mediapipe/graphs/face_effect/data/BUILD @@ -18,6 +18,16 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +encode_binary_proto( + name = "axis", + input = "axis.pbtxt", + message_type = "mediapipe.face_geometry.Mesh3d", + output = "axis.binarypb", + deps = [ + "//mediapipe/modules/face_geometry/protos:mesh_3d_proto", + ], +) + encode_binary_proto( name = "glasses", input = "glasses.pbtxt", @@ -31,6 +41,7 @@ encode_binary_proto( # `.pngblob` is used instead of `.png` to prevent iOS build from preprocessing the image. # OpenCV is unable to read a PNG file preprocessed by the iOS build. exports_files([ + "axis.pngblob", "facepaint.pngblob", "glasses.pngblob", ]) diff --git a/mediapipe/graphs/face_effect/data/axis.pbtxt b/mediapipe/graphs/face_effect/data/axis.pbtxt new file mode 100644 index 000000000..6a3fd52d6 --- /dev/null +++ b/mediapipe/graphs/face_effect/data/axis.pbtxt @@ -0,0 +1,320 @@ +vertex_type: VERTEX_PT +primitive_type: TRIANGLE +vertex_buffer: -0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 11.500000 +vertex_buffer: 0.873006 +vertex_buffer: 1.000000 +vertex_buffer: 0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 11.500000 +vertex_buffer: 0.928502 +vertex_buffer: 1.000000 +vertex_buffer: 0.100000 +vertex_buffer: 0.100000 +vertex_buffer: 11.500000 +vertex_buffer: 0.928502 +vertex_buffer: 0.750000 +vertex_buffer: -0.100000 +vertex_buffer: 0.100000 +vertex_buffer: 11.500000 +vertex_buffer: 0.873006 +vertex_buffer: 0.750000 +vertex_buffer: 0.100000 +vertex_buffer: 0.100000 +vertex_buffer: 8.500000 +vertex_buffer: 0.928502 +vertex_buffer: 0.500000 +vertex_buffer: -0.100000 +vertex_buffer: 0.100000 +vertex_buffer: 8.500000 +vertex_buffer: 0.873006 +vertex_buffer: 0.500000 +vertex_buffer: 0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 
8.500000 +vertex_buffer: 0.928502 +vertex_buffer: 0.250000 +vertex_buffer: -0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 8.500000 +vertex_buffer: 0.873006 +vertex_buffer: 0.250000 +vertex_buffer: 0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 11.500000 +vertex_buffer: 0.928502 +vertex_buffer: 0.000000 +vertex_buffer: -0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 11.500000 +vertex_buffer: 0.873006 +vertex_buffer: 0.000000 +vertex_buffer: 0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 8.500000 +vertex_buffer: 0.983999 +vertex_buffer: 1.000000 +vertex_buffer: 0.100000 +vertex_buffer: 0.100000 +vertex_buffer: 8.500000 +vertex_buffer: 0.983999 +vertex_buffer: 0.750000 +vertex_buffer: -0.100000 +vertex_buffer: -0.100000 +vertex_buffer: 8.500000 +vertex_buffer: 0.817509 +vertex_buffer: 1.000000 +vertex_buffer: -0.100000 +vertex_buffer: 0.100000 +vertex_buffer: 8.500000 +vertex_buffer: 0.817509 +vertex_buffer: 0.750000 +vertex_buffer: 3.000000 +vertex_buffer: -0.100000 +vertex_buffer: 8.600000 +vertex_buffer: 0.069341 +vertex_buffer: 1.000000 +vertex_buffer: 3.000000 +vertex_buffer: -0.100000 +vertex_buffer: 8.400000 +vertex_buffer: 0.123429 +vertex_buffer: 1.000000 +vertex_buffer: 3.000000 +vertex_buffer: 0.100000 +vertex_buffer: 8.400000 +vertex_buffer: 0.123429 +vertex_buffer: 0.750000 +vertex_buffer: 3.000000 +vertex_buffer: 0.100000 +vertex_buffer: 8.600000 +vertex_buffer: 0.069341 +vertex_buffer: 0.750000 +vertex_buffer: 0.000000 +vertex_buffer: 0.100000 +vertex_buffer: 8.400000 +vertex_buffer: 0.123419 +vertex_buffer: 0.499992 +vertex_buffer: 0.000000 +vertex_buffer: 0.100000 +vertex_buffer: 8.600000 +vertex_buffer: 0.069341 +vertex_buffer: 0.500000 +vertex_buffer: 0.000000 +vertex_buffer: -0.100000 +vertex_buffer: 8.400000 +vertex_buffer: 0.123429 +vertex_buffer: 0.250000 +vertex_buffer: 0.000000 +vertex_buffer: -0.100000 +vertex_buffer: 8.600000 +vertex_buffer: 0.069341 +vertex_buffer: 0.250000 +vertex_buffer: 3.000000 +vertex_buffer: -0.100000 
+vertex_buffer: 8.400000 +vertex_buffer: 0.123429 +vertex_buffer: 0.000000 +vertex_buffer: 3.000000 +vertex_buffer: -0.100000 +vertex_buffer: 8.600000 +vertex_buffer: 0.069341 +vertex_buffer: 0.000000 +vertex_buffer: 0.000000 +vertex_buffer: -0.100000 +vertex_buffer: 8.400000 +vertex_buffer: 0.177516 +vertex_buffer: 1.000000 +vertex_buffer: 0.000000 +vertex_buffer: 0.100000 +vertex_buffer: 8.400000 +vertex_buffer: 0.177516 +vertex_buffer: 0.750000 +vertex_buffer: 0.000000 +vertex_buffer: -0.100000 +vertex_buffer: 8.600000 +vertex_buffer: 0.015254 +vertex_buffer: 1.000000 +vertex_buffer: 0.000000 +vertex_buffer: 0.100000 +vertex_buffer: 8.600000 +vertex_buffer: 0.015254 +vertex_buffer: 0.750000 +vertex_buffer: -0.100000 +vertex_buffer: 0.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.472252 +vertex_buffer: 1.000000 +vertex_buffer: 0.100000 +vertex_buffer: 0.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.527748 +vertex_buffer: 1.000000 +vertex_buffer: 0.100000 +vertex_buffer: 0.000000 +vertex_buffer: 8.600000 +vertex_buffer: 0.527748 +vertex_buffer: 0.750000 +vertex_buffer: -0.100000 +vertex_buffer: 0.000000 +vertex_buffer: 8.600000 +vertex_buffer: 0.472252 +vertex_buffer: 0.750000 +vertex_buffer: 0.100000 +vertex_buffer: 3.000000 +vertex_buffer: 8.600000 +vertex_buffer: 0.527748 +vertex_buffer: 0.500000 +vertex_buffer: -0.100000 +vertex_buffer: 3.000000 +vertex_buffer: 8.600000 +vertex_buffer: 0.472252 +vertex_buffer: 0.500000 +vertex_buffer: 0.100000 +vertex_buffer: 3.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.527748 +vertex_buffer: 0.250000 +vertex_buffer: -0.100000 +vertex_buffer: 3.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.472252 +vertex_buffer: 0.250000 +vertex_buffer: 0.100000 +vertex_buffer: 0.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.527748 +vertex_buffer: 0.000000 +vertex_buffer: -0.100000 +vertex_buffer: 0.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.472252 +vertex_buffer: 0.000000 +vertex_buffer: 0.100000 +vertex_buffer: 
3.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.583245 +vertex_buffer: 1.000000 +vertex_buffer: 0.100000 +vertex_buffer: 3.000000 +vertex_buffer: 8.600000 +vertex_buffer: 0.583245 +vertex_buffer: 0.750000 +vertex_buffer: -0.100000 +vertex_buffer: 3.000000 +vertex_buffer: 8.400000 +vertex_buffer: 0.416755 +vertex_buffer: 1.000000 +vertex_buffer: -0.100000 +vertex_buffer: 3.000000 +vertex_buffer: 8.600000 +vertex_buffer: 0.416755 +vertex_buffer: 0.750000 +index_buffer: 0 +index_buffer: 1 +index_buffer: 2 +index_buffer: 0 +index_buffer: 2 +index_buffer: 3 +index_buffer: 3 +index_buffer: 2 +index_buffer: 4 +index_buffer: 3 +index_buffer: 4 +index_buffer: 5 +index_buffer: 5 +index_buffer: 4 +index_buffer: 6 +index_buffer: 5 +index_buffer: 6 +index_buffer: 7 +index_buffer: 7 +index_buffer: 6 +index_buffer: 8 +index_buffer: 7 +index_buffer: 8 +index_buffer: 9 +index_buffer: 1 +index_buffer: 10 +index_buffer: 11 +index_buffer: 1 +index_buffer: 11 +index_buffer: 2 +index_buffer: 12 +index_buffer: 0 +index_buffer: 3 +index_buffer: 12 +index_buffer: 3 +index_buffer: 13 +index_buffer: 14 +index_buffer: 15 +index_buffer: 16 +index_buffer: 14 +index_buffer: 16 +index_buffer: 17 +index_buffer: 17 +index_buffer: 16 +index_buffer: 18 +index_buffer: 17 +index_buffer: 18 +index_buffer: 19 +index_buffer: 19 +index_buffer: 18 +index_buffer: 20 +index_buffer: 19 +index_buffer: 20 +index_buffer: 21 +index_buffer: 21 +index_buffer: 20 +index_buffer: 22 +index_buffer: 21 +index_buffer: 22 +index_buffer: 23 +index_buffer: 15 +index_buffer: 24 +index_buffer: 25 +index_buffer: 15 +index_buffer: 25 +index_buffer: 16 +index_buffer: 26 +index_buffer: 14 +index_buffer: 17 +index_buffer: 26 +index_buffer: 17 +index_buffer: 27 +index_buffer: 28 +index_buffer: 29 +index_buffer: 30 +index_buffer: 28 +index_buffer: 30 +index_buffer: 31 +index_buffer: 31 +index_buffer: 30 +index_buffer: 32 +index_buffer: 31 +index_buffer: 32 +index_buffer: 33 +index_buffer: 33 +index_buffer: 32 +index_buffer: 34 
+index_buffer: 33 +index_buffer: 34 +index_buffer: 35 +index_buffer: 35 +index_buffer: 34 +index_buffer: 36 +index_buffer: 35 +index_buffer: 36 +index_buffer: 37 +index_buffer: 29 +index_buffer: 38 +index_buffer: 39 +index_buffer: 29 +index_buffer: 39 +index_buffer: 30 +index_buffer: 40 +index_buffer: 28 +index_buffer: 31 +index_buffer: 40 +index_buffer: 31 +index_buffer: 41 diff --git a/mediapipe/graphs/face_effect/data/axis.pngblob b/mediapipe/graphs/face_effect/data/axis.pngblob new file mode 100644 index 000000000..3c36c7895 Binary files /dev/null and b/mediapipe/graphs/face_effect/data/axis.pngblob differ diff --git a/mediapipe/graphs/face_effect/face_effect_gpu.pbtxt b/mediapipe/graphs/face_effect/face_effect_gpu.pbtxt index a7af2d571..40888d0f4 100644 --- a/mediapipe/graphs/face_effect/face_effect_gpu.pbtxt +++ b/mediapipe/graphs/face_effect/face_effect_gpu.pbtxt @@ -3,11 +3,20 @@ # GPU buffer. (GpuBuffer) input_stream: "input_video" -# Boolean flag, which indicates whether the Facepaint effect is selected. (bool) +# An integer that indicates which effect is selected. (int) # -# If `true`, the Facepaint effect will be rendered. -# If `false`, the Glasses effect will be rendered. -input_stream: "is_facepaint_effect_selected" +# If `selected_effect_id` is `0`, the Axis effect is selected. +# If `selected_effect_id` is `1`, the Facepaint effect is selected. +# If `selected_effect_id` is `2`, the Glasses effect is selected. +# +# No other values are allowed for `selected_effect_id`. +input_stream: "selected_effect_id" + +# Indicates whether to use the face detection as the input source. (bool) +# +# If `true`, the face detection pipeline will be used to produce landmarks. +# If `false`, the face landmark pipeline will be used to produce landmarks. +input_side_packet: "use_face_detection_input_source" # Output image with rendered results.
(GpuBuffer) output_stream: "output_video" @@ -59,87 +68,63 @@ node { } } -# Subgraph that detects a single face and corresponding landmarks. The landmarks -# are also "smoothed" to achieve better visual results. +# Computes the face geometry for a single face. The input source is defined +# through `use_face_detection_input_source`. node { - calculator: "SingleFaceSmoothLandmarkGpu" + calculator: "SwitchContainer" input_stream: "IMAGE:throttled_input_video" - output_stream: "LANDMARKS:multi_face_landmarks" -} - -# Extracts the throttled input video frame dimensions as a separate packet. -node { - calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE_GPU:throttled_input_video" - output_stream: "SIZE:input_video_size" -} - -# Subgraph that computes face geometry from landmarks for a single face. -node { - calculator: "FaceGeometry" - input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" - input_stream: "IMAGE_SIZE:input_video_size" + input_side_packet: "ENABLE:use_face_detection_input_source" input_side_packet: "ENVIRONMENT:environment" output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" -} - -# Decides whether to render the Facepaint effect based on the -# `is_facepaint_effect_selected` flag value. -node { - calculator: "GateCalculator" - input_stream: "throttled_input_video" - input_stream: "multi_face_geometry" - input_stream: "ALLOW:is_facepaint_effect_selected" - output_stream: "facepaint_effect_throttled_input_video" - output_stream: "facepaint_effect_multi_face_geometry" -} - -# Renders the Facepaint effect. 
-node { - calculator: "FaceGeometryEffectRendererCalculator" - input_side_packet: "ENVIRONMENT:environment" - input_stream: "IMAGE_GPU:facepaint_effect_throttled_input_video" - input_stream: "MULTI_FACE_GEOMETRY:facepaint_effect_multi_face_geometry" - output_stream: "IMAGE_GPU:facepaint_effect_output_video" node_options: { - [type.googleapis.com/mediapipe.FaceGeometryEffectRendererCalculatorOptions] { - effect_texture_path: "mediapipe/graphs/face_effect/data/facepaint.pngblob" + [type.googleapis.com/mediapipe.SwitchContainerOptions] { + contained_node: { + calculator: "SingleFaceGeometryFromLandmarksGpu" + } + contained_node: { + calculator: "SingleFaceGeometryFromDetectionGpu" + } } } } -# Decides whether to render the Glasses effect based on the -# `is_facepaint_effect_selected` flag value. +# Renders the selected effect based on `selected_effect_id`. node { - calculator: "GateCalculator" - input_stream: "throttled_input_video" - input_stream: "multi_face_geometry" - input_stream: "DISALLOW:is_facepaint_effect_selected" - output_stream: "glasses_effect_throttled_input_video" - output_stream: "glasses_effect_multi_face_geometry" -} - -# Renders the Glasses effect. 
-node { - calculator: "FaceGeometryEffectRendererCalculator" + calculator: "SwitchContainer" + input_stream: "SELECT:selected_effect_id" + input_stream: "IMAGE_GPU:throttled_input_video" + input_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" input_side_packet: "ENVIRONMENT:environment" - input_stream: "IMAGE_GPU:glasses_effect_throttled_input_video" - input_stream: "MULTI_FACE_GEOMETRY:glasses_effect_multi_face_geometry" - output_stream: "IMAGE_GPU:glasses_effect_output_video" + output_stream: "IMAGE_GPU:output_video" node_options: { - [type.googleapis.com/mediapipe.FaceGeometryEffectRendererCalculatorOptions] { - effect_texture_path: "mediapipe/graphs/face_effect/data/glasses.pngblob" - effect_mesh_3d_path: "mediapipe/graphs/face_effect/data/glasses.binarypb" + [type.googleapis.com/mediapipe.SwitchContainerOptions] { + contained_node: { + calculator: "FaceGeometryEffectRendererCalculator" + node_options: { + [type.googleapis.com/mediapipe.FaceGeometryEffectRendererCalculatorOptions] { + effect_texture_path: "mediapipe/graphs/face_effect/data/axis.pngblob" + effect_mesh_3d_path: "mediapipe/graphs/face_effect/data/axis.binarypb" + } + } + } + contained_node: { + calculator: "FaceGeometryEffectRendererCalculator" + node_options: { + [type.googleapis.com/mediapipe.FaceGeometryEffectRendererCalculatorOptions] { + effect_texture_path: "mediapipe/graphs/face_effect/data/facepaint.pngblob" + } + } + } + contained_node: { + calculator: "FaceGeometryEffectRendererCalculator" + node_options: { + [type.googleapis.com/mediapipe.FaceGeometryEffectRendererCalculatorOptions] { + effect_texture_path: "mediapipe/graphs/face_effect/data/glasses.pngblob" + effect_mesh_3d_path: "mediapipe/graphs/face_effect/data/glasses.binarypb" + } + } + } } } } -# Decides which of the Facepaint or the Glasses rendered results should be sent -# as the output GPU frame. 
-node { - calculator: "ImmediateMuxCalculator" - input_stream: "facepaint_effect_output_video" - input_stream: "glasses_effect_output_video" - output_stream: "output_video" -} - diff --git a/mediapipe/graphs/face_effect/subgraphs/BUILD b/mediapipe/graphs/face_effect/subgraphs/BUILD index f7df8fb87..c38008e8f 100644 --- a/mediapipe/graphs/face_effect/subgraphs/BUILD +++ b/mediapipe/graphs/face_effect/subgraphs/BUILD @@ -22,15 +22,40 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) mediapipe_simple_subgraph( - name = "single_face_smooth_landmark_gpu", - graph = "single_face_smooth_landmark_gpu.pbtxt", - register_as = "SingleFaceSmoothLandmarkGpu", + name = "face_landmarks_smoothing", + graph = "face_landmarks_smoothing.pbtxt", + register_as = "FaceLandmarksSmoothing", deps = [ + "//mediapipe/calculators/util:landmarks_smoothing_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "single_face_geometry_from_detection_gpu", + graph = "single_face_geometry_from_detection_gpu.pbtxt", + register_as = "SingleFaceGeometryFromDetectionGpu", + deps = [ + ":face_landmarks_smoothing", + "//mediapipe/calculators/core:concatenate_detection_vector_calculator", + "//mediapipe/calculators/core:split_vector_calculator", + "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/modules/face_detection:face_detection_front_gpu", + "//mediapipe/modules/face_geometry:face_geometry_from_detection", + ], +) + +mediapipe_simple_subgraph( + name = "single_face_geometry_from_landmarks_gpu", + graph = "single_face_geometry_from_landmarks_gpu.pbtxt", + register_as = "SingleFaceGeometryFromLandmarksGpu", + deps = [ + ":face_landmarks_smoothing", "//mediapipe/calculators/core:concatenate_vector_calculator", "//mediapipe/calculators/core:constant_side_packet_calculator", "//mediapipe/calculators/core:split_vector_calculator", "//mediapipe/calculators/image:image_properties_calculator", 
"//mediapipe/calculators/util:landmarks_smoothing_calculator", + "//mediapipe/modules/face_geometry:face_geometry_from_landmarks", "//mediapipe/modules/face_landmark:face_landmark_front_gpu", ], ) diff --git a/mediapipe/graphs/face_effect/subgraphs/face_landmarks_smoothing.pbtxt b/mediapipe/graphs/face_effect/subgraphs/face_landmarks_smoothing.pbtxt new file mode 100644 index 000000000..3f565f5c0 --- /dev/null +++ b/mediapipe/graphs/face_effect/subgraphs/face_landmarks_smoothing.pbtxt @@ -0,0 +1,24 @@ +# MediaPipe subgraph that smoothes face landmarks. + +type: "FaceLandmarksSmoothing" + +input_stream: "NORM_LANDMARKS:landmarks" +input_stream: "IMAGE_SIZE:input_image_size" +output_stream: "NORM_FILTERED_LANDMARKS:filtered_landmarks" + +# Applies smoothing to a face landmark list. The filter options were handpicked +# to achieve better visual results. +node { + calculator: "LandmarksSmoothingCalculator" + input_stream: "NORM_LANDMARKS:landmarks" + input_stream: "IMAGE_SIZE:input_image_size" + output_stream: "NORM_FILTERED_LANDMARKS:filtered_landmarks" + node_options: { + [type.googleapis.com/mediapipe.LandmarksSmoothingCalculatorOptions] { + velocity_filter: { + window_size: 5 + velocity_scale: 20.0 + } + } + } +} diff --git a/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_detection_gpu.pbtxt b/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_detection_gpu.pbtxt new file mode 100644 index 000000000..582107584 --- /dev/null +++ b/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_detection_gpu.pbtxt @@ -0,0 +1,91 @@ +# MediaPipe subgraph that extracts geometry from a single face using the face +# landmark pipeline on an input GPU image. The face landmarks are also +# "smoothed" to achieve better visual results. + +type: "SingleFaceGeometryFromDetectionGpu" + +# GPU image. (GpuBuffer) +input_stream: "IMAGE:input_image" + +# Environment that describes the current virtual scene. 
+# (face_geometry::Environment) +input_side_packet: "ENVIRONMENT:environment" + +# A list of geometry data for a single detected face. The size of this +# collection is at most 1 because of the single-face use in this graph. +# (std::vector) +# +# NOTE: if no face is detected at a particular timestamp, there will not be an +# output packet in the `MULTI_FACE_GEOMETRY` stream for this timestamp. However, +# the MediaPipe framework will internally inform the downstream calculators of +# the absence of this packet so that they don't wait for it unnecessarily. +output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" + +# Subgraph that detects faces and corresponding landmarks using the face +# detection pipeline. +node { + calculator: "FaceDetectionFrontGpu" + input_stream: "IMAGE:input_image" + output_stream: "DETECTIONS:multi_face_detection" +} + +# Extracts the first face detection associated with the most prominent face from +# a collection. +node { + calculator: "SplitDetectionVectorCalculator" + input_stream: "multi_face_detection" + output_stream: "face_detection" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 0 end: 1 } + element_only: true + } + } +} + +# Extracts face detection keypoints as normalized landmarks. +node { + calculator: "DetectionToLandmarksCalculator" + input_stream: "DETECTION:face_detection" + output_stream: "LANDMARKS:face_landmarks" +} + +# Extracts the input image frame dimensions as a separate packet. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_image" + output_stream: "SIZE:input_image_size" +} + +# Applies smoothing to the face landmarks previously extracted from the face +# detection keypoints. 
+node { + calculator: "FaceLandmarksSmoothing" + input_stream: "NORM_LANDMARKS:face_landmarks" + input_stream: "IMAGE_SIZE:input_image_size" + output_stream: "NORM_FILTERED_LANDMARKS:smoothed_face_landmarks" +} + +# Converts smoothed face landmarks back into the detection format. +node { + calculator: "LandmarksToDetectionCalculator" + input_stream: "NORM_LANDMARKS:smoothed_face_landmarks" + output_stream: "DETECTION:smoothed_face_detection" +} + +# Puts the smoothed single face detection back into a collection to simplify +# passing the result into the `FaceGeometryFromDetection` subgraph. +node { + calculator: "ConcatenateDetectionVectorCalculator" + input_stream: "smoothed_face_detection" + output_stream: "multi_smoothed_face_detection" +} + +# Computes face geometry from the single face detection. +node { + calculator: "FaceGeometryFromDetection" + input_stream: "MULTI_FACE_DETECTION:multi_smoothed_face_detection" + input_stream: "IMAGE_SIZE:input_image_size" + input_side_packet: "ENVIRONMENT:environment" + output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" +} diff --git a/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_landmarks_gpu.pbtxt b/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_landmarks_gpu.pbtxt new file mode 100644 index 000000000..364e38654 --- /dev/null +++ b/mediapipe/graphs/face_effect/subgraphs/single_face_geometry_from_landmarks_gpu.pbtxt @@ -0,0 +1,89 @@ +# MediaPipe subgraph that extracts geometry from a single face using the face +# landmark pipeline on an input GPU image. The face landmarks are also +# "smoothed" to achieve better visual results. + +type: "SingleFaceGeometryFromLandmarksGpu" + +# GPU image. (GpuBuffer) +input_stream: "IMAGE:input_image" + +# Environment that describes the current virtual scene. +# (face_geometry::Environment) +input_side_packet: "ENVIRONMENT:environment" + +# A list of geometry data for a single detected face. 
The size of this +# collection is at most 1 because of the single-face use in this graph. +# (std::vector) +# +# NOTE: if no face is detected at a particular timestamp, there will not be an +# output packet in the `MULTI_FACE_GEOMETRY` stream for this timestamp. However, +# the MediaPipe framework will internally inform the downstream calculators of +# the absence of this packet so that they don't wait for it unnecessarily. +output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" + +# Creates a packet to inform the `FaceLandmarkFrontGpu` subgraph to detect at +# most 1 face. +node { + calculator: "ConstantSidePacketCalculator" + output_side_packet: "PACKET:num_faces" + node_options: { + [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { + packet { int_value: 1 } + } + } +} + +# Subgraph that detects faces and corresponding landmarks using the face +# landmark pipeline. +node { + calculator: "FaceLandmarkFrontGpu" + input_stream: "IMAGE:input_image" + input_side_packet: "NUM_FACES:num_faces" + output_stream: "LANDMARKS:multi_face_landmarks" +} + +# Extracts a single set of face landmarks associated with the most prominent +# face detected from a collection. +node { + calculator: "SplitNormalizedLandmarkListVectorCalculator" + input_stream: "multi_face_landmarks" + output_stream: "face_landmarks" + node_options: { + [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { + ranges: { begin: 0 end: 1 } + element_only: true + } + } +} + +# Extracts the input image frame dimensions as a separate packet. +node { + calculator: "ImagePropertiesCalculator" + input_stream: "IMAGE_GPU:input_image" + output_stream: "SIZE:input_image_size" +} + +# Applies smoothing to the single set of face landmarks. 
+node { + calculator: "FaceLandmarksSmoothing" + input_stream: "NORM_LANDMARKS:face_landmarks" + input_stream: "IMAGE_SIZE:input_image_size" + output_stream: "NORM_FILTERED_LANDMARKS:smoothed_face_landmarks" +} + +# Puts the single set of smoothed landmarks back into a collection to simplify +# passing the result into the `FaceGeometryFromLandmarks` subgraph. +node { + calculator: "ConcatenateLandmarListVectorCalculator" + input_stream: "smoothed_face_landmarks" + output_stream: "multi_smoothed_face_landmarks" +} + +# Computes face geometry from face landmarks for a single face. +node { + calculator: "FaceGeometryFromLandmarks" + input_stream: "MULTI_FACE_LANDMARKS:multi_smoothed_face_landmarks" + input_stream: "IMAGE_SIZE:input_image_size" + input_side_packet: "ENVIRONMENT:environment" + output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" +} diff --git a/mediapipe/graphs/face_effect/subgraphs/single_face_smooth_landmark_gpu.pbtxt b/mediapipe/graphs/face_effect/subgraphs/single_face_smooth_landmark_gpu.pbtxt deleted file mode 100644 index e2e347103..000000000 --- a/mediapipe/graphs/face_effect/subgraphs/single_face_smooth_landmark_gpu.pbtxt +++ /dev/null @@ -1,84 +0,0 @@ -# MediaPipe subgraph that detects a single face and corresponding landmarks on -# a input GPU image. The landmarks are also "smoothed" to achieve better visual -# results. - -type: "SingleFaceSmoothLandmarkGpu" - -# GPU image. (GpuBuffer) -input_stream: "IMAGE:input_image" - -# Collection of detected/predicted faces, each represented as a list of face -# landmarks. However, the size of this collection is always 1 because of the -# single-face use in this graph. The decision to wrap the landmark list into a -# collection was made to simplify passing the result into the `FaceGeometry` -# subgraph. (std::vector) -# -# NOTE: there will not be an output packet in the LANDMARKS stream for this -# particular timestamp if none of faces detected. 
However, the MediaPipe -# framework will internally inform the downstream calculators of the absence of -# this packet so that they don't wait for it unnecessarily. -output_stream: "LANDMARKS:multi_face_smooth_landmarks" - -# Creates a packet to inform the `FaceLandmarkFrontGpu` subgraph to detect at -# most 1 face. -node { - calculator: "ConstantSidePacketCalculator" - output_side_packet: "PACKET:num_faces" - node_options: { - [type.googleapis.com/mediapipe.ConstantSidePacketCalculatorOptions]: { - packet { int_value: 1 } - } - } -} - -# Subgraph that detects faces and corresponding landmarks. -node { - calculator: "FaceLandmarkFrontGpu" - input_stream: "IMAGE:input_image" - input_side_packet: "NUM_FACES:num_faces" - output_stream: "LANDMARKS:multi_face_landmarks" -} - -# Extracts the detected face landmark list from a collection. -node { - calculator: "SplitNormalizedLandmarkListVectorCalculator" - input_stream: "multi_face_landmarks" - output_stream: "face_landmarks" - node_options: { - [type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] { - ranges: { begin: 0 end: 1 } - element_only: true - } - } -} - -# Extracts the input image frame dimensions as a separate packet. -node { - calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE_GPU:input_image" - output_stream: "SIZE:input_image_size" -} - -# Applies smoothing to the single face landmarks. -node { - calculator: "LandmarksSmoothingCalculator" - input_stream: "NORM_LANDMARKS:face_landmarks" - input_stream: "IMAGE_SIZE:input_image_size" - output_stream: "NORM_FILTERED_LANDMARKS:face_smooth_landmarks" - node_options: { - [type.googleapis.com/mediapipe.LandmarksSmoothingCalculatorOptions] { - velocity_filter: { - window_size: 5 - velocity_scale: 20.0 - } - } - } -} - -# Puts the single face smooth landmarks back into a collection to simplify -# passing the result into the `FaceGeometry` subgraph. 
-node { - calculator: "ConcatenateLandmarListVectorCalculator" - input_stream: "face_smooth_landmarks" - output_stream: "multi_face_smooth_landmarks" -} diff --git a/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc b/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc index 86ad06ec5..db1bc3422 100644 --- a/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc +++ b/mediapipe/graphs/face_mesh/calculators/face_landmarks_to_render_data_calculator.cc @@ -81,12 +81,11 @@ constexpr int kFaceLandmarkConnections[] = { class FaceLandmarksToRenderDataCalculator : public LandmarksToRenderDataCalculator { public: - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(FaceLandmarksToRenderDataCalculator); -mediapipe::Status FaceLandmarksToRenderDataCalculator::Open( - CalculatorContext* cc) { +absl::Status FaceLandmarksToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); @@ -95,7 +94,7 @@ mediapipe::Status FaceLandmarksToRenderDataCalculator::Open( landmark_connections_.push_back(kFaceLandmarkConnections[i * 2 + 1]); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/graphs/hand_tracking/BUILD b/mediapipe/graphs/hand_tracking/BUILD index ed922ba8e..71525bb52 100644 --- a/mediapipe/graphs/hand_tracking/BUILD +++ b/mediapipe/graphs/hand_tracking/BUILD @@ -46,7 +46,6 @@ cc_library( "//mediapipe/calculators/core:merge_calculator", "//mediapipe/graphs/hand_tracking/subgraphs:hand_renderer_cpu", "//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu", - "//mediapipe/modules/palm_detection:palm_detection_cpu", ], ) diff --git a/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc 
b/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc index d3abb6540..f8190c506 100644 --- a/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc +++ b/mediapipe/graphs/instant_motion_tracking/calculators/matrices_manager_calculator.cc @@ -87,9 +87,9 @@ constexpr float kInitialZ = -10.0f; class MatricesManagerCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: // Device properties that will be preset by side packets @@ -137,8 +137,7 @@ class MatricesManagerCalculator : public CalculatorBase { REGISTER_CALCULATOR(MatricesManagerCalculator); -mediapipe::Status MatricesManagerCalculator::GetContract( - CalculatorContract* cc) { +absl::Status MatricesManagerCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kAnchorsTag) && cc->Inputs().HasTag(kIMUMatrixTag) && cc->Inputs().HasTag(kUserRotationsTag) && @@ -162,20 +161,20 @@ mediapipe::Status MatricesManagerCalculator::GetContract( cc->InputSidePackets().Tag(kFOVSidePacketTag).Set(); cc->InputSidePackets().Tag(kAspectRatioSidePacketTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status MatricesManagerCalculator::Open(CalculatorContext* cc) { +absl::Status MatricesManagerCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); // Set device properties from side packets vertical_fov_radians_ = cc->InputSidePackets().Tag(kFOVSidePacketTag).Get(); aspect_ratio_ = cc->InputSidePackets().Tag(kAspectRatioSidePacketTag).Get(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status 
MatricesManagerCalculator::Process(CalculatorContext* cc) { +absl::Status MatricesManagerCalculator::Process(CalculatorContext* cc) { // Define each object's model matrices auto asset_matrices_gif = std::make_unique(); @@ -276,7 +275,7 @@ mediapipe::Status MatricesManagerCalculator::Process(CalculatorContext* cc) { .Get(cc->Outputs().GetId("MATRICES", 1)) .Add(asset_matrices_1.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Using a specified rotation value in radians, generate a rotation matrix for diff --git a/mediapipe/graphs/instant_motion_tracking/calculators/sticker_manager_calculator.cc b/mediapipe/graphs/instant_motion_tracking/calculators/sticker_manager_calculator.cc index 5f0ee94ac..40210c27a 100644 --- a/mediapipe/graphs/instant_motion_tracking/calculators/sticker_manager_calculator.cc +++ b/mediapipe/graphs/instant_motion_tracking/calculators/sticker_manager_calculator.cc @@ -53,7 +53,7 @@ constexpr char kRenderDescriptorsTag[] = "RENDER_DATA"; class StickerManagerCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kProtoDataString)); RET_CHECK(cc->Outputs().HasTag(kAnchorsTag) && cc->Outputs().HasTag(kUserRotationsTag) && @@ -66,15 +66,15 @@ class StickerManagerCalculator : public CalculatorBase { cc->Outputs().Tag(kUserScalingsTag).Set>(); cc->Outputs().Tag(kRenderDescriptorsTag).Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { std::string sticker_proto_string = cc->Inputs().Tag(kProtoDataString).Get(); @@ -138,11 
+138,11 @@ class StickerManagerCalculator : public CalculatorBase { .At(cc->InputTimestamp())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Close(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc b/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc index 96ffa7aa8..446aee781 100644 --- a/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc +++ b/mediapipe/graphs/instant_motion_tracking/calculators/tracked_anchor_manager_calculator.cc @@ -71,7 +71,7 @@ class TrackedAnchorManagerCalculator : public CalculatorBase { std::vector previous_anchor_data_; public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kAnchorsTag) && cc->Inputs().HasTag(kSentinelTag)); RET_CHECK(cc->Outputs().HasTag(kAnchorsTag) && @@ -91,19 +91,16 @@ class TrackedAnchorManagerCalculator : public CalculatorBase { cc->Outputs().Tag(kCancelTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { - return mediapipe::OkStatus(); - } + absl::Status Open(CalculatorContext* cc) override { return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(TrackedAnchorManagerCalculator); -mediapipe::Status TrackedAnchorManagerCalculator::Process( - CalculatorContext* cc) { +absl::Status TrackedAnchorManagerCalculator::Process(CalculatorContext* cc) { mediapipe::Timestamp timestamp = cc->InputTimestamp(); const int sticker_sentinel = cc->Inputs().Tag(kSentinelTag).Get(); std::vector current_anchor_data = @@ 
-208,6 +205,6 @@ mediapipe::Status TrackedAnchorManagerCalculator::Process( .Tag(kBoxesOutputTag) .Add(pos_boxes.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/graphs/iris_tracking/calculators/iris_to_depth_calculator.cc b/mediapipe/graphs/iris_tracking/calculators/iris_to_depth_calculator.cc index caac042e6..3522274de 100644 --- a/mediapipe/graphs/iris_tracking/calculators/iris_to_depth_calculator.cc +++ b/mediapipe/graphs/iris_tracking/calculators/iris_to_depth_calculator.cc @@ -89,7 +89,7 @@ float CalculateDepth(const NormalizedLandmark& center, float focal_length, // } class IrisToDepthCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kIrisTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); @@ -111,12 +111,12 @@ class IrisToDepthCalculator : public CalculatorBase { if (cc->Outputs().HasTag(kRightIrisDepthTag)) { cc->Outputs().Tag(kRightIrisDepthTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: float focal_length_pixels_ = -1.f; @@ -134,7 +134,7 @@ class IrisToDepthCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(IrisToDepthCalculator); -mediapipe::Status IrisToDepthCalculator::Open(CalculatorContext* cc) { +absl::Status IrisToDepthCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); if (cc->InputSidePackets().HasTag(kFocalLengthPixelTag)) { #if defined(__APPLE__) @@ -155,13 +155,13 @@ mediapipe::Status IrisToDepthCalculator::Open(CalculatorContext* cc) { } options_ = cc->Options<::mediapipe::IrisToDepthCalculatorOptions>(); - return 
mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status IrisToDepthCalculator::Process(CalculatorContext* cc) { +absl::Status IrisToDepthCalculator::Process(CalculatorContext* cc) { // Only process if there's input landmarks. if (cc->Inputs().Tag(kIrisTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& iris_landmarks = @@ -220,7 +220,7 @@ mediapipe::Status IrisToDepthCalculator::Process(CalculatorContext* cc) { .At(cc->InputTimestamp())); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } void IrisToDepthCalculator::GetLeftIris(const NormalizedLandmarkList& lds, diff --git a/mediapipe/graphs/iris_tracking/calculators/iris_to_render_data_calculator.cc b/mediapipe/graphs/iris_tracking/calculators/iris_to_render_data_calculator.cc index 5bf86b170..c19db2a39 100644 --- a/mediapipe/graphs/iris_tracking/calculators/iris_to_render_data_calculator.cc +++ b/mediapipe/graphs/iris_tracking/calculators/iris_to_render_data_calculator.cc @@ -108,7 +108,7 @@ float CalculateDepth(const NormalizedLandmark& center, float focal_length, // } class IrisToRenderDataCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kIrisTag).Set(); cc->Outputs().Tag(kRenderDataTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); @@ -119,12 +119,12 @@ class IrisToRenderDataCalculator : public CalculatorBase { if (cc->Inputs().HasTag(kRightIrisDepthTag)) { cc->Inputs().Tag(kRightIrisDepthTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: void RenderIris(const NormalizedLandmarkList& iris_landmarks, @@ -150,15 +150,15 @@ class 
IrisToRenderDataCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(IrisToRenderDataCalculator); -mediapipe::Status IrisToRenderDataCalculator::Open(CalculatorContext* cc) { +absl::Status IrisToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status IrisToRenderDataCalculator::Process(CalculatorContext* cc) { +absl::Status IrisToRenderDataCalculator::Process(CalculatorContext* cc) { // Only process if there's input landmarks. if (cc->Inputs().Tag(kIrisTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& options = cc->Options<::mediapipe::IrisToRenderDataCalculatorOptions>(); @@ -212,7 +212,7 @@ mediapipe::Status IrisToRenderDataCalculator::Process(CalculatorContext* cc) { cc->Outputs() .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void IrisToRenderDataCalculator::AddTextRenderData( diff --git a/mediapipe/graphs/iris_tracking/calculators/update_face_landmarks_calculator.cc b/mediapipe/graphs/iris_tracking/calculators/update_face_landmarks_calculator.cc index 3616eba33..de9549a3d 100644 --- a/mediapipe/graphs/iris_tracking/calculators/update_face_landmarks_calculator.cc +++ b/mediapipe/graphs/iris_tracking/calculators/update_face_landmarks_calculator.cc @@ -215,28 +215,27 @@ constexpr int kEyeLandmarkIndicesInFaceLandmarks[] = { // class UpdateFaceLandmarksCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kFaceLandmarksTag).Set(); cc->Inputs().Tag(kNewEyeLandmarksTag).Set(); cc->Outputs().Tag(kUpdatedFaceLandmarksTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) { + absl::Status Open(CalculatorContext* cc) { 
cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(UpdateFaceLandmarksCalculator); -mediapipe::Status UpdateFaceLandmarksCalculator::Process( - CalculatorContext* cc) { +absl::Status UpdateFaceLandmarksCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kFaceLandmarksTag).IsEmpty() || cc->Inputs().Tag(kNewEyeLandmarksTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& face_landmarks = cc->Inputs().Tag(kFaceLandmarksTag).Get(); @@ -263,7 +262,7 @@ mediapipe::Status UpdateFaceLandmarksCalculator::Process( .Tag(kUpdatedFaceLandmarksTag) .Add(refined_face_landmarks.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/graphs/object_detection_3d/BUILD b/mediapipe/graphs/object_detection_3d/BUILD index 467347f65..7ba00c0eb 100644 --- a/mediapipe/graphs/object_detection_3d/BUILD +++ b/mediapipe/graphs/object_detection_3d/BUILD @@ -56,8 +56,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//mediapipe/calculators/core:constant_side_packet_calculator", - "//mediapipe/calculators/tflite:tflite_model_calculator", - "//mediapipe/calculators/util:local_file_contents_calculator", "//mediapipe/calculators/video:opencv_video_decoder_calculator", "//mediapipe/calculators/video:opencv_video_encoder_calculator", "//mediapipe/graphs/object_detection_3d/subgraphs:renderer_cpu", diff --git a/mediapipe/graphs/object_detection_3d/calculators/annotations_to_model_matrices_calculator.cc b/mediapipe/graphs/object_detection_3d/calculators/annotations_to_model_matrices_calculator.cc index 73098e840..c2166c648 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/annotations_to_model_matrices_calculator.cc +++ 
b/mediapipe/graphs/object_detection_3d/calculators/annotations_to_model_matrices_calculator.cc @@ -37,6 +37,7 @@ namespace { constexpr char kAnnotationTag[] = "ANNOTATIONS"; constexpr char kModelMatricesTag[] = "MODEL_MATRICES"; +using Matrix3fRM = Eigen::Matrix; using Matrix4fRM = Eigen::Matrix; } // namespace @@ -66,14 +67,14 @@ class AnnotationsToModelMatricesCalculator : public CalculatorBase { AnnotationsToModelMatricesCalculator& operator=( const AnnotationsToModelMatricesCalculator&) = delete; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: - mediapipe::Status GetModelMatricesForAnnotations( + absl::Status GetModelMatricesForAnnotations( const FrameAnnotation& annotations, TimedModelMatrixProtoList* model_matrix_list); @@ -83,7 +84,7 @@ class AnnotationsToModelMatricesCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(AnnotationsToModelMatricesCalculator); -mediapipe::Status AnnotationsToModelMatricesCalculator::GetContract( +absl::Status AnnotationsToModelMatricesCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kAnnotationTag)) << "No input stream found."; if (cc->Inputs().HasTag(kAnnotationTag)) { @@ -101,11 +102,10 @@ mediapipe::Status AnnotationsToModelMatricesCalculator::GetContract( if (cc->InputSidePackets().HasTag("MODEL_TRANSFORMATION")) { cc->InputSidePackets().Tag("MODEL_TRANSFORMATION").Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationsToModelMatricesCalculator::Open( - CalculatorContext* cc) { +absl::Status AnnotationsToModelMatricesCalculator::Open(CalculatorContext* cc) { RET_CHECK(cc->Inputs().HasTag(kAnnotationTag)); 
cc->SetOffset(TimestampDiff(0)); @@ -131,10 +131,10 @@ mediapipe::Status AnnotationsToModelMatricesCalculator::Open( model_transformation_.setIdentity(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationsToModelMatricesCalculator::Process( +absl::Status AnnotationsToModelMatricesCalculator::Process( CalculatorContext* cc) { auto model_matrices = std::make_unique(); @@ -142,73 +142,66 @@ mediapipe::Status AnnotationsToModelMatricesCalculator::Process( cc->Inputs().Tag(kAnnotationTag).Get(); if (!GetModelMatricesForAnnotations(annotations, model_matrices.get()).ok()) { - return mediapipe::InvalidArgumentError("Error in GetModelMatricesForBoxes"); + return absl::InvalidArgumentError("Error in GetModelMatricesForBoxes"); } cc->Outputs() .Tag(kModelMatricesTag) .Add(model_matrices.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status +absl::Status AnnotationsToModelMatricesCalculator::GetModelMatricesForAnnotations( const FrameAnnotation& annotations, TimedModelMatrixProtoList* model_matrix_list) { if (model_matrix_list == nullptr) { - return mediapipe::InvalidArgumentError("model_matrix_list is nullptr"); + return absl::InvalidArgumentError("model_matrix_list is nullptr"); } model_matrix_list->clear_model_matrix(); - Box box("category"); for (const auto& object : annotations.annotations()) { TimedModelMatrixProto* model_matrix = model_matrix_list->add_model_matrix(); model_matrix->set_id(object.object_id()); - // Fit a box to the original vertices to estimate the scale of the box - std::vector vertices; - for (const auto& keypoint : object.keypoints()) { - const auto& point = keypoint.point_3d(); - Eigen::Vector3f p(point.x(), point.y(), point.z()); - vertices.emplace_back(p); - } - box.Fit(vertices); + // Get object rotation, translation and scale. 
+ const auto object_rotation = + Eigen::Map(object.rotation().data()); + const auto object_translation = + Eigen::Map(object.translation().data()); + const auto object_scale = + Eigen::Map(object.scale().data()); - // Re-scale the box if necessary - Eigen::Vector3f estimated_scale = box.GetScale(); - vertices.clear(); - for (const auto& keypoint : object.keypoints()) { - const auto& point = keypoint.point_3d(); - Eigen::Vector3f p(point.x(), point.y(), point.z()); - vertices.emplace_back(p); - } - box.Fit(vertices); + // Compose object transformation matrix. + Matrix4fRM object_transformation; + object_transformation.setIdentity(); + object_transformation.topLeftCorner<3, 3>() = object_rotation; + object_transformation.topRightCorner<3, 1>() = object_translation; - Matrix4fRM object_transformation = box.GetTransformation(); Matrix4fRM model_view; - Matrix4fRM pursuit_model; + Matrix4fRM objectron_model; // The reference view is // // ref << 0., 0., 1., 0., // -1., 0., 0., 0., // 0., -1., 0., 0., // 0., 0., 0., 1.; - // We have pursuit_model * model = model_view, to get pursuit_model: - // pursuit_model = model_view * model^-1 + // We have objectron_model * model = model_view, to get objectron_model: + // objectron_model = model_view * model^-1 // clang-format off - pursuit_model << 0.0, 1.0, 0.0, 0.0, - 1.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 1.0, 0.0, - 0.0, 0.0, 0.0, 1.0; + objectron_model << 1.0, 0.0, 0.0, 0.0, + 0.0, -1., 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0; // clang-format on // Re-scale the CAD model to the scale of the estimated bounding box. - const Eigen::Vector3f scale = model_scale_.cwiseProduct(estimated_scale); + const Eigen::Vector3f scale = model_scale_.cwiseProduct(object_scale); const Matrix4fRM model = model_transformation_.array().colwise() * scale.homogeneous().array(); // Finally compute the model_view matrix. 
- model_view = pursuit_model * object_transformation * model; + model_view = objectron_model * object_transformation * model; for (int i = 0; i < model_view.rows(); ++i) { for (int j = 0; j < model_view.cols(); ++j) { @@ -216,7 +209,7 @@ AnnotationsToModelMatricesCalculator::GetModelMatricesForAnnotations( } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/graphs/object_detection_3d/calculators/annotations_to_render_data_calculator.cc b/mediapipe/graphs/object_detection_3d/calculators/annotations_to_render_data_calculator.cc index fc8287d25..65bff7768 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/annotations_to_render_data_calculator.cc +++ b/mediapipe/graphs/object_detection_3d/calculators/annotations_to_render_data_calculator.cc @@ -98,11 +98,11 @@ class AnnotationsToRenderDataCalculator : public CalculatorBase { AnnotationsToRenderDataCalculator& operator=( const AnnotationsToRenderDataCalculator&) = delete; - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: static void SetRenderAnnotationColorThickness( @@ -134,7 +134,7 @@ class AnnotationsToRenderDataCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(AnnotationsToRenderDataCalculator); -mediapipe::Status AnnotationsToRenderDataCalculator::GetContract( +absl::Status AnnotationsToRenderDataCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kAnnotationTag)) << "No input stream found."; if (cc->Inputs().HasTag(kAnnotationTag)) { @@ -142,19 +142,17 @@ mediapipe::Status AnnotationsToRenderDataCalculator::GetContract( } cc->Outputs().Tag(kRenderDataTag).Set(); - return mediapipe::OkStatus(); + 
return absl::OkStatus(); } -mediapipe::Status AnnotationsToRenderDataCalculator::Open( - CalculatorContext* cc) { +absl::Status AnnotationsToRenderDataCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AnnotationsToRenderDataCalculator::Process( - CalculatorContext* cc) { +absl::Status AnnotationsToRenderDataCalculator::Process(CalculatorContext* cc) { auto render_data = absl::make_unique(); bool visualize_depth = options_.visualize_landmark_depth(); float z_min = 0.f; @@ -215,7 +213,7 @@ mediapipe::Status AnnotationsToRenderDataCalculator::Process( .Tag(kRenderDataTag) .Add(render_data.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void AnnotationsToRenderDataCalculator::AddConnectionToRenderData( diff --git a/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc b/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc index b9f1346a9..9bc43ba03 100644 --- a/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc +++ b/mediapipe/graphs/object_detection_3d/calculators/gl_animation_overlay_calculator.cc @@ -128,10 +128,10 @@ class GlAnimationOverlayCalculator : public CalculatorBase { GlAnimationOverlayCalculator() {} ~GlAnimationOverlayCalculator(); - static ::mediapipe::Status GetContract(CalculatorContract *cc); + static absl::Status GetContract(CalculatorContract *cc); - ::mediapipe::Status Open(CalculatorContext *cc) override; - ::mediapipe::Status Process(CalculatorContext *cc) override; + absl::Status Open(CalculatorContext *cc) override; + absl::Status Process(CalculatorContext *cc) override; private: bool has_video_stream_ = false; @@ -171,11 +171,11 @@ class GlAnimationOverlayCalculator : public CalculatorBase { float *vertical_fov_degrees); int GetAnimationFrameIndex(Timestamp timestamp); - 
::mediapipe::Status GlSetup(); - ::mediapipe::Status GlBind(const TriangleMesh &triangle_mesh, - const GlTexture &texture); - ::mediapipe::Status GlRender(const TriangleMesh &triangle_mesh, - const float *model_matrix); + absl::Status GlSetup(); + absl::Status GlBind(const TriangleMesh &triangle_mesh, + const GlTexture &texture); + absl::Status GlRender(const TriangleMesh &triangle_mesh, + const float *model_matrix); void InitializePerspectiveMatrix(float aspect_ratio, float vertical_fov_degrees, float z_near, float z_far); @@ -198,8 +198,7 @@ class GlAnimationOverlayCalculator : public CalculatorBase { REGISTER_CALCULATOR(GlAnimationOverlayCalculator); // static -::mediapipe::Status GlAnimationOverlayCalculator::GetContract( - CalculatorContract *cc) { +absl::Status GlAnimationOverlayCalculator::GetContract(CalculatorContract *cc) { MP_RETURN_IF_ERROR( GlCalculatorHelper::SetupInputSidePackets(&(cc->InputSidePackets()))); if (cc->Inputs().HasTag("VIDEO")) { @@ -236,7 +235,7 @@ REGISTER_CALCULATOR(GlAnimationOverlayCalculator); cc->InputSidePackets().Tag("MASK_ASSET").Set(); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } void GlAnimationOverlayCalculator::CalculateTriangleMeshNormals( @@ -515,7 +514,7 @@ void GlAnimationOverlayCalculator::ComputeAspectRatioAndFovFromCameraParameters( std::atan(camera_parameters.portrait_height() * 0.5f) * 2 * 180 / M_PI; } -::mediapipe::Status GlAnimationOverlayCalculator::Open(CalculatorContext *cc) { +absl::Status GlAnimationOverlayCalculator::Open(CalculatorContext *cc) { cc->SetOffset(TimestampDiff(0)); MP_RETURN_IF_ERROR(helper_.Open(cc)); @@ -562,7 +561,7 @@ void GlAnimationOverlayCalculator::ComputeAspectRatioAndFovFromCameraParameters( loaded_animation = LoadAnimationAndroid(mask_asset_name, &mask_meshes_); if (!loaded_animation) { LOG(ERROR) << "Failed to load mask asset."; - return ::mediapipe::UnknownError("Failed to load mask asset."); + return absl::UnknownError("Failed to load mask asset."); } } 
loaded_animation = LoadAnimationAndroid(asset_name, &triangle_meshes_); @@ -571,10 +570,10 @@ void GlAnimationOverlayCalculator::ComputeAspectRatioAndFovFromCameraParameters( #endif if (!loaded_animation) { LOG(ERROR) << "Failed to load animation asset."; - return ::mediapipe::UnknownError("Failed to load animation asset."); + return absl::UnknownError("Failed to load animation asset."); } - return helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { + return helper_.RunInGlContext([this, &cc]() -> absl::Status { if (cc->InputSidePackets().HasTag("MASK_TEXTURE")) { const auto &mask_texture = cc->InputSidePackets().Tag("MASK_TEXTURE").Get(); @@ -591,7 +590,7 @@ void GlAnimationOverlayCalculator::ComputeAspectRatioAndFovFromCameraParameters( VLOG(2) << "Input texture size: " << texture_.width() << ", " << texture_.height() << std::endl; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } @@ -624,9 +623,8 @@ void GlAnimationOverlayCalculator::LoadModelMatrices( } } -::mediapipe::Status GlAnimationOverlayCalculator::Process( - CalculatorContext *cc) { - return helper_.RunInGlContext([this, &cc]() -> mediapipe::Status { +absl::Status GlAnimationOverlayCalculator::Process(CalculatorContext *cc) { + return helper_.RunInGlContext([this, &cc]() -> absl::Status { if (!initialized_) { MP_RETURN_IF_ERROR(GlSetup()); initialized_ = true; @@ -663,7 +661,7 @@ void GlAnimationOverlayCalculator::LoadModelMatrices( if (has_video_stream_ && !(cc->Inputs().Tag("VIDEO").IsEmpty())) { auto result = cc->Inputs().Tag("VIDEO").Value().Consume(); if (result.ok()) { - input_frame = std::move(result).ValueOrDie(); + input_frame = std::move(result).value(); #if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER input_frame->GetGlTextureBufferSharedPtr()->Reuse(); #endif @@ -679,7 +677,7 @@ void GlAnimationOverlayCalculator::LoadModelMatrices( dst = helper_.CreateDestinationTexture(width, height); } else { // We have an input video stream, but not for this frame. Don't render! 
- return ::mediapipe::OkStatus(); + return absl::OkStatus(); } helper_.BindFramebuffer(dst); @@ -759,11 +757,11 @@ void GlAnimationOverlayCalculator::LoadModelMatrices( TagOrIndex(&(cc->Outputs()), "OUTPUT", 0) .Add(output.release(), cc->InputTimestamp()); GLCHECK(glFrontFace(GL_CCW)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); } -::mediapipe::Status GlAnimationOverlayCalculator::GlSetup() { +absl::Status GlAnimationOverlayCalculator::GlSetup() { // Load vertex and fragment shaders const GLint attr_location[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, @@ -881,10 +879,10 @@ void GlAnimationOverlayCalculator::LoadModelMatrices( GLCHECK(glGetUniformLocation(program_, "perspectiveMatrix")); model_matrix_uniform_ = GLCHECK(glGetUniformLocation(program_, "modelMatrix")); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlAnimationOverlayCalculator::GlBind( +absl::Status GlAnimationOverlayCalculator::GlBind( const TriangleMesh &triangle_mesh, const GlTexture &texture) { GLCHECK(glUseProgram(program_)); @@ -915,16 +913,16 @@ void GlAnimationOverlayCalculator::LoadModelMatrices( GLCHECK(glUniformMatrix4fv(perspective_matrix_uniform_, 1, GL_FALSE, perspective_matrix_)); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status GlAnimationOverlayCalculator::GlRender( +absl::Status GlAnimationOverlayCalculator::GlRender( const TriangleMesh &triangle_mesh, const float *model_matrix) { GLCHECK(glUniformMatrix4fv(model_matrix_uniform_, 1, GL_FALSE, model_matrix)); GLCHECK(glDrawElements(GL_TRIANGLES, triangle_mesh.index_count, GL_UNSIGNED_SHORT, triangle_mesh.triangle_indices.get())); - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } GlAnimationOverlayCalculator::~GlAnimationOverlayCalculator() { diff --git a/mediapipe/graphs/object_detection_3d/obj_parser/BUILD b/mediapipe/graphs/object_detection_3d/obj_parser/BUILD new file mode 100644 index 000000000..3b84cc84d --- /dev/null +++ 
b/mediapipe/graphs/object_detection_3d/obj_parser/BUILD @@ -0,0 +1,33 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +licenses(["notice"]) + +java_library( + name = "obj_parser_lib", + srcs = [ + "ObjParserMain.java", + "SimpleObjParser.java", + ], + javacopts = ["-Xep:DefaultPackage:OFF"], +) + +java_binary( + name = "ObjParser", + javacopts = ["-Xep:DefaultPackage:OFF"], + main_class = "ObjParserMain", + runtime_deps = [ + ":obj_parser_lib", + ], +) diff --git a/mediapipe/graphs/object_detection_3d/obj_parser/ObjParserMain.java b/mediapipe/graphs/object_detection_3d/obj_parser/ObjParserMain.java new file mode 100644 index 000000000..80e639d96 --- /dev/null +++ b/mediapipe/graphs/object_detection_3d/obj_parser/ObjParserMain.java @@ -0,0 +1,205 @@ +// Copyright 2021 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileFilter; +import java.io.FileOutputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; + +/** + * Class for running desktop-side parsing/packing routines on .obj AR assets. Usage: ObjParser + * --input_dir=[INPUT_DIRECTORY] --output_dir=[OUTPUT_DIRECTORY] where INPUT_DIRECTORY is the folder + * with asset obj files to process, and OUTPUT_DIRECTORY is the folder where processed asset uuu + * file should be placed. + * + *

NOTE: Directories are assumed to be absolute paths. + */ +public final class ObjParserMain { + // Simple FileFilter implementation to let us walk over only our .obj files in a particular + // directory. + private static final class ObjFileFilter implements FileFilter { + ObjFileFilter() { + // Nothing to do here. + } + + @Override + public boolean accept(File file) { + return file.getName().endsWith(".obj"); + } + } + + // File extension for binary output files; tagged onto end of initial file extension. + private static final String BINARY_FILE_EXT = ".uuu"; + private static final String INPUT_DIR_FLAG = "--input_dir="; + private static final String OUTPUT_DIR_FLAG = "--output_dir="; + private static final float DEFAULT_VERTEX_SCALE_FACTOR = 30.0f; + private static final double NS_TO_SECONDS = 1e9; + + public final PrintWriter writer; + + public ObjParserMain() { + super(); + this.writer = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out, UTF_8))); + } + + // Simple overridable logging function. + protected void logString(String infoLog) { + writer.println(infoLog); + } + + /* + * Main program logic: parse command-line arguments and perform actions. + */ + public void run(String inDirectory, String outDirectory) { + if (inDirectory.isEmpty()) { + logString("Error: Must provide input directory with " + INPUT_DIR_FLAG); + return; + } + if (outDirectory.isEmpty()) { + logString("Error: Must provide output directory with " + OUTPUT_DIR_FLAG); + return; + } + + File dirAsFile = new File(inDirectory); + ObjFileFilter objFileFilter = new ObjFileFilter(); + File[] objFiles = dirAsFile.listFiles(objFileFilter); + + FileOutputStream outputStream = null; + logString("Parsing directory: " + inDirectory); + // We need frames processed in correct order. + Arrays.sort(objFiles); + + for (File objFile : objFiles) { + String fileName = objFile.getAbsolutePath(); + + // Just take the file name of the first processed frame. 
+ if (outputStream == null) { + String outputFileName = outDirectory + objFile.getName() + BINARY_FILE_EXT; + try { + // Create new file here, if we can. + outputStream = new FileOutputStream(outputFileName); + logString("Created outfile: " + outputFileName); + } catch (Exception e) { + logString("Error creating outfile: " + e.toString()); + e.printStackTrace(writer); + return; + } + } + + // Process each file into the stream. + logString("Processing file: " + fileName); + processFile(fileName, outputStream); + } + + // Finally close the stream out. + try { + if (outputStream != null) { + outputStream.close(); + } + } catch (Exception e) { + logString("Error trying to close output stream: " + e.toString()); + e.printStackTrace(writer); + } + } + + /* + * Entrypoint for command-line executable. + */ + public static void main(String[] args) { + // Parse flags + String inDirectory = ""; + String outDirectory = ""; + for (int i = 0; i < args.length; i++) { + if (args[i].startsWith(INPUT_DIR_FLAG)) { + inDirectory = args[i].substring(INPUT_DIR_FLAG.length()); + // Make sure this will be treated as a directory + if (!inDirectory.endsWith("/")) { + inDirectory += "/"; + } + } + if (args[i].startsWith(OUTPUT_DIR_FLAG)) { + outDirectory = args[i].substring(OUTPUT_DIR_FLAG.length()); + // Make sure this will be treated as a directory + if (!outDirectory.endsWith("/")) { + outDirectory += "/"; + } + } + } + ObjParserMain parser = new ObjParserMain(); + parser.run(inDirectory, outDirectory); + parser.writer.flush(); + } + + /* + * Internal helper function to parse a .obj from an infile name and stream the resulting data + * directly out in binary-dump format to outputStream. + */ + private void processFile(String infileName, OutputStream outputStream) { + long start = System.nanoTime(); + + // First we parse the obj. 
+ SimpleObjParser objParser = new SimpleObjParser(infileName, DEFAULT_VERTEX_SCALE_FACTOR); + if (!objParser.parse()) { + logString("Error parsing .obj file before processing"); + return; + } + + final float[] vertices = objParser.getVertices(); + final float[] textureCoords = objParser.getTextureCoords(); + final ArrayList triangleList = objParser.getTriangles(); + + // Overall byte count to stream: 12 for the 3 list-length ints, and then 4 for each vertex and + // texCoord int, and finally 2 for each triangle index short. + final int bbSize = + 12 + 4 * vertices.length + 4 * textureCoords.length + 2 * triangleList.size(); + + // Ensure ByteBuffer is native order, just like we want to read it in, but is NOT direct, so + // we can call .array() on it. + ByteBuffer bb = ByteBuffer.allocate(bbSize); + bb.order(ByteOrder.nativeOrder()); + + bb.putInt(vertices.length); + bb.putInt(textureCoords.length); + bb.putInt(triangleList.size()); + logString(String.format("Writing... Vertices: %d, TextureCoords: %d, Indices: %d.%n", + vertices.length, textureCoords.length, triangleList.size())); + for (float vertex : vertices) { + bb.putFloat(vertex); + } + for (float textureCoord : textureCoords) { + bb.putFloat(textureCoord); + } + for (Short vertexIndex : triangleList) { + bb.putShort(vertexIndex.shortValue()); + } + bb.position(0); + try { + outputStream.write(bb.array(), 0, bbSize); + logString(String.format("Processing successful! Took %.4f seconds.%n", + (System.nanoTime() - start) / NS_TO_SECONDS)); + } catch (Exception e) { + logString("Error writing during processing: " + e.toString()); + e.printStackTrace(writer); + } + } +} diff --git a/mediapipe/graphs/object_detection_3d/obj_parser/SimpleObjParser.java b/mediapipe/graphs/object_detection_3d/obj_parser/SimpleObjParser.java new file mode 100644 index 000000000..937fdff89 --- /dev/null +++ b/mediapipe/graphs/object_detection_3d/obj_parser/SimpleObjParser.java @@ -0,0 +1,386 @@ +// Copyright 2021 The MediaPipe Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import static java.nio.charset.StandardCharsets.UTF_8; + +import java.io.BufferedReader; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +/** + * Class for parsing a single .obj file into openGL-usable pieces. + * + *

Usage: + * + *

SimpleObjParser objParser = new SimpleObjParser("animations/cow/cow320.obj", .015f); + * + *

if (objParser.parse()) { ... } + */ +public class SimpleObjParser { + private static class ShortPair { + private final Short first; + private final Short second; + + public ShortPair(Short newFirst, Short newSecond) { + first = newFirst; + second = newSecond; + } + + public Short getFirst() { + return first; + } + + public Short getSecond() { + return second; + } + } + + private static final String TAG = SimpleObjParser.class.getSimpleName(); + private static final boolean DEBUG = false; + private static final int INVALID_INDEX = -1; + private static final int POSITIONS_COORDS_PER_VERTEX = 3; + private static final int TEXTURE_COORDS_PER_VERTEX = 2; + private final String fileName; + + // Since .obj doesn't tie together texture coordinates and vertex + // coordinates, but OpenGL does, we need to keep a map of all such pairings that occur in + // our face list. + private final HashMap vertexTexCoordMap; + + // Internal (de-coupled) unique vertices and texture coordinates + private ArrayList vertices; + private ArrayList textureCoords; + + // Data we expose to openGL for rendering + private float[] finalizedVertices; + private float[] finalizedTextureCoords; + private ArrayList finalizedTriangles; + + // So we only display warnings about dropped w-coordinates once + private boolean vertexCoordIgnoredWarning; + private boolean textureCoordIgnoredWarning; + private boolean startedProcessingFaces; + + private int numPrimitiveVertices; + private int numPrimitiveTextureCoords; + private int numPrimitiveFaces; + + // For scratchwork, so we don't have to keep reallocating + private float[] tempCoords; + + // We scale all our position coordinates uniformly by this factor + private float objectUniformScaleFactor; + + public SimpleObjParser(String objFile, float scaleFactor) { + objectUniformScaleFactor = scaleFactor; + + fileName = objFile; + vertices = new ArrayList(); + textureCoords = new ArrayList(); + + vertexTexCoordMap = new HashMap(); + finalizedTriangles = new 
ArrayList(); + + tempCoords = new float[Math.max(POSITIONS_COORDS_PER_VERTEX, TEXTURE_COORDS_PER_VERTEX)]; + numPrimitiveFaces = 0; + + vertexCoordIgnoredWarning = false; + textureCoordIgnoredWarning = false; + startedProcessingFaces = false; + } + + // Simple helper wrapper function + private void debugLogString(String message) { + if (DEBUG) { + System.out.println(message); + } + } + + private void parseVertex(String[] linePieces) { + // Note: Traditionally xyzw is acceptable as a format, with w defaulting to 1.0, but for now + // we only parse xyz. + if (linePieces.length < POSITIONS_COORDS_PER_VERTEX + 1 + || linePieces.length > POSITIONS_COORDS_PER_VERTEX + 2) { + System.out.println("Malformed vertex coordinate specification, assuming xyz format only."); + return; + } else if (linePieces.length == POSITIONS_COORDS_PER_VERTEX + 2 && !vertexCoordIgnoredWarning) { + System.out.println( + "Only x, y, and z parsed for vertex coordinates; w coordinates will be ignored."); + vertexCoordIgnoredWarning = true; + } + + boolean success = true; + try { + for (int i = 1; i < POSITIONS_COORDS_PER_VERTEX + 1; i++) { + tempCoords[i - 1] = Float.parseFloat(linePieces[i]); + } + } catch (NumberFormatException e) { + success = false; + System.out.println("Malformed vertex coordinate error: " + e.toString()); + } + + if (success) { + for (int i = 0; i < POSITIONS_COORDS_PER_VERTEX; i++) { + vertices.add(Float.valueOf(tempCoords[i] * objectUniformScaleFactor)); + } + } + } + + private void parseTextureCoordinate(String[] linePieces) { + // Similar to vertices, uvw is acceptable as a format, with w defaulting to 0.0, but for now we + // only parse uv. 
+ if (linePieces.length < TEXTURE_COORDS_PER_VERTEX + 1 + || linePieces.length > TEXTURE_COORDS_PER_VERTEX + 2) { + System.out.println("Malformed texture coordinate specification, assuming uv format only."); + return; + } else if (linePieces.length == (TEXTURE_COORDS_PER_VERTEX + 2) + && !textureCoordIgnoredWarning) { + debugLogString("Only u and v parsed for texture coordinates; w coordinates will be ignored."); + textureCoordIgnoredWarning = true; + } + + boolean success = true; + try { + for (int i = 1; i < TEXTURE_COORDS_PER_VERTEX + 1; i++) { + tempCoords[i - 1] = Float.parseFloat(linePieces[i]); + } + } catch (NumberFormatException e) { + success = false; + System.out.println("Malformed texture coordinate error: " + e.toString()); + } + + if (success) { + // .obj files treat (0,0) as top-left, compared to bottom-left for openGL. So invert "v" + // texture coordinate only here. + textureCoords.add(Float.valueOf(tempCoords[0])); + textureCoords.add(Float.valueOf(1.0f - tempCoords[1])); + } + } + + // Will return INVALID_INDEX if error occurs, and otherwise will return finalized (combined) + // index, adding and hashing new combinations as it sees them. + private short parseAndProcessCombinedVertexCoord(String coordString) { + String[] coords = coordString.split("/"); + try { + // Parse vertex and texture indices; 1-indexed from front if positive and from end of list if + // negative. + short vertexIndex = Short.parseShort(coords[0]); + short textureIndex = Short.parseShort(coords[1]); + if (vertexIndex > 0) { + vertexIndex--; + } else { + vertexIndex = (short) (vertexIndex + numPrimitiveVertices); + } + if (textureIndex > 0) { + textureIndex--; + } else { + textureIndex = (short) (textureIndex + numPrimitiveTextureCoords); + } + + // Combine indices and look up in pair map. 
+ ShortPair indexPair = new ShortPair(Short.valueOf(vertexIndex), Short.valueOf(textureIndex)); + Short combinedIndex = vertexTexCoordMap.get(indexPair); + if (combinedIndex == null) { + short numIndexPairs = (short) vertexTexCoordMap.size(); + vertexTexCoordMap.put(indexPair, numIndexPairs); + return numIndexPairs; + } else { + return combinedIndex.shortValue(); + } + } catch (NumberFormatException e) { + // Failure to parse coordinates as shorts + return INVALID_INDEX; + } + } + + // Note: it is assumed that face list occurs AFTER vertex and texture coordinate lists finish in + // the obj file format. + private void parseFace(String[] linePieces) { + if (linePieces.length < 4) { + System.out.println("Malformed face index list: there must be at least 3 indices per face"); + return; + } + + short[] faceIndices = new short[linePieces.length - 1]; + boolean success = true; + for (int i = 1; i < linePieces.length; i++) { + short faceIndex = parseAndProcessCombinedVertexCoord(linePieces[i]); + + if (faceIndex < 0) { + System.out.println(faceIndex); + System.out.println("Malformed face index: " + linePieces[i]); + success = false; + break; + } + faceIndices[i - 1] = faceIndex; + } + + if (success) { + numPrimitiveFaces++; + // Manually triangulate the face under the assumption that the points are coplanar, the poly + // is convex, and the points are listed in either clockwise or anti-clockwise orientation. + for (int i = 1; i < faceIndices.length - 1; i++) { + // We use a triangle fan here, so first point is part of all triangles + finalizedTriangles.add(faceIndices[0]); + finalizedTriangles.add(faceIndices[i]); + finalizedTriangles.add(faceIndices[i + 1]); + } + } + } + + // Iterate over map and reconstruct proper vertex/texture coordinate pairings. 
+ private boolean constructFinalCoordinatesFromMap() { + final int numIndexPairs = vertexTexCoordMap.size(); + // XYZ vertices and UV texture coordinates + finalizedVertices = new float[POSITIONS_COORDS_PER_VERTEX * numIndexPairs]; + finalizedTextureCoords = new float[TEXTURE_COORDS_PER_VERTEX * numIndexPairs]; + try { + for (Map.Entry entry : vertexTexCoordMap.entrySet()) { + ShortPair indexPair = entry.getKey(); + short rawVertexIndex = indexPair.getFirst().shortValue(); + short rawTexCoordIndex = indexPair.getSecond().shortValue(); + short finalIndex = entry.getValue().shortValue(); + for (int i = 0; i < POSITIONS_COORDS_PER_VERTEX; i++) { + finalizedVertices[POSITIONS_COORDS_PER_VERTEX * finalIndex + i] + = vertices.get(rawVertexIndex * POSITIONS_COORDS_PER_VERTEX + i); + } + for (int i = 0; i < TEXTURE_COORDS_PER_VERTEX; i++) { + finalizedTextureCoords[TEXTURE_COORDS_PER_VERTEX * finalIndex + i] + = textureCoords.get(rawTexCoordIndex * TEXTURE_COORDS_PER_VERTEX + i); + } + } + } catch (NumberFormatException e) { + System.out.println("Malformed index in vertex/texture coordinate mapping."); + return false; + } + return true; + } + + /** + * Returns the vertex position coordinate list (x1, y1, z1, x2, y2, z2, ...) after a successful + * call to parse(). + */ + public float[] getVertices() { + return finalizedVertices; + } + + /** + * Returns the vertex texture coordinate list (u1, v1, u2, v2, ...) after a successful call to + * parse(). + */ + public float[] getTextureCoords() { + return finalizedTextureCoords; + } + + /** + * Returns the list of indices (a1, b1, c1, a2, b2, c2, ...) after a successful call to parse(). + * Each (a, b, c) triplet specifies a triangle to be rendered, with a, b, and c Short objects used + * to index into the coordinates returned by getVertices() and getTextureCoords().

+ * For example, a Short index representing 5 should be used to index into vertices[15], + * vertices[16], and vertices[17], as well as textureCoords[10] and textureCoords[11]. + */ + public ArrayList getTriangles() { + return finalizedTriangles; + } + + /** + * Attempts to locate and read the specified .obj file, and parse it accordingly. None of the + * getter functions in this class will return valid results until a value of true is returned + * from this function. + * @return true on success. + */ + public boolean parse() { + boolean success = true; + BufferedReader reader = null; + try { + reader = Files.newBufferedReader(Paths.get(fileName), UTF_8); + String line; + while ((line = reader.readLine()) != null) { + // Skip over lines with no characters + if (line.length() < 1) { + continue; + } + + // Ignore comment lines entirely + if (line.charAt(0) == '#') { + continue; + } + + // Split into pieces based on whitespace, and process according to first command piece + String[] linePieces = line.split(" +"); + switch (linePieces[0]) { + case "v": + // Add vertex + if (startedProcessingFaces) { + throw new IOException("Vertices must all be declared before faces in obj files."); + } + parseVertex(linePieces); + break; + case "vt": + // Add texture coordinate + if (startedProcessingFaces) { + throw new IOException( + "Texture coordinates must all be declared before faces in obj files."); + } + parseTextureCoordinate(linePieces); + break; + case "f": + // Vertex and texture coordinate lists should be locked into place by now + if (!startedProcessingFaces) { + startedProcessingFaces = true; + numPrimitiveVertices = vertices.size() / POSITIONS_COORDS_PER_VERTEX; + numPrimitiveTextureCoords = textureCoords.size() / TEXTURE_COORDS_PER_VERTEX; + } + // Add face + parseFace(linePieces); + break; + default: + // Unknown or unused directive: ignoring + // Note: We do not yet process vertex normals or curves, so we ignore {vp, vn, s} + // Note: We assume only a single object, 
so we ignore {g, o} + // Note: We also assume a single texture, which we process independently, so we ignore + // {mtllib, usemtl} + break; + } + } + + // If we made it all the way through, then we have a vertex-to-tex-coord pair mapping, so + // construct our final vertex and texture coordinate lists now. + success = constructFinalCoordinatesFromMap(); + + } catch (IOException e) { + success = false; + System.out.println("Failure to parse obj file: " + e.toString()); + } finally { + try { + if (reader != null) { + reader.close(); + } + } catch (IOException e) { + System.out.println("Couldn't close reader"); + } + } + if (success) { + debugLogString("Successfully parsed " + numPrimitiveVertices + " vertices and " + + numPrimitiveTextureCoords + " texture coordinates into " + vertexTexCoordMap.size() + + " combined vertices and " + numPrimitiveFaces + " faces, represented as a mesh of " + + finalizedTriangles.size() / 3 + " triangles."); + } + return success; + } +} diff --git a/mediapipe/graphs/object_detection_3d/obj_parser/obj_cleanup.sh b/mediapipe/graphs/object_detection_3d/obj_parser/obj_cleanup.sh new file mode 100755 index 000000000..1573387a4 --- /dev/null +++ b/mediapipe/graphs/object_detection_3d/obj_parser/obj_cleanup.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# The SimpleObjParser expects the obj commands to follow v/vt/f order. 
# This script reads every .obj file located directly inside input_folder and
# writes a copy with the same name to output_folder in which all "v" lines come
# first, followed by all "vt" lines, followed by all "f" lines. Every other
# line (comments, normals, groups, materials, ...) is dropped.

# Usage: ./obj_cleanup.sh input_folder output_folder
# input_folder and output_folder paths can be absolute or relative, but they
# must be two different directories. output_folder is created if missing.

input_folder=$1
output_folder=$2
if [[ "${input_folder}" == "" ]]; then
  echo "input_folder must be defined. Usage: ./obj_cleanup.sh input_folder output_folder"
  exit 1
fi
if [[ "${output_folder}" == "" ]]; then
  echo "output_folder must be defined. Usage: ./obj_cleanup.sh input_folder output_folder"
  exit 1
fi

mkdir -p "${output_folder}" || exit 1

# Writing into the folder we read from would truncate each input file before it
# is read, so refuse to run in place.
if [[ "${input_folder}" -ef "${output_folder}" ]]; then
  echo "input_folder and output_folder must be different directories."
  exit 1
fi

# The glob expands in sorted order and only matches files directly inside
# input_folder, which is the only place the output name can be mapped back to.
for input_file in "${input_folder}"/*.obj; do
  # Skips the literal pattern when nothing matches, and anything not a file.
  [[ -f "${input_file}" ]] || continue

  filename="${input_file##*/}"
  output_file="${output_folder}/${filename}"
  echo "Clean up ${filename}"

  # Patterns are anchored to the directive at the start of the line so that
  # comments or other lines merely containing "v " / "f " are not copied.
  grep '^[[:space:]]*v[[:space:]]' "${input_file}" > "${output_file}"
  grep '^[[:space:]]*vt[[:space:]]' "${input_file}" >> "${output_file}"
  grep '^[[:space:]]*f[[:space:]]' "${input_file}" >> "${output_file}"
done
node { calculator: "ObjectronCpuSubgraph" input_stream: "IMAGE:input_video" - input_side_packet: "MODEL:box_landmark_model" + input_side_packet: "MODEL_PATH:box_landmark_model_path" input_side_packet: "LABELS_CSV:allowed_labels" input_side_packet: "MAX_NUM_OBJECTS:max_num_objects" output_stream: "MULTI_LANDMARKS:box_landmarks" diff --git a/mediapipe/java/com/google/mediapipe/framework/BUILD b/mediapipe/java/com/google/mediapipe/framework/BUILD index dd5ae2e2a..ed1a42c20 100644 --- a/mediapipe/java/com/google/mediapipe/framework/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/BUILD @@ -18,7 +18,6 @@ licenses(["notice"]) exports_files([ "proguard.pgcfg", - "proguard_allowobfuscation.pgcfg", ]) android_library( @@ -93,6 +92,7 @@ android_library( deps = [ ":framework_proto_lite", ":mediapipe_exception_android", + "@com_google_protobuf//:protobuf_javalite", "@maven//:com_google_code_findbugs_jsr305", "@maven//:com_google_flogger_flogger", "@maven//:com_google_flogger_flogger_system_backend", @@ -119,6 +119,7 @@ android_library( "//mediapipe/framework:calculator_profile_java_proto_lite", "//mediapipe/framework:stream_handler_java_proto_lite", "//mediapipe/framework/tool:calculator_graph_template_java_proto_lite", + "@com_google_protobuf//:protobuf_javalite", "@maven//:com_google_guava_guava", ], ) diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc index 955947559..dde43f567 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.cc @@ -39,9 +39,9 @@ #else #include "mediapipe/framework/port/file_helpers.h" #endif // __ANDROID__ -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/egl_surface_holder.h" -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU namespace mediapipe { namespace android { @@ -172,10 +172,10 @@ bool Graph::RemovePacket(int64_t 
packet_handle) { void Graph::EnsureMinimumExecutorStackSizeForJava() {} -mediapipe::Status Graph::AddCallbackHandler(std::string output_stream_name, - jobject java_callback) { +absl::Status Graph::AddCallbackHandler(std::string output_stream_name, + jobject java_callback) { if (!graph_config()) { - return mediapipe::InternalError("Graph is not loaded!"); + return absl::InternalError("Graph is not loaded!"); } std::unique_ptr handler( new internal::CallbackHandler(this, java_callback)); @@ -188,7 +188,7 @@ mediapipe::Status Graph::AddCallbackHandler(std::string output_stream_name, side_packet_name, MakePacket>( handler->CreateCallback())); callback_handlers_.emplace_back(std::move(handler)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { @@ -197,7 +197,7 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { return 0; } -#ifdef MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_DISABLE_GPU LOG(FATAL) << "GPU support has been disabled in this build!"; #else CalculatorGraphConfig::Node* sink_node = graph_config()->add_node(); @@ -219,12 +219,12 @@ int64_t Graph::AddSurfaceOutput(const std::string& output_stream_name) { AdoptAsUniquePtr(new mediapipe::EglSurfaceHolder())); return WrapPacketIntoContext(it_inserted.first->second); -#endif // defined(MEDIAPIPE_DISABLE_GPU) +#endif // MEDIAPIPE_DISABLE_GPU } -mediapipe::Status Graph::LoadBinaryGraph(std::string path_to_graph) { +absl::Status Graph::LoadBinaryGraph(std::string path_to_graph) { std::string graph_config_string; - mediapipe::Status status = + absl::Status status = mediapipe::file::GetContents(path_to_graph, &graph_config_string); if (!status.ok()) { return status; @@ -233,39 +233,39 @@ mediapipe::Status Graph::LoadBinaryGraph(std::string path_to_graph) { graph_config_string.length()); } -mediapipe::Status Graph::LoadBinaryGraph(const char* data, int size) { +absl::Status Graph::LoadBinaryGraph(const char* data, int 
size) { CalculatorGraphConfig graph_config; if (!graph_config.ParseFromArray(data, size)) { - return mediapipe::InvalidArgumentError("Failed to parse the graph"); + return absl::InvalidArgumentError("Failed to parse the graph"); } graph_configs_.push_back(graph_config); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Graph::LoadBinaryGraphTemplate(const char* data, int size) { +absl::Status Graph::LoadBinaryGraphTemplate(const char* data, int size) { CalculatorGraphTemplate graph_template; if (!graph_template.ParseFromArray(data, size)) { - return mediapipe::InvalidArgumentError("Failed to parse the graph"); + return absl::InvalidArgumentError("Failed to parse the graph"); } graph_templates_.push_back(graph_template); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Graph::SetGraphType(std::string graph_type) { +absl::Status Graph::SetGraphType(std::string graph_type) { graph_type_ = graph_type; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Graph::SetGraphOptions(const char* data, int size) { +absl::Status Graph::SetGraphOptions(const char* data, int size) { if (!graph_options_.ParseFromArray(data, size)) { - return mediapipe::InvalidArgumentError("Failed to parse the graph"); + return absl::InvalidArgumentError("Failed to parse the graph"); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } CalculatorGraphConfig Graph::GetCalculatorGraphConfig() { CalculatorGraph temp_graph; - mediapipe::Status status = InitializeGraph(&temp_graph); + absl::Status status = InitializeGraph(&temp_graph); if (!status.ok()) { LOG(ERROR) << "GetCalculatorGraphConfig failed:\n" << status.message(); } @@ -344,14 +344,14 @@ void Graph::SetPacketJavaClass(JNIEnv* env) { } } -mediapipe::Status Graph::RunGraphUntilClose(JNIEnv* env) { +absl::Status Graph::RunGraphUntilClose(JNIEnv* env) { // Get a global reference to the packet class, so it can be used in other // native thread for 
call back. SetPacketJavaClass(env); // Running as a synchronized mode, the same Java thread is available through // out the run. CalculatorGraph calculator_graph; - mediapipe::Status status = InitializeGraph(&calculator_graph); + absl::Status status = InitializeGraph(&calculator_graph); if (!status.ok()) { LOG(ERROR) << status.message(); running_graph_.reset(nullptr); @@ -364,9 +364,9 @@ mediapipe::Status Graph::RunGraphUntilClose(JNIEnv* env) { return status; } -mediapipe::Status Graph::StartRunningGraph(JNIEnv* env) { +absl::Status Graph::StartRunningGraph(JNIEnv* env) { if (running_graph_) { - return mediapipe::InternalError("Graph is already running."); + return absl::InternalError("Graph is already running."); } // Get a global reference to the packet class, so it can be used in other // native thread for call back. @@ -382,15 +382,15 @@ mediapipe::Status Graph::StartRunningGraph(JNIEnv* env) { LOG(INFO) << name; } } - mediapipe::Status status; -#ifndef MEDIAPIPE_DISABLE_GPU + absl::Status status; +#if !MEDIAPIPE_DISABLE_GPU status = running_graph_->SetGpuResources(gpu_resources_); if (!status.ok()) { LOG(ERROR) << status.message(); running_graph_.reset(nullptr); return status; } -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU for (const auto& service_packet : service_packets_) { status = running_graph_->SetServicePacket(*service_packet.first, @@ -416,10 +416,10 @@ mediapipe::Status Graph::StartRunningGraph(JNIEnv* env) { running_graph_.reset(nullptr); return status; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Graph::SetTimestampAndMovePacketToInputStream( +absl::Status Graph::SetTimestampAndMovePacketToInputStream( const std::string& stream_name, int64_t packet_handle, int64_t timestamp) { internal::PacketWithContext* packet_with_context = reinterpret_cast(packet_handle); @@ -433,60 +433,60 @@ mediapipe::Status Graph::SetTimestampAndMovePacketToInputStream( return 
AddPacketToInputStream(stream_name, std::move(packet)); } -mediapipe::Status Graph::AddPacketToInputStream(const std::string& stream_name, - const Packet& packet) { +absl::Status Graph::AddPacketToInputStream(const std::string& stream_name, + const Packet& packet) { if (!running_graph_) { - return mediapipe::FailedPreconditionError("Graph must be running."); + return absl::FailedPreconditionError("Graph must be running."); } return running_graph_->AddPacketToInputStream(stream_name, packet); } -mediapipe::Status Graph::AddPacketToInputStream(const std::string& stream_name, - Packet&& packet) { +absl::Status Graph::AddPacketToInputStream(const std::string& stream_name, + Packet&& packet) { if (!running_graph_) { - return mediapipe::FailedPreconditionError("Graph must be running."); + return absl::FailedPreconditionError("Graph must be running."); } return running_graph_->AddPacketToInputStream(stream_name, std::move(packet)); } -mediapipe::Status Graph::CloseInputStream(std::string stream_name) { +absl::Status Graph::CloseInputStream(std::string stream_name) { if (!running_graph_) { - return mediapipe::FailedPreconditionError("Graph must be running."); + return absl::FailedPreconditionError("Graph must be running."); } LOG(INFO) << "Close input stream: " << stream_name; return running_graph_->CloseInputStream(stream_name); } -mediapipe::Status Graph::CloseAllInputStreams() { +absl::Status Graph::CloseAllInputStreams() { LOG(INFO) << "Close all input streams."; if (!running_graph_) { - return mediapipe::FailedPreconditionError("Graph must be running."); + return absl::FailedPreconditionError("Graph must be running."); } return running_graph_->CloseAllInputStreams(); } -mediapipe::Status Graph::CloseAllPacketSources() { +absl::Status Graph::CloseAllPacketSources() { LOG(INFO) << "Close all input streams."; if (!running_graph_) { - return mediapipe::FailedPreconditionError("Graph must be running."); + return absl::FailedPreconditionError("Graph must be running."); } 
return running_graph_->CloseAllPacketSources(); } -mediapipe::Status Graph::WaitUntilDone(JNIEnv* env) { +absl::Status Graph::WaitUntilDone(JNIEnv* env) { if (!running_graph_) { - return mediapipe::FailedPreconditionError("Graph must be running."); + return absl::FailedPreconditionError("Graph must be running."); } - mediapipe::Status status = running_graph_->WaitUntilDone(); + absl::Status status = running_graph_->WaitUntilDone(); running_graph_.reset(nullptr); return status; } -mediapipe::Status Graph::WaitUntilIdle(JNIEnv* env) { +absl::Status Graph::WaitUntilIdle(JNIEnv* env) { if (!running_graph_) { - return mediapipe::FailedPreconditionError("Graph must be running."); + return absl::FailedPreconditionError("Graph must be running."); } return running_graph_->WaitUntilIdle(); } @@ -511,20 +511,20 @@ mediapipe::GpuResources* Graph::GetGpuResources() const { return gpu_resources_.get(); } -mediapipe::Status Graph::SetParentGlContext(int64 java_gl_context) { +absl::Status Graph::SetParentGlContext(int64 java_gl_context) { if (gpu_resources_) { - return mediapipe::AlreadyExistsError( + return absl::AlreadyExistsError( "trying to set the parent GL context, but the gpu shared " "data has already been set up."); } -#ifdef MEDIAPIPE_DISABLE_GPU +#if MEDIAPIPE_DISABLE_GPU LOG(FATAL) << "GPU support has been disabled in this build!"; #else gpu_resources_ = mediapipe::GpuResources::Create( reinterpret_cast(java_gl_context)) - .ValueOrDie(); -#endif // defined(MEDIAPIPE_DISABLE_GPU) - return mediapipe::OkStatus(); + .value(); +#endif // MEDIAPIPE_DISABLE_GPU + return absl::OkStatus(); } void Graph::SetServicePacket(const GraphServiceBase& service, Packet packet) { @@ -583,7 +583,7 @@ std::string Graph::graph_type() { return ""; } -mediapipe::Status Graph::InitializeGraph(CalculatorGraph* graph) { +absl::Status Graph::InitializeGraph(CalculatorGraph* graph) { if (graph_configs_.size() == 1 && graph_templates_.empty()) { return graph->Initialize(*graph_config()); } else { 
diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h index 43da503ff..488920f8e 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph.h @@ -25,9 +25,9 @@ #include #include "mediapipe/framework/calculator_framework.h" -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU #include "absl/synchronization/mutex.h" #include "mediapipe/gpu/gpu_shared_data_internal.h" @@ -49,26 +49,26 @@ class Graph { ~Graph(); // Adds a callback for a given stream name. - mediapipe::Status AddCallbackHandler(std::string output_stream_name, - jobject java_callback); + absl::Status AddCallbackHandler(std::string output_stream_name, + jobject java_callback); // Loads a binary graph from a file. - mediapipe::Status LoadBinaryGraph(std::string path_to_graph); + absl::Status LoadBinaryGraph(std::string path_to_graph); // Loads a binary graph from a buffer. - mediapipe::Status LoadBinaryGraph(const char* data, int size); + absl::Status LoadBinaryGraph(const char* data, int size); // Loads a binary graph template from a buffer. - mediapipe::Status LoadBinaryGraphTemplate(const char* data, int size); + absl::Status LoadBinaryGraphTemplate(const char* data, int size); // Specifies the CalculatorGraphConfig::type of the top level graph. - mediapipe::Status SetGraphType(std::string graph_type); + absl::Status SetGraphType(std::string graph_type); // Specifies options such as template arguments for the graph. - mediapipe::Status SetGraphOptions(const char* data, int size); + absl::Status SetGraphOptions(const char* data, int size); // Returns the expanded calculator graph config. CalculatorGraphConfig GetCalculatorGraphConfig(); // Runs the graph until it closes. // Mainly is used for writing tests. 
- mediapipe::Status RunGraphUntilClose(JNIEnv* env); + absl::Status RunGraphUntilClose(JNIEnv* env); // The following 4 functions are used to run the graph in // step by step mode, the usual call sequence is like this: @@ -81,26 +81,26 @@ class Graph { // wait until nothing is running and nothing can be scheduled. // // Starts running the graph. - mediapipe::Status StartRunningGraph(JNIEnv* env); + absl::Status StartRunningGraph(JNIEnv* env); // Closes one input stream. - mediapipe::Status CloseInputStream(std::string stream_name); + absl::Status CloseInputStream(std::string stream_name); // Closes all the graph input streams. - mediapipe::Status CloseAllInputStreams(); + absl::Status CloseAllInputStreams(); // Closes all the graph packet sources. - mediapipe::Status CloseAllPacketSources(); + absl::Status CloseAllPacketSources(); // Waits util graph is done. - mediapipe::Status WaitUntilDone(JNIEnv* env); + absl::Status WaitUntilDone(JNIEnv* env); // Waits util graph is idle. - mediapipe::Status WaitUntilIdle(JNIEnv* env); + absl::Status WaitUntilIdle(JNIEnv* env); // Adds a packet to an input stream. - mediapipe::Status AddPacketToInputStream(const std::string& stream_name, - const Packet& packet); + absl::Status AddPacketToInputStream(const std::string& stream_name, + const Packet& packet); // Moves a packet into an input stream. - mediapipe::Status AddPacketToInputStream(const std::string& stream_name, - Packet&& packet); + absl::Status AddPacketToInputStream(const std::string& stream_name, + Packet&& packet); // Takes the MediaPipe Packet referenced by the handle, sets its timestamp, // and then tries to move the Packet into the given input stream. - mediapipe::Status SetTimestampAndMovePacketToInputStream( + absl::Status SetTimestampAndMovePacketToInputStream( const std::string& stream_name, int64_t packet_handle, int64_t timestamp); // Sets the mode for adding packets to a graph input stream. 
@@ -127,7 +127,7 @@ class Graph { int64_t AddSurfaceOutput(const std::string& stream_name); // Sets a parent GL context to use for texture sharing. - mediapipe::Status SetParentGlContext(int64 java_gl_context); + absl::Status SetParentGlContext(int64 java_gl_context); // Sets the object for a service. template @@ -176,7 +176,7 @@ class Graph { // CalculatorGraphConfig::type is not yet defined. std::string graph_type(); // Initializes CalculatorGraph |graph| using the loaded graph-configs. - mediapipe::Status InitializeGraph(CalculatorGraph* graph); + absl::Status InitializeGraph(CalculatorGraph* graph); // CalculatorGraphConfigs for the calculator graph and subgraphs. std::vector graph_configs_; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc index c3e981570..a9ed0ccc8 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_jni.cc @@ -27,12 +27,12 @@ using mediapipe::android::JStringToStdString; using mediapipe::android::ThrowIfError; namespace { -mediapipe::Status AddSidePacketsIntoGraph( - mediapipe::android::Graph* mediapipe_graph, JNIEnv* env, - jobjectArray stream_names, jlongArray packets) { +absl::Status AddSidePacketsIntoGraph(mediapipe::android::Graph* mediapipe_graph, + JNIEnv* env, jobjectArray stream_names, + jlongArray packets) { jsize num_side_packets = env->GetArrayLength(stream_names); if (num_side_packets != env->GetArrayLength(packets)) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "Number of streams and packets doesn't match!"); } // Note, packets_array_ref is really a const jlong* but this clashes with the @@ -47,16 +47,16 @@ mediapipe::Status AddSidePacketsIntoGraph( env->DeleteLocalRef(name); } env->ReleaseLongArrayElements(packets, packets_array_ref, JNI_ABORT); - return mediapipe::OkStatus(); + return absl::OkStatus(); } 
-mediapipe::Status AddStreamHeadersIntoGraph( +absl::Status AddStreamHeadersIntoGraph( mediapipe::android::Graph* mediapipe_graph, JNIEnv* env, jobjectArray stream_names, jlongArray packets) { jsize num_headers = env->GetArrayLength(stream_names); if (num_headers != env->GetArrayLength(packets)) { - return mediapipe::Status(mediapipe::StatusCode::kFailedPrecondition, - "Number of streams and packets doesn't match!"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Number of streams and packets doesn't match!"); } jlong* packets_array_ref = env->GetLongArrayElements(packets, nullptr); for (jsize i = 0; i < num_headers; ++i) { @@ -68,7 +68,7 @@ mediapipe::Status AddStreamHeadersIntoGraph( env->DeleteLocalRef(name); } env->ReleaseLongArrayElements(packets, packets_array_ref, JNI_ABORT); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -106,7 +106,7 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraphBytes)( reinterpret_cast(context); jbyte* data_ptr = env->GetByteArrayElements(data, nullptr); int size = env->GetArrayLength(data); - mediapipe::Status status = + absl::Status status = mediapipe_graph->LoadBinaryGraph(reinterpret_cast(data_ptr), size); env->ReleaseByteArrayElements(data, data_ptr, JNI_ABORT); ThrowIfError(env, status); @@ -118,7 +118,7 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeLoadBinaryGraphTemplate)( reinterpret_cast(context); jbyte* data_ptr = env->GetByteArrayElements(data, nullptr); int size = env->GetArrayLength(data); - mediapipe::Status status = mediapipe_graph->LoadBinaryGraphTemplate( + absl::Status status = mediapipe_graph->LoadBinaryGraphTemplate( reinterpret_cast(data_ptr), size); env->ReleaseByteArrayElements(data, data_ptr, JNI_ABORT); ThrowIfError(env, status); @@ -145,7 +145,7 @@ JNIEXPORT void JNICALL GRAPH_METHOD(nativeSetGraphOptions)(JNIEnv* env, reinterpret_cast(context); jbyte* data_ptr = env->GetByteArrayElements(data, nullptr); int size = env->GetArrayLength(data); - 
mediapipe::Status status = + absl::Status status = mediapipe_graph->SetGraphOptions(reinterpret_cast(data_ptr), size); env->ReleaseByteArrayElements(data, data_ptr, JNI_ABORT); ThrowIfError(env, status); @@ -179,8 +179,8 @@ GRAPH_METHOD(nativeAddPacketCallback)(JNIEnv* env, jobject thiz, jlong context, // be accessed later. jobject global_callback_ref = env->NewGlobalRef(callback); if (!global_callback_ref) { - ThrowIfError( - env, mediapipe::InternalError("Failed to allocate packet callback")); + ThrowIfError(env, + absl::InternalError("Failed to allocate packet callback")); return; } ThrowIfError(env, mediapipe_graph->AddCallbackHandler(output_stream_name, diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc index 494609e29..6f9df3bee 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/graph_profiler_jni.cc @@ -45,7 +45,7 @@ JNIEXPORT jobjectArray JNICALL GRAPH_METHOD(nativeGetCalculatorProfiles)( std::vector profiles_vec; if (profiling_context->GetCalculatorProfiles(&profiles_vec) != - mediapipe::OkStatus()) { + absl::OkStatus()) { return nullptr; } int num_profiles = profiles_vec.size(); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc index 10d6852d9..08b340495 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.cc @@ -111,7 +111,7 @@ std::string JStringToStdString(JNIEnv* env, jstring jstr) { return str; } -jthrowable CreateMediaPipeException(JNIEnv* env, mediapipe::Status status) { +jthrowable CreateMediaPipeException(JNIEnv* env, absl::Status status) { auto& class_registry = mediapipe::android::ClassRegistry::GetInstance(); std::string mpe_class_name = class_registry.GetClassName( 
mediapipe::android::ClassRegistry::kMediaPipeExceptionClassName); @@ -131,7 +131,7 @@ jthrowable CreateMediaPipeException(JNIEnv* env, mediapipe::Status status) { env->NewObject(status_cls, status_ctr, status.code(), message_bytes)); } -bool ThrowIfError(JNIEnv* env, mediapipe::Status status) { +bool ThrowIfError(JNIEnv* env, absl::Status status) { if (!status.ok()) { env->Throw(mediapipe::android::CreateMediaPipeException(env, status)); return true; diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h index f52e142ee..2524467ff 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h +++ b/mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h @@ -27,12 +27,12 @@ namespace android { std::string JStringToStdString(JNIEnv* env, jstring jstr); -// Creates a java MediaPipeException object for a mediapipe::Status. -jthrowable CreateMediaPipeException(JNIEnv* env, mediapipe::Status status); +// Creates a java MediaPipeException object for a absl::Status. +jthrowable CreateMediaPipeException(JNIEnv* env, absl::Status status); -// Throws a MediaPipeException for any non-ok mediapipe::Status. +// Throws a MediaPipeException for any non-ok absl::Status. // Note that the exception is thrown after execution returns to Java. -bool ThrowIfError(JNIEnv* env, mediapipe::Status status); +bool ThrowIfError(JNIEnv* env, absl::Status status); // The Jni ids for Java class SerializedMessage. 
class SerializedMessageIds { diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc index ecf388000..c1bfd09ac 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_creator_jni.cc @@ -29,9 +29,9 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU namespace { using mediapipe::android::SerializedMessageIds; @@ -300,7 +300,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateMatrix)( return CreatePacketWithContext(context, packet); } -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuBuffer)( JNIEnv* env, jobject thiz, jlong context, jint name, jint width, @@ -351,7 +351,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateGpuBuffer)( return CreatePacketWithContext(context, packet); } -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU // TODO: Add vector creators. 
@@ -450,7 +450,7 @@ JNIEXPORT jlong JNICALL PACKET_CREATOR_METHOD(nativeCreateProto)(JNIEnv* env, auto packet_or = mediapipe::packet_internal::PacketFromDynamicProto( type_name, std::string((char*)value_ref, value_len)); if (!ThrowIfError(env, packet_or.status())) { - packet = packet_or.ValueOrDie(); + packet = packet_or.value(); } env->ReleaseByteArrayElements(value_array, value_ref, JNI_ABORT); return CreatePacketWithContext(context, packet); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc index 9391bce2f..0b40dd642 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/packet_getter_jni.cc @@ -24,9 +24,9 @@ #include "mediapipe/java/com/google/mediapipe/framework/jni/colorspace.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/graph.h" #include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h" -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU #include "mediapipe/gpu/gl_calculator_helper.h" -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU namespace { using mediapipe::android::SerializedMessageIds; @@ -152,7 +152,7 @@ JNIEXPORT void JNICALL PACKET_GETTER_METHOD(nativeGetProto)(JNIEnv* env, jobject result) { mediapipe::Packet mediapipe_packet = mediapipe::android::Graph::GetPacketFromHandle(packet); - mediapipe::Status status = mediapipe_packet.ValidateAsProtoMessageLite(); + absl::Status status = mediapipe_packet.ValidateAsProtoMessageLite(); if (!ThrowIfError(env, status)) { // Convert type_name and value to Java data. 
const auto& proto_message = mediapipe_packet.GetProtoMessageLite(); @@ -182,7 +182,7 @@ JNIEXPORT jobjectArray JNICALL PACKET_GETTER_METHOD(nativeGetProtoVector)( env, get_proto_vector.status())); } const std::vector& proto_vector = - get_proto_vector.ValueOrDie(); + get_proto_vector.value(); jobjectArray proto_array = env->NewObjectArray(proto_vector.size(), env->FindClass("[B"), nullptr); for (int i = 0; i < proto_vector.size(); ++i) { @@ -403,7 +403,7 @@ JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetMatrixCols)(JNIEnv* env, return GetFromNativeHandle(packet).cols(); } -#ifndef MEDIAPIPE_DISABLE_GPU +#if !MEDIAPIPE_DISABLE_GPU JNIEXPORT jint JNICALL PACKET_GETTER_METHOD(nativeGetGpuBufferName)( JNIEnv* env, jobject thiz, jlong packet) { @@ -427,4 +427,4 @@ JNIEXPORT jlong JNICALL PACKET_GETTER_METHOD(nativeGetGpuBuffer)(JNIEnv* env, new mediapipe::GlTextureBufferSharedPtr(ptr)); } -#endif // !defined(MEDIAPIPE_DISABLE_GPU) +#endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc index 38148be89..d687fbecd 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/register_natives.cc @@ -184,6 +184,12 @@ void RegisterPacketCreatorNatives(JNIEnv *env) { AddJNINativeMethod(&packet_creator_methods, packet_creator, "nativeCreateInt32", "(JI)J", (void *)&PACKET_CREATOR_METHOD(nativeCreateInt32)); + AddJNINativeMethod(&packet_creator_methods, packet_creator, + "nativeCreateFloat32", "(JF)J", + (void *)&PACKET_CREATOR_METHOD(nativeCreateFloat32)); + AddJNINativeMethod(&packet_creator_methods, packet_creator, + "nativeCreateBool", "(JZ)J", + (void *)&PACKET_CREATOR_METHOD(nativeCreateBool)); RegisterNativesVector(env, packet_creator_class, packet_creator_methods); } @@ -203,6 +209,12 @@ void RegisterPacketGetterNatives(JNIEnv *env) { 
AddJNINativeMethod(&packet_getter_methods, packet_getter, "nativeGetImageData", "(JLjava/nio/ByteBuffer;)Z", (void *)&PACKET_GETTER_METHOD(nativeGetImageData)); + AddJNINativeMethod(&packet_getter_methods, packet_getter, + "nativeGetImageWidth", "(J)I", + (void *)&PACKET_GETTER_METHOD(nativeGetImageWidth)); + AddJNINativeMethod(&packet_getter_methods, packet_getter, + "nativeGetImageHeight", "(J)I", + (void *)&PACKET_GETTER_METHOD(nativeGetImageHeight)); AddJNINativeMethod(&packet_getter_methods, packet_getter, "nativeGetFloat32Vector", "(J)[F", (void *)&PACKET_GETTER_METHOD(nativeGetFloat32Vector)); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc index b64788dd3..51d693b20 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc +++ b/mediapipe/java/com/google/mediapipe/framework/jni/surface_output_jni.cc @@ -62,7 +62,7 @@ JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetSurface)( } auto status = gl_context->Run( - [gl_context, surface_holder, surface, window]() -> mediapipe::Status { + [gl_context, surface_holder, surface, window]() -> absl::Status { absl::MutexLock lock(&surface_holder->mutex); // Must destroy old surface first in case we are assigning the same // surface. 
@@ -90,7 +90,7 @@ JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetSurface)( } surface_holder->surface = egl_surface; surface_holder->owned = egl_surface != EGL_NO_SURFACE; - return mediapipe::OkStatus(); + return absl::OkStatus(); }); MEDIAPIPE_CHECK_OK(status); @@ -122,10 +122,10 @@ JNIEXPORT void JNICALL MEDIAPIPE_SURFACE_OUTPUT_METHOD(nativeSetEglSurface)( if (old_surface != EGL_NO_SURFACE) { MEDIAPIPE_CHECK_OK( - gl_context->Run([gl_context, old_surface]() -> mediapipe::Status { + gl_context->Run([gl_context, old_surface]() -> absl::Status { RET_CHECK(eglDestroySurface(gl_context->egl_display(), old_surface)) << "eglDestroySurface failed:" << eglGetError(); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); } } diff --git a/mediapipe/java/com/google/mediapipe/framework/proguard_allowobfuscation.pgcfg b/mediapipe/java/com/google/mediapipe/framework/proguard_allowobfuscation.pgcfg deleted file mode 100644 index 1de918205..000000000 --- a/mediapipe/java/com/google/mediapipe/framework/proguard_allowobfuscation.pgcfg +++ /dev/null @@ -1,29 +0,0 @@ -# Additional flags to pass to Proguard when processing a binary that uses -# MediaPipe wherein classes and methods invoked from native are kept but -# they are allowed to be obfuscated. In such cases, you must use the -# go/proguard-jni system to feed the proguard obfuscation map using the -# generate_obfuscation_mapping genrule. - -# Keep public members of our public interfaces. This also prevents the -# obfuscation of the corresponding methods in classes implementing them, -# such as implementations of PacketCallback#process. --keep,allowobfuscation public interface com.google.mediapipe.framework.* { - public *; -} - -# This method is invoked by native code. --keep,allowobfuscation public class com.google.mediapipe.framework.Packet { - public static *** create(***); - public long getNativeHandle(); - public void release(); -} - -# This method is invoked by native code. 
--keep,allowobfuscation public class com.google.mediapipe.framework.PacketCreator { - *** releaseWithSyncToken(...); -} - -# This method is invoked by native code. --keep,allowobfuscation public class com.google.mediapipe.framework.MediaPipeException { - (int, byte[]); -} diff --git a/mediapipe/models/object_detection_oidv4_labelmap.pbtxt b/mediapipe/models/object_detection_oidv4_labelmap.pbtxt deleted file mode 100644 index 7d93e6402..000000000 --- a/mediapipe/models/object_detection_oidv4_labelmap.pbtxt +++ /dev/null @@ -1,195 +0,0 @@ -??? -Container -Ambulance -Ladder -Toothbrush -Sink -Cassette deck -Beer -Parking meter -Traffic light -Washing machine -Sunglasses -Ball -Backpack -Bicycle -Home appliance -Boat -Boot -Headphones -Bus -Screwdriver -Laptop -Teapot -Person -Swimwear -Balloon -Wrench -Vehicle registration plate -Lantern -Toaster -Flashlight -Billboard -Limousine -Necklace -Scissors -Stairs -Computer keyboard -Printer -Traffic sign -Chair -Poster -Fire hydrant -Land vehicle -Cabinetry -Suitcase -Snowmobile -Clock -Cattle -Cello -Desk -Cat -Computer mouse -Calculator -Computer monitor -Box -Stapler -Studio couch -Drum -Dice -Oven -Couch -Whiteboard -Door -Hat -Eraser -Tin can -Mug -Can opener -Goggles -Roller skates -Coffee cup -Cutting board -Blender -Stop sign -Volleyball -Vase -Slow cooker -Wardrobe -Paper towel -Sun hat -Tree house -Gas stove -Salt and pepper shakers -Mechanical fan -Fax -Nightstand -Barrel -Guitar -Pillow -Stationary bicycle -Hammer -Ceiling fan -Sofa bed -Sandal -Bicycle helmet -Bed -Kettle -Hair dryer -Kitchenware -Bookcase -Refrigerator -Alarm clock -Filing cabinet -Table -Knife -Bottle -Dumbbell -Bowl -Billiard table -Motorcycle -Frying pan -Bathroom cabinet -Plate -Mobile phone -Table tennis racket -Musical keyboard -Scoreboard -Briefcase -Kitchen knife -Piano -Pumpkin -Infant bed -Mixer -Cupboard -Digital clock -Rifle -Skateboard -High heels -Snowboard -Sword -Training bench -Coffee table -Television -Trombone -Tank 
-Telephone -Trumpet -Train -Picnic basket -Football helmet -Truck -Measuring cup -Coffeemaker -Violin -Vehicle -Wine -Wheel -Jug -Toilet -Clothing -Footwear -Tablet computer -Dog -Book -Candle -Hand dryer -Soap dispenser -Furniture -Airplane -Spoon -Bench -Window -Closet -Fork -Lamp -Camera -Racket -Human face -Unicycle -Flowerpot -Drawer -Stool -Microwave oven -Shelf -Handgun -Van -Corded phone -Tennis racket -Wall clock -Kitchen & dining room table -Pressure cooker -Kitchen appliance -Tire -Luggage and bags -Microphone -Glasses -Pen -Car -Aircraft -Dishwasher -Binoculars -Rays and skates -Remote control -Wheelchair -Helmet diff --git a/mediapipe/models/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite b/mediapipe/models/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite deleted file mode 100644 index fa6ad878d..000000000 Binary files a/mediapipe/models/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite and /dev/null differ diff --git a/mediapipe/modules/face_detection/BUILD b/mediapipe/modules/face_detection/BUILD index 25349e181..374cfeb58 100644 --- a/mediapipe/modules/face_detection/BUILD +++ b/mediapipe/modules/face_detection/BUILD @@ -26,12 +26,10 @@ mediapipe_simple_subgraph( graph = "face_detection_front_by_roi_cpu.pbtxt", register_as = "FaceDetectionFrontByRoiCpu", deps = [ + ":face_detection_front_common", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/util:detection_projection_calculator", - "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:to_image_calculator", ], ) @@ -40,12 +38,10 @@ mediapipe_simple_subgraph( graph = "face_detection_front_by_roi_gpu.pbtxt", register_as = "FaceDetectionFrontByRoiGpu", deps = [ + ":face_detection_front_common", 
"//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/util:detection_projection_calculator", - "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:to_image_calculator", ], ) @@ -54,12 +50,10 @@ mediapipe_simple_subgraph( graph = "face_detection_front_cpu.pbtxt", register_as = "FaceDetectionFrontCpu", deps = [ + ":face_detection_front_common", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", - "//mediapipe/calculators/tensor:tensors_to_detections_calculator", - "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/util:detection_letterbox_removal_calculator", - "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:to_image_calculator", ], ) @@ -68,11 +62,21 @@ mediapipe_simple_subgraph( graph = "face_detection_front_gpu.pbtxt", register_as = "FaceDetectionFrontGpu", deps = [ + ":face_detection_front_common", "//mediapipe/calculators/tensor:image_to_tensor_calculator", "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/util:to_image_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "face_detection_front_common", + graph = "face_detection_front_common.pbtxt", + register_as = "FaceDetectionFrontCommon", + deps = [ "//mediapipe/calculators/tensor:tensors_to_detections_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/util:detection_letterbox_removal_calculator", + "//mediapipe/calculators/util:detection_projection_calculator", "//mediapipe/calculators/util:non_max_suppression_calculator", ], ) diff --git a/mediapipe/modules/face_detection/face_detection_front_by_roi_cpu.pbtxt 
b/mediapipe/modules/face_detection/face_detection_front_by_roi_cpu.pbtxt index 956dd727c..65d9b482b 100644 --- a/mediapipe/modules/face_detection/face_detection_front_by_roi_cpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_front_by_roi_cpu.pbtxt @@ -13,7 +13,7 @@ # output_stream: "DETECTIONS:face_detections" # } -type: "FaceDetectionFrontCpu" +type: "FaceDetectionFrontByRoiCpu" # CPU image. (ImageFrame) input_stream: "IMAGE:image" @@ -29,6 +29,14 @@ input_stream: "ROI:roi" # this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" +# Converts the input CPU image (ImageFrame) to the multi-backend image type +# (Image). +node: { + calculator: "ToImageCalculator" + input_stream: "IMAGE_CPU:image" + output_stream: "IMAGE:multi_backend_image" +} + # Transforms specified region of image into 128x128 tensor keeping aspect ratio # (padding tensor if needed). node { @@ -66,78 +74,10 @@ node { } } -# Generates a single side packet containing a vector of SSD anchors based on -# the specification in the options. +# Performs tensor post processing to generate face detections. node { - calculator: "SsdAnchorsCalculator" - output_side_packet: "anchors" - options: { - [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 4 - min_scale: 0.1484375 - max_scale: 0.75 - input_size_height: 128 - input_size_width: 128 - anchor_offset_x: 0.5 - anchor_offset_y: 0.5 - strides: 8 - strides: 16 - strides: 16 - strides: 16 - aspect_ratios: 1.0 - fixed_anchor_size: true - } - } -} - -# Decodes the detection tensors generated by the TensorFlow Lite model, based on -# the SSD anchors and the specification in the options, into a vector of -# detections. Each detection describes a detected object. 
-node { - calculator: "TensorsToDetectionsCalculator" + calculator: "FaceDetectionFrontCommon" input_stream: "TENSORS:detection_tensors" - input_side_packet: "ANCHORS:anchors" - output_stream: "DETECTIONS:unfiltered_detections" - options: { - [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { - num_classes: 1 - num_boxes: 896 - num_coords: 16 - box_coord_offset: 0 - keypoint_coord_offset: 4 - num_keypoints: 6 - num_values_per_keypoint: 2 - sigmoid_score: true - score_clipping_thresh: 100.0 - reverse_output_order: true - x_scale: 128.0 - y_scale: 128.0 - h_scale: 128.0 - w_scale: 128.0 - min_score_thresh: 0.5 - } - } -} - -# Performs non-max suppression to remove excessive detections. -node { - calculator: "NonMaxSuppressionCalculator" - input_stream: "unfiltered_detections" - output_stream: "filtered_detections" - options: { - [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { - min_suppression_threshold: 0.3 - overlap_type: INTERSECTION_OVER_UNION - algorithm: WEIGHTED - } - } -} - -# Projects the detections from input tensor to the corresponding locations on -# the original image (input to the graph). -node { - calculator: "DetectionProjectionCalculator" - input_stream: "DETECTIONS:filtered_detections" - input_stream: "PROJECTION_MATRIX:transform_matrix" + input_stream: "MATRIX:transform_matrix" output_stream: "DETECTIONS:detections" } diff --git a/mediapipe/modules/face_detection/face_detection_front_by_roi_gpu.pbtxt b/mediapipe/modules/face_detection/face_detection_front_by_roi_gpu.pbtxt index 331b1129f..179ace814 100644 --- a/mediapipe/modules/face_detection/face_detection_front_by_roi_gpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_front_by_roi_gpu.pbtxt @@ -13,7 +13,7 @@ # output_stream: "DETECTIONS:face_detections" # } -type: "FaceDetectionFrontGpu" +type: "FaceDetectionFrontByRoiGpu" # GPU image. 
(GpuBuffer) input_stream: "IMAGE:image" @@ -29,11 +29,19 @@ input_stream: "ROI:roi" # this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" +# Converts the input GPU image (GpuBuffer) to the multi-backend image type +# (Image). +node: { + calculator: "ToImageCalculator" + input_stream: "IMAGE_GPU:image" + output_stream: "IMAGE:multi_backend_image" +} + # Transforms specified region of image into 128x128 tensor keeping aspect ratio # (padding tensor if needed). node { calculator: "ImageToTensorCalculator" - input_stream: "IMAGE_GPU:image" + input_stream: "IMAGE:multi_backend_image" input_stream: "NORM_RECT:roi" output_stream: "TENSORS:input_tensors" output_stream: "MATRIX:transform_matrix" @@ -66,78 +74,10 @@ node { } } -# Generates a single side packet containing a vector of SSD anchors based on -# the specification in the options. +# Performs tensor post processing to generate face detections. node { - calculator: "SsdAnchorsCalculator" - output_side_packet: "anchors" - options: { - [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 4 - min_scale: 0.1484375 - max_scale: 0.75 - input_size_height: 128 - input_size_width: 128 - anchor_offset_x: 0.5 - anchor_offset_y: 0.5 - strides: 8 - strides: 16 - strides: 16 - strides: 16 - aspect_ratios: 1.0 - fixed_anchor_size: true - } - } -} - -# Decodes the detection tensors generated by the TensorFlow Lite model, based on -# the SSD anchors and the specification in the options, into a vector of -# detections. Each detection describes a detected object. 
-node { - calculator: "TensorsToDetectionsCalculator" + calculator: "FaceDetectionFrontCommon" input_stream: "TENSORS:detection_tensors" - input_side_packet: "ANCHORS:anchors" - output_stream: "DETECTIONS:unfiltered_detections" - options: { - [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { - num_classes: 1 - num_boxes: 896 - num_coords: 16 - box_coord_offset: 0 - keypoint_coord_offset: 4 - num_keypoints: 6 - num_values_per_keypoint: 2 - sigmoid_score: true - score_clipping_thresh: 100.0 - reverse_output_order: true - x_scale: 128.0 - y_scale: 128.0 - h_scale: 128.0 - w_scale: 128.0 - min_score_thresh: 0.5 - } - } -} - -# Performs non-max suppression to remove excessive detections. -node { - calculator: "NonMaxSuppressionCalculator" - input_stream: "unfiltered_detections" - output_stream: "filtered_detections" - options: { - [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { - min_suppression_threshold: 0.3 - overlap_type: INTERSECTION_OVER_UNION - algorithm: WEIGHTED - } - } -} - -# Projects the detections from input tensor to the corresponding locations on -# the original image (input to the graph). -node { - calculator: "DetectionProjectionCalculator" - input_stream: "DETECTIONS:filtered_detections" - input_stream: "PROJECTION_MATRIX:transform_matrix" + input_stream: "MATRIX:transform_matrix" output_stream: "DETECTIONS:detections" } diff --git a/mediapipe/modules/face_detection/face_detection_front_common.pbtxt b/mediapipe/modules/face_detection/face_detection_front_common.pbtxt new file mode 100644 index 000000000..3c9200596 --- /dev/null +++ b/mediapipe/modules/face_detection/face_detection_front_common.pbtxt @@ -0,0 +1,103 @@ +# MediaPipe graph performing common processing to detect faces, currently +# consisting of tensor post processing. 
+# +# EXAMPLE: +# node { +# calculator: "FaceDetectionFrontCommon" +# input_stream: "TENSORS:detection_tensors" +# input_stream: "MATRIX:transform_matrix" +# output_stream: "DETECTIONS:detections" +# } + +type: "FaceDetectionFrontCommon" + +# Detection tensors. (std::vector) +input_stream: "TENSORS:detection_tensors" + +# A 4x4 row-major-order matrix that maps a point represented in the detection +# tensors to a desired coordinate system, e.g., in the original input image +# before scaling/cropping. (std::array) +input_stream: "MATRIX:transform_matrix" + +# Detected faces. (std::vector) +# NOTE: there will not be an output packet in the DETECTIONS stream for this +# particular timestamp if none of faces detected. However, the MediaPipe +# framework will internally inform the downstream calculators of the absence of +# this packet so that they don't wait for it unnecessarily. +output_stream: "DETECTIONS:detections" + +# Generates a single side packet containing a vector of SSD anchors based on +# the specification in the options. +node { + calculator: "SsdAnchorsCalculator" + output_side_packet: "anchors" + options: { + [mediapipe.SsdAnchorsCalculatorOptions.ext] { + num_layers: 4 + min_scale: 0.1484375 + max_scale: 0.75 + input_size_height: 128 + input_size_width: 128 + anchor_offset_x: 0.5 + anchor_offset_y: 0.5 + strides: 8 + strides: 16 + strides: 16 + strides: 16 + aspect_ratios: 1.0 + fixed_anchor_size: true + } + } +} + +# Decodes the detection tensors generated by the TensorFlow Lite model, based on +# the SSD anchors and the specification in the options, into a vector of +# detections. Each detection describes a detected object. 
+node { + calculator: "TensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" + input_side_packet: "ANCHORS:anchors" + output_stream: "DETECTIONS:unfiltered_detections" + options: { + [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { + num_classes: 1 + num_boxes: 896 + num_coords: 16 + box_coord_offset: 0 + keypoint_coord_offset: 4 + num_keypoints: 6 + num_values_per_keypoint: 2 + sigmoid_score: true + score_clipping_thresh: 100.0 + reverse_output_order: true + x_scale: 128.0 + y_scale: 128.0 + h_scale: 128.0 + w_scale: 128.0 + min_score_thresh: 0.5 + } + } +} + +# Performs non-max suppression to remove excessive detections. +node { + calculator: "NonMaxSuppressionCalculator" + input_stream: "unfiltered_detections" + output_stream: "filtered_detections" + options: { + [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { + min_suppression_threshold: 0.3 + overlap_type: INTERSECTION_OVER_UNION + algorithm: WEIGHTED + } + } +} + +# Projects the detections from input tensor to the corresponding locations on +# the original image (input to the graph). +node { + calculator: "DetectionProjectionCalculator" + input_stream: "DETECTIONS:filtered_detections" + input_stream: "PROJECTION_MATRIX:transform_matrix" + output_stream: "DETECTIONS:detections" +} diff --git a/mediapipe/modules/face_detection/face_detection_front_cpu.pbtxt b/mediapipe/modules/face_detection/face_detection_front_cpu.pbtxt index a7d8dbcc1..bfae6162c 100644 --- a/mediapipe/modules/face_detection/face_detection_front_cpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_front_cpu.pbtxt @@ -24,14 +24,22 @@ input_stream: "IMAGE:image" # this packet so that they don't wait for it unnecessarily. output_stream: "DETECTIONS:detections" +# Converts the input CPU image (ImageFrame) to the multi-backend image type +# (Image). 
+node: { + calculator: "ToImageCalculator" + input_stream: "IMAGE_CPU:image" + output_stream: "IMAGE:multi_backend_image" +} + # Transforms the input image into a 128x128 tensor while keeping the aspect # ratio (what is expected by the corresponding face detection model), resulting # in potential letterboxing in the transformed image. node: { calculator: "ImageToTensorCalculator" - input_stream: "IMAGE:image" + input_stream: "IMAGE:multi_backend_image" output_stream: "TENSORS:input_tensors" - output_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "MATRIX:transform_matrix" options: { [mediapipe.ImageToTensorCalculatorOptions.ext] { output_tensor_width: 128 @@ -61,80 +69,10 @@ node { } } -# Generates a single side packet containing a vector of SSD anchors based on -# the specification in the options. +# Performs tensor post processing to generate face detections. node { - calculator: "SsdAnchorsCalculator" - output_side_packet: "anchors" - options: { - [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 4 - min_scale: 0.1484375 - max_scale: 0.75 - input_size_height: 128 - input_size_width: 128 - anchor_offset_x: 0.5 - anchor_offset_y: 0.5 - strides: 8 - strides: 16 - strides: 16 - strides: 16 - aspect_ratios: 1.0 - fixed_anchor_size: true - } - } -} - -# Decodes the detection tensors generated by the TensorFlow Lite model, based on -# the SSD anchors and the specification in the options, into a vector of -# detections. Each detection describes a detected object. 
-node { - calculator: "TensorsToDetectionsCalculator" + calculator: "FaceDetectionFrontCommon" input_stream: "TENSORS:detection_tensors" - input_side_packet: "ANCHORS:anchors" - output_stream: "DETECTIONS:unfiltered_detections" - options: { - [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { - num_classes: 1 - num_boxes: 896 - num_coords: 16 - box_coord_offset: 0 - keypoint_coord_offset: 4 - num_keypoints: 6 - num_values_per_keypoint: 2 - sigmoid_score: true - score_clipping_thresh: 100.0 - reverse_output_order: true - x_scale: 128.0 - y_scale: 128.0 - h_scale: 128.0 - w_scale: 128.0 - min_score_thresh: 0.5 - } - } -} - -# Performs non-max suppression to remove excessive detections. -node { - calculator: "NonMaxSuppressionCalculator" - input_stream: "unfiltered_detections" - output_stream: "filtered_detections" - options: { - [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { - min_suppression_threshold: 0.3 - overlap_type: INTERSECTION_OVER_UNION - algorithm: WEIGHTED - } - } -} - -# Adjusts detection locations (already normalized to [0.f, 1.f]) on the -# letterboxed image (after image transformation with the FIT scale mode) to the -# corresponding locations on the same image with the letterbox removed (the -# input image to the graph before image transformation). -node { - calculator: "DetectionLetterboxRemovalCalculator" - input_stream: "DETECTIONS:filtered_detections" - input_stream: "LETTERBOX_PADDING:letterbox_padding" + input_stream: "MATRIX:transform_matrix" output_stream: "DETECTIONS:detections" } diff --git a/mediapipe/modules/face_detection/face_detection_front_gpu.pbtxt b/mediapipe/modules/face_detection/face_detection_front_gpu.pbtxt index 4fd7dc5fc..21a9158ba 100644 --- a/mediapipe/modules/face_detection/face_detection_front_gpu.pbtxt +++ b/mediapipe/modules/face_detection/face_detection_front_gpu.pbtxt @@ -24,14 +24,22 @@ input_stream: "IMAGE:image" # this packet so that they don't wait for it unnecessarily. 
output_stream: "DETECTIONS:detections" +# Converts the input GPU image (GpuBuffer) to the multi-backend image type +# (Image). +node: { + calculator: "ToImageCalculator" + input_stream: "IMAGE_GPU:image" + output_stream: "IMAGE:multi_backend_image" +} + # Transforms the input image into a 128x128 tensor while keeping the aspect # ratio (what is expected by the corresponding face detection model), resulting # in potential letterboxing in the transformed image. node: { calculator: "ImageToTensorCalculator" - input_stream: "IMAGE_GPU:image" + input_stream: "IMAGE:multi_backend_image" output_stream: "TENSORS:input_tensors" - output_stream: "LETTERBOX_PADDING:letterbox_padding" + output_stream: "MATRIX:transform_matrix" options: { [mediapipe.ImageToTensorCalculatorOptions.ext] { output_tensor_width: 128 @@ -61,80 +69,10 @@ node { } } -# Generates a single side packet containing a vector of SSD anchors based on -# the specification in the options. +# Performs tensor post processing to generate face detections. node { - calculator: "SsdAnchorsCalculator" - output_side_packet: "anchors" - options: { - [mediapipe.SsdAnchorsCalculatorOptions.ext] { - num_layers: 4 - min_scale: 0.1484375 - max_scale: 0.75 - input_size_height: 128 - input_size_width: 128 - anchor_offset_x: 0.5 - anchor_offset_y: 0.5 - strides: 8 - strides: 16 - strides: 16 - strides: 16 - aspect_ratios: 1.0 - fixed_anchor_size: true - } - } -} - -# Decodes the detection tensors generated by the TensorFlow Lite model, based on -# the SSD anchors and the specification in the options, into a vector of -# detections. Each detection describes a detected object. 
-node { - calculator: "TensorsToDetectionsCalculator" + calculator: "FaceDetectionFrontCommon" input_stream: "TENSORS:detection_tensors" - input_side_packet: "ANCHORS:anchors" - output_stream: "DETECTIONS:unfiltered_detections" - options: { - [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { - num_classes: 1 - num_boxes: 896 - num_coords: 16 - box_coord_offset: 0 - keypoint_coord_offset: 4 - num_keypoints: 6 - num_values_per_keypoint: 2 - sigmoid_score: true - score_clipping_thresh: 100.0 - reverse_output_order: true - x_scale: 128.0 - y_scale: 128.0 - h_scale: 128.0 - w_scale: 128.0 - min_score_thresh: 0.5 - } - } -} - -# Performs non-max suppression to remove excessive detections. -node { - calculator: "NonMaxSuppressionCalculator" - input_stream: "unfiltered_detections" - output_stream: "filtered_detections" - options: { - [mediapipe.NonMaxSuppressionCalculatorOptions.ext] { - min_suppression_threshold: 0.3 - overlap_type: INTERSECTION_OVER_UNION - algorithm: WEIGHTED - } - } -} - -# Adjusts detection locations (already normalized to [0.f, 1.f]) on the -# letterboxed image (after image transformation with the FIT scale mode) to the -# corresponding locations on the same image with the letterbox removed (the -# input image to the graph before image transformation). 
-node { - calculator: "DetectionLetterboxRemovalCalculator" - input_stream: "DETECTIONS:filtered_detections" - input_stream: "LETTERBOX_PADDING:letterbox_padding" + input_stream: "MATRIX:transform_matrix" output_stream: "DETECTIONS:detections" } diff --git a/mediapipe/modules/face_geometry/BUILD b/mediapipe/modules/face_geometry/BUILD index ce869b0f5..c1f996755 100644 --- a/mediapipe/modules/face_geometry/BUILD +++ b/mediapipe/modules/face_geometry/BUILD @@ -28,6 +28,27 @@ mediapipe_simple_subgraph( ], ) +mediapipe_simple_subgraph( + name = "face_geometry_from_detection", + graph = "face_geometry_from_detection.pbtxt", + register_as = "FaceGeometryFromDetection", + deps = [ + ":geometry_pipeline_calculator", + "//mediapipe/calculators/core:begin_loop_calculator", + "//mediapipe/calculators/core:end_loop_calculator", + "//mediapipe/calculators/util:detection_to_landmarks_calculator", + ], +) + +mediapipe_simple_subgraph( + name = "face_geometry_from_landmarks", + graph = "face_geometry_from_landmarks.pbtxt", + register_as = "FaceGeometryFromLandmarks", + deps = [ + ":geometry_pipeline_calculator", + ], +) + mediapipe_proto_library( name = "effect_renderer_calculator_proto", srcs = ["effect_renderer_calculator.proto"], diff --git a/mediapipe/modules/face_geometry/README.md b/mediapipe/modules/face_geometry/README.md index 662cdd1cf..8427ea63c 100644 --- a/mediapipe/modules/face_geometry/README.md +++ b/mediapipe/modules/face_geometry/README.md @@ -15,4 +15,6 @@ Calculators|Details Subgraphs|Details :--- | :--- -[`FaceGeometry`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry.pbtxt)| Extracts face geometry from landmarks for multiple faces. +[`FaceGeometryFromDetection`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt)| Extracts geometry from face detection for multiple faces. 
+[`FaceGeometryFromLandmarks`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt)| Extracts geometry from face landmarks for multiple faces. +[`FaceGeometry`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/face_geometry/face_geometry.pbtxt)| Extracts geometry from face landmarks for multiple faces. Deprecated, please use `FaceGeometryFromLandmarks` in the new code. diff --git a/mediapipe/modules/face_geometry/data/BUILD b/mediapipe/modules/face_geometry/data/BUILD index 5682e4e6f..1661a2283 100644 --- a/mediapipe/modules/face_geometry/data/BUILD +++ b/mediapipe/modules/face_geometry/data/BUILD @@ -18,9 +18,31 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +encode_binary_proto( + name = "geometry_pipeline_metadata_detection", + input = "geometry_pipeline_metadata_detection.pbtxt", + message_type = "mediapipe.face_geometry.GeometryPipelineMetadata", + output = "geometry_pipeline_metadata_detection.binarypb", + deps = [ + "//mediapipe/modules/face_geometry/protos:geometry_pipeline_metadata_proto", + ], +) + +encode_binary_proto( + name = "geometry_pipeline_metadata_landmarks", + input = "geometry_pipeline_metadata_landmarks.pbtxt", + message_type = "mediapipe.face_geometry.GeometryPipelineMetadata", + output = "geometry_pipeline_metadata_landmarks.binarypb", + deps = [ + "//mediapipe/modules/face_geometry/protos:geometry_pipeline_metadata_proto", + ], +) + +# For backward-compatibility reasons, generate `geometry_pipeline_metadata.binarypb` from +# the `geometry_pipeline_metadata_landmarks.pbtxt` definition. 
encode_binary_proto( name = "geometry_pipeline_metadata", - input = "geometry_pipeline_metadata.pbtxt", + input = "geometry_pipeline_metadata_landmarks.pbtxt", message_type = "mediapipe.face_geometry.GeometryPipelineMetadata", output = "geometry_pipeline_metadata.binarypb", deps = [ diff --git a/mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_detection.pbtxt b/mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_detection.pbtxt new file mode 100644 index 000000000..c4389a624 --- /dev/null +++ b/mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_detection.pbtxt @@ -0,0 +1,78 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +input_source: FACE_DETECTION_PIPELINE +procrustes_landmark_basis { landmark_id: 0 weight: 1.0 } +procrustes_landmark_basis { landmark_id: 1 weight: 1.0 } +procrustes_landmark_basis { landmark_id: 2 weight: 1.0 } +procrustes_landmark_basis { landmark_id: 3 weight: 1.0 } +procrustes_landmark_basis { landmark_id: 4 weight: 1.0 } +procrustes_landmark_basis { landmark_id: 5 weight: 1.0 } +# NOTE: the triangular topology of the face meshes is only useful when derived +# from the 468 face landmarks, not from the 6 face detection landmarks +# (keypoints). The latter don't cover the entire face and this mesh is +# defined here only to comply with the API. It should be considered as +# a placeholder and/or for debugging purposes.
+# +# Use the face geometry derived from the face detection landmarks +# (keypoints) for the face pose transformation matrix, not the mesh. +canonical_mesh: { + vertex_type: VERTEX_PT + primitive_type: TRIANGLE + vertex_buffer: -3.1511454582214355 + vertex_buffer: 2.6246179342269897 + vertex_buffer: 3.4656630754470825 + vertex_buffer: 0.349575996398926 + vertex_buffer: 0.38137748837470997 + vertex_buffer: 3.1511454582214355 + vertex_buffer: 2.6246179342269897 + vertex_buffer: 3.4656630754470825 + vertex_buffer: 0.650443494319916 + vertex_buffer: 0.38137999176979054 + vertex_buffer: 0.0 + vertex_buffer: -1.126865029335022 + vertex_buffer: 7.475604057312012 + vertex_buffer: 0.500025987625122 + vertex_buffer: 0.547487020492554 + vertex_buffer: 0.0 + vertex_buffer: -4.304508209228516 + vertex_buffer: 4.162498950958252 + vertex_buffer: 0.499989986419678 + vertex_buffer: 0.694203019142151 + vertex_buffer: -7.664182186126709 + vertex_buffer: 0.673132002353668 + vertex_buffer: -2.435867071151733 + vertex_buffer: 0.007561000064015 + vertex_buffer: 0.480777025222778 + vertex_buffer: 7.664182186126709 + vertex_buffer: 0.673132002353668 + vertex_buffer: -2.435867071151733 + vertex_buffer: 0.992439985275269 + vertex_buffer: 0.480777025222778 + index_buffer: 0 + index_buffer: 1 + index_buffer: 2 + index_buffer: 1 + index_buffer: 5 + index_buffer: 2 + index_buffer: 4 + index_buffer: 0 + index_buffer: 2 + index_buffer: 4 + index_buffer: 2 + index_buffer: 3 + index_buffer: 2 + index_buffer: 5 + index_buffer: 3 +} diff --git a/mediapipe/modules/face_geometry/data/geometry_pipeline_metadata.pbtxt b/mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_landmarks.pbtxt similarity index 99% rename from mediapipe/modules/face_geometry/data/geometry_pipeline_metadata.pbtxt rename to mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_landmarks.pbtxt index ee06f2779..8dfb46394 100644 --- a/mediapipe/modules/face_geometry/data/geometry_pipeline_metadata.pbtxt +++ 
b/mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_landmarks.pbtxt @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +input_source: FACE_LANDMARK_PIPELINE procrustes_landmark_basis { landmark_id: 4 weight: 0.070909939706326 } procrustes_landmark_basis { landmark_id: 6 weight: 0.032100144773722 } procrustes_landmark_basis { landmark_id: 10 weight: 0.008446550928056 } diff --git a/mediapipe/modules/face_geometry/effect_renderer_calculator.cc b/mediapipe/modules/face_geometry/effect_renderer_calculator.cc index fba714905..f353b8f96 100644 --- a/mediapipe/modules/face_geometry/effect_renderer_calculator.cc +++ b/mediapipe/modules/face_geometry/effect_renderer_calculator.cc @@ -85,7 +85,7 @@ static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY"; // class EffectRendererCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)) << "Failed to update contract for the GPU helper!"; @@ -101,12 +101,12 @@ class EffectRendererCalculator : public CalculatorBase { return mediapipe::GlCalculatorHelper::UpdateContract(cc); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(mediapipe::TimestampDiff(0)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)) << "Failed to open the GPU helper!"; - return gpu_helper_.RunInGlContext([&]() -> mediapipe::Status { + return gpu_helper_.RunInGlContext([&]() -> absl::Status { const auto& options = cc->Options(); @@ -136,19 +136,19 @@ class EffectRendererCalculator : public CalculatorBase { std::move(effect_texture)), _ << "Failed to create the effect renderer!"); - return mediapipe::OkStatus(); + return absl::OkStatus(); }); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status 
Process(CalculatorContext* cc) override { // The `IMAGE_GPU` stream is required to have a non-empty packet. In case // this requirement is not met, there's nothing to be processed at the // current timestamp. if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } - return gpu_helper_.RunInGlContext([this, cc]() -> mediapipe::Status { + return gpu_helper_.RunInGlContext([this, cc]() -> absl::Status { const auto& input_gpu_buffer = cc->Inputs().Tag(kImageGpuTag).Get(); @@ -191,7 +191,7 @@ class EffectRendererCalculator : public CalculatorBase { output_gl_texture.Release(); input_gl_texture.Release(); - return mediapipe::OkStatus(); + return absl::OkStatus(); }); } @@ -200,7 +200,7 @@ class EffectRendererCalculator : public CalculatorBase { } private: - static mediapipe::StatusOr ReadTextureFromFile( + static absl::StatusOr ReadTextureFromFile( const std::string& texture_path) { ASSIGN_OR_RETURN(std::string texture_blob, ReadContentBlobFromFile(texture_path), @@ -244,7 +244,7 @@ class EffectRendererCalculator : public CalculatorBase { return output_image_frame; } - static mediapipe::StatusOr ReadMesh3dFromFile( + static absl::StatusOr ReadMesh3dFromFile( const std::string& mesh_3d_path) { ASSIGN_OR_RETURN(std::string mesh_3d_blob, ReadContentBlobFromFile(mesh_3d_path), @@ -257,7 +257,7 @@ class EffectRendererCalculator : public CalculatorBase { return mesh_3d; } - static mediapipe::StatusOr ReadContentBlobFromFile( + static absl::StatusOr ReadContentBlobFromFile( const std::string& unresolved_path) { ASSIGN_OR_RETURN(std::string resolved_path, mediapipe::PathToResourceAsFile(unresolved_path), diff --git a/mediapipe/modules/face_geometry/env_generator_calculator.cc b/mediapipe/modules/face_geometry/env_generator_calculator.cc index 2464998de..2e95a66e6 100644 --- a/mediapipe/modules/face_geometry/env_generator_calculator.cc +++ b/mediapipe/modules/face_geometry/env_generator_calculator.cc @@ -40,14 +40,14 @@ static 
constexpr char kEnvironmentTag[] = "ENVIRONMENT"; // class EnvGeneratorCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->OutputSidePackets() .Tag(kEnvironmentTag) .Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(mediapipe::TimestampDiff(0)); const face_geometry::Environment& environment = @@ -60,15 +60,15 @@ class EnvGeneratorCalculator : public CalculatorBase { .Tag(kEnvironmentTag) .Set(mediapipe::MakePacket(environment)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Process(CalculatorContext* cc) override { + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Close(CalculatorContext* cc) override { + return absl::OkStatus(); } }; diff --git a/mediapipe/modules/face_geometry/face_geometry.pbtxt b/mediapipe/modules/face_geometry/face_geometry.pbtxt index 41ad9810b..76228d4b1 100644 --- a/mediapipe/modules/face_geometry/face_geometry.pbtxt +++ b/mediapipe/modules/face_geometry/face_geometry.pbtxt @@ -1,17 +1,12 @@ -# MediaPipe graph to extract face geometry from landmarks for multiple faces. +# MediaPipe graph to extract geometry from face landmarks for multiple faces. # # It is required that "geometry_pipeline_metadata.binarypb" is available at # "mediapipe/modules/face_geometry/data/geometry_pipeline_metadata.binarypb" # path during execution. 
# -# EXAMPLE: -# node { -# calculator: "FaceGeometry" -# input_stream: "IMAGE_SIZE:image_size" -# input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" -# input_side_packet: "ENVIRONMENT:environment" -# output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" -# } +# This is a deprecated subgraph kept for backward-compatibility reasons. Please, +# be explicit and use the `FaceGeometryFromLandmarks` subgraph in the new code +# to enable the same runtime behaviour. type: "FaceGeometry" @@ -37,7 +32,8 @@ input_side_packet: "ENVIRONMENT:environment" # (std::vector) output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" -# Extracts face geometry for multiple faces from a vector of landmark lists. +# Extracts face geometry for multiple faces from a vector of face landmark +# lists. node { calculator: "FaceGeometryPipelineCalculator" input_side_packet: "ENVIRONMENT:environment" diff --git a/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt b/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt new file mode 100644 index 000000000..f570286aa --- /dev/null +++ b/mediapipe/modules/face_geometry/face_geometry_from_detection.pbtxt @@ -0,0 +1,87 @@ +# MediaPipe graph to extract geometry from face detection for multiple faces. +# +# It is required that "geometry_pipeline_metadata_detection.binarypb" is +# available at +# "mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_detection.binarypb" +# path during execution. +# +# EXAMPLE: +# node { +# calculator: "FaceGeometryFromDetection" +# input_stream: "IMAGE_SIZE:image_size" +# input_stream: "MULTI_FACE_DETECTION:multi_face_detection" +# input_side_packet: "ENVIRONMENT:environment" +# output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" +# } + +type: "FaceGeometryFromDetection" + +# The size of the input frame. The first element of the pair is the frame width; +# the other one is the frame height. 
+# +# The face landmarks should have been detected on a frame with the same +# ratio. If used as-is, the resulting face geometry visualization should be +# happening on a frame with the same ratio as well. +# +# (std::pair) +input_stream: "IMAGE_SIZE:image_size" + +# Collection of detected/predicted faces, each represented as a detection. +# (std::vector) +input_stream: "MULTI_FACE_DETECTION:multi_face_detection" + +# Environment that describes the current virtual scene. +# (face_geometry::Environment) +input_side_packet: "ENVIRONMENT:environment" + +# A list of geometry data for each detected face. +# (std::vector) +# +# NOTE: the triangular topology of the face meshes is only useful when derived +# from the 468 face landmarks, not from the 6 face detection landmarks +# (keypoints). The latter don't cover the entire face and this mesh is +# defined here only to comply with the API. It should be considered as +# a placeholder and/or for debugging purposes. +# +# Use the face geometry derived from the face detection landmarks +# (keypoints) for the face pose transformation matrix, not the mesh. +output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" + +# Begin iterating over a vector of the face detections. +node { + calculator: "BeginLoopDetectionCalculator" + input_stream: "ITERABLE:multi_face_detection" + output_stream: "ITEM:face_detection" + output_stream: "BATCH_END:detection_timestamp" +} + +# Extracts face detection keypoints as normalized landmarks. +node { + calculator: "DetectionToLandmarksCalculator" + input_stream: "DETECTION:face_detection" + output_stream: "LANDMARKS:face_landmarks" +} + +# End iterating over a vector of the face detections and receive a vector of +# face landmark lists as a result.
+node { + calculator: "EndLoopNormalizedLandmarkListVectorCalculator" + input_stream: "ITEM:face_landmarks" + input_stream: "BATCH_END:detection_timestamp" + output_stream: "ITERABLE:multi_face_landmarks" +} + +# Extracts face geometry for multiple faces from a vector of face detection +# landmark lists. +node { + calculator: "FaceGeometryPipelineCalculator" + input_side_packet: "ENVIRONMENT:environment" + input_stream: "IMAGE_SIZE:image_size" + input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" + output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" + options: { + [mediapipe.FaceGeometryPipelineCalculatorOptions.ext] { + metadata_path: "mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_detection.binarypb" + } + } +} diff --git a/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt b/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt new file mode 100644 index 000000000..329147663 --- /dev/null +++ b/mediapipe/modules/face_geometry/face_geometry_from_landmarks.pbtxt @@ -0,0 +1,54 @@ +# MediaPipe graph to extract geometry from face landmarks for multiple faces. +# +# It is required that "geometry_pipeline_metadata_landmarks.binarypb" is +# available at +# "mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_landmarks.binarypb" +# path during execution. +# +# EXAMPLE: +# node { +# calculator: "FaceGeometryFromLandmarks" +# input_stream: "IMAGE_SIZE:image_size" +# input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" +# input_side_packet: "ENVIRONMENT:environment" +# output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" +# } + +type: "FaceGeometryFromLandmarks" + +# The size of the input frame. The first element of the pair is the frame width; +# the other one is the frame height. +# +# The face landmarks should have been detected on a frame with the same +# ratio. If used as-is, the resulting face geometry visualization should be +# happening on a frame with the same ratio as well.
+# +# (std::pair) +input_stream: "IMAGE_SIZE:image_size" + +# Collection of detected/predicted faces, each represented as a list of face +# landmarks. (std::vector) +input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" + +# Environment that describes the current virtual scene. +# (face_geometry::Environment) +input_side_packet: "ENVIRONMENT:environment" + +# A list of geometry data for each detected face. +# (std::vector) +output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" + +# Extracts face geometry for multiple faces from a vector of face landmark +# lists. +node { + calculator: "FaceGeometryPipelineCalculator" + input_side_packet: "ENVIRONMENT:environment" + input_stream: "IMAGE_SIZE:image_size" + input_stream: "MULTI_FACE_LANDMARKS:multi_face_landmarks" + output_stream: "MULTI_FACE_GEOMETRY:multi_face_geometry" + options: { + [mediapipe.FaceGeometryPipelineCalculatorOptions.ext] { + metadata_path: "mediapipe/modules/face_geometry/data/geometry_pipeline_metadata_landmarks.binarypb" + } + } +} diff --git a/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc b/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc index 600c9d428..87e710e42 100644 --- a/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc +++ b/mediapipe/modules/face_geometry/geometry_pipeline_calculator.cc @@ -72,7 +72,7 @@ static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS"; // class GeometryPipelineCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets() .Tag(kEnvironmentTag) .Set(); @@ -84,10 +84,10 @@ class GeometryPipelineCalculator : public CalculatorBase { .Tag(kMultiFaceGeometryTag) .Set>(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { 
cc->SetOffset(mediapipe::TimestampDiff(0)); const auto& options = cc->Options(); @@ -114,16 +114,16 @@ class GeometryPipelineCalculator : public CalculatorBase { face_geometry::CreateGeometryPipeline(environment, metadata), _ << "Failed to create a geometry pipeline!"); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { // Both the `IMAGE_SIZE` and the `MULTI_FACE_LANDMARKS` streams are required // to have a non-empty packet. In case this requirement is not met, there's // nothing to be processed at the current timestamp. if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() || cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } const auto& image_size = @@ -150,15 +150,15 @@ class GeometryPipelineCalculator : public CalculatorBase { multi_face_geometry.release()) .At(cc->InputTimestamp())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Close(CalculatorContext* cc) override { - return mediapipe::OkStatus(); + absl::Status Close(CalculatorContext* cc) override { + return absl::OkStatus(); } private: - static mediapipe::StatusOr + static absl::StatusOr ReadMetadataFromFile(const std::string& metadata_path) { ASSIGN_OR_RETURN(std::string metadata_blob, ReadContentBlobFromFile(metadata_path), @@ -171,7 +171,7 @@ class GeometryPipelineCalculator : public CalculatorBase { return metadata; } - static mediapipe::StatusOr ReadContentBlobFromFile( + static absl::StatusOr ReadContentBlobFromFile( const std::string& unresolved_path) { ASSIGN_OR_RETURN(std::string resolved_path, mediapipe::PathToResourceAsFile(unresolved_path), diff --git a/mediapipe/modules/face_geometry/libs/effect_renderer.cc b/mediapipe/modules/face_geometry/libs/effect_renderer.cc index 31ee47ef7..27a54e011 100644 --- a/mediapipe/modules/face_geometry/libs/effect_renderer.cc +++ 
b/mediapipe/modules/face_geometry/libs/effect_renderer.cc @@ -42,7 +42,7 @@ namespace mediapipe::face_geometry { namespace { struct RenderableMesh3d { - static mediapipe::StatusOr CreateFromProtoMesh3d( + static absl::StatusOr CreateFromProtoMesh3d( const Mesh3d& proto_mesh_3d) { Mesh3d::VertexType vertex_type = proto_mesh_3d.vertex_type(); @@ -106,14 +106,14 @@ struct RenderableMesh3d { class Texture { public: - static mediapipe::StatusOr> WrapExternalTexture( + static absl::StatusOr> WrapExternalTexture( GLuint handle, GLenum target, int width, int height) { RET_CHECK(handle) << "External texture must have a non-null handle!"; return absl::WrapUnique(new Texture(handle, target, width, height, /*is_owned*/ false)); } - static mediapipe::StatusOr> CreateFromImageFrame( + static absl::StatusOr> CreateFromImageFrame( const ImageFrame& image_frame) { RET_CHECK(image_frame.IsAligned(ImageFrame::kGlDefaultAlignmentBoundary)) << "Image frame memory must be aligned for GL usage!"; @@ -187,7 +187,7 @@ class Texture { class RenderTarget { public: - static mediapipe::StatusOr> Create() { + static absl::StatusOr> Create() { GLuint framebuffer_handle; glGenFramebuffers(1, &framebuffer_handle); RET_CHECK(framebuffer_handle) @@ -205,7 +205,7 @@ class RenderTarget { } } - mediapipe::Status SetColorbuffer(const Texture& colorbuffer_texture) { + absl::Status SetColorbuffer(const Texture& colorbuffer_texture) { glBindFramebuffer(GL_FRAMEBUFFER, framebuffer_handle_); glViewport(0, 0, colorbuffer_texture.width(), colorbuffer_texture.height()); @@ -245,7 +245,7 @@ class RenderTarget { glBindFramebuffer(GL_FRAMEBUFFER, 0); glFlush(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void Bind() const { @@ -288,7 +288,7 @@ class Renderer { public: enum class RenderMode { OPAQUE, OVERDRAW, OCCLUSION }; - static mediapipe::StatusOr> Create() { + static absl::StatusOr> Create() { static const GLint kAttrLocation[NUM_ATTRIBUTES] = { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, @@ 
-346,12 +346,11 @@ class Renderer { ~Renderer() { glDeleteProgram(program_handle_); } - mediapipe::Status Render(const RenderTarget& render_target, - const Texture& texture, - const RenderableMesh3d& mesh_3d, - const std::array& projection_mat, - const std::array& model_mat, - RenderMode render_mode) const { + absl::Status Render(const RenderTarget& render_target, const Texture& texture, + const RenderableMesh3d& mesh_3d, + const std::array& projection_mat, + const std::array& model_mat, + RenderMode render_mode) const { glUseProgram(program_handle_); // Set up the GL state. glEnable(GL_BLEND); @@ -413,7 +412,7 @@ class Renderer { glUseProgram(0); glFlush(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -451,7 +450,7 @@ class EffectRendererImpl : public EffectRenderer { effect_texture_(std::move(effect_texture)), identity_matrix_(Create4x4IdentityMatrix()) {} - mediapipe::Status RenderEffect( + absl::Status RenderEffect( const std::vector& multi_face_geometry, int frame_width, // int frame_height, // @@ -567,7 +566,7 @@ class EffectRendererImpl : public EffectRenderer { // At this point in the code, the destination texture must contain the // correctly renderer effect, so we should just return. 
- return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -606,7 +605,7 @@ class EffectRendererImpl : public EffectRenderer { 0.f, 0.f, 0.f, 1.f}; } - static mediapipe::StatusOr> + static absl::StatusOr> Convert4x4MatrixDataToArrayFormat(const MatrixData& matrix_data) { RET_CHECK(matrix_data.rows() == 4 && // matrix_data.cols() == 4 && // @@ -689,7 +688,7 @@ ImageFrame CreateEmptyColorTexture() { } // namespace -mediapipe::StatusOr> CreateEffectRenderer( +absl::StatusOr> CreateEffectRenderer( const Environment& environment, // const absl::optional& effect_mesh_3d, // ImageFrame&& effect_texture) { diff --git a/mediapipe/modules/face_geometry/libs/effect_renderer.h b/mediapipe/modules/face_geometry/libs/effect_renderer.h index ea989d449..71330e742 100644 --- a/mediapipe/modules/face_geometry/libs/effect_renderer.h +++ b/mediapipe/modules/face_geometry/libs/effect_renderer.h @@ -49,7 +49,7 @@ class EffectRenderer { // reference existing OpenGL textures in the current context. They should also // reference different textures as the in-place effect rendering is not yet // supported. - virtual mediapipe::Status RenderEffect( + virtual absl::Status RenderEffect( const std::vector& multi_face_geometry, int frame_width, // int frame_height, // @@ -82,7 +82,7 @@ class EffectRenderer { // // `effect_texture` must have positive dimensions. Its format must be either // `SRGB` or `SRGBA`. Its memory must be aligned for GL usage. 
-mediapipe::StatusOr> CreateEffectRenderer( +absl::StatusOr> CreateEffectRenderer( const Environment& environment, // const absl::optional& effect_mesh_3d, // ImageFrame&& effect_texture); diff --git a/mediapipe/modules/face_geometry/libs/geometry_pipeline.cc b/mediapipe/modules/face_geometry/libs/geometry_pipeline.cc index 309e31903..bcfce7cff 100644 --- a/mediapipe/modules/face_geometry/libs/geometry_pipeline.cc +++ b/mediapipe/modules/face_geometry/libs/geometry_pipeline.cc @@ -73,10 +73,12 @@ class ScreenToMetricSpaceConverter { public: ScreenToMetricSpaceConverter( OriginPointLocation origin_point_location, // + InputSource input_source, // Eigen::Matrix3Xf&& canonical_metric_landmarks, // Eigen::VectorXf&& landmark_weights, // std::unique_ptr procrustes_solver) : origin_point_location_(origin_point_location), + input_source_(input_source), canonical_metric_landmarks_(std::move(canonical_metric_landmarks)), landmark_weights_(std::move(landmark_weights)), procrustes_solver_(std::move(procrustes_solver)) {} @@ -118,11 +120,10 @@ class ScreenToMetricSpaceConverter { // // To keep the logic correct, the landmark set handedness is changed any // time the screen-to-metric semantic barrier is passed. 
- mediapipe::Status Convert( - const NormalizedLandmarkList& screen_landmark_list, // - const PerspectiveCameraFrustum& pcf, // - LandmarkList& metric_landmark_list, // - Eigen::Matrix4f& pose_transform_mat) const { + absl::Status Convert(const NormalizedLandmarkList& screen_landmark_list, // + const PerspectiveCameraFrustum& pcf, // + LandmarkList& metric_landmark_list, // + Eigen::Matrix4f& pose_transform_mat) const { RET_CHECK_EQ(screen_landmark_list.landmark_size(), canonical_metric_landmarks_.cols()) << "The number of landmarks doesn't match the number passed upon " @@ -151,12 +152,27 @@ class ScreenToMetricSpaceConverter { intermediate_landmarks); UnprojectXY(pcf, intermediate_landmarks); ChangeHandedness(intermediate_landmarks); + + // For face detection input landmarks, re-write Z-coord from the canonical + // landmarks. + if (input_source_ == InputSource::FACE_DETECTION_PIPELINE) { + Eigen::Matrix4f intermediate_pose_transform_mat; + MP_RETURN_IF_ERROR(procrustes_solver_->SolveWeightedOrthogonalProblem( + canonical_metric_landmarks_, intermediate_landmarks, + landmark_weights_, intermediate_pose_transform_mat)) + << "Failed to estimate pose transform matrix!"; + + intermediate_landmarks.row(2) = + (intermediate_pose_transform_mat * + canonical_metric_landmarks_.colwise().homogeneous()) + .row(2); + } ASSIGN_OR_RETURN(const float second_iteration_scale, EstimateScale(intermediate_landmarks), _ << "Failed to estimate second iteration scale!"); // Use the total scale to unproject the screen landmarks. 
- float total_scale = first_iteration_scale * second_iteration_scale; + const float total_scale = first_iteration_scale * second_iteration_scale; MoveAndRescaleZ(pcf, depth_offset, total_scale, screen_landmarks); UnprojectXY(pcf, screen_landmarks); ChangeHandedness(screen_landmarks); @@ -169,18 +185,30 @@ class ScreenToMetricSpaceConverter { pose_transform_mat)) << "Failed to estimate pose transform matrix!"; - // Multiply each of the metric landmarks by the inverse pose transformation - // matrix to align the runtime metric face landmarks with the canonical - // metric face landmarks. - Eigen::Matrix4f inv_pose_transform_mat = pose_transform_mat.inverse(); - auto inv_pose_rotation = inv_pose_transform_mat.leftCols(3).topRows(3); - auto inv_pose_translation = inv_pose_transform_mat.col(3).topRows(3); - metric_landmarks = - (inv_pose_rotation * metric_landmarks).colwise() + inv_pose_translation; + // For face detection input landmarks, re-write Z-coord from the canonical + // landmarks and run the pose transform estimation again. + if (input_source_ == InputSource::FACE_DETECTION_PIPELINE) { + metric_landmarks.row(2) = + (pose_transform_mat * + canonical_metric_landmarks_.colwise().homogeneous()) + .row(2); + + MP_RETURN_IF_ERROR(procrustes_solver_->SolveWeightedOrthogonalProblem( + canonical_metric_landmarks_, metric_landmarks, landmark_weights_, + pose_transform_mat)) + << "Failed to estimate pose transform matrix!"; + } + + // Multiply each of the metric landmarks by the inverse pose + // transformation matrix to align the runtime metric face landmarks with + // the canonical metric face landmarks. 
+ metric_landmarks = (pose_transform_mat.inverse() * + metric_landmarks.colwise().homogeneous()) + .topRows(3); ConvertEigenMatrixToLandmarkList(metric_landmarks, metric_landmark_list); - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: @@ -200,7 +228,7 @@ class ScreenToMetricSpaceConverter { landmarks.colwise() += Eigen::Vector3f(x_translation, y_translation, 0.f); } - mediapipe::StatusOr EstimateScale(Eigen::Matrix3Xf& landmarks) const { + absl::StatusOr EstimateScale(Eigen::Matrix3Xf& landmarks) const { Eigen::Matrix4f transform_mat; MP_RETURN_IF_ERROR(procrustes_solver_->SolveWeightedOrthogonalProblem( canonical_metric_landmarks_, landmarks, landmark_weights_, @@ -253,7 +281,8 @@ class ScreenToMetricSpaceConverter { } } - OriginPointLocation origin_point_location_; + const OriginPointLocation origin_point_location_; + const InputSource input_source_; Eigen::Matrix3Xf canonical_metric_landmarks_; Eigen::VectorXf landmark_weights_; @@ -277,7 +306,7 @@ class GeometryPipelineImpl : public GeometryPipeline { canonical_mesh_vertex_position_offset), space_converter_(std::move(space_converter)) {} - mediapipe::StatusOr> EstimateFaceGeometry( + absl::StatusOr> EstimateFaceGeometry( const std::vector& multi_face_landmarks, int frame_width, int frame_height) const override { MP_RETURN_IF_ERROR(ValidateFrameDimensions(frame_width, frame_height)) @@ -301,8 +330,8 @@ class GeometryPipelineImpl : public GeometryPipeline { continue; } - // Convert the screen landmarks into the metric landmarks and - // get the pose transformation matrix. + // Convert the screen landmarks into the metric landmarks and get the pose + // transformation matrix. 
LandmarkList metric_face_landmarks; Eigen::Matrix4f pose_transform_mat; MP_RETURN_IF_ERROR(space_converter_->Convert(screen_face_landmarks, pcf, @@ -370,7 +399,7 @@ class GeometryPipelineImpl : public GeometryPipeline { } // namespace -mediapipe::StatusOr> CreateGeometryPipeline( +absl::StatusOr> CreateGeometryPipeline( const Environment& environment, const GeometryPipelineMetadata& metadata) { MP_RETURN_IF_ERROR(ValidateEnvironment(environment)) << "Invalid environment!"; @@ -392,7 +421,7 @@ mediapipe::StatusOr> CreateGeometryPipeline( uint32_t canonical_mesh_vertex_position_offset = GetVertexComponentOffset(canonical_mesh.vertex_type(), VertexComponent::POSITION) - .ValueOrDie(); + .value(); // Put the Procrustes landmark basis into Eigen matrices for an easier access. Eigen::Matrix3Xf canonical_metric_landmarks = @@ -424,6 +453,9 @@ mediapipe::StatusOr> CreateGeometryPipeline( canonical_mesh_vertex_position_offset, absl::make_unique( environment.origin_point_location(), + metadata.input_source() == InputSource::DEFAULT + ? InputSource::FACE_LANDMARK_PIPELINE + : metadata.input_source(), std::move(canonical_metric_landmarks), std::move(landmark_weights), CreateFloatPrecisionProcrustesSolver())); diff --git a/mediapipe/modules/face_geometry/libs/geometry_pipeline.h b/mediapipe/modules/face_geometry/libs/geometry_pipeline.h index da530966b..ffa779c5d 100644 --- a/mediapipe/modules/face_geometry/libs/geometry_pipeline.h +++ b/mediapipe/modules/face_geometry/libs/geometry_pipeline.h @@ -47,7 +47,7 @@ class GeometryPipeline { // geometry pipeline metadata). // // Both `frame_width` and `frame_height` must be positive. 
- virtual mediapipe::StatusOr> EstimateFaceGeometry( + virtual absl::StatusOr> EstimateFaceGeometry( const std::vector& multi_face_landmarks, int frame_width, int frame_height) const = 0; }; @@ -59,7 +59,7 @@ class GeometryPipeline { // // Canonical face mesh (defined as a part of `metadata`) must have the // `POSITION` and the `TEX_COORD` vertex components. -mediapipe::StatusOr> CreateGeometryPipeline( +absl::StatusOr> CreateGeometryPipeline( const Environment& environment, const GeometryPipelineMetadata& metadata); } // namespace mediapipe::face_geometry diff --git a/mediapipe/modules/face_geometry/libs/mesh_3d_utils.cc b/mediapipe/modules/face_geometry/libs/mesh_3d_utils.cc index 80510818d..2078ec6f8 100644 --- a/mediapipe/modules/face_geometry/libs/mesh_3d_utils.cc +++ b/mediapipe/modules/face_geometry/libs/mesh_3d_utils.cc @@ -78,7 +78,7 @@ bool HasVertexComponent(Mesh3d::VertexType vertex_type, } } -mediapipe::StatusOr GetVertexComponentOffset( +absl::StatusOr GetVertexComponentOffset( Mesh3d::VertexType vertex_type, VertexComponent vertex_component) { RET_CHECK(HasVertexComponentVertexPT(vertex_component)) << "A given vertex type doesn't have the requested component!"; @@ -89,7 +89,7 @@ mediapipe::StatusOr GetVertexComponentOffset( } } -mediapipe::StatusOr GetVertexComponentSize( +absl::StatusOr GetVertexComponentSize( Mesh3d::VertexType vertex_type, VertexComponent vertex_component) { RET_CHECK(HasVertexComponentVertexPT(vertex_component)) << "A given vertex type doesn't have the requested component!"; diff --git a/mediapipe/modules/face_geometry/libs/mesh_3d_utils.h b/mediapipe/modules/face_geometry/libs/mesh_3d_utils.h index e6e76d3a6..a320aae91 100644 --- a/mediapipe/modules/face_geometry/libs/mesh_3d_utils.h +++ b/mediapipe/modules/face_geometry/libs/mesh_3d_utils.h @@ -36,14 +36,14 @@ bool HasVertexComponent(Mesh3d::VertexType vertex_type, // // Returns an error status if a given vertex type doesn't have the requested // component. 
-mediapipe::StatusOr GetVertexComponentOffset( +absl::StatusOr GetVertexComponentOffset( Mesh3d::VertexType vertex_type, VertexComponent vertex_component); // Computes the vertex component size. // // Returns an error status if a given vertex type doesn't have the requested // component. -mediapipe::StatusOr GetVertexComponentSize( +absl::StatusOr GetVertexComponentSize( Mesh3d::VertexType vertex_type, VertexComponent vertex_component); } // namespace mediapipe::face_geometry diff --git a/mediapipe/modules/face_geometry/libs/procrustes_solver.cc b/mediapipe/modules/face_geometry/libs/procrustes_solver.cc index 0004ac229..2ffae0e6f 100644 --- a/mediapipe/modules/face_geometry/libs/procrustes_solver.cc +++ b/mediapipe/modules/face_geometry/libs/procrustes_solver.cc @@ -32,7 +32,7 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { public: FloatPrecisionProcrustesSolver() = default; - mediapipe::Status SolveWeightedOrthogonalProblem( + absl::Status SolveWeightedOrthogonalProblem( const Eigen::Matrix3Xf& source_points, // const Eigen::Matrix3Xf& target_points, // const Eigen::VectorXf& point_weights, @@ -52,13 +52,13 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { source_points, target_points, sqrt_weights, transform_mat)) << "Failed to solve the WEOP problem!"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } private: static constexpr float kAbsoluteErrorEps = 1e-9f; - static mediapipe::Status ValidateInputPoints( + static absl::Status ValidateInputPoints( const Eigen::Matrix3Xf& source_points, const Eigen::Matrix3Xf& target_points) { RET_CHECK_GT(source_points.cols(), 0) @@ -67,10 +67,10 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { RET_CHECK_EQ(source_points.cols(), target_points.cols()) << "The number of source and target points must be equal!"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::Status ValidatePointWeights( + static absl::Status ValidatePointWeights( 
int num_points, const Eigen::VectorXf& point_weights) { RET_CHECK_GT(point_weights.size(), 0) << "The number of point weights must be positive!"; @@ -89,7 +89,7 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { RET_CHECK_GT(total_weight, kAbsoluteErrorEps) << "The total point weight is too small!"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } static Eigen::VectorXf ExtractSquareRoot( @@ -139,7 +139,7 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { // Note: the output `transform_mat` argument is used instead of `StatusOr<>` // return type in order to avoid Eigen memory alignment issues. Details: // https://eigen.tuxfamily.org/dox/group__TopicStructHavingEigenMembers.html - static mediapipe::Status InternalSolveWeightedOrthogonalProblem( + static absl::Status InternalSolveWeightedOrthogonalProblem( const Eigen::Matrix3Xf& sources, const Eigen::Matrix3Xf& targets, const Eigen::VectorXf& sqrt_weights, Eigen::Matrix4f& transform_mat) { // tranposed(A_w). @@ -195,7 +195,7 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { transform_mat = CombineTransformMatrix(rotation_and_scale, translation); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // `design_matrix` is a transposed LHS of (51) in the paper. @@ -203,7 +203,7 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { // Note: the output `rotation` argument is used instead of `StatusOr<>` // return type in order to avoid Eigen memory alignment issues. Details: // https://eigen.tuxfamily.org/dox/group__TopicStructHavingEigenMembers.html - static mediapipe::Status ComputeOptimalRotation( + static absl::Status ComputeOptimalRotation( const Eigen::Matrix3f& design_matrix, Eigen::Matrix3f& rotation) { RET_CHECK_GT(design_matrix.norm(), kAbsoluteErrorEps) << "Design matrix norm is too small!"; @@ -228,10 +228,10 @@ class FloatPrecisionProcrustesSolver : public ProcrustesSolver { // Transposed (52) from the paper. 
rotation = postrotation * prerotation; - return mediapipe::OkStatus(); + return absl::OkStatus(); } - static mediapipe::StatusOr ComputeOptimalScale( + static absl::StatusOr ComputeOptimalScale( const Eigen::Matrix3Xf& centered_weighted_sources, const Eigen::Matrix3Xf& weighted_sources, const Eigen::Matrix3Xf& weighted_targets, diff --git a/mediapipe/modules/face_geometry/libs/procrustes_solver.h b/mediapipe/modules/face_geometry/libs/procrustes_solver.h index 4147a21c7..c34b8f60b 100644 --- a/mediapipe/modules/face_geometry/libs/procrustes_solver.h +++ b/mediapipe/modules/face_geometry/libs/procrustes_solver.h @@ -56,7 +56,7 @@ class ProcrustesSolver { // Note: the output `transform_mat` argument is used instead of `StatusOr<>` // return type in order to avoid Eigen memory alignment issues. Details: // https://eigen.tuxfamily.org/dox/group__TopicStructHavingEigenMembers.html - virtual mediapipe::Status SolveWeightedOrthogonalProblem( + virtual absl::Status SolveWeightedOrthogonalProblem( const Eigen::Matrix3Xf& source_points, // const Eigen::Matrix3Xf& target_points, // const Eigen::VectorXf& point_weights, // diff --git a/mediapipe/modules/face_geometry/libs/validation_utils.cc b/mediapipe/modules/face_geometry/libs/validation_utils.cc index eceaebce1..eb4fd08f6 100644 --- a/mediapipe/modules/face_geometry/libs/validation_utils.cc +++ b/mediapipe/modules/face_geometry/libs/validation_utils.cc @@ -28,7 +28,7 @@ namespace mediapipe::face_geometry { -mediapipe::Status ValidatePerspectiveCamera( +absl::Status ValidatePerspectiveCamera( const PerspectiveCamera& perspective_camera) { static constexpr float kAbsoluteErrorEps = 1e-9f; @@ -46,18 +46,18 @@ mediapipe::Status ValidatePerspectiveCamera( 180.f) << "Vertical FOV must be less than 180 degrees with a margin of 10^{-9}"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidateEnvironment(const Environment& environment) { +absl::Status ValidateEnvironment(const Environment& 
environment) { MP_RETURN_IF_ERROR( ValidatePerspectiveCamera(environment.perspective_camera())) << "Invalid perspective camera!"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidateMesh3d(const Mesh3d& mesh_3d) { +absl::Status ValidateMesh3d(const Mesh3d& mesh_3d) { const std::size_t vertex_size = GetVertexSize(mesh_3d.vertex_type()); const std::size_t primitive_type = GetPrimitiveSize(mesh_3d.primitive_type()); @@ -73,10 +73,10 @@ mediapipe::Status ValidateMesh3d(const Mesh3d& mesh_3d) { << "All mesh indices must refer to an existing vertex!"; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidateFaceGeometry(const FaceGeometry& face_geometry) { +absl::Status ValidateFaceGeometry(const FaceGeometry& face_geometry) { MP_RETURN_IF_ERROR(ValidateMesh3d(face_geometry.mesh())) << "Invalid mesh!"; static constexpr char kInvalid4x4MatrixMessage[] = @@ -89,10 +89,10 @@ mediapipe::Status ValidateFaceGeometry(const FaceGeometry& face_geometry) { RET_CHECK_EQ(pose_transform_matrix.packed_data_size(), 16) << kInvalid4x4MatrixMessage; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidateGeometryPipelineMetadata( +absl::Status ValidateGeometryPipelineMetadata( const GeometryPipelineMetadata& metadata) { MP_RETURN_IF_ERROR(ValidateMesh3d(metadata.canonical_mesh())) << "Invalid canonical mesh!"; @@ -113,14 +113,14 @@ mediapipe::Status ValidateGeometryPipelineMetadata( << "All Procrustes basis landmarks must have a non-negative weight!"; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status ValidateFrameDimensions(int frame_width, int frame_height) { +absl::Status ValidateFrameDimensions(int frame_width, int frame_height) { RET_CHECK_GT(frame_width, 0) << "Frame width must be positive!"; RET_CHECK_GT(frame_height, 0) << "Frame height must be positive!"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace 
mediapipe::face_geometry diff --git a/mediapipe/modules/face_geometry/libs/validation_utils.h b/mediapipe/modules/face_geometry/libs/validation_utils.h index b633d38b0..c0a7e08a6 100644 --- a/mediapipe/modules/face_geometry/libs/validation_utils.h +++ b/mediapipe/modules/face_geometry/libs/validation_utils.h @@ -29,26 +29,26 @@ namespace mediapipe::face_geometry { // Far Z must be greater than Near Z with a margin of `1e-9`. // Vertical FOV must be in range (0, 180) with a margin of `1e-9` on the range // edges. -mediapipe::Status ValidatePerspectiveCamera( +absl::Status ValidatePerspectiveCamera( const PerspectiveCamera& perspective_camera); // Validates `environment`. // // Environment's perspective camera must be valid. -mediapipe::Status ValidateEnvironment(const Environment& environment); +absl::Status ValidateEnvironment(const Environment& environment); // Validates `mesh_3d`. // // Mesh vertex buffer size must a multiple of the vertex size. // Mesh index buffer size must a multiple of the primitive size. // All mesh indices must reference an existing mesh vertex. -mediapipe::Status ValidateMesh3d(const Mesh3d& mesh_3d); +absl::Status ValidateMesh3d(const Mesh3d& mesh_3d); // Validates `face_geometry`. // // Face mesh must be valid. // Face pose transformation matrix must be a 4x4 matrix. -mediapipe::Status ValidateFaceGeometry(const FaceGeometry& face_geometry); +absl::Status ValidateFaceGeometry(const FaceGeometry& face_geometry); // Validates `metadata`. // @@ -57,13 +57,13 @@ mediapipe::Status ValidateFaceGeometry(const FaceGeometry& face_geometry); // All Procrustes basis indices must reference an existing canonical mesh // vertex. // All Procrustes basis landmarks must have a non-negative weight. -mediapipe::Status ValidateGeometryPipelineMetadata( +absl::Status ValidateGeometryPipelineMetadata( const GeometryPipelineMetadata& metadata); // Validates frame dimensions. // // Both frame width and frame height must be positive. 
-mediapipe::Status ValidateFrameDimensions(int frame_width, int frame_height); +absl::Status ValidateFrameDimensions(int frame_width, int frame_height); } // namespace mediapipe::face_geometry diff --git a/mediapipe/modules/face_geometry/protos/face_geometry.proto b/mediapipe/modules/face_geometry/protos/face_geometry.proto index 459861b29..b91a7d7d5 100644 --- a/mediapipe/modules/face_geometry/protos/face_geometry.proto +++ b/mediapipe/modules/face_geometry/protos/face_geometry.proto @@ -34,6 +34,15 @@ message FaceGeometry { // the input face landmarks after (1) being multiplied by the face pose // transformation matrix and then (2) being projected with a perspective // camera matrix of the same environment. + // + // NOTE: the triangular topology of the face mesh is only useful when derived + // from the 468 face landmarks, not from the 6 face detection landmarks + // (keypoints). The latter don't cover the entire face and this mesh is + // defined here only to comply with the API. It should be considered as + // a placeholder and/or for debugging purposes. + // + // Use the face geometry derived from the face detection landmarks + // (keypoints) for the face pose transformation matrix, not the mesh. 
optional Mesh3d mesh = 1; // Defines a face pose transformation matrix, which provides mapping from diff --git a/mediapipe/modules/face_geometry/protos/geometry_pipeline_metadata.proto b/mediapipe/modules/face_geometry/protos/geometry_pipeline_metadata.proto index ab9906c0b..dac0e25e6 100644 --- a/mediapipe/modules/face_geometry/protos/geometry_pipeline_metadata.proto +++ b/mediapipe/modules/face_geometry/protos/geometry_pipeline_metadata.proto @@ -21,6 +21,12 @@ import "mediapipe/modules/face_geometry/protos/mesh_3d.proto"; option java_package = "com.google.mediapipe.modules.facegeometry"; option java_outer_classname = "GeometryPipelineMetadataProto"; +enum InputSource { + DEFAULT = 0; // FACE_LANDMARK_PIPELINE + FACE_LANDMARK_PIPELINE = 1; + FACE_DETECTION_PIPELINE = 2; +} + message WeightedLandmarkRef { // Defines the landmark ID. References an existing face landmark ID. optional uint32 landmark_id = 1; @@ -31,7 +37,18 @@ message WeightedLandmarkRef { optional float weight = 2; } +// Next field ID: 4 message GeometryPipelineMetadata { + // Defines the source of the input landmarks to let the underlying geometry + // pipeline adjust in order to produce the best results. + // + // Face landmark pipeline is expected to produce 3D landmarks with relative Z + // coordinate, which is scaled as the X coordinate assuming the weak + // perspective projection camera model. + // + // Face detection pipeline is expected to produce 2D landmarks with Z + // coordinate being equal to 0. + optional InputSource input_source = 3; // Defines a mesh surface for a canonical face. The canonical face mesh vertex // IDs are the same as the face landmark IDs. 
// diff --git a/mediapipe/modules/hand_landmark/BUILD b/mediapipe/modules/hand_landmark/BUILD index 9c0ac6dba..4dd45e130 100644 --- a/mediapipe/modules/hand_landmark/BUILD +++ b/mediapipe/modules/hand_landmark/BUILD @@ -23,6 +23,7 @@ package(default_visibility = ["//visibility:public"]) exports_files([ "hand_landmark.tflite", + "hand_landmark_sparse.tflite", "handedness.txt", ]) @@ -104,7 +105,7 @@ mediapipe_simple_subgraph( "//mediapipe/calculators/util:association_norm_rect_calculator", "//mediapipe/calculators/util:collection_has_min_size_calculator", "//mediapipe/calculators/util:filter_collection_calculator", - "//mediapipe/modules/palm_detection:palm_detection_gpu", + "//mediapipe/modules/palm_detection:palm_detection_cpu", ], ) diff --git a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc index f95d18722..3e3f5c8fa 100644 --- a/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc +++ b/mediapipe/modules/hand_landmark/calculators/hand_landmarks_to_rect_calculator.cc @@ -58,7 +58,7 @@ float ComputeRotation(const NormalizedLandmarkList& landmarks, return rotation; } -mediapipe::Status NormalizedLandmarkListToRect( +absl::Status NormalizedLandmarkListToRect( const NormalizedLandmarkList& landmarks, const std::pair& image_size, NormalizedRect* rect) { const float rotation = ComputeRotation(landmarks, image_size); @@ -117,7 +117,7 @@ mediapipe::Status NormalizedLandmarkListToRect( rect->set_height(height); rect->set_rotation(rotation); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -130,21 +130,21 @@ mediapipe::Status NormalizedLandmarkListToRect( // mean of PIP joints at the top. 
class HandLandmarksToRectCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc) { + static absl::Status GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kNormalizedLandmarksTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs().Tag(kNormRectTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Open(CalculatorContext* cc) override { + absl::Status Open(CalculatorContext* cc) override { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } - mediapipe::Status Process(CalculatorContext* cc) override { + absl::Status Process(CalculatorContext* cc) override { if (cc->Inputs().Tag(kNormalizedLandmarksTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } RET_CHECK(!cc->Inputs().Tag(kImageSizeTag).IsEmpty()); @@ -159,7 +159,7 @@ class HandLandmarksToRectCalculator : public CalculatorBase { .Tag(kNormRectTag) .Add(output_rect.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(HandLandmarksToRectCalculator); diff --git a/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt b/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt index 246a1fc2e..c46b243ad 100644 --- a/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt +++ b/mediapipe/modules/hand_landmark/hand_landmark_cpu.pbtxt @@ -52,6 +52,7 @@ node { model_path: "mediapipe/modules/hand_landmark/hand_landmark.tflite" delegate { xnnpack {} } } + # } } diff --git a/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite b/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite new file mode 100755 index 000000000..28d2730a3 Binary files /dev/null and b/mediapipe/modules/hand_landmark/hand_landmark_sparse.tflite differ diff --git a/mediapipe/modules/holistic_landmark/README.md b/mediapipe/modules/holistic_landmark/README.md index 126518a51..d285f155a 100644 --- 
a/mediapipe/modules/holistic_landmark/README.md +++ b/mediapipe/modules/holistic_landmark/README.md @@ -3,4 +3,4 @@ Subgraphs|Details :--- | :--- [`HolisticLandmarkCpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark/holistic_landmark_cpu.pbtxt)| Predicts pose + left/right hand + face landmarks. (CPU input) -[`HolisticLandmarkCpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt)| Predicts pose + left/right hand + face landmarks. (GPU input.) +[`HolisticLandmarkGpu`](https://github.com/google/mediapipe/tree/master/mediapipe/modules/holistic_landmark/holistic_landmark_gpu.pbtxt)| Predicts pose + left/right hand + face landmarks. (GPU input.) diff --git a/mediapipe/modules/holistic_landmark/calculators/hand_detections_from_pose_to_rects_calculator.cc b/mediapipe/modules/holistic_landmark/calculators/hand_detections_from_pose_to_rects_calculator.cc index 67faf60a4..5afdb8a2c 100644 --- a/mediapipe/modules/holistic_landmark/calculators/hand_detections_from_pose_to_rects_calculator.cc +++ b/mediapipe/modules/holistic_landmark/calculators/hand_detections_from_pose_to_rects_calculator.cc @@ -39,15 +39,15 @@ namespace {} // namespace class HandDetectionsFromPoseToRectsCalculator : public DetectionsToRectsCalculator { public: - ::mediapipe::Status Open(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; private: - ::mediapipe::Status DetectionToNormalizedRect( - const Detection& detection, const DetectionSpec& detection_spec, - NormalizedRect* rect) override; - ::mediapipe::Status ComputeRotation(const Detection& detection, - const DetectionSpec& detection_spec, - float* rotation) override; + ::absl::Status DetectionToNormalizedRect(const Detection& detection, + const DetectionSpec& detection_spec, + NormalizedRect* rect) override; + absl::Status ComputeRotation(const Detection& detection, + const DetectionSpec& detection_spec, + float* 
rotation) override; }; REGISTER_CALCULATOR(HandDetectionsFromPoseToRectsCalculator); @@ -61,7 +61,7 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE"; } // namespace -::mediapipe::Status HandDetectionsFromPoseToRectsCalculator::Open( +::absl::Status HandDetectionsFromPoseToRectsCalculator::Open( CalculatorContext* cc) { RET_CHECK(cc->Inputs().HasTag(kImageSizeTag)) << "Image size is required to calculate rotated rect."; @@ -72,10 +72,10 @@ constexpr char kImageSizeTag[] = "IMAGE_SIZE"; output_zero_rect_for_empty_detections_ = options_.output_zero_rect_for_empty_detections(); - return ::mediapipe::OkStatus(); + return ::absl::OkStatus(); } -::mediapipe::Status +::absl::Status HandDetectionsFromPoseToRectsCalculator ::DetectionToNormalizedRect( const Detection& detection, const DetectionSpec& detection_spec, NormalizedRect* rect) { @@ -118,10 +118,10 @@ HandDetectionsFromPoseToRectsCalculator ::DetectionToNormalizedRect( rect->set_width(box_size / image_size->first); rect->set_height(box_size / image_size->second); - return ::mediapipe::OkStatus(); + return ::absl::OkStatus(); } -::mediapipe::Status HandDetectionsFromPoseToRectsCalculator::ComputeRotation( +absl::Status HandDetectionsFromPoseToRectsCalculator::ComputeRotation( const Detection& detection, const DetectionSpec& detection_spec, float* rotation) { const auto& location_data = detection.location_data(); @@ -150,7 +150,7 @@ HandDetectionsFromPoseToRectsCalculator ::DetectionToNormalizedRect( *rotation = NormalizeRadians( target_angle_ - std::atan2(-(y_middle - y_wrist), x_middle - x_wrist)); - return ::mediapipe::OkStatus(); + return ::absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc index 80560694c..0da6cd7f7 100644 --- a/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc +++ 
b/mediapipe/modules/holistic_landmark/calculators/roi_tracking_calculator.cc @@ -253,36 +253,36 @@ bool LandmarksRequirementsSatisfied(const NormalizedLandmarkList& landmarks, // } class RoiTrackingCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: RoiTrackingCalculatorOptions options_; }; REGISTER_CALCULATOR(RoiTrackingCalculator); -mediapipe::Status RoiTrackingCalculator::GetContract(CalculatorContract* cc) { +absl::Status RoiTrackingCalculator::GetContract(CalculatorContract* cc) { cc->Inputs().Tag(kPrevLandmarksTag).Set(); cc->Inputs().Tag(kPrevLandmarksRectTag).Set(); cc->Inputs().Tag(kRecropRectTag).Set(); cc->Inputs().Tag(kImageSizeTag).Set>(); cc->Outputs().Tag(kTrackingRectTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RoiTrackingCalculator::Open(CalculatorContext* cc) { +absl::Status RoiTrackingCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status RoiTrackingCalculator::Process(CalculatorContext* cc) { +absl::Status RoiTrackingCalculator::Process(CalculatorContext* cc) { // If there is no current frame re-crop rect (i.e. object is not present on // the current frame) - return empty packet. 
if (cc->Inputs().Tag(kRecropRectTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } // If there is no previous rect, but there is current re-crop rect - return @@ -291,7 +291,7 @@ mediapipe::Status RoiTrackingCalculator::Process(CalculatorContext* cc) { cc->Outputs() .Tag(kTrackingRectTag) .AddPacket(cc->Inputs().Tag(kRecropRectTag).Value()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } // At this point we have both previous rect (which also means we have previous @@ -352,7 +352,7 @@ mediapipe::Status RoiTrackingCalculator::Process(CalculatorContext* cc) { VLOG(1) << "Lost tracking: check messages above for details"; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/modules/objectron/BUILD b/mediapipe/modules/objectron/BUILD index 832d363f0..cee576879 100644 --- a/mediapipe/modules/objectron/BUILD +++ b/mediapipe/modules/objectron/BUILD @@ -21,6 +21,17 @@ licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +exports_files([ + "object_detection_3d_camera.tflite", + "object_detection_3d_chair.tflite", + "object_detection_3d_chair_1stage.tflite", + "object_detection_3d_cup.tflite", + "object_detection_3d_sneakers.tflite", + "object_detection_3d_sneakers_1stage.tflite", + "object_detection_oidv4_labelmap.txt", + "object_detection_ssd_mobilenetv2_oidv4_fp16.tflite", +]) + mediapipe_simple_subgraph( name = "objectron_detection_1stage_gpu", graph = "objectron_detection_1stage_gpu.pbtxt", @@ -98,11 +109,10 @@ mediapipe_simple_subgraph( graph = "object_detection_oid_v4_gpu.pbtxt", register_as = "ObjectDetectionOidV4Subgraph", deps = [ - "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator", - 
"//mediapipe/calculators/tflite:tflite_converter_calculator", - "//mediapipe/calculators/tflite:tflite_inference_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", "//mediapipe/calculators/util:detection_label_id_to_text_calculator", "//mediapipe/calculators/util:non_max_suppression_calculator", "//mediapipe/modules/objectron/calculators:filter_detection_calculator", @@ -114,11 +124,10 @@ mediapipe_simple_subgraph( graph = "object_detection_oid_v4_cpu.pbtxt", register_as = "ObjectDetectionOidV4Subgraph", deps = [ - "//mediapipe/calculators/image:image_transformation_calculator", + "//mediapipe/calculators/tensor:image_to_tensor_calculator", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", "//mediapipe/calculators/tflite:ssd_anchors_calculator", - "//mediapipe/calculators/tflite:tflite_converter_calculator", - "//mediapipe/calculators/tflite:tflite_inference_calculator", - "//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", "//mediapipe/calculators/util:detection_label_id_to_text_calculator", "//mediapipe/calculators/util:non_max_suppression_calculator", "//mediapipe/modules/objectron/calculators:filter_detection_calculator", @@ -139,9 +148,11 @@ mediapipe_simple_subgraph( "//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:previous_loopback_calculator", "//mediapipe/calculators/image:image_properties_calculator", + "//mediapipe/calculators/tflite:tflite_model_calculator", "//mediapipe/calculators/util:association_norm_rect_calculator", "//mediapipe/calculators/util:collection_has_min_size_calculator", "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:local_file_contents_calculator", "//mediapipe/modules/objectron/calculators:frame_annotation_to_rect_calculator", "//mediapipe/modules/objectron/calculators:landmarks_to_frame_annotation_calculator", 
"//mediapipe/modules/objectron/calculators:lift_2d_frame_annotation_to_3d_calculator", diff --git a/mediapipe/modules/objectron/calculators/BUILD b/mediapipe/modules/objectron/calculators/BUILD index 0a8b326a6..eb985d04d 100644 --- a/mediapipe/modules/objectron/calculators/BUILD +++ b/mediapipe/modules/objectron/calculators/BUILD @@ -13,6 +13,7 @@ # limitations under the License. load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") licenses(["notice"]) @@ -40,6 +41,15 @@ mediapipe_proto_library( ], ) +mediapipe_register_type( + base_name = "annotation", + include_headers = ["mediapipe/modules/objectron/calculators/annotation_data.pb.h"], + types = [ + "::mediapipe::FrameAnnotation", + ], + deps = [":annotation_cc_proto"], +) + mediapipe_proto_library( name = "camera_parameters_proto", srcs = ["camera_parameters.proto"], @@ -153,6 +163,7 @@ cc_library( deps = [ ":annotation_cc_proto", ":belief_decoder_config_cc_proto", + ":box", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_imgproc", @@ -291,7 +302,6 @@ cc_library( ":decoder", ":lift_2d_frame_annotation_to_3d_calculator_cc_proto", ":tensor_util", - ":tflite_tensors_to_objects_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/deps:file_path", "//mediapipe/framework/formats:detection_cc_proto", @@ -311,7 +321,6 @@ cc_library( srcs = ["frame_annotation_to_rect_calculator.cc"], deps = [ ":annotation_cc_proto", - ":box", ":frame_annotation_to_rect_calculator_cc_proto", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:rect_cc_proto", diff --git a/mediapipe/modules/objectron/calculators/annotation_data.proto b/mediapipe/modules/objectron/calculators/annotation_data.proto index f1a9600eb..897e7001f 100644 --- a/mediapipe/modules/objectron/calculators/annotation_data.proto +++ 
b/mediapipe/modules/objectron/calculators/annotation_data.proto @@ -56,6 +56,18 @@ message ObjectAnnotation { // Visibiity of this annotation in a frame. float visibility = 3; + + // 3x3 row-major rotation matrix describing the orientation of the rigid + // object's frame of reference in the camera-coordinate system. + repeated float rotation = 4; + + // 3x1 vector describing the translation of the rigid object's frame of + // reference in the camera-coordinate system in meters. + repeated float translation = 5; + + // 3x1 vector describing the scale of the rigid object's frame of reference in + // the camera-coordinate system. + repeated float scale = 6; } message FrameAnnotation { diff --git a/mediapipe/modules/objectron/calculators/decoder.cc b/mediapipe/modules/objectron/calculators/decoder.cc index 0f66c3a79..6eff637ca 100644 --- a/mediapipe/modules/objectron/calculators/decoder.cc +++ b/mediapipe/modules/objectron/calculators/decoder.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/modules/objectron/calculators/annotation_data.pb.h" +#include "mediapipe/modules/objectron/calculators/box.h" namespace mediapipe { constexpr int Decoder::kNumOffsetmaps = 16; @@ -203,8 +204,9 @@ absl::Status Decoder::Lift2DTo3D( float u, v; for (int i = 0; i < 8; ++i) { const auto& keypoint2d = annotation.keypoints(i + 1).point_2d(); + // Convert 2d point from screen coordinates to NDC coordinates([-1, 1]). if (portrait) { - // swap x and y given that our image is in portrait orientation + // Swap x and y given that our image is in portrait orientation u = keypoint2d.y() * 2 - 1; v = keypoint2d.x() * 2 - 1; } else { @@ -237,6 +239,7 @@ absl::Status Decoder::Lift2DTo3D( Eigen::VectorXf eigen_vec = eigen_solver.eigenvectors().col(0); Eigen::Map> control_matrix( eigen_vec.data()); + // All 3d points should be in front of camera (z < 0). 
if (control_matrix(0, 2) > 0) { control_matrix = -control_matrix; } @@ -246,10 +249,36 @@ absl::Status Decoder::Lift2DTo3D( // Then set the 8 vertices. Eigen::Matrix vertices = epnp_alpha_ * control_matrix; + + std::vector vertices_vec; + vertices_vec.emplace_back(Eigen::Vector3f( + control_matrix(0, 0), control_matrix(0, 1), control_matrix(0, 2))); for (int i = 0; i < 8; ++i) { SetPoint3d(vertices(i, 0), vertices(i, 1), vertices(i, 2), annotation.mutable_keypoints(i + 1)->mutable_point_3d()); + vertices_vec.emplace_back( + Eigen::Vector3f(vertices(i, 0), vertices(i, 1), vertices(i, 2))); } + + // Fit a box to the vertices to get box scale, rotation, translation. + Box box("category"); + box.Fit(vertices_vec); + const Eigen::Matrix rotation = + box.GetRotation(); + const Eigen::Vector3f translation = box.GetTranslation(); + const Eigen::Vector3f scale = box.GetScale(); + // Fill box rotation. + std::vector rotation_vec(rotation.data(), + rotation.data() + rotation.size()); + *annotation.mutable_rotation() = {rotation_vec.begin(), rotation_vec.end()}; + // Fill box translation. + std::vector translation_vec(translation.data(), + translation.data() + translation.size()); + *annotation.mutable_translation() = {translation_vec.begin(), + translation_vec.end()}; + // Fill box scale. 
+ std::vector scale_vec(scale.data(), scale.data() + scale.size()); + *annotation.mutable_scale() = {scale_vec.begin(), scale_vec.end()}; } return absl::OkStatus(); } diff --git a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc index 37eb74a81..ea238a86b 100644 --- a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc +++ b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc @@ -77,8 +77,8 @@ struct FirstGreaterComparator { } }; -mediapipe::Status SortLabelsByDecreasingScore(const Detection& detection, - Detection* sorted_detection) { +absl::Status SortLabelsByDecreasingScore(const Detection& detection, + Detection* sorted_detection) { RET_CHECK(sorted_detection); RET_CHECK_EQ(detection.score_size(), detection.label_size()); if (!detection.label_id().empty()) { @@ -110,14 +110,14 @@ mediapipe::Status SortLabelsByDecreasingScore(const Detection& detection, sorted_detection->set_label_id(i, detection.label_id(index)); } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } class FilterDetectionCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: bool IsValidLabel(const std::string& label); @@ -134,8 +134,7 @@ class FilterDetectionCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(FilterDetectionCalculator); -mediapipe::Status FilterDetectionCalculator::GetContract( - CalculatorContract* cc) { +absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -153,10 
+152,10 @@ mediapipe::Status FilterDetectionCalculator::GetContract( if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { cc->InputSidePackets().Tag(kLabelsCsvTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FilterDetectionCalculator::Open(CalculatorContext* cc) { +absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) || @@ -187,12 +186,12 @@ mediapipe::Status FilterDetectionCalculator::Open(CalculatorContext* cc) { limit_labels_ = false; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FilterDetectionCalculator::Process(CalculatorContext* cc) { +absl::Status FilterDetectionCalculator::Process(CalculatorContext* cc) { if (limit_labels_ && allowed_labels_.empty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } Detections detections; if (cc->Inputs().HasTag(kDetectionsTag)) { @@ -234,7 +233,7 @@ mediapipe::Status FilterDetectionCalculator::Process(CalculatorContext* cc) { .Tag(kDetectionsTag) .Add(new Detection((*outputs)[0]), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } bool FilterDetectionCalculator::IsValidLabel(const std::string& label) { @@ -258,7 +257,6 @@ bool FilterDetectionCalculator::IsValidScore(float score) { LOG(ERROR) << "Filter out detection with high score " << score; return false; } - LOG(ERROR) << "Pass detection with score " << score; return true; } diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc index 741efc777..476f8cb54 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.cc @@ -21,7 +21,6 @@ #include "mediapipe/framework/port/ret_check.h" 
#include "mediapipe/framework/port/status.h" #include "mediapipe/modules/objectron/calculators/annotation_data.pb.h" -#include "mediapipe/modules/objectron/calculators/box.h" #include "mediapipe/modules/objectron/calculators/frame_annotation_to_rect_calculator.pb.h" namespace mediapipe { @@ -48,9 +47,9 @@ class FrameAnnotationToRectCalculator : public CalculatorBase { TOP_VIEW_OFF, }; - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: void AddAnnotationToRect(const ObjectAnnotation& annotation, @@ -65,7 +64,7 @@ class FrameAnnotationToRectCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(FrameAnnotationToRectCalculator); -mediapipe::Status FrameAnnotationToRectCalculator::GetContract( +absl::Status FrameAnnotationToRectCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -77,23 +76,22 @@ mediapipe::Status FrameAnnotationToRectCalculator::GetContract( if (cc->Outputs().HasTag(kOutputNormRectsTag)) { cc->Outputs().Tag(kOutputNormRectsTag).Set>(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationToRectCalculator::Open(CalculatorContext* cc) { +absl::Status FrameAnnotationToRectCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); status_ = TOP_VIEW_OFF; const auto& options = cc->Options(); off_threshold_ = options.off_threshold(); on_threshold_ = options.on_threshold(); RET_CHECK(off_threshold_ <= on_threshold_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationToRectCalculator::Process( - CalculatorContext* cc) { +absl::Status 
FrameAnnotationToRectCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputFrameAnnotationTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } auto output_rects = absl::make_unique>(); const auto& frame_annotation = @@ -106,7 +104,7 @@ mediapipe::Status FrameAnnotationToRectCalculator::Process( cc->Outputs() .Tag(kOutputNormRectsTag) .Add(output_rects.release(), cc->InputTimestamp()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void FrameAnnotationToRectCalculator::AddAnnotationToRect( @@ -133,20 +131,11 @@ void FrameAnnotationToRectCalculator::AddAnnotationToRect( float FrameAnnotationToRectCalculator::RotationAngleFromAnnotation( const ObjectAnnotation& annotation) { - Box box("category"); - std::vector vertices_3d; - std::vector vertices_2d; - for (const auto& keypoint : annotation.keypoints()) { - const auto& point_3d = keypoint.point_3d(); - const auto& point_2d = keypoint.point_2d(); - vertices_3d.emplace_back( - Vector3f(point_3d.x(), point_3d.y(), point_3d.z())); - vertices_2d.emplace_back(Vector2f(point_2d.x(), point_2d.y())); - } - box.Fit(vertices_3d); - Vector3f scale = box.GetScale(); - Matrix3fRM box_rotation = box.GetRotation(); - Vector3f box_translation = box.GetTranslation(); + // Get box rotation and translation from annotation. + const auto box_rotation = + Eigen::Map(annotation.rotation().data()); + const auto box_translation = + Eigen::Map(annotation.translation().data()); // Rotation angle to use when top-view is on(top-view on), // Which will make z-axis upright after the rotation. 
@@ -180,9 +169,9 @@ float FrameAnnotationToRectCalculator::RotationAngleFromPose( const Vector3f& vec) { auto p1 = rotation * vec + translation; auto p2 = -rotation * vec + translation; - const float dy = p2[2] * p1[1] - p1[2] * p2[1]; - const float dx = p2[2] * p1[0] - p1[2] * p2[0]; - return std::atan2(-dy, dx); + const float dy = p2[2] * p2[1] - p1[2] * p1[1]; + const float dx = p2[2] * p2[0] - p1[2] * p1[0]; + return M_PI / 2 - std::atan2(dy, dx); } } // namespace mediapipe diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc index 55b1acac6..74678804f 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_to_timed_box_list_calculator.cc @@ -47,15 +47,15 @@ namespace mediapipe { // } class FrameAnnotationToTimedBoxListCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; }; REGISTER_CALCULATOR(FrameAnnotationToTimedBoxListCalculator); -mediapipe::Status FrameAnnotationToTimedBoxListCalculator::GetContract( +absl::Status FrameAnnotationToTimedBoxListCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -67,15 +67,15 @@ mediapipe::Status FrameAnnotationToTimedBoxListCalculator::GetContract( if (cc->Outputs().HasTag(kOutputStreamTag)) { cc->Outputs().Tag(kOutputStreamTag).Set(); } - 
return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationToTimedBoxListCalculator::Open( +absl::Status FrameAnnotationToTimedBoxListCalculator::Open( CalculatorContext* cc) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationToTimedBoxListCalculator::Process( +absl::Status FrameAnnotationToTimedBoxListCalculator::Process( CalculatorContext* cc) { if (cc->Inputs().HasTag(kInputStreamTag) && !cc->Inputs().Tag(kInputStreamTag).IsEmpty()) { @@ -104,12 +104,12 @@ mediapipe::Status FrameAnnotationToTimedBoxListCalculator::Process( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationToTimedBoxListCalculator::Close( +absl::Status FrameAnnotationToTimedBoxListCalculator::Close( CalculatorContext* cc) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/modules/objectron/calculators/frame_annotation_tracker_calculator.cc b/mediapipe/modules/objectron/calculators/frame_annotation_tracker_calculator.cc index dfdc581a2..9079b9af0 100644 --- a/mediapipe/modules/objectron/calculators/frame_annotation_tracker_calculator.cc +++ b/mediapipe/modules/objectron/calculators/frame_annotation_tracker_calculator.cc @@ -52,18 +52,18 @@ namespace mediapipe { // } class FrameAnnotationTrackerCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: std::unique_ptr frame_annotation_tracker_; }; REGISTER_CALCULATOR(FrameAnnotationTrackerCalculator); 
-mediapipe::Status FrameAnnotationTrackerCalculator::GetContract( +absl::Status FrameAnnotationTrackerCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -80,19 +80,17 @@ mediapipe::Status FrameAnnotationTrackerCalculator::GetContract( if (cc->Outputs().HasTag(kOutputCancelObjectIdTag)) { cc->Outputs().Tag(kOutputCancelObjectIdTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationTrackerCalculator::Open( - CalculatorContext* cc) { +absl::Status FrameAnnotationTrackerCalculator::Open(CalculatorContext* cc) { const auto& options = cc->Options(); frame_annotation_tracker_ = absl::make_unique( options.iou_threshold(), options.img_width(), options.img_height()); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationTrackerCalculator::Process( - CalculatorContext* cc) { +absl::Status FrameAnnotationTrackerCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().HasTag(kInputFrameAnnotationTag) && !cc->Inputs().Tag(kInputFrameAnnotationTag).IsEmpty()) { frame_annotation_tracker_->AddDetectionResult( @@ -126,12 +124,11 @@ mediapipe::Status FrameAnnotationTrackerCalculator::Process( } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status FrameAnnotationTrackerCalculator::Close( - CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status FrameAnnotationTrackerCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/modules/objectron/calculators/landmarks_to_frame_annotation_calculator.cc b/mediapipe/modules/objectron/calculators/landmarks_to_frame_annotation_calculator.cc index 0ca16aff6..60c48766c 100644 --- a/mediapipe/modules/objectron/calculators/landmarks_to_frame_annotation_calculator.cc +++ b/mediapipe/modules/objectron/calculators/landmarks_to_frame_annotation_calculator.cc @@ 
-31,9 +31,9 @@ constexpr char kOutputFrameAnnotationTag[] = "FRAME_ANNOTATION"; // A calculator that converts NormalizedLandmarkList to FrameAnnotation proto. class LandmarksToFrameAnnotationCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; + static absl::Status GetContract(CalculatorContract* cc); + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; private: void AddLandmarksToFrameAnnotation(const NormalizedLandmarkList& landmarks, @@ -41,7 +41,7 @@ class LandmarksToFrameAnnotationCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(LandmarksToFrameAnnotationCalculator); -mediapipe::Status LandmarksToFrameAnnotationCalculator::GetContract( +absl::Status LandmarksToFrameAnnotationCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -57,16 +57,15 @@ mediapipe::Status LandmarksToFrameAnnotationCalculator::GetContract( if (cc->Outputs().HasTag(kOutputFrameAnnotationTag)) { cc->Outputs().Tag(kOutputFrameAnnotationTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksToFrameAnnotationCalculator::Open( - CalculatorContext* cc) { +absl::Status LandmarksToFrameAnnotationCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LandmarksToFrameAnnotationCalculator::Process( +absl::Status LandmarksToFrameAnnotationCalculator::Process( CalculatorContext* cc) { auto frame_annotation = absl::make_unique(); @@ -96,7 +95,7 @@ mediapipe::Status LandmarksToFrameAnnotationCalculator::Process( .Tag(kOutputFrameAnnotationTag) .Add(frame_annotation.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + 
return absl::OkStatus(); } void LandmarksToFrameAnnotationCalculator::AddLandmarksToFrameAnnotation( diff --git a/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc b/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc index 2400cd813..1405e5ac0 100644 --- a/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc +++ b/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.cc @@ -55,16 +55,16 @@ namespace mediapipe { // } class Lift2DFrameAnnotationTo3DCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status ProcessCPU(CalculatorContext* cc, - FrameAnnotation* output_objects); - mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status ProcessCPU(CalculatorContext* cc, + FrameAnnotation* output_objects); + absl::Status LoadOptions(CalculatorContext* cc); // Increment and assign object ID for each detected object. // In a single MediaPipe session, the IDs are unique. 
@@ -78,37 +78,39 @@ class Lift2DFrameAnnotationTo3DCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(Lift2DFrameAnnotationTo3DCalculator); -mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::GetContract( +absl::Status Lift2DFrameAnnotationTo3DCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(cc->Inputs().HasTag(kInputStreamTag)); RET_CHECK(cc->Outputs().HasTag(kOutputStreamTag)); cc->Inputs().Tag(kInputStreamTag).Set(); cc->Outputs().Tag(kOutputStreamTag).Set(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Open( - CalculatorContext* cc) { +absl::Status Lift2DFrameAnnotationTo3DCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); MP_RETURN_IF_ERROR(LoadOptions(cc)); + // Load camera intrinsic matrix. + const float fx = options_.normalized_focal_x(); + const float fy = options_.normalized_focal_y(); + const float px = options_.normalized_principal_point_x(); + const float py = options_.normalized_principal_point_y(); // clang-format off - projection_matrix_ << - 1.5731, 0, 0, 0, - 0, 2.0975, 0, 0, - 0, 0, -1.0002, -0.2, - 0, 0, -1, 0; + projection_matrix_ << fx, 0., px, 0., + 0., fy, py, 0., + 0., 0., -1., 0., + 0., 0., -1., 0.; // clang-format on - decoder_ = absl::make_unique( BeliefDecoderConfig(options_.decoder_config())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Process( +absl::Status Lift2DFrameAnnotationTo3DCalculator::Process( CalculatorContext* cc) { if (cc->Inputs().Tag(kInputStreamTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } auto output_objects = absl::make_unique(); @@ -122,17 +124,17 @@ mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Process( .Add(output_objects.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status 
Lift2DFrameAnnotationTo3DCalculator::ProcessCPU( +absl::Status Lift2DFrameAnnotationTo3DCalculator::ProcessCPU( CalculatorContext* cc, FrameAnnotation* output_objects) { const auto& input_frame_annotations = cc->Inputs().Tag(kInputStreamTag).Get(); // Copy the input frame annotation to the output *output_objects = input_frame_annotations; - auto status = decoder_->Lift2DTo3D(projection_matrix_, /*portrait*/ true, + auto status = decoder_->Lift2DTo3D(projection_matrix_, /*portrait*/ false, output_objects); if (!status.ok()) { LOG(ERROR) << status; @@ -141,20 +143,19 @@ mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::ProcessCPU( AssignObjectIdAndTimestamp(cc->InputTimestamp().Microseconds(), output_objects); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::Close( - CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status Lift2DFrameAnnotationTo3DCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); } -mediapipe::Status Lift2DFrameAnnotationTo3DCalculator::LoadOptions( +absl::Status Lift2DFrameAnnotationTo3DCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = cc->Options(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void Lift2DFrameAnnotationTo3DCalculator::AssignObjectIdAndTimestamp( diff --git a/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.proto b/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.proto index 5e73ea600..a3005c1f9 100644 --- a/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.proto +++ b/mediapipe/modules/objectron/calculators/lift_2d_frame_annotation_to_3d_calculator.proto @@ -27,4 +27,16 @@ message Lift2DFrameAnnotationTo3DCalculatorOptions { } optional BeliefDecoderConfig decoder_config = 1; + + // Camera focal length along x, normalized by width/2. 
+ optional float normalized_focal_x = 2 [default = 1.0]; + + // Camera focal length along y, normalized by height/2. + optional float normalized_focal_y = 3 [default = 1.0]; + + // Camera principle point x, normalized by width/2, origin is image center. + optional float normalized_principal_point_x = 4 [default = 0.0]; + + // Camera principle point y, normalized by height/2, origin is image center. + optional float normalized_principal_point_y = 5 [default = 0.0]; } diff --git a/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc b/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc index 2147ea5ce..6989c34ce 100644 --- a/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc +++ b/mediapipe/modules/objectron/calculators/tensors_to_objects_calculator.cc @@ -58,16 +58,16 @@ namespace mediapipe { // } class TensorsToObjectsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status ProcessCPU(CalculatorContext* cc, - FrameAnnotation* output_objects); - mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status ProcessCPU(CalculatorContext* cc, + FrameAnnotation* output_objects); + absl::Status LoadOptions(CalculatorContext* cc); // Takes point_3d in FrameAnnotation, projects to 2D, and overwrite the // point_2d field with the projection. 
void Project3DTo2D(bool portrait, FrameAnnotation* annotation) const; @@ -87,8 +87,7 @@ class TensorsToObjectsCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TensorsToObjectsCalculator); -mediapipe::Status TensorsToObjectsCalculator::GetContract( - CalculatorContract* cc) { +absl::Status TensorsToObjectsCalculator::GetContract(CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -99,10 +98,10 @@ mediapipe::Status TensorsToObjectsCalculator::GetContract( if (cc->Outputs().HasTag(kOutputStreamTag)) { cc->Outputs().Tag(kOutputStreamTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToObjectsCalculator::Open(CalculatorContext* cc) { +absl::Status TensorsToObjectsCalculator::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(LoadOptions(cc)); // clang-format off projection_matrix_ << @@ -114,12 +113,12 @@ mediapipe::Status TensorsToObjectsCalculator::Open(CalculatorContext* cc) { decoder_ = absl::make_unique( BeliefDecoderConfig(options_.decoder_config())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToObjectsCalculator::Process(CalculatorContext* cc) { +absl::Status TensorsToObjectsCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputStreamTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } auto output_objects = absl::make_unique(); @@ -133,10 +132,10 @@ mediapipe::Status TensorsToObjectsCalculator::Process(CalculatorContext* cc) { .Add(output_objects.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToObjectsCalculator::ProcessCPU( +absl::Status TensorsToObjectsCalculator::ProcessCPU( CalculatorContext* cc, FrameAnnotation* output_objects) { const auto& input_tensors = cc->Inputs().Tag(kInputStreamTag).Get>(); @@ -156,15 +155,14 @@ mediapipe::Status 
TensorsToObjectsCalculator::ProcessCPU( AssignObjectIdAndTimestamp(cc->InputTimestamp().Microseconds(), output_objects); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TensorsToObjectsCalculator::Close(CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status TensorsToObjectsCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); } -mediapipe::Status TensorsToObjectsCalculator::LoadOptions( - CalculatorContext* cc) { +absl::Status TensorsToObjectsCalculator::LoadOptions(CalculatorContext* cc) { // Get calculator options specified in the graph. options_ = cc->Options<::mediapipe::TensorsToObjectsCalculatorOptions>(); @@ -174,7 +172,7 @@ mediapipe::Status TensorsToObjectsCalculator::LoadOptions( // Currently only support 2D when num_values_per_keypoint equals to 2. CHECK_EQ(options_.num_values_per_keypoint(), 2); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void TensorsToObjectsCalculator::Project3DTo2D( diff --git a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc index 55d7104bd..e3686f65e 100644 --- a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc +++ b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.cc @@ -59,16 +59,16 @@ namespace mediapipe { // } class TfLiteTensorsToObjectsCalculator : public CalculatorBase { public: - static mediapipe::Status GetContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); - mediapipe::Status Open(CalculatorContext* cc) override; - mediapipe::Status Process(CalculatorContext* cc) override; - mediapipe::Status Close(CalculatorContext* cc) override; + absl::Status Open(CalculatorContext* cc) override; + absl::Status Process(CalculatorContext* cc) override; + absl::Status Close(CalculatorContext* cc) override; private: - mediapipe::Status 
ProcessCPU(CalculatorContext* cc, - FrameAnnotation* output_objects); - mediapipe::Status LoadOptions(CalculatorContext* cc); + absl::Status ProcessCPU(CalculatorContext* cc, + FrameAnnotation* output_objects); + absl::Status LoadOptions(CalculatorContext* cc); // Takes point_3d in FrameAnnotation, projects to 2D, and overwrite the // point_2d field with the projection. void Project3DTo2D(bool portrait, FrameAnnotation* annotation) const; @@ -88,7 +88,7 @@ class TfLiteTensorsToObjectsCalculator : public CalculatorBase { }; REGISTER_CALCULATOR(TfLiteTensorsToObjectsCalculator); -mediapipe::Status TfLiteTensorsToObjectsCalculator::GetContract( +absl::Status TfLiteTensorsToObjectsCalculator::GetContract( CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); @@ -100,29 +100,31 @@ mediapipe::Status TfLiteTensorsToObjectsCalculator::GetContract( if (cc->Outputs().HasTag(kOutputStreamTag)) { cc->Outputs().Tag(kOutputStreamTag).Set(); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToObjectsCalculator::Open( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToObjectsCalculator::Open(CalculatorContext* cc) { MP_RETURN_IF_ERROR(LoadOptions(cc)); + // Load camera intrinsic matrix. 
+ const float fx = options_.normalized_focal_x(); + const float fy = options_.normalized_focal_y(); + const float px = options_.normalized_principal_point_x(); + const float py = options_.normalized_principal_point_y(); // clang-format off - projection_matrix_ << - 1.5731, 0, 0, 0, - 0, 2.0975, 0, 0, - 0, 0, -1.0002, -0.2, - 0, 0, -1, 0; + projection_matrix_ << fx, 0., px, 0., + 0., fy, py, 0., + 0., 0., -1., 0., + 0., 0., -1., 0.; // clang-format on decoder_ = absl::make_unique( BeliefDecoderConfig(options_.decoder_config())); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToObjectsCalculator::Process( - CalculatorContext* cc) { +absl::Status TfLiteTensorsToObjectsCalculator::Process(CalculatorContext* cc) { if (cc->Inputs().Tag(kInputStreamTag).IsEmpty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } auto output_objects = absl::make_unique(); @@ -136,10 +138,10 @@ mediapipe::Status TfLiteTensorsToObjectsCalculator::Process( .Add(output_objects.release(), cc->InputTimestamp()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToObjectsCalculator::ProcessCPU( +absl::Status TfLiteTensorsToObjectsCalculator::ProcessCPU( CalculatorContext* cc, FrameAnnotation* output_objects) { const auto& input_tensors = cc->Inputs().Tag(kInputStreamTag).Get>(); @@ -159,15 +161,14 @@ mediapipe::Status TfLiteTensorsToObjectsCalculator::ProcessCPU( AssignObjectIdAndTimestamp(cc->InputTimestamp().Microseconds(), output_objects); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToObjectsCalculator::Close( - CalculatorContext* cc) { - return mediapipe::OkStatus(); +absl::Status TfLiteTensorsToObjectsCalculator::Close(CalculatorContext* cc) { + return absl::OkStatus(); } -mediapipe::Status TfLiteTensorsToObjectsCalculator::LoadOptions( +absl::Status TfLiteTensorsToObjectsCalculator::LoadOptions( CalculatorContext* cc) { // Get calculator 
options specified in the graph. options_ = @@ -179,7 +180,7 @@ mediapipe::Status TfLiteTensorsToObjectsCalculator::LoadOptions( // Currently only support 2D when num_values_per_keypoint equals to 2. CHECK_EQ(options_.num_values_per_keypoint(), 2); - return mediapipe::OkStatus(); + return absl::OkStatus(); } void TfLiteTensorsToObjectsCalculator::Project3DTo2D( diff --git a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.proto b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.proto index 7237ee559..32520d98b 100644 --- a/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.proto +++ b/mediapipe/modules/objectron/calculators/tflite_tensors_to_objects_calculator.proto @@ -36,4 +36,16 @@ message TfLiteTensorsToObjectsCalculatorOptions { optional int32 num_values_per_keypoint = 3 [default = 2]; optional BeliefDecoderConfig decoder_config = 4; + + // Camera focal length along x, normalized by width/2. + optional float normalized_focal_x = 5 [default = 1.0]; + + // Camera focal length along y, normalized by height/2. + optional float normalized_focal_y = 6 [default = 1.0]; + + // Camera principle point x, normalized by width/2, origin is image center. + optional float normalized_principal_point_x = 7 [default = 0.0]; + + // Camera principle point y, normalized by height/2, origin is image center. 
+ optional float normalized_principal_point_y = 8 [default = 0.0]; } diff --git a/mediapipe/models/object_detection_3d_camera.tflite b/mediapipe/modules/objectron/object_detection_3d_camera.tflite similarity index 100% rename from mediapipe/models/object_detection_3d_camera.tflite rename to mediapipe/modules/objectron/object_detection_3d_camera.tflite diff --git a/mediapipe/models/object_detection_3d_chair.tflite b/mediapipe/modules/objectron/object_detection_3d_chair.tflite similarity index 100% rename from mediapipe/models/object_detection_3d_chair.tflite rename to mediapipe/modules/objectron/object_detection_3d_chair.tflite diff --git a/mediapipe/models/object_detection_3d_chair_1stage.tflite b/mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite similarity index 100% rename from mediapipe/models/object_detection_3d_chair_1stage.tflite rename to mediapipe/modules/objectron/object_detection_3d_chair_1stage.tflite diff --git a/mediapipe/models/object_detection_3d_cup.tflite b/mediapipe/modules/objectron/object_detection_3d_cup.tflite similarity index 100% rename from mediapipe/models/object_detection_3d_cup.tflite rename to mediapipe/modules/objectron/object_detection_3d_cup.tflite diff --git a/mediapipe/models/object_detection_3d_sneakers.tflite b/mediapipe/modules/objectron/object_detection_3d_sneakers.tflite similarity index 100% rename from mediapipe/models/object_detection_3d_sneakers.tflite rename to mediapipe/modules/objectron/object_detection_3d_sneakers.tflite diff --git a/mediapipe/models/object_detection_3d_sneakers_1stage.tflite b/mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite similarity index 100% rename from mediapipe/models/object_detection_3d_sneakers_1stage.tflite rename to mediapipe/modules/objectron/object_detection_3d_sneakers_1stage.tflite diff --git a/mediapipe/modules/objectron/object_detection_oid_v4_cpu.pbtxt b/mediapipe/modules/objectron/object_detection_oid_v4_cpu.pbtxt index bfd488dc5..f7a09fc19 
100644 --- a/mediapipe/modules/objectron/object_detection_oid_v4_cpu.pbtxt +++ b/mediapipe/modules/objectron/object_detection_oid_v4_cpu.pbtxt @@ -6,41 +6,37 @@ input_stream: "IMAGE:input_video" input_side_packet: "LABELS_CSV:allowed_labels" output_stream: "DETECTIONS:detections" -# Transforms the input image on CPU to a 300x300 image. To scale the image, by -# default it uses the STRETCH scale mode that maps the entire input image to the -# entire transformed image. As a result, image aspect ratio may be changed and -# objects in the image may be deformed (stretched or squeezed), but the object -# detection model used in this graph is agnostic to that deformation. -node: { - calculator: "ImageTransformationCalculator" +# Crops, resizes, and converts the input video into tensor. +# Preserves aspect ratio of the images. +node { + calculator: "ImageToTensorCalculator" input_stream: "IMAGE:input_video" - output_stream: "IMAGE:transformed_input_video" - options: { - [mediapipe.ImageTransformationCalculatorOptions.ext] { - output_width: 300 - output_height: 300 + output_stream: "TENSORS:image_tensor" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + options { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_width: 300 + output_tensor_height: 300 + keep_aspect_ratio: false + output_tensor_float_range { + min: -1.0 + max: 1.0 + } } } } -# Converts the transformed input image on CPU into an image tensor stored as a -# TfLiteTensor. -node { - calculator: "TfLiteConverterCalculator" - input_stream: "IMAGE:transformed_input_video" - output_stream: "TENSORS:image_tensor" -} - # Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a # vector of tensors representing, for instance, detection boxes/keypoints and # scores. 
node { - calculator: "TfLiteInferenceCalculator" + calculator: "InferenceCalculator" input_stream: "TENSORS:image_tensor" output_stream: "TENSORS:detection_tensors" options: { - [mediapipe.TfLiteInferenceCalculatorOptions.ext] { - model_path: "mediapipe/models/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite" + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/modules/objectron/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite" + delegate { xnnpack {} } } } } @@ -79,13 +75,13 @@ node { # the SSD anchors and the specification in the options, into a vector of # detections. Each detection describes a detected object. node { - calculator: "TfLiteTensorsToDetectionsCalculator" + calculator: "TensorsToDetectionsCalculator" input_stream: "TENSORS:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:all_detections" options: { - [mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] { - num_classes: 195 + [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { + num_classes: 24 num_boxes: 1917 num_coords: 4 ignore_classes: 0 @@ -108,7 +104,7 @@ node { output_stream: "labeled_detections" options: { [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] { - label_map_path: "mediapipe/models/object_detection_oidv4_labelmap.pbtxt" + label_map_path: "mediapipe/modules/objectron/object_detection_oidv4_labelmap.txt" } } } diff --git a/mediapipe/modules/objectron/object_detection_oid_v4_gpu.pbtxt b/mediapipe/modules/objectron/object_detection_oid_v4_gpu.pbtxt index 76adbdd46..7873e8081 100644 --- a/mediapipe/modules/objectron/object_detection_oid_v4_gpu.pbtxt +++ b/mediapipe/modules/objectron/object_detection_oid_v4_gpu.pbtxt @@ -6,41 +6,39 @@ input_stream: "IMAGE_GPU:input_video" input_side_packet: "LABELS_CSV:allowed_labels" output_stream: "DETECTIONS:detections" -# Transforms the input image on GPU to a 300x300 image. 
To scale the image, by -# default it uses the STRETCH scale mode that maps the entire input image to the -# entire transformed image. As a result, image aspect ratio may be changed and -# objects in the image may be deformed (stretched or squeezed), but the object -# detection model used in this graph is agnostic to that deformation. -node: { - calculator: "ImageTransformationCalculator" +# Crops, resizes, and converts the input video into tensor. +# Preserves aspect ratio of the images. +node { + calculator: "ImageToTensorCalculator" input_stream: "IMAGE_GPU:input_video" - output_stream: "IMAGE_GPU:transformed_input_video" - options: { - [mediapipe.ImageTransformationCalculatorOptions.ext] { - output_width: 300 - output_height: 300 + output_stream: "TENSORS:image_tensor" + output_stream: "LETTERBOX_PADDING:letterbox_padding" + options { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_width: 300 + output_tensor_height: 300 + keep_aspect_ratio: false + output_tensor_float_range { + min: -1.0 + max: 1.0 + } + gpu_origin: TOP_LEFT } } } -# Converts the transformed input image on GPU into an image tensor stored as a -# TfLiteTensor. -node { - calculator: "TfLiteConverterCalculator" - input_stream: "IMAGE_GPU:transformed_input_video" - output_stream: "TENSORS_GPU:image_tensor" -} # Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a # vector of tensors representing, for instance, detection boxes/keypoints and # scores. 
node { - calculator: "TfLiteInferenceCalculator" - input_stream: "TENSORS_GPU:image_tensor" - output_stream: "TENSORS_GPU:detection_tensors" + calculator: "InferenceCalculator" + input_stream: "TENSORS:image_tensor" + output_stream: "TENSORS:detection_tensors" options: { - [mediapipe.TfLiteInferenceCalculatorOptions.ext] { - model_path: "object_detection_ssd_mobilenetv2_oidv4_fp16.tflite" + [mediapipe.InferenceCalculatorOptions.ext] { + model_path: "mediapipe/modules/objectron/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite" + delegate { gpu {} } } } } @@ -79,13 +77,13 @@ node { # the SSD anchors and the specification in the options, into a vector of # detections. Each detection describes a detected object. node { - calculator: "TfLiteTensorsToDetectionsCalculator" - input_stream: "TENSORS_GPU:detection_tensors" + calculator: "TensorsToDetectionsCalculator" + input_stream: "TENSORS:detection_tensors" input_side_packet: "ANCHORS:anchors" output_stream: "DETECTIONS:all_detections" options: { - [mediapipe.TfLiteTensorsToDetectionsCalculatorOptions.ext] { - num_classes: 195 + [mediapipe.TensorsToDetectionsCalculatorOptions.ext] { + num_classes: 24 num_boxes: 1917 num_coords: 4 ignore_classes: 0 @@ -108,7 +106,7 @@ node { output_stream: "labeled_detections" options: { [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] { - label_map_path: "object_detection_oidv4_labelmap.pbtxt" + label_map_path: "object_detection_oidv4_labelmap.txt" } } } diff --git a/mediapipe/modules/objectron/object_detection_oidv4_labelmap.txt b/mediapipe/modules/objectron/object_detection_oidv4_labelmap.txt new file mode 100644 index 000000000..ef9032cee --- /dev/null +++ b/mediapipe/modules/objectron/object_detection_oidv4_labelmap.txt @@ -0,0 +1,24 @@ +??? 
+Bicycle +Boot +Laptop +Person +Chair +Cattle +Desk +Cat +Computer mouse +Computer monitor +Box +Mug +Coffee cup +Stationary bicycle +Table +Bottle +High heels +Vehicle +Footwear +Dog +Book +Camera +Car diff --git a/mediapipe/modules/objectron/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite b/mediapipe/modules/objectron/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite new file mode 100644 index 000000000..3cb7291d9 Binary files /dev/null and b/mediapipe/modules/objectron/object_detection_ssd_mobilenetv2_oidv4_fp16.tflite differ diff --git a/mediapipe/modules/objectron/objectron_cpu.pbtxt b/mediapipe/modules/objectron/objectron_cpu.pbtxt index 6b6ef5f6b..6f8fbade5 100644 --- a/mediapipe/modules/objectron/objectron_cpu.pbtxt +++ b/mediapipe/modules/objectron/objectron_cpu.pbtxt @@ -1,7 +1,10 @@ # MediaPipe Objectron on CPU that produces 3D bounding boxes for objects. -input_stream: "IMAGE:input_video" -# TfLite model for 3D bounding box landmark prediction -input_side_packet: "MODEL:box_landmark_model" +type: "ObjectronCpuSubgraph" + +# Input/Output streams and input side packets. +input_stream: "IMAGE:image" +# Path to TfLite model for 3D bounding box landmark prediction +input_side_packet: "MODEL_PATH:box_landmark_model_path" # Allowed category labels, e.g. Footwear, Coffee cup, Mug, Chair, Camera input_side_packet: "LABELS_CSV:allowed_labels" # Max number of objects to detect/track. (int) @@ -23,11 +26,28 @@ input_side_packet: "MAX_NUM_OBJECTS:max_num_objects" # \ + \ + \ # \+ \+ # 2 + + + + + + + + 6 -# + +# Collection of detected 3D objects, represented as a FrameAnnotation. +output_stream: "FRAME_ANNOTATION:detected_objects" +# Collection of box landmarks. (NormalizedLandmarkList) output_stream: "MULTI_LANDMARKS:multi_box_landmarks" # Crop rectangles derived from bounding box landmarks. output_stream: "NORM_RECTS:multi_box_rects" +# Loads the file in the specified path into a blob. 
+node { + calculator: "LocalFileContentsCalculator" + input_side_packet: "FILE_PATH:0:box_landmark_model_path" + output_side_packet: "CONTENTS:0:box_landmark_model_blob" +} + +# Converts the input blob into a TF Lite model. +node { + calculator: "TfLiteModelCalculator" + input_side_packet: "MODEL_BLOB:box_landmark_model_blob" + output_side_packet: "MODEL:box_landmark_model" +} + # Defines whether landmarks from the previous video frame should be used to help # predict landmarks on the current video frame. node { @@ -62,9 +82,9 @@ node { # to trigger a new round of box detection in ObjectDetectionOidV4Subgraph. node { calculator: "GateCalculator" - input_stream: "input_video" + input_stream: "image" input_stream: "DISALLOW:prev_has_enough_objects" - output_stream: "detection_input_video" + output_stream: "detection_image" options: { [mediapipe.GateCalculatorOptions.ext] { @@ -76,7 +96,7 @@ node { # Subgraph that performs 2D object detection. node { calculator: "ObjectDetectionOidV4Subgraph" - input_stream: "IMAGE:detection_input_video" + input_stream: "IMAGE:detection_image" input_side_packet: "LABELS_CSV:allowed_labels" output_stream: "DETECTIONS:raw_detections" } @@ -93,7 +113,7 @@ node { # Extracts image size from the input images. node { calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE:input_video" + input_stream: "IMAGE:image" output_stream: "SIZE:image_size" } @@ -135,16 +155,16 @@ node { node { calculator: "BeginLoopNormalizedRectCalculator" input_stream: "ITERABLE:multi_box_rects" - input_stream: "CLONE:input_video" + input_stream: "CLONE:image" output_stream: "ITEM:single_box_rect" - output_stream: "CLONE:landmarks_input_video" + output_stream: "CLONE:landmarks_image" output_stream: "BATCH_END:box_rects_timestamp" } # Subgraph that localizes box landmarks. 
node { calculator: "BoxLandmarkSubgraph" - input_stream: "IMAGE:landmarks_input_video" + input_stream: "IMAGE:landmarks_image" input_side_packet: "MODEL:box_landmark_model" input_stream: "NORM_RECT:single_box_rect" output_stream: "NORM_LANDMARKS:single_box_landmarks" @@ -169,15 +189,22 @@ node { # Lift the 2D landmarks to 3D using EPnP algorithm. node { + name: "Lift2DFrameAnnotationTo3DCalculator" calculator: "Lift2DFrameAnnotationTo3DCalculator" input_stream: "FRAME_ANNOTATION:box_annotations" - output_stream: "LIFTED_FRAME_ANNOTATION:lifted_objects" + output_stream: "LIFTED_FRAME_ANNOTATION:detected_objects" + options: { + [mediapipe.Lift2DFrameAnnotationTo3DCalculatorOptions.ext] { + normalized_focal_x: 1.0 + normalized_focal_y: 1.0 + } + } } -# Get rotated rectangle from lifted box. +# Get rotated rectangle from detected box. node { calculator: "FrameAnnotationToRectCalculator" - input_stream: "FRAME_ANNOTATION:lifted_objects" + input_stream: "FRAME_ANNOTATION:detected_objects" output_stream: "NORM_RECTS:box_rects_from_landmarks" } @@ -189,7 +216,7 @@ node { # feedback loop. 
node { calculator: "PreviousLoopbackCalculator" - input_stream: "MAIN:input_video" + input_stream: "MAIN:image" input_stream: "LOOP:box_rects_from_landmarks" input_stream_info: { tag_index: "LOOP" diff --git a/mediapipe/modules/objectron/objectron_detection_1stage_gpu.pbtxt b/mediapipe/modules/objectron/objectron_detection_1stage_gpu.pbtxt index 96d4eec7a..290b120b1 100644 --- a/mediapipe/modules/objectron/objectron_detection_1stage_gpu.pbtxt +++ b/mediapipe/modules/objectron/objectron_detection_1stage_gpu.pbtxt @@ -76,6 +76,8 @@ node { voting_allowance: 1 voting_threshold: 0.2 } + normalized_focal_x: 2.0975 + normalized_focal_y: 1.5731 } } } diff --git a/mediapipe/modules/objectron/objectron_gpu.pbtxt b/mediapipe/modules/objectron/objectron_gpu.pbtxt index 185752dc9..16187deae 100644 --- a/mediapipe/modules/objectron/objectron_gpu.pbtxt +++ b/mediapipe/modules/objectron/objectron_gpu.pbtxt @@ -1,15 +1,16 @@ # MediaPipe Objectron on GPU that produces 3D bounding boxes for objects. +type: "ObjectronGpuSubgraph" # Input/Output streams and input side packets. -# Note that the input video is assumed to have aspect ratio 3:4 (width:height). -input_stream: "IMAGE_GPU:input_video" +# Note that the input image is assumed to have aspect ratio 3:4 (width:height). +input_stream: "IMAGE_GPU:image" # Allowed category labels, e.g. Footwear, Coffee cup, Mug, Chair, Camera input_side_packet: "LABELS_CSV:allowed_labels" # Max number of objects to detect/track. (int) input_side_packet: "MAX_NUM_OBJECTS:max_num_objects" # Collection of detected 3D objects, represented as a FrameAnnotation. -output_stream: "FRAME_ANNOTATION:lifted_objects" +output_stream: "FRAME_ANNOTATION:detected_objects" # Defines whether landmarks from the previous video frame should be used to help # predict landmarks on the current video frame. @@ -45,9 +46,9 @@ node { # to trigger a new round of box detection in ObjectDetectionOidV4Subgraph. 
node { calculator: "GateCalculator" - input_stream: "input_video" + input_stream: "image" input_stream: "DISALLOW:prev_has_enough_objects" - output_stream: "detection_input_video" + output_stream: "detection_image" options: { [mediapipe.GateCalculatorOptions.ext] { @@ -59,7 +60,7 @@ node { # Subgraph that performs 2D object detection. node { calculator: "ObjectDetectionOidV4Subgraph" - input_stream: "IMAGE_GPU:detection_input_video" + input_stream: "IMAGE_GPU:detection_image" input_side_packet: "LABELS_CSV:allowed_labels" output_stream: "DETECTIONS:raw_detections" } @@ -76,7 +77,7 @@ node { # Extracts image size from the input images. node { calculator: "ImagePropertiesCalculator" - input_stream: "IMAGE_GPU:input_video" + input_stream: "IMAGE_GPU:image" output_stream: "SIZE:image_size" } @@ -118,16 +119,16 @@ node { node { calculator: "BeginLoopNormalizedRectCalculator" input_stream: "ITERABLE:box_rects" - input_stream: "CLONE:input_video" + input_stream: "CLONE:image" output_stream: "ITEM:single_box_rect" - output_stream: "CLONE:landmarks_input_video" + output_stream: "CLONE:landmarks_image" output_stream: "BATCH_END:box_rects_timestamp" } # Subgraph that localizes box landmarks. node { calculator: "BoxLandmarkSubgraph" - input_stream: "IMAGE:landmarks_input_video" + input_stream: "IMAGE:landmarks_image" input_stream: "NORM_RECT:single_box_rect" output_stream: "NORM_LANDMARKS:single_box_landmarks" } @@ -153,13 +154,19 @@ node { node { calculator: "Lift2DFrameAnnotationTo3DCalculator" input_stream: "FRAME_ANNOTATION:box_annotations" - output_stream: "LIFTED_FRAME_ANNOTATION:lifted_objects" + output_stream: "LIFTED_FRAME_ANNOTATION:detected_objects" + options: { + [mediapipe.Lift2DFrameAnnotationTo3DCalculatorOptions.ext] { + normalized_focal_x: 2.0975 + normalized_focal_y: 1.5731 + } + } } -# Get rotated rectangle from lifted box. +# Get rotated rectangle from detected box. 
node { calculator: "FrameAnnotationToRectCalculator" - input_stream: "FRAME_ANNOTATION:lifted_objects" + input_stream: "FRAME_ANNOTATION:detected_objects" output_stream: "NORM_RECTS:box_rects_from_landmarks" } @@ -171,7 +178,7 @@ node { # feedback loop. node { calculator: "PreviousLoopbackCalculator" - input_stream: "MAIN:input_video" + input_stream: "MAIN:image" input_stream: "LOOP:box_rects_from_landmarks" input_stream_info: { tag_index: "LOOP" diff --git a/mediapipe/modules/objectron/objectron_tracking_1stage_gpu.pbtxt b/mediapipe/modules/objectron/objectron_tracking_1stage_gpu.pbtxt index 5b33f6055..eb19a446b 100644 --- a/mediapipe/modules/objectron/objectron_tracking_1stage_gpu.pbtxt +++ b/mediapipe/modules/objectron/objectron_tracking_1stage_gpu.pbtxt @@ -167,4 +167,10 @@ node { calculator: "Lift2DFrameAnnotationTo3DCalculator" input_stream: "FRAME_ANNOTATION:tracked_objects" output_stream: "LIFTED_FRAME_ANNOTATION:lifted_tracked_objects" + options: { + [mediapipe.Lift2DFrameAnnotationTo3DCalculatorOptions.ext] { + normalized_focal_x: 2.0975 + normalized_focal_y: 1.5731 + } + } } diff --git a/mediapipe/modules/palm_detection/README b/mediapipe/modules/palm_detection/README.md similarity index 100% rename from mediapipe/modules/palm_detection/README rename to mediapipe/modules/palm_detection/README.md diff --git a/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt b/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt index 81079fc27..7f10b59a1 100644 --- a/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt +++ b/mediapipe/modules/pose_detection/pose_detection_cpu.pbtxt @@ -54,6 +54,10 @@ node: { max: 1.0 } border_mode: BORDER_ZERO + # If this calculator truly operates in the CPU, then gpu_origin is + # ignored, but if some build switch insists on GPU inference, then we will + # still need to set this. 
+ gpu_origin: TOP_LEFT } } } diff --git a/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite b/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite index 186885611..713130c2e 100755 Binary files a/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite and b/mediapipe/modules/pose_landmark/pose_landmark_full_body.tflite differ diff --git a/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt b/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt index 7365bcbcd..d40cf8416 100644 --- a/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt +++ b/mediapipe/modules/pose_landmark/pose_landmark_model_loader.pbtxt @@ -48,7 +48,7 @@ node { output_side_packet: "CONTENTS:model_blob" options: { [mediapipe.LocalFileContentsCalculatorOptions.ext]: { - read_as_binary: true + text_mode: false } } } diff --git a/mediapipe/objc/BUILD b/mediapipe/objc/BUILD index 266c52466..a58a210f9 100644 --- a/mediapipe/objc/BUILD +++ b/mediapipe/objc/BUILD @@ -261,9 +261,9 @@ exports_files( [ "testdata/googlelogo_color_272x92dp.png", "testdata/googlelogo_color_272x92dp_luminance.png", + "testdata/sergey.png", ], visibility = [ - "//mediapipe/feature_extraction/video/video_effects:__pkg__", - "//mediapipe/gpu:__pkg__", + "//mediapipe:__subpackages__", ], ) diff --git a/mediapipe/objc/CGImageRefUtils.mm b/mediapipe/objc/CGImageRefUtils.mm index 4d7f47325..4d82cac64 100644 --- a/mediapipe/objc/CGImageRefUtils.mm +++ b/mediapipe/objc/CGImageRefUtils.mm @@ -23,7 +23,7 @@ CGImageRef CreateCGImageFromCVPixelBuffer(CVPixelBufferRef imageBuffer, NSError **error) { CFHolder cg_image_holder; - ::mediapipe::Status status = CreateCGImageFromCVPixelBuffer(imageBuffer, &cg_image_holder); + absl::Status status = CreateCGImageFromCVPixelBuffer(imageBuffer, &cg_image_holder); if (!status.ok()) { *error = [NSError gus_errorWithStatus:status]; return nil; @@ -35,7 +35,7 @@ CGImageRef CreateCGImageFromCVPixelBuffer(CVPixelBufferRef imageBuffer, NSError 
CVPixelBufferRef CreateCVPixelBufferFromCGImage(CGImageRef image, NSError **error) { CFHolder pixel_buffer_holder; - ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(image, &pixel_buffer_holder); + absl::Status status = CreateCVPixelBufferFromCGImage(image, &pixel_buffer_holder); if (!status.ok()) { *error = [NSError gus_errorWithStatus:status]; return nil; diff --git a/mediapipe/objc/MPPDisplayLinkWeakTarget.m b/mediapipe/objc/MPPDisplayLinkWeakTarget.m index c5922a473..80c0a5e8e 100644 --- a/mediapipe/objc/MPPDisplayLinkWeakTarget.m +++ b/mediapipe/objc/MPPDisplayLinkWeakTarget.m @@ -33,8 +33,9 @@ #pragma mark - Public - (void)displayLinkCallback:(CADisplayLink *)sender { - void (*display)(id, SEL, CADisplayLink *) = (void *)[_target methodForSelector:_selector]; - display(_target, _selector, sender); + __strong id target = _target; + void (*display)(id, SEL, CADisplayLink *) = (void *)[target methodForSelector:_selector]; + display(target, _selector, sender); } @end diff --git a/mediapipe/objc/MPPGraph.mm b/mediapipe/objc/MPPGraph.mm index 0aa1590d2..04a96bb1b 100644 --- a/mediapipe/objc/MPPGraph.mm +++ b/mediapipe/objc/MPPGraph.mm @@ -213,7 +213,7 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, } - (BOOL)startWithError:(NSError**)error { - ::mediapipe::Status status = [self performStart]; + absl::Status status = [self performStart]; if (!status.ok()) { if (error) { *error = [NSError gus_errorWithStatus:status]; @@ -224,8 +224,8 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, return YES; } -- (::mediapipe::Status)performStart { - ::mediapipe::Status status = _graph->Initialize(_config); +- (absl::Status)performStart { + absl::Status status = _graph->Initialize(_config); if (!status.ok()) { return status; } @@ -251,13 +251,13 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, } - (BOOL)closeInputStream:(const std::string&)inputName error:(NSError**)error { - 
::mediapipe::Status status = _graph->CloseInputStream(inputName); + absl::Status status = _graph->CloseInputStream(inputName); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } - (BOOL)closeAllInputStreamsWithError:(NSError**)error { - ::mediapipe::Status status = _graph->CloseAllInputStreams(); + absl::Status status = _graph->CloseAllInputStreams(); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } @@ -268,14 +268,14 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, // TODO: is this too heavy-handed? Maybe a warning would be fine. _GTMDevAssert(![NSThread isMainThread] || (NSClassFromString(@"XCTest")), @"waitUntilDoneWithError: should not be called on the main thread"); - ::mediapipe::Status status = _graph->WaitUntilDone(); + absl::Status status = _graph->WaitUntilDone(); _started = NO; if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } - (BOOL)waitUntilIdleWithError:(NSError**)error { - ::mediapipe::Status status = _graph->WaitUntilIdle(); + absl::Status status = _graph->WaitUntilIdle(); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } @@ -283,7 +283,7 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, - (BOOL)movePacket:(mediapipe::Packet&&)packet intoStream:(const std::string&)streamName error:(NSError**)error { - ::mediapipe::Status status = _graph->AddPacketToInputStream(streamName, std::move(packet)); + absl::Status status = _graph->AddPacketToInputStream(streamName, std::move(packet)); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } @@ -291,7 +291,7 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, - (BOOL)sendPacket:(const mediapipe::Packet&)packet intoStream:(const std::string&)streamName error:(NSError**)error { - ::mediapipe::Status status = 
_graph->AddPacketToInputStream(streamName, packet); + absl::Status status = _graph->AddPacketToInputStream(streamName, packet); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } @@ -299,7 +299,7 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, - (BOOL)setMaxQueueSize:(int)maxQueueSize forStream:(const std::string&)streamName error:(NSError**)error { - ::mediapipe::Status status = _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize); + absl::Status status = _graph->SetInputStreamMaxQueueSize(streamName, maxQueueSize); if (!status.ok() && error) *error = [NSError gus_errorWithStatus:status]; return status.ok(); } @@ -396,7 +396,7 @@ void CallFrameDelegate(void* wrapperVoid, const std::string& streamName, NSString* extensionString; (void)gpu_resources->gl_context()->Run([&extensionString]{ extensionString = [NSString stringWithUTF8String:(char*)glGetString(GL_EXTENSIONS)]; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); }); NSArray* extensions = [extensionString componentsSeparatedByCharactersInSet: diff --git a/mediapipe/objc/MPPGraphTestBase.mm b/mediapipe/objc/MPPGraphTestBase.mm index b8c19a145..6b759d3a7 100644 --- a/mediapipe/objc/MPPGraphTestBase.mm +++ b/mediapipe/objc/MPPGraphTestBase.mm @@ -19,7 +19,7 @@ static UIImage* UIImageWithPixelBuffer(CVPixelBufferRef pixelBuffer) { CFHolder cgImage; - ::mediapipe::Status status = CreateCGImageFromCVPixelBuffer(pixelBuffer, &cgImage); + absl::Status status = CreateCGImageFromCVPixelBuffer(pixelBuffer, &cgImage); if (!status.ok()) { return nil; } @@ -441,7 +441,7 @@ static void EnsureOutputDirFor(NSString *outputFile) { for (NSString* inputStream in fileInputs) { UIImage* inputImage = [self testImageNamed:fileInputs[inputStream] extension:nil]; XCTAssertNotNil(inputImage); - ::mediapipe::Status status = + absl::Status status = CreateCVPixelBufferFromCGImage(inputImage.CGImage, &inputBuffers[inputStream.UTF8String]); 
XCTAssert(status.ok()); } @@ -449,8 +449,7 @@ static void EnsureOutputDirFor(NSString *outputFile) { UIImage* expectedImage = [self testImageNamed:expectedPath extension:nil]; XCTAssertNotNil(expectedImage); CFHolder expectedBuffer; - ::mediapipe::Status status = - CreateCVPixelBufferFromCGImage(expectedImage.CGImage, &expectedBuffer); + absl::Status status = CreateCVPixelBufferFromCGImage(expectedImage.CGImage, &expectedBuffer); XCTAssert(status.ok()); CVPixelBufferRef outputBuffer = [self runGraph:graph diff --git a/mediapipe/objc/MPPGraphTests.mm b/mediapipe/objc/MPPGraphTests.mm index e4562da38..7c1ea8e06 100644 --- a/mediapipe/objc/MPPGraphTests.mm +++ b/mediapipe/objc/MPPGraphTests.mm @@ -28,15 +28,14 @@ namespace mediapipe { class GrayscaleCalculator : public Calculator { public: - static ::mediapipe::Status FillExpectations(const CalculatorOptions& options, - PacketTypeSet* inputs, PacketTypeSet* outputs, - PacketTypeSet* input_side_packets) { + static absl::Status FillExpectations(const CalculatorOptions& options, PacketTypeSet* inputs, + PacketTypeSet* outputs, PacketTypeSet* input_side_packets) { inputs->Index(0).Set(); outputs->Index(0).Set(); - return ::util::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process() final { + absl::Status Process() final { const auto& input = Input()->Get(); int w = input.Width(); int h = input.Height(); @@ -49,7 +48,7 @@ class GrayscaleCalculator : public Calculator { NSCAssert(vErr == kvImageNoError, @"vImageRGBAToGray failed: %zd", vErr); Output()->Add(output.release(), InputTimestamp()); - return ::util::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(GrayscaleCalculator); @@ -58,41 +57,37 @@ REGISTER_CALCULATOR(GrayscaleCalculator); // if the video header is not present in the input stream. 
class VideoHeaderCalculator : public Calculator { public: - static ::mediapipe::Status FillExpectations(const CalculatorOptions& options, - PacketTypeSet* inputs, PacketTypeSet* outputs, - PacketTypeSet* input_side_packets) { + static absl::Status FillExpectations(const CalculatorOptions& options, PacketTypeSet* inputs, + PacketTypeSet* outputs, PacketTypeSet* input_side_packets) { inputs->Index(0).Set(); outputs->Index(0).Set(); - return ::util::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Open() final { + absl::Status Open() final { if (Input()->Header().IsEmpty()) { - return ::util::UnknownError("No video header present."); + return absl::UnknownError("No video header present."); } - return ::util::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process() final { + absl::Status Process() final { Output()->AddPacket(Input()->Value()); - return ::util::OkStatus(); + return absl::OkStatus(); } }; REGISTER_CALCULATOR(VideoHeaderCalculator); class ErrorCalculator : public Calculator { public: - static ::mediapipe::Status FillExpectations(const CalculatorOptions& options, - PacketTypeSet* inputs, PacketTypeSet* outputs, - PacketTypeSet* input_side_packets) { + static absl::Status FillExpectations(const CalculatorOptions& options, PacketTypeSet* inputs, + PacketTypeSet* outputs, PacketTypeSet* input_side_packets) { inputs->Index(0).SetAny(); outputs->Index(0).SetSameAs(&inputs->Index(0)); - return ::util::OkStatus(); + return absl::OkStatus(); } - ::mediapipe::Status Process() final { - return ::mediapipe::Status(absl::StatusCode::kUnknown, kExpectedError); - } + absl::Status Process() final { return absl::Status(absl::StatusCode::kUnknown, kExpectedError); } }; REGISTER_CALCULATOR(ErrorCalculator); @@ -127,7 +122,7 @@ REGISTER_CALCULATOR(ErrorCalculator); _graph = [[MPPGraph alloc] initWithGraphConfig:config]; [_graph addFrameOutputStream:"output_frames" outputPacketType:MPPPacketTypePixelBuffer]; CFHolder inputBuffer; - 
::mediapipe::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); + absl::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); XCTAssert(status.ok()); CVPixelBufferRef outputBuffer = [self runGraph:_graph withPixelBuffer:*inputBuffer @@ -166,7 +161,7 @@ REGISTER_CALCULATOR(ErrorCalculator); [_graph addFrameOutputStream:"gray_frames" outputPacketType:MPPPacketTypeImageFrame]; CFHolder inputBuffer; - ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); + absl::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); XCTAssert(status.ok()); WEAKIFY(self); @@ -204,7 +199,7 @@ REGISTER_CALCULATOR(ErrorCalculator); _graph = [[MPPGraph alloc] initWithGraphConfig:config]; [_graph addFrameOutputStream:"output_frames" outputPacketType:MPPPacketTypeImageFrame]; CFHolder inputBuffer; - ::mediapipe::Status status = CreateCVPixelBufferFromCGImage(grayImage.CGImage, &inputBuffer); + absl::Status status = CreateCVPixelBufferFromCGImage(grayImage.CGImage, &inputBuffer); XCTAssert(status.ok()); CVPixelBufferRef outputBuffer = [self runGraph:_graph withPixelBuffer:*inputBuffer @@ -222,8 +217,7 @@ REGISTER_CALCULATOR(ErrorCalculator); node->add_input_stream("input_frames"); node->add_output_stream("output_frames"); CFHolder srcPixelBuffer; - ::mediapipe::Status status = - CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &srcPixelBuffer); + absl::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &srcPixelBuffer); XCTAssert(status.ok()); _graph = [[MPPGraph alloc] initWithGraphConfig:config]; [_graph addFrameOutputStream:"output_frames" outputPacketType:MPPPacketTypeImageFrame]; @@ -290,7 +284,7 @@ REGISTER_CALCULATOR(ErrorCalculator); _graph = [[MPPGraph alloc] initWithGraphConfig:config]; [_graph addFrameOutputStream:"output_frames" outputPacketType:MPPPacketTypePixelBuffer]; CFHolder inputBuffer; - ::mediapipe::Status status = 
CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); + absl::Status status = CreateCVPixelBufferFromCGImage(_sourceImage.CGImage, &inputBuffer); XCTAssert(status.ok()); CVPixelBufferRef outputBuffer = [self runGraph:_graph withPixelBuffer:*inputBuffer diff --git a/mediapipe/objc/NSError+util_status.h b/mediapipe/objc/NSError+util_status.h index ebbc6fb6e..38e367997 100644 --- a/mediapipe/objc/NSError+util_status.h +++ b/mediapipe/objc/NSError+util_status.h @@ -16,33 +16,33 @@ #include "mediapipe/framework/port/status.h" -/// Error domain for ::mediapipe::Status errors. +/// Error domain for absl::Status errors. extern NSString *const kGUSGoogleUtilStatusErrorDomain; -/// Key for the ::mediapipe::Status wrapper in an NSError's user info dictionary. +/// Key for the absl::Status wrapper in an NSError's user info dictionary. extern NSString *const kGUSGoogleUtilStatusErrorKey; -/// This just wraps ::mediapipe::Status into an Objective-C object. +/// This just wraps absl::Status into an Objective-C object. @interface GUSUtilStatusWrapper : NSObject -@property(nonatomic)::mediapipe::Status status; +@property(nonatomic) absl::Status status; -+ (instancetype)wrapStatus:(const ::mediapipe::Status &)status; ++ (instancetype)wrapStatus:(const absl::Status &)status; @end -/// This category adds methods for generating NSError objects from ::mediapipe::Status +/// This category adds methods for generating NSError objects from absl::Status /// objects, and vice versa. @interface NSError (GUSGoogleUtilStatus) -/// Generates an NSError representing a ::mediapipe::Status. Note that NSError always -/// represents an error, so this should not be called with ::mediapipe::Status::OK. -+ (NSError *)gus_errorWithStatus:(const ::mediapipe::Status &)status; +/// Generates an NSError representing a absl::Status. Note that NSError always +/// represents an error, so this should not be called with absl::Status::OK. 
++ (NSError *)gus_errorWithStatus:(const absl::Status &)status; -/// Returns a ::mediapipe::Status object representing an NSError. If the NSError was -/// generated from a ::mediapipe::Status, the ::mediapipe::Status returned is identical to +/// Returns a absl::Status object representing an NSError. If the NSError was +/// generated from a absl::Status, the absl::Status returned is identical to /// the original. Otherwise, this returns a status with code ::util::error::UNKNOWN /// and a message extracted from the NSError. -@property(nonatomic, readonly)::mediapipe::Status gus_status; // NOLINT(identifier-naming) +@property(nonatomic, readonly) absl::Status gus_status; // NOLINT(identifier-naming) @end diff --git a/mediapipe/objc/NSError+util_status.mm b/mediapipe/objc/NSError+util_status.mm index f8ba4802d..2b425fd44 100644 --- a/mediapipe/objc/NSError+util_status.mm +++ b/mediapipe/objc/NSError+util_status.mm @@ -16,11 +16,11 @@ @implementation GUSUtilStatusWrapper -+ (instancetype)wrapStatus:(const ::mediapipe::Status &)status { ++ (instancetype)wrapStatus:(const absl::Status &)status { return [[self alloc] initWithStatus:status]; } -- (instancetype)initWithStatus:(const ::mediapipe::Status &)status { +- (instancetype)initWithStatus:(const absl::Status &)status { self = [super init]; if (self) { _status = status; @@ -40,7 +40,7 @@ NSString *const kGUSGoogleUtilStatusErrorDomain = @"GoogleUtilStatusErrorDomain"; NSString *const kGUSGoogleUtilStatusErrorKey = @"GUSGoogleUtilStatusErrorKey"; -+ (NSError *)gus_errorWithStatus:(const ::mediapipe::Status &)status { ++ (NSError *)gus_errorWithStatus:(const absl::Status &)status { NSDictionary *userInfo = @{ NSLocalizedDescriptionKey : @(status.message().data()), kGUSGoogleUtilStatusErrorKey : [GUSUtilStatusWrapper wrapStatus:status], @@ -51,7 +51,7 @@ NSString *const kGUSGoogleUtilStatusErrorKey = @"GUSGoogleUtilStatusErrorKey"; return error; } -- (::mediapipe::Status)gus_status { +- (absl::Status)gus_status { NSString 
*domain = self.domain; if ([domain isEqual:kGUSGoogleUtilStatusErrorDomain]) { GUSUtilStatusWrapper *wrapper = self.userInfo[kGUSGoogleUtilStatusErrorKey]; @@ -63,7 +63,7 @@ NSString *const kGUSGoogleUtilStatusErrorKey = @"GUSGoogleUtilStatusErrorKey"; return ::util::PosixErrorToStatus(self.code, self.localizedDescription.UTF8String); #endif } - return ::mediapipe::Status(mediapipe::StatusCode::kUnknown, self.localizedDescription.UTF8String); + return absl::Status(absl::StatusCode::kUnknown, self.localizedDescription.UTF8String); } @end diff --git a/mediapipe/objc/util.cc b/mediapipe/objc/util.cc index bca9ff4be..e02b8aba7 100644 --- a/mediapipe/objc/util.cc +++ b/mediapipe/objc/util.cc @@ -249,20 +249,20 @@ void ReleaseMediaPipePacket(void* refcon, const void* base_address) { CVPixelBufferRef CreateCVPixelBufferForImageFramePacket( const mediapipe::Packet& image_frame_packet) { CFHolder buffer; - ::mediapipe::Status status = + absl::Status status = CreateCVPixelBufferForImageFramePacket(image_frame_packet, &buffer); MEDIAPIPE_CHECK_OK(status) << "Failed to create CVPixelBufferRef"; return (CVPixelBufferRef)CFRetain(*buffer); } -::mediapipe::Status CreateCVPixelBufferForImageFramePacket( +absl::Status CreateCVPixelBufferForImageFramePacket( const mediapipe::Packet& image_frame_packet, CFHolder* out_buffer) { return CreateCVPixelBufferForImageFramePacket(image_frame_packet, false, out_buffer); } -::mediapipe::Status CreateCVPixelBufferForImageFramePacket( +absl::Status CreateCVPixelBufferForImageFramePacket( const mediapipe::Packet& image_frame_packet, bool can_overwrite, CFHolder* out_buffer) { if (!out_buffer) { @@ -342,11 +342,11 @@ CVPixelBufferRef CreateCVPixelBufferForImageFramePacket( } *out_buffer = pixel_buffer; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status CreateCGImageFromCVPixelBuffer( - CVPixelBufferRef image_buffer, CFHolder* image) { +absl::Status CreateCGImageFromCVPixelBuffer(CVPixelBufferRef image_buffer, + 
CFHolder* image) { CVReturn status = CVPixelBufferLockBaseAddress(image_buffer, kCVPixelBufferLock_ReadOnly); RET_CHECK(status == kCVReturnSuccess) @@ -390,10 +390,10 @@ CVPixelBufferRef CreateCVPixelBufferForImageFramePacket( << "CVPixelBufferUnlockBaseAddress failed: " << status; *image = cg_image_holder; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status CreateCVPixelBufferFromCGImage( +absl::Status CreateCVPixelBufferFromCGImage( CGImageRef image, CFHolder* out_buffer) { size_t width = CGImageGetWidth(image); size_t height = CGImageGetHeight(image); @@ -428,7 +428,7 @@ CVPixelBufferRef CreateCVPixelBufferForImageFramePacket( << "CVPixelBufferUnlockBaseAddress failed: " << status; *out_buffer = pixel_buffer; - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } std::unique_ptr CreateImageFrameForCVPixelBuffer( diff --git a/mediapipe/objc/util.h b/mediapipe/objc/util.h index 2bb07500b..e499162f3 100644 --- a/mediapipe/objc/util.h +++ b/mediapipe/objc/util.h @@ -72,22 +72,22 @@ void ReleaseMediaPipePacket(void* refcon, const void* base_address); /// necessary to convert the data. This is done by creating a new buffer. /// If the optional can_overwrite parameter is true, the old buffer may be /// modified instead. -::mediapipe::Status CreateCVPixelBufferForImageFramePacket( +absl::Status CreateCVPixelBufferForImageFramePacket( const mediapipe::Packet& image_frame_packet, CFHolder* out_buffer); -::mediapipe::Status CreateCVPixelBufferForImageFramePacket( +absl::Status CreateCVPixelBufferForImageFramePacket( const mediapipe::Packet& image_frame_packet, bool can_overwrite, CFHolder* out_buffer); /// Creates a CVPixelBuffer with a copy of the contents of the CGImage. -::mediapipe::Status CreateCVPixelBufferFromCGImage( +absl::Status CreateCVPixelBufferFromCGImage( CGImageRef image, CFHolder* out_buffer); /// Creates a CGImage with a copy of the contents of the CVPixelBuffer. 
-::mediapipe::Status CreateCGImageFromCVPixelBuffer( - CVPixelBufferRef image_buffer, CFHolder* image); +absl::Status CreateCGImageFromCVPixelBuffer(CVPixelBufferRef image_buffer, + CFHolder* image); -/// DEPRECATED: use the version that returns ::mediapipe::Status instead. +/// DEPRECATED: use the version that returns absl::Status instead. CVPixelBufferRef CreateCVPixelBufferForImageFramePacket( const mediapipe::Packet& image_frame_packet); diff --git a/mediapipe/opensource_only/.bazelversion b/mediapipe/opensource_only/.bazelversion new file mode 100644 index 000000000..0b2eb36f5 --- /dev/null +++ b/mediapipe/opensource_only/.bazelversion @@ -0,0 +1 @@ +3.7.2 diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 59b20a2c7..3bd188de4 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -47,6 +47,8 @@ pybind_extension( "//mediapipe/framework/formats:classification_registration", "//mediapipe/framework/formats:detection_registration", "//mediapipe/framework/formats:landmark_registration", + "//mediapipe/framework/formats:rect_registration", + "//mediapipe/modules/objectron/calculators:annotation_registration", ], ) @@ -64,6 +66,7 @@ cc_library( "//mediapipe/modules/face_landmark:face_landmark_front_cpu", "//mediapipe/modules/hand_landmark:hand_landmark_tracking_cpu", "//mediapipe/modules/holistic_landmark:holistic_landmark_cpu", + "//mediapipe/modules/objectron:objectron_cpu", "//mediapipe/modules/palm_detection:palm_detection_cpu", "//mediapipe/modules/pose_detection:pose_detection_cpu", "//mediapipe/modules/pose_landmark:pose_landmark_by_roi_cpu", diff --git a/mediapipe/python/pybind/calculator_graph.cc b/mediapipe/python/pybind/calculator_graph.cc index 1a4ddd325..9878d384a 100644 --- a/mediapipe/python/pybind/calculator_graph.cc +++ b/mediapipe/python/pybind/calculator_graph.cc @@ -376,7 +376,7 @@ void CalculatorGraphSubmodule(pybind11::module* module) { calculator_graph.def( "get_combined_error_message", [](CalculatorGraph* self) { - 
mediapipe::Status error_status; + absl::Status error_status; if (self->GetCombinedErrors(&error_status) && !error_status.ok()) { return error_status.ToString(); } @@ -400,7 +400,7 @@ void CalculatorGraphSubmodule(pybind11::module* module) { // Acquire a mutex so that only one callback_fn can run at once. absl::MutexLock lock(&callback_mutex); callback_fn(stream_name, packet); - return mediapipe::OkStatus(); + return absl::OkStatus(); })); }, R"doc(Observe the named output stream. @@ -438,7 +438,7 @@ void CalculatorGraphSubmodule(pybind11::module* module) { [](CalculatorGraph* self, const std::string& packet_name) { auto status_or_packet = self->GetOutputSidePacket(packet_name); RaisePyErrorIfNotOk(status_or_packet.status()); - return status_or_packet.ValueOrDie(); + return status_or_packet.value(); }, R"doc(Get output side packet by name after the graph is done. diff --git a/mediapipe/python/pybind/packet_creator.cc b/mediapipe/python/pybind/packet_creator.cc index d990b3791..2212732cd 100644 --- a/mediapipe/python/pybind/packet_creator.cc +++ b/mediapipe/python/pybind/packet_creator.cc @@ -602,7 +602,7 @@ void InternalPacketCreators(pybind11::module* m) { "_create_proto", [](const std::string& type_name, const py::bytes& serialized_proto) { using packet_internal::HolderBase; - mediapipe::StatusOr> maybe_holder = + absl::StatusOr> maybe_holder = packet_internal::MessageHolderRegistry::CreateByName(type_name); if (!maybe_holder.ok()) { throw RaisePyError( @@ -612,7 +612,7 @@ void InternalPacketCreators(pybind11::module* m) { } // Creates a Packet with the concrete C++ payload type. 
std::unique_ptr message_holder = - std::move(maybe_holder).ValueOrDie(); + std::move(maybe_holder).value(); auto* copy = const_cast( message_holder->GetProtoMessageLite()); copy->ParseFromString(std::string(serialized_proto)); diff --git a/mediapipe/python/pybind/packet_getter.cc b/mediapipe/python/pybind/packet_getter.cc index 0287519e4..f88e48b4c 100644 --- a/mediapipe/python/pybind/packet_getter.cc +++ b/mediapipe/python/pybind/packet_getter.cc @@ -358,7 +358,7 @@ void InternalPacketGetters(pybind11::module* m) { [](Packet& packet) { auto proto_vector = packet.GetVectorOfProtoMessageLitePtrs(); RaisePyErrorIfNotOk(proto_vector.status()); - return proto_vector.ValueOrDie().size(); + return proto_vector.value().size(); }, py::return_value_policy::move); @@ -367,10 +367,10 @@ void InternalPacketGetters(pybind11::module* m) { [](Packet& packet) { auto proto_vector = packet.GetVectorOfProtoMessageLitePtrs(); RaisePyErrorIfNotOk(proto_vector.status()); - if (proto_vector.ValueOrDie().empty()) { + if (proto_vector.value().empty()) { return std::string(); } - return proto_vector.ValueOrDie()[0]->GetTypeName(); + return proto_vector.value()[0]->GetTypeName(); }, py::return_value_policy::move); @@ -391,10 +391,10 @@ void InternalPacketGetters(pybind11::module* m) { [](Packet& packet) { auto proto_vector = packet.GetVectorOfProtoMessageLitePtrs(); RaisePyErrorIfNotOk(proto_vector.status()); - int size = proto_vector.ValueOrDie().size(); + int size = proto_vector.value().size(); std::vector results; results.reserve(size); - for (const proto_ns::MessageLite* ptr : proto_vector.ValueOrDie()) { + for (const proto_ns::MessageLite* ptr : proto_vector.value()) { results.push_back(py::bytes(ptr->SerializeAsString())); } return results; diff --git a/mediapipe/python/pybind/util.h b/mediapipe/python/pybind/util.h index b84539e8f..099f75bd6 100644 --- a/mediapipe/python/pybind/util.h +++ b/mediapipe/python/pybind/util.h @@ -45,7 +45,7 @@ inline PyObject* StatusCodeToPyError(const 
::absl::StatusCode& code) { } } -inline void RaisePyErrorIfNotOk(const mediapipe::Status& status) { +inline void RaisePyErrorIfNotOk(const absl::Status& status) { if (!status.ok()) { throw RaisePyError(StatusCodeToPyError(status.code()), status.message().data()); diff --git a/mediapipe/python/pybind/validated_graph_config.cc b/mediapipe/python/pybind/validated_graph_config.cc index bf0f81d2b..ee85de825 100644 --- a/mediapipe/python/pybind/validated_graph_config.cc +++ b/mediapipe/python/pybind/validated_graph_config.cc @@ -98,7 +98,7 @@ void ValidatedGraphConfigSubmodule(pybind11::module* module) { [](ValidatedGraphConfig& self, const std::string& stream_name) { auto status_or_type_name = self.RegisteredStreamTypeName(stream_name); RaisePyErrorIfNotOk(status_or_type_name.status()); - return status_or_type_name.ValueOrDie(); + return status_or_type_name.value(); }, R"doc(Return the registered type name of the specified stream if it can be determined. @@ -122,7 +122,7 @@ void ValidatedGraphConfigSubmodule(pybind11::module* module) { auto status_or_type_name = self.RegisteredSidePacketTypeName(side_packet_name); RaisePyErrorIfNotOk(status_or_type_name.status()); - return status_or_type_name.ValueOrDie(); + return status_or_type_name.value(); }, R"doc(Return the registered type name of the specified side packet if it can be determined. 
diff --git a/mediapipe/python/solution_base.py b/mediapipe/python/solution_base.py index 90c36560c..1bed3855e 100644 --- a/mediapipe/python/solution_base.py +++ b/mediapipe/python/solution_base.py @@ -45,6 +45,8 @@ from mediapipe.calculators.util import thresholding_calculator_pb2 from mediapipe.framework.formats import classification_pb2 from mediapipe.framework.formats import landmark_pb2 from mediapipe.framework.formats import rect_pb2 +from mediapipe.modules.objectron.calculators import annotation_data_pb2 +from mediapipe.modules.objectron.calculators import lift_2d_frame_annotation_to_3d_calculator_pb2 # pylint: enable=unused-import from mediapipe.python._framework_bindings import calculator_graph from mediapipe.python._framework_bindings import image_frame @@ -71,6 +73,9 @@ CALCULATOR_TO_OPTIONS = { 'TensorsToDetectionsCalculator': tensors_to_detections_calculator_pb2 .TensorsToDetectionsCalculatorOptions, + 'Lift2DFrameAnnotationTo3DCalculator': + lift_2d_frame_annotation_to_3d_calculator_pb2 + .Lift2DFrameAnnotationTo3DCalculatorOptions, } @@ -120,6 +125,8 @@ NAME_TO_TYPE: Mapping[str, '_PacketDataType'] = { _PacketDataType.PROTO, '::mediapipe::NormalizedLandmark': _PacketDataType.PROTO, + '::mediapipe::FrameAnnotation': + _PacketDataType.PROTO, '::mediapipe::Trigger': _PacketDataType.PROTO, '::mediapipe::Rect': @@ -157,15 +164,14 @@ class SolutionBase: shutdown. Example usage: - hand_tracker = solution_base.SolutionBase( - binary_graph_path='mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.binarypb', - side_inputs={'num_hands': 2}) - # Read an image and convert the BGR image to RGB. 
- input_image = cv2.cvtColor(cv2.imread('/tmp/hand.png'), COLOR_BGR2RGB) - results = hand_tracker.process(input_image) - print(results.palm_detections) - print(results.multi_hand_landmarks) - hand_tracker.close() + with solution_base.SolutionBase( + binary_graph_path='mediapipe/modules/hand_landmark/hand_landmark_tracking_cpu.binarypb', + side_inputs={'num_hands': 2}) as hand_tracker: + # Read an image and convert the BGR image to RGB. + input_image = cv2.cvtColor(cv2.imread('/tmp/hand.png'), COLOR_BGR2RGB) + results = hand_tracker.process(input_image) + print(results.palm_detections) + print(results.multi_hand_landmarks) """ def __init__( @@ -479,3 +485,11 @@ class SolutionBase: else: return getattr(packet_getter, 'get_' + packet_data_type.value)( output_packet) + + def __enter__(self): + """A "with" statement support.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Closes all the input sources and the graph.""" + self.close() diff --git a/mediapipe/python/solution_base_test.py b/mediapipe/python/solution_base_test.py index 35abca9f1..c24c4d5a7 100644 --- a/mediapipe/python/solution_base_test.py +++ b/mediapipe/python/solution_base_test.py @@ -105,14 +105,14 @@ class SolutionBaseTest(parameterized.TestCase): """ config_proto = text_format.Parse(text_config, calculator_pb2.CalculatorGraphConfig()) - solution = solution_base.SolutionBase(graph_config=config_proto) - detection = detection_pb2.Detection() - text_format.Parse('score: 0.5', detection) - with self.assertRaisesRegex( - NotImplementedError, - 'SolutionBase can only process image data. PROTO_LIST type is not supported.' - ): - solution.process({'input_detections': detection}) + with solution_base.SolutionBase(graph_config=config_proto) as solution: + detection = detection_pb2.Detection() + text_format.Parse('score: 0.5', detection) + with self.assertRaisesRegex( + NotImplementedError, + 'SolutionBase can only process image data. PROTO_LIST type is not supported.' 
+ ): + solution.process({'input_detections': detection}) def test_invalid_input_image_data(self): text_config = """ @@ -131,10 +131,10 @@ class SolutionBaseTest(parameterized.TestCase): """ config_proto = text_format.Parse(text_config, calculator_pb2.CalculatorGraphConfig()) - solution = solution_base.SolutionBase(graph_config=config_proto) - with self.assertRaisesRegex( - ValueError, 'Input image must contain three channel rgb data.'): - solution.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) + with solution_base.SolutionBase(graph_config=config_proto) as solution: + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + solution.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) @parameterized.named_parameters(('graph_without_side_packets', """ input_stream: 'image_in' @@ -272,15 +272,14 @@ class SolutionBaseTest(parameterized.TestCase): side_inputs=None, calculator_params=None): input_image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3) - solution = solution_base.SolutionBase( + with solution_base.SolutionBase( graph_config=config_proto, side_inputs=side_inputs, - calculator_params=calculator_params) - outputs = solution.process(input_image) + calculator_params=calculator_params) as solution: + outputs = solution.process(input_image) + outputs2 = solution.process({'image_in': input_image}) self.assertTrue(np.array_equal(input_image, outputs.image_out)) - outputs2 = solution.process({'image_in': input_image}) self.assertTrue(np.array_equal(input_image, outputs2.image_out)) - solution.close() if __name__ == '__main__': diff --git a/mediapipe/python/solutions/__init__.py b/mediapipe/python/solutions/__init__.py index 64875c029..8cd9af327 100644 --- a/mediapipe/python/solutions/__init__.py +++ b/mediapipe/python/solutions/__init__.py @@ -15,7 +15,9 @@ """MediaPipe Solutions Python API.""" import mediapipe.python.solutions.drawing_utils +import mediapipe.python.solutions.face_detection import 
mediapipe.python.solutions.face_mesh import mediapipe.python.solutions.hands import mediapipe.python.solutions.holistic +import mediapipe.python.solutions.objectron import mediapipe.python.solutions.pose diff --git a/mediapipe/python/solutions/drawing_utils.py b/mediapipe/python/solutions/drawing_utils.py index ae1e0c401..06936741a 100644 --- a/mediapipe/python/solutions/drawing_utils.py +++ b/mediapipe/python/solutions/drawing_utils.py @@ -21,6 +21,8 @@ import cv2 import dataclasses import numpy as np +from mediapipe.framework.formats import detection_pb2 +from mediapipe.framework.formats import location_data_pb2 from mediapipe.framework.formats import landmark_pb2 PRESENCE_THRESHOLD = 0.5 @@ -58,6 +60,57 @@ def _normalized_to_pixel_coordinates( return x_px, y_px +def draw_detection( + image: np.ndarray, + detection: detection_pb2.Detection, + keypoint_drawing_spec: DrawingSpec = DrawingSpec(color=RED_COLOR), + bbox_drawing_spec: DrawingSpec = DrawingSpec()): + """Draws the detction bounding box and keypoints on the image. + + Args: + image: A three channel RGB image represented as numpy ndarray. + detection: A detection proto message to be annotated on the image. + keypoint_drawing_spec: A DrawingSpec object that specifies the keypoints' + drawing settings such as color, line thickness, and circle radius. + bbox_drawing_spec: A DrawingSpec object that specifies the bounding box's + drawing settings such as color and line thickness. + + Raises: + ValueError: If one of the followings: + a) If the input image is not three channel RGB. + b) If the location data is not relative data. 
+ """ + if not detection.location_data: + return + if image.shape[2] != RGB_CHANNELS: + raise ValueError('Input image must contain three channel rgb data.') + image_rows, image_cols, _ = image.shape + + location = detection.location_data + if location.format != location_data_pb2.LocationData.RELATIVE_BOUNDING_BOX: + raise ValueError( + 'LocationData must be relative for this drawing funtion to work.') + # Draws keypoints. + for keypoint in location.relative_keypoints: + keypoint_px = _normalized_to_pixel_coordinates(keypoint.x, keypoint.y, + image_cols, image_rows) + cv2.circle(image, keypoint_px, keypoint_drawing_spec.circle_radius, + keypoint_drawing_spec.color, keypoint_drawing_spec.thickness) + # Draws bounding box if exists. + if not location.HasField('relative_bounding_box'): + return + relative_bounding_box = location.relative_bounding_box + rect_start_point = _normalized_to_pixel_coordinates( + relative_bounding_box.xmin, relative_bounding_box.ymin, image_cols, + image_rows) + rect_end_point = _normalized_to_pixel_coordinates( + relative_bounding_box.xmin + relative_bounding_box.width, + relative_bounding_box.ymin + +relative_bounding_box.height, image_cols, + image_rows) + cv2.rectangle(image, rect_start_point, rect_end_point, + bbox_drawing_spec.color, bbox_drawing_spec.thickness) + + def draw_landmarks( image: np.ndarray, landmark_list: landmark_pb2.NormalizedLandmarkList, @@ -116,3 +169,63 @@ def draw_landmarks( for landmark_px in idx_to_coordinates.values(): cv2.circle(image, landmark_px, landmark_drawing_spec.circle_radius, landmark_drawing_spec.color, landmark_drawing_spec.thickness) + + +def draw_axis( + image: np.ndarray, + rotation: np.ndarray, + translation: np.ndarray, + focal_length: Tuple[float, float] = (1.0, 1.0), + principal_point: Tuple[float, float] = (0.0, 0.0), + axis_length: float = 0.1, + x_axis_drawing_spec: DrawingSpec = DrawingSpec(color=(0, 0, 255)), + y_axis_drawing_spec: DrawingSpec = DrawingSpec(color=(0, 128, 0)), + 
z_axis_drawing_spec: DrawingSpec = DrawingSpec(color=(255, 0, 0))): + """Draws the 3D axis on the image. + + Args: + image: A three channel RGB image represented as numpy ndarray. + rotation: Rotation matrix from object to camera coordinate frame. + translation: Translation vector from object to camera coordinate frame. + focal_length: camera focal length along x and y directions. + principal_point: camera principal point in x and y. + axis_length: length of the axis in the drawing. + x_axis_drawing_spec: A DrawingSpec object that specifies the x axis + drawing settings such as color, line thickness. + y_axis_drawing_spec: A DrawingSpec object that specifies the y axis + drawing settings such as color, line thickness. + z_axis_drawing_spec: A DrawingSpec object that specifies the z axis + drawing settings such as color, line thickness. + + Raises: + ValueError: If one of the followings: + a) If the input image is not three channel RGB. + """ + if image.shape[2] != RGB_CHANNELS: + raise ValueError('Input image must contain three channel rgb data.') + image_rows, image_cols, _ = image.shape + # Create axis points in camera coordinate frame. + axis_world = np.float32([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) + axis_cam = np.matmul(rotation, axis_length*axis_world.T).T + translation + x = axis_cam[..., 0] + y = axis_cam[..., 1] + z = axis_cam[..., 2] + # Project 3D points to NDC space. + fx, fy = focal_length + px, py = principal_point + x_ndc = -fx * x / z + px + y_ndc = -fy * y / z + py + # Convert from NDC space to image space. + x_im = np.int32((1 + x_ndc) * 0.5 * image_cols) + y_im = np.int32((1 - y_ndc) * 0.5 * image_rows) + # Draw xyz axis on the image. 
+ origin = (x_im[0], y_im[0]) + x_axis = (x_im[1], y_im[1]) + y_axis = (x_im[2], y_im[2]) + z_axis = (x_im[3], y_im[3]) + image = cv2.arrowedLine(image, origin, x_axis, x_axis_drawing_spec.color, + x_axis_drawing_spec.thickness) + image = cv2.arrowedLine(image, origin, y_axis, y_axis_drawing_spec.color, + y_axis_drawing_spec.thickness) + image = cv2.arrowedLine(image, origin, z_axis, z_axis_drawing_spec.color, + z_axis_drawing_spec.thickness) diff --git a/mediapipe/python/solutions/drawing_utils_test.py b/mediapipe/python/solutions/drawing_utils_test.py index 0c5e21e01..6391aca80 100644 --- a/mediapipe/python/solutions/drawing_utils_test.py +++ b/mediapipe/python/solutions/drawing_utils_test.py @@ -21,11 +21,13 @@ import numpy as np from google.protobuf import text_format +from mediapipe.framework.formats import detection_pb2 from mediapipe.framework.formats import landmark_pb2 from mediapipe.python.solutions import drawing_utils +DEFAULT_BBOX_DRAWING_SPEC = drawing_utils.DrawingSpec() DEFAULT_CONNECTION_DRAWING_SPEC = drawing_utils.DrawingSpec() -DEFAULT_LANDMARK_DRAWING_SPEC = drawing_utils.DrawingSpec(color=(0, 0, 255)) +DEFAULT_CIRCLE_DRAWING_SPEC = drawing_utils.DrawingSpec(color=(0, 0, 255)) class DrawingUtilTest(parameterized.TestCase): @@ -35,6 +37,9 @@ class DrawingUtilTest(parameterized.TestCase): with self.assertRaisesRegex( ValueError, 'Input image must contain three channel rgb data.'): drawing_utils.draw_landmarks(image, landmark_pb2.NormalizedLandmarkList()) + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + drawing_utils.draw_detection(image, detection_pb2.Detection()) def test_invalid_connection(self): landmark_list = text_format.Parse( @@ -44,6 +49,46 @@ class DrawingUtilTest(parameterized.TestCase): with self.assertRaisesRegex(ValueError, 'Landmark index is out of range.'): drawing_utils.draw_landmarks(image, landmark_list, [(0, 2)]) + def test_unqualified_detection(self): + detection = 
text_format.Parse('location_data {format: GLOBAL}', + detection_pb2.Detection()) + image = np.arange(27, dtype=np.uint8).reshape(3, 3, 3) + with self.assertRaisesRegex(ValueError, 'LocationData must be relative'): + drawing_utils.draw_detection(image, detection) + + def test_draw_keypoints_only(self): + detection = text_format.Parse( + 'location_data {' + ' format: RELATIVE_BOUNDING_BOX' + ' relative_keypoints {x: 0 y: 1}' + ' relative_keypoints {x: 1 y: 0}}', detection_pb2.Detection()) + image = np.zeros((100, 100, 3), np.uint8) + expected_result = np.copy(image) + cv2.circle(expected_result, (0, 99), + DEFAULT_CIRCLE_DRAWING_SPEC.circle_radius, + DEFAULT_CIRCLE_DRAWING_SPEC.color, + DEFAULT_CIRCLE_DRAWING_SPEC.thickness) + cv2.circle(expected_result, (99, 0), + DEFAULT_CIRCLE_DRAWING_SPEC.circle_radius, + DEFAULT_CIRCLE_DRAWING_SPEC.color, + DEFAULT_CIRCLE_DRAWING_SPEC.thickness) + drawing_utils.draw_detection(image, detection) + np.testing.assert_array_equal(image, expected_result) + + def test_draw_bboxs_only(self): + detection = text_format.Parse( + 'location_data {' + ' format: RELATIVE_BOUNDING_BOX' + ' relative_bounding_box {xmin: 0 ymin: 0 width: 1 height: 1}}', + detection_pb2.Detection()) + image = np.zeros((100, 100, 3), np.uint8) + expected_result = np.copy(image) + cv2.rectangle(expected_result, (0, 0), (99, 99), + DEFAULT_BBOX_DRAWING_SPEC.color, + DEFAULT_BBOX_DRAWING_SPEC.thickness) + drawing_utils.draw_detection(image, detection) + np.testing.assert_array_equal(image, expected_result) + @parameterized.named_parameters( ('landmark_list_has_only_one_element', 'landmark {x: 0.1 y: 0.1}'), ('second_landmark_is_invisible', @@ -54,9 +99,9 @@ class DrawingUtilTest(parameterized.TestCase): image = np.zeros((100, 100, 3), np.uint8) expected_result = np.copy(image) cv2.circle(expected_result, (10, 10), - DEFAULT_LANDMARK_DRAWING_SPEC.circle_radius, - DEFAULT_LANDMARK_DRAWING_SPEC.color, - DEFAULT_LANDMARK_DRAWING_SPEC.thickness) + 
DEFAULT_CIRCLE_DRAWING_SPEC.circle_radius, + DEFAULT_CIRCLE_DRAWING_SPEC.color, + DEFAULT_CIRCLE_DRAWING_SPEC.thickness) drawing_utils.draw_landmarks(image, landmark_list) np.testing.assert_array_equal(image, expected_result) @@ -77,13 +122,13 @@ class DrawingUtilTest(parameterized.TestCase): DEFAULT_CONNECTION_DRAWING_SPEC.color, DEFAULT_CONNECTION_DRAWING_SPEC.thickness) cv2.circle(expected_result, start_point, - DEFAULT_LANDMARK_DRAWING_SPEC.circle_radius, - DEFAULT_LANDMARK_DRAWING_SPEC.color, - DEFAULT_LANDMARK_DRAWING_SPEC.thickness) + DEFAULT_CIRCLE_DRAWING_SPEC.circle_radius, + DEFAULT_CIRCLE_DRAWING_SPEC.color, + DEFAULT_CIRCLE_DRAWING_SPEC.thickness) cv2.circle(expected_result, end_point, - DEFAULT_LANDMARK_DRAWING_SPEC.circle_radius, - DEFAULT_LANDMARK_DRAWING_SPEC.color, - DEFAULT_LANDMARK_DRAWING_SPEC.thickness) + DEFAULT_CIRCLE_DRAWING_SPEC.circle_radius, + DEFAULT_CIRCLE_DRAWING_SPEC.color, + DEFAULT_CIRCLE_DRAWING_SPEC.thickness) drawing_utils.draw_landmarks( image=image, landmark_list=landmark_list, connections=[(0, 1)]) np.testing.assert_array_equal(image, expected_result) @@ -100,13 +145,13 @@ class DrawingUtilTest(parameterized.TestCase): DEFAULT_CONNECTION_DRAWING_SPEC.color, DEFAULT_CONNECTION_DRAWING_SPEC.thickness) cv2.circle(expected_result, start_point, - DEFAULT_LANDMARK_DRAWING_SPEC.circle_radius, - DEFAULT_LANDMARK_DRAWING_SPEC.color, - DEFAULT_LANDMARK_DRAWING_SPEC.thickness) + DEFAULT_CIRCLE_DRAWING_SPEC.circle_radius, + DEFAULT_CIRCLE_DRAWING_SPEC.color, + DEFAULT_CIRCLE_DRAWING_SPEC.thickness) cv2.circle(expected_result, end_point, - DEFAULT_LANDMARK_DRAWING_SPEC.circle_radius, - DEFAULT_LANDMARK_DRAWING_SPEC.color, - DEFAULT_LANDMARK_DRAWING_SPEC.thickness) + DEFAULT_CIRCLE_DRAWING_SPEC.circle_radius, + DEFAULT_CIRCLE_DRAWING_SPEC.color, + DEFAULT_CIRCLE_DRAWING_SPEC.thickness) drawing_utils.draw_landmarks( image=image, landmark_list=landmark_list, connections=[(0, 1)]) np.testing.assert_array_equal(image, expected_result) diff 
--git a/mediapipe/python/solutions/face_detection.py b/mediapipe/python/solutions/face_detection.py new file mode 100644 index 000000000..ef5eefdb3 --- /dev/null +++ b/mediapipe/python/solutions/face_detection.py @@ -0,0 +1,103 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe Face Detection.""" + +import enum +from typing import NamedTuple, Union + +import numpy as np +from mediapipe.framework.formats import detection_pb2 +from mediapipe.framework.formats import location_data_pb2 +# pylint: disable=unused-import +from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2 +from mediapipe.calculators.tensor import inference_calculator_pb2 +from mediapipe.calculators.tensor import tensors_to_detections_calculator_pb2 +from mediapipe.calculators.tflite import ssd_anchors_calculator_pb2 +from mediapipe.calculators.util import non_max_suppression_calculator_pb2 +# pylint: enable=unused-import +from mediapipe.python.solution_base import SolutionBase + +BINARYPB_FILE_PATH = 'mediapipe/modules/face_detection/face_detection_front_cpu.binarypb' + + +def get_key_point( + detection: detection_pb2.Detection, key_point_enum: 'FaceKeyPoint' +) -> Union[None, location_data_pb2.LocationData.RelativeKeypoint]: + """A convenience method to return a face key point by the FaceKeyPoint type. + + Args: + detection: A detection proto message that contains face key points. + key_point_enum: A FaceKeyPoint type. 
+ + Returns: + A RelativeKeypoint proto message. + """ + if not detection or not detection.location_data: + return None + return detection.location_data.relative_keypoints[key_point_enum] + + +class FaceKeyPoint(enum.IntEnum): + """The enum type of the six face detection key points.""" + RIGHT_EYE = 0 + LEFT_EYE = 1 + NOSE_TIP = 2 + MOUTH_CENTER = 3 + RIGHT_EAR_TRAGION = 4 + LEFT_EAR_TRAGION = 5 + + +class FaceDetection(SolutionBase): + """MediaPipe Face Detection. + + MediaPipe Face Detection processes an RGB image and returns a list of the + detected face location data. + + Please refer to + https://solutions.mediapipe.dev/face_detection#python-solution-api + for usage examples. + """ + + def __init__(self, min_detection_confidence=0.5): + """Initializes a MediaPipe Face Detection object. + + Args: + min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for face + detection to be considered successful. See details in + https://solutions.mediapipe.dev/face_detection#min_detection_confidence. + """ + super().__init__( + binary_graph_path=BINARYPB_FILE_PATH, + calculator_params={ + 'facedetectionfrontcommon__TensorsToDetectionsCalculator.min_score_thresh': + min_detection_confidence, + }, + outputs=['detections']) + + def process(self, image: np.ndarray) -> NamedTuple: + """Processes an RGB image and returns a list of the detected face location data. + + Args: + image: An RGB image represented as a numpy ndarray. + + Raises: + RuntimeError: If the underlying graph throws any error. + ValueError: If the input image is not three channel RGB. + + Returns: + A NamedTuple object with a "detections" field that contains a list of the + detected face location data. 
+ """ + + return super().process(input_data={'image': image}) diff --git a/mediapipe/python/solutions/face_detection_test.py b/mediapipe/python/solutions/face_detection_test.py new file mode 100644 index 000000000..86eb794c5 --- /dev/null +++ b/mediapipe/python/solutions/face_detection_test.py @@ -0,0 +1,67 @@ +# Copyright 2021 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for mediapipe.python.solutions.face_detection.""" + +import os + +from absl.testing import absltest +import cv2 +import numpy as np +import numpy.testing as npt + +# resources dependency +from mediapipe.python.solutions import face_detection as mp_faces + +TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' +EXPECTED_FACE_KEY_POINTS = [[182, 368], [186, 467], [236, 416], [284, 415], + [203, 310], [212, 521]] +DIFF_THRESHOLD = 10 # pixels + + +class FaceDetectionTest(absltest.TestCase): + + def test_invalid_image_shape(self): + with mp_faces.FaceDetection() as faces: + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + faces.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) + + def test_blank_image(self): + image = np.zeros([100, 100, 3], dtype=np.uint8) + image.fill(255) + with mp_faces.FaceDetection(min_detection_confidence=0.5) as faces: + results = faces.process(image) + self.assertIsNone(results.detections) + + def test_face(self): + image_path = os.path.join(os.path.dirname(__file__), 'testdata/face.jpg') 
+ image = cv2.flip(cv2.imread(image_path), 1) + + with mp_faces.FaceDetection(min_detection_confidence=0.5) as faces: + for _ in range(5): + results = faces.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + location_data = results.detections[0].location_data + x = [keypoint.x for keypoint in location_data.relative_keypoints] + y = [keypoint.y for keypoint in location_data.relative_keypoints] + face_keypoints = np.transpose(np.stack((y, x))) * image.shape[0:2] + prediction_error = np.abs( + np.asarray(face_keypoints) - np.asarray(EXPECTED_FACE_KEY_POINTS)) + + self.assertLen(results.detections, 1) + self.assertLen(location_data.relative_keypoints, 6) + npt.assert_array_less(prediction_error, DIFF_THRESHOLD) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/python/solutions/face_mesh.py b/mediapipe/python/solutions/face_mesh.py index a100b8bc5..85b7efdcf 100644 --- a/mediapipe/python/solutions/face_mesh.py +++ b/mediapipe/python/solutions/face_mesh.py @@ -227,7 +227,7 @@ class FaceMesh(SolutionBase): image: An RGB image represented as a numpy ndarray. Raises: - RuntimeError: If the underlying graph occurs any error. + RuntimeError: If the underlying graph throws any error. ValueError: If the input image is not three channel RGB. 
Returns: diff --git a/mediapipe/python/solutions/face_mesh_test.py b/mediapipe/python/solutions/face_mesh_test.py index 6edb21b62..e53479821 100644 --- a/mediapipe/python/solutions/face_mesh_test.py +++ b/mediapipe/python/solutions/face_mesh_test.py @@ -26,7 +26,7 @@ import numpy.testing as npt from mediapipe.python.solutions import face_mesh as mp_faces TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -DIFF_THRESHOLOD = 20 +DIFF_THRESHOLD = 20 # pixels EYE_INDICES_TO_LANDMARKS = { 33: [176, 350], 7: [177, 353], @@ -66,46 +66,42 @@ EYE_INDICES_TO_LANDMARKS = { class FaceMeshTest(parameterized.TestCase): def test_invalid_image_shape(self): - faces = mp_faces.FaceMesh() - with self.assertRaisesRegex( - ValueError, 'Input image must contain three channel rgb data.'): - faces.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) + with mp_faces.FaceMesh() as faces: + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + faces.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) def test_blank_image(self): - faces = mp_faces.FaceMesh() - image = np.zeros([100, 100, 3], dtype=np.uint8) - image.fill(255) - results = faces.process(image) - self.assertIsNone(results.multi_face_landmarks) - faces.close() + with mp_faces.FaceMesh() as faces: + image = np.zeros([100, 100, 3], dtype=np.uint8) + image.fill(255) + results = faces.process(image) + self.assertIsNone(results.multi_face_landmarks) @parameterized.named_parameters(('static_image_mode', True, 1), ('video_mode', False, 5)) def test_face(self, static_image_mode: bool, num_frames: int): image_path = os.path.join(os.path.dirname(__file__), 'testdata/face.jpg') - faces = mp_faces.FaceMesh( - static_image_mode=static_image_mode, min_detection_confidence=0.5) image = cv2.flip(cv2.imread(image_path), 1) - def process_one_frame(): - results = faces.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - multi_face_landmarks = [] - for landmarks in results.multi_face_landmarks: - 
self.assertLen(landmarks.landmark, 468) - x = [landmark.x for landmark in landmarks.landmark] - y = [landmark.y for landmark in landmarks.landmark] - face_landmarks = np.transpose(np.stack((y, x))) * image.shape[0:2] - multi_face_landmarks.append(face_landmarks) - self.assertLen(multi_face_landmarks, 1) - # Verify the eye landmarks are correct as sanity check. - for idx, gt_lds in EYE_INDICES_TO_LANDMARKS.items(): - prediction_error = np.abs( - np.asarray(multi_face_landmarks[0][idx]) - np.asarray(gt_lds)) - npt.assert_array_less(prediction_error, DIFF_THRESHOLOD) - - for _ in range(num_frames): - process_one_frame() - faces.close() + with mp_faces.FaceMesh( + static_image_mode=static_image_mode, + min_detection_confidence=0.5) as faces: + for _ in range(num_frames): + results = faces.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + multi_face_landmarks = [] + for landmarks in results.multi_face_landmarks: + self.assertLen(landmarks.landmark, 468) + x = [landmark.x for landmark in landmarks.landmark] + y = [landmark.y for landmark in landmarks.landmark] + face_landmarks = np.transpose(np.stack((y, x))) * image.shape[0:2] + multi_face_landmarks.append(face_landmarks) + self.assertLen(multi_face_landmarks, 1) + # Verify the eye landmarks are correct as sanity check. + for idx, gt_lds in EYE_INDICES_TO_LANDMARKS.items(): + prediction_error = np.abs( + np.asarray(multi_face_landmarks[0][idx]) - np.asarray(gt_lds)) + npt.assert_array_less(prediction_error, DIFF_THRESHOLD) if __name__ == '__main__': diff --git a/mediapipe/python/solutions/hands.py b/mediapipe/python/solutions/hands.py index c1f0c5e9c..15760ed75 100644 --- a/mediapipe/python/solutions/hands.py +++ b/mediapipe/python/solutions/hands.py @@ -151,7 +151,7 @@ class Hands(SolutionBase): image: An RGB image represented as a numpy ndarray. Raises: - RuntimeError: If the underlying graph occurs any error. + RuntimeError: If the underlying graph throws any error. 
ValueError: If the input image is not three channel RGB. Returns: diff --git a/mediapipe/python/solutions/hands_test.py b/mediapipe/python/solutions/hands_test.py index b6c67ca85..1ea4dc563 100644 --- a/mediapipe/python/solutions/hands_test.py +++ b/mediapipe/python/solutions/hands_test.py @@ -26,7 +26,7 @@ import numpy.testing as npt from mediapipe.python.solutions import hands as mp_hands TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -DIFF_THRESHOLOD = 20 +DIFF_THRESHOLD = 20 # pixels EXPECTED_HAND_COORDINATES_PREDICTION = [[[332, 144], [323, 211], [286, 257], [237, 289], [203, 322], [216, 219], [138, 238], [90, 249], [51, 253], @@ -46,53 +46,48 @@ EXPECTED_HAND_COORDINATES_PREDICTION = [[[332, 144], [323, 211], [286, 257], class HandsTest(parameterized.TestCase): def test_invalid_image_shape(self): - hands = mp_hands.Hands() - with self.assertRaisesRegex( - ValueError, 'Input image must contain three channel rgb data.'): - hands.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) + with mp_hands.Hands() as hands: + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + hands.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) def test_blank_image(self): - hands = mp_hands.Hands() - image = np.zeros([100, 100, 3], dtype=np.uint8) - image.fill(255) - results = hands.process(image) - self.assertIsNone(results.multi_hand_landmarks) - self.assertIsNone(results.multi_handedness) - hands.close() + with mp_hands.Hands() as hands: + image = np.zeros([100, 100, 3], dtype=np.uint8) + image.fill(255) + results = hands.process(image) + self.assertIsNone(results.multi_hand_landmarks) + self.assertIsNone(results.multi_handedness) @parameterized.named_parameters(('static_image_mode', True, 1), ('video_mode', False, 5)) def test_multi_hands(self, static_image_mode, num_frames): image_path = os.path.join(os.path.dirname(__file__), 'testdata/hands.jpg') - hands = mp_hands.Hands( - static_image_mode=static_image_mode, - 
max_num_hands=2, - min_detection_confidence=0.5) image = cv2.flip(cv2.imread(image_path), 1) - def process_one_frame(): - results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - handedness = [ - handedness.classification[0].label - for handedness in results.multi_handedness - ] - self.assertLen(handedness, 2) - multi_hand_coordinates = [] - for landmarks in results.multi_hand_landmarks: - self.assertLen(landmarks.landmark, 21) - x = [landmark.x for landmark in landmarks.landmark] - y = [landmark.y for landmark in landmarks.landmark] - hand_coordinates = np.transpose(np.stack((y, x))) * image.shape[0:2] - multi_hand_coordinates.append(hand_coordinates) - self.assertLen(multi_hand_coordinates, 2) - prediction_error = np.abs( - np.asarray(multi_hand_coordinates) - - np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION)) - npt.assert_array_less(prediction_error, DIFF_THRESHOLOD) - - for _ in range(num_frames): - process_one_frame() - hands.close() + with mp_hands.Hands( + static_image_mode=static_image_mode, + max_num_hands=2, + min_detection_confidence=0.5) as hands: + for _ in range(num_frames): + results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + handedness = [ + handedness.classification[0].label + for handedness in results.multi_handedness + ] + multi_hand_coordinates = [] + for landmarks in results.multi_hand_landmarks: + self.assertLen(landmarks.landmark, 21) + x = [landmark.x for landmark in landmarks.landmark] + y = [landmark.y for landmark in landmarks.landmark] + hand_coordinates = np.transpose(np.stack((y, x))) * image.shape[0:2] + multi_hand_coordinates.append(hand_coordinates) + self.assertLen(handedness, 2) + self.assertLen(multi_hand_coordinates, 2) + prediction_error = np.abs( + np.asarray(multi_hand_coordinates) - + np.asarray(EXPECTED_HAND_COORDINATES_PREDICTION)) + npt.assert_array_less(prediction_error, DIFF_THRESHOLD) if __name__ == '__main__': diff --git a/mediapipe/python/solutions/holistic.py 
b/mediapipe/python/solutions/holistic.py index 02b891222..898bb1562 100644 --- a/mediapipe/python/solutions/holistic.py +++ b/mediapipe/python/solutions/holistic.py @@ -41,6 +41,7 @@ from mediapipe.python.solutions.hands import HAND_CONNECTIONS from mediapipe.python.solutions.hands import HandLandmark from mediapipe.python.solutions.pose import POSE_CONNECTIONS from mediapipe.python.solutions.pose import PoseLandmark +from mediapipe.python.solutions.pose import UPPER_BODY_POSE_CONNECTIONS # pylint: enable=unused-import BINARYPB_FILE_PATH = 'mediapipe/modules/holistic_landmark/holistic_landmark_cpu.binarypb' @@ -111,7 +112,7 @@ class Holistic(SolutionBase): image: An RGB image represented as a numpy ndarray. Raises: - RuntimeError: If the underlying graph occurs any error. + RuntimeError: If the underlying graph throws any error. ValueError: If the input image is not three channel RGB. Returns: diff --git a/mediapipe/python/solutions/holistic_test.py b/mediapipe/python/solutions/holistic_test.py index 38c70abc7..cfd29e24c 100644 --- a/mediapipe/python/solutions/holistic_test.py +++ b/mediapipe/python/solutions/holistic_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
"""Tests for mediapipe.python.solutions.pose.""" -import math import os from absl.testing import absltest @@ -26,117 +25,118 @@ import numpy.testing as npt from mediapipe.python.solutions import holistic as mp_holistic TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -POSE_DIFF_THRESHOLOD = 30 # pixels -HAND_DIFF_THRESHOLOD = 10 # pixels -EXPECTED_POSE_COORDINATES_PREDICTION = [[593, 645], [593, 626], [599, 621], - [605, 617], [575, 637], [569, 640], - [563, 643], [621, 616], [565, 652], - [617, 652], [595, 667], [714, 662], - [567, 749], [792, 559], [497, 844], - [844, 435], [407, 906], [866, 403], - [381, 921], [859, 392], [366, 922], - [850, 405], [381, 918], [707, 948], - [631, 940], [582, 1122], [599, 1097], - [495, 1277], [641, 1239], [485, 1300], - [658, 1257], [453, 1332], [626, 1308]] -EXPECTED_LEFT_HAND_COORDINATES_PREDICTION = [[843, 404], [862, 395], [876, 383], - [887, 369], [896, 359], [854, 367], - [868, 347], [879, 346], [885, 349], - [843, 362], [859, 341], [871, 340], - [878, 344], [837, 361], [849, 341], - [859, 338], [867, 339], [834, 361], - [841, 346], [848, 342], [854, 341]] -EXPECTED_RIGHT_HAND_COORDINATES_PREDICTION = [[391, 934], [371, - 930], [354, 930], - [340, 934], [328, - 939], [350, 938], - [339, 946], [347, - 951], [355, 952], - [356, 946], [346, - 955], [358, 956], - [366, 953], [361, - 952], [354, 959], - [364, 958], [372, - 954], [366, 957], - [359, 963], [364, 962], - [368, 960]] +POSE_DIFF_THRESHOLD = 30 # pixels +HAND_DIFF_THRESHOLD = 30 # pixels +EXPECTED_UPPER_BODY_LANDMARKS = np.array([[457, 289], [465, 278], [467, 278], + [470, 277], [461, 279], [461, 279], + [461, 279], [485, 277], [474, 278], + [468, 296], [463, 297], [542, 324], + [449, 327], [614, 321], [376, 318], + [680, 322], [312, 310], [697, 320], + [293, 305], [699, 314], [289, 302], + [693, 316], [296, 305], [515, 451], + [467, 453]]) +EXPECTED_FULL_BODY_LANDMARKS = np.array([[460, 287], [469, 277], [472, 276], + [475, 276], [464, 277], [463, 277], + 
[463, 276], [492, 277], [472, 277], + [471, 295], [465, 295], [542, 323], + [448, 318], [619, 319], [372, 313], + [695, 316], [296, 308], [717, 313], + [273, 304], [718, 304], [280, 298], + [709, 307], [289, 303], [521, 470], + [459, 466], [626, 533], [364, 500], + [704, 616], [347, 614], [710, 631], + [357, 633], [737, 625], [306, 639]]) +EXPECTED_LEFT_HAND_LANDMARKS = np.array([[698, 314], [712, 314], [721, 314], + [727, 314], [732, 313], [728, 309], + [738, 309], [745, 308], [751, 307], + [724, 310], [735, 309], [742, 309], + [747, 307], [719, 312], [727, 313], + [729, 312], [731, 311], [713, 315], + [717, 315], [719, 314], [719, 313]]) +EXPECTED_RIGHT_HAND_LANDMARKS = np.array([[293, 307], [284, 306], [277, 304], + [271, 303], [266, 303], [271, 302], + [261, 302], [254, 301], [247, 299], + [272, 303], [261, 303], [253, 301], + [245, 299], [275, 304], [266, 303], + [258, 302], [252, 300], [279, 305], + [273, 305], [268, 304], [263, 303]]) class PoseTest(parameterized.TestCase): - def _verify_output_landmarks(self, landmark_list, image_shape, num_landmarks, - expected_results, diff_thresholds): - self.assertLen(landmark_list.landmark, num_landmarks) - image_rows, image_cols, _ = image_shape - pose_coordinates = [(math.floor(landmark.x * image_cols), - math.floor(landmark.y * image_rows)) - for landmark in landmark_list.landmark] - prediction_error = np.abs( - np.asarray(pose_coordinates) - - np.asarray(expected_results[:num_landmarks])) - npt.assert_array_less(prediction_error, diff_thresholds) + def _landmarks_list_to_array(self, landmark_list, image_shape): + rows, cols, _ = image_shape + return np.asarray([(lmk.x * cols, lmk.y * rows) + for lmk in landmark_list.landmark]) + + def _assert_diff_less(self, array1, array2, threshold): + npt.assert_array_less(np.abs(array1 - array2), threshold) def test_invalid_image_shape(self): - holistic = mp_holistic.Holistic() - with self.assertRaisesRegex( - ValueError, 'Input image must contain three channel rgb data.'): - 
holistic.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) + with mp_holistic.Holistic() as holistic: + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + holistic.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) def test_blank_image(self): - holistic = mp_holistic.Holistic() - image = np.zeros([100, 100, 3], dtype=np.uint8) - image.fill(255) - results = holistic.process(image) - self.assertIsNone(results.pose_landmarks) - holistic.close() + with mp_holistic.Holistic() as holistic: + image = np.zeros([100, 100, 3], dtype=np.uint8) + image.fill(255) + results = holistic.process(image) + self.assertIsNone(results.pose_landmarks) @parameterized.named_parameters(('static_image_mode', True, 3), ('video_mode', False, 3)) def test_upper_body_model(self, static_image_mode, num_frames): image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') - holistic = mp_holistic.Holistic( - static_image_mode=static_image_mode, upper_body_only=True) - image = cv2.imread(image_path) - for _ in range(num_frames): - results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - self._verify_output_landmarks(results.pose_landmarks, image.shape, 25, - EXPECTED_POSE_COORDINATES_PREDICTION, - POSE_DIFF_THRESHOLOD) - self._verify_output_landmarks(results.left_hand_landmarks, image.shape, - 21, - EXPECTED_LEFT_HAND_COORDINATES_PREDICTION, - HAND_DIFF_THRESHOLOD) - self._verify_output_landmarks(results.right_hand_landmarks, image.shape, - 21, - EXPECTED_RIGHT_HAND_COORDINATES_PREDICTION, - HAND_DIFF_THRESHOLOD) - # TODO: Verify the correctness of the face landmarks. 
- self.assertLen(results.face_landmarks.landmark, 468) - holistic.close() + with mp_holistic.Holistic( + static_image_mode=static_image_mode, upper_body_only=True) as holistic: + image = cv2.imread(image_path) + for _ in range(num_frames): + results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._assert_diff_less( + self._landmarks_list_to_array(results.pose_landmarks, image.shape), + EXPECTED_UPPER_BODY_LANDMARKS, + POSE_DIFF_THRESHOLD) + self._assert_diff_less( + self._landmarks_list_to_array(results.left_hand_landmarks, + image.shape), + EXPECTED_LEFT_HAND_LANDMARKS, + HAND_DIFF_THRESHOLD) + self._assert_diff_less( + self._landmarks_list_to_array(results.right_hand_landmarks, + image.shape), + EXPECTED_RIGHT_HAND_LANDMARKS, + HAND_DIFF_THRESHOLD) + # TODO: Verify the correctness of the face landmarks. + self.assertLen(results.face_landmarks.landmark, 468) @parameterized.named_parameters(('static_image_mode', True, 3), ('video_mode', False, 3)) def test_full_body_model(self, static_image_mode, num_frames): image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') - holistic = mp_holistic.Holistic(static_image_mode=static_image_mode) image = cv2.imread(image_path) - for _ in range(num_frames): - results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - self._verify_output_landmarks(results.pose_landmarks, image.shape, 33, - EXPECTED_POSE_COORDINATES_PREDICTION, - POSE_DIFF_THRESHOLOD) - self._verify_output_landmarks(results.left_hand_landmarks, image.shape, - 21, - EXPECTED_LEFT_HAND_COORDINATES_PREDICTION, - HAND_DIFF_THRESHOLOD) - self._verify_output_landmarks(results.right_hand_landmarks, image.shape, - 21, - EXPECTED_RIGHT_HAND_COORDINATES_PREDICTION, - HAND_DIFF_THRESHOLOD) - # TODO: Verify the correctness of the face landmarks. 
- self.assertLen(results.face_landmarks.landmark, 468) - holistic.close() + with mp_holistic.Holistic(static_image_mode=static_image_mode) as holistic: + for _ in range(num_frames): + results = holistic.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._assert_diff_less( + self._landmarks_list_to_array(results.pose_landmarks, image.shape), + EXPECTED_FULL_BODY_LANDMARKS, + POSE_DIFF_THRESHOLD) + self._assert_diff_less( + self._landmarks_list_to_array(results.left_hand_landmarks, + image.shape), + EXPECTED_LEFT_HAND_LANDMARKS, + HAND_DIFF_THRESHOLD) + self._assert_diff_less( + self._landmarks_list_to_array(results.right_hand_landmarks, + image.shape), + EXPECTED_RIGHT_HAND_LANDMARKS, + HAND_DIFF_THRESHOLD) + # TODO: Verify the correctness of the face landmarks. + self.assertLen(results.face_landmarks.landmark, 468) if __name__ == '__main__': diff --git a/mediapipe/python/solutions/objectron.py b/mediapipe/python/solutions/objectron.py new file mode 100644 index 000000000..7802c814b --- /dev/null +++ b/mediapipe/python/solutions/objectron.py @@ -0,0 +1,278 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MediaPipe Objectron.""" + +import enum +from typing import List, Tuple, NamedTuple, Optional + +import attr +import numpy as np + +from mediapipe.calculators.core import constant_side_packet_calculator_pb2 +# pylint: disable=unused-import +from mediapipe.calculators.core import gate_calculator_pb2 +from mediapipe.calculators.core import split_vector_calculator_pb2 +from mediapipe.calculators.tensor import image_to_tensor_calculator_pb2 +from mediapipe.calculators.tensor import inference_calculator_pb2 +from mediapipe.calculators.tensor import tensors_to_detections_calculator_pb2 +from mediapipe.calculators.tensor import tensors_to_floats_calculator_pb2 +from mediapipe.calculators.tensor import tensors_to_landmarks_calculator_pb2 +from mediapipe.calculators.tflite import ssd_anchors_calculator_pb2 +from mediapipe.calculators.util import association_calculator_pb2 +from mediapipe.calculators.util import collection_has_min_size_calculator_pb2 +from mediapipe.calculators.util import detection_label_id_to_text_calculator_pb2 +from mediapipe.calculators.util import detections_to_rects_calculator_pb2 +from mediapipe.calculators.util import landmark_projection_calculator_pb2 +from mediapipe.calculators.util import local_file_contents_calculator_pb2 +from mediapipe.calculators.util import non_max_suppression_calculator_pb2 +from mediapipe.calculators.util import rect_transformation_calculator_pb2 +from mediapipe.calculators.util import thresholding_calculator_pb2 +from mediapipe.framework.formats import landmark_pb2 +from mediapipe.modules.objectron.calculators import annotation_data_pb2 +from mediapipe.modules.objectron.calculators import frame_annotation_to_rect_calculator_pb2 +from mediapipe.modules.objectron.calculators import lift_2d_frame_annotation_to_3d_calculator_pb2 +# pylint: enable=unused-import +from mediapipe.python.solution_base import SolutionBase + + +class BoxLandmark(enum.IntEnum): + """The 9 3D box landmarks.""" + # + # 3 + + + + + + + + 7 + # +\ 
+\ UP + # + \ + \ + # + \ + \ | + # + 4 + + + + + + + + 8 | y + # + + + + | + # + + + + | + # + + (0) + + .------- x + # + + + + \ + # 1 + + + + + + + + 5 + \ + # \ + \ + \ z + # \ + \ + \ + # \+ \+ + # 2 + + + + + + + + 6 + CENTER = 0 + BACK_BOTTOM_LEFT = 1 + FRONT_BOTTOM_LEFT = 2 + BACK_TOP_LEFT = 3 + FRONT_TOP_LEFT = 4 + BACK_BOTTOM_RIGHT = 5 + FRONT_BOTTOM_RIGHT = 6 + BACK_TOP_RIGHT = 7 + FRONT_TOP_RIGHT = 8 + +BINARYPB_FILE_PATH = 'mediapipe/modules/objectron/objectron_cpu.binarypb' +BOX_CONNECTIONS = frozenset([ + (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.FRONT_BOTTOM_LEFT), + (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.BACK_TOP_LEFT), + (BoxLandmark.BACK_BOTTOM_LEFT, BoxLandmark.BACK_BOTTOM_RIGHT), + (BoxLandmark.FRONT_BOTTOM_LEFT, BoxLandmark.FRONT_TOP_LEFT), + (BoxLandmark.FRONT_BOTTOM_LEFT, BoxLandmark.FRONT_BOTTOM_RIGHT), + (BoxLandmark.BACK_TOP_LEFT, BoxLandmark.FRONT_TOP_LEFT), + (BoxLandmark.BACK_TOP_LEFT, BoxLandmark.BACK_TOP_RIGHT), + (BoxLandmark.FRONT_TOP_LEFT, BoxLandmark.FRONT_TOP_RIGHT), + (BoxLandmark.BACK_BOTTOM_RIGHT, BoxLandmark.FRONT_BOTTOM_RIGHT), + (BoxLandmark.BACK_BOTTOM_RIGHT, BoxLandmark.BACK_TOP_RIGHT), + (BoxLandmark.FRONT_BOTTOM_RIGHT, BoxLandmark.FRONT_TOP_RIGHT), + (BoxLandmark.BACK_TOP_RIGHT, BoxLandmark.FRONT_TOP_RIGHT), +]) + + +@attr.s(auto_attribs=True) +class ObjectronModel(object): + model_path: str + label_name: str + + +@attr.s(auto_attribs=True, frozen=True) +class ShoeModel(ObjectronModel): + model_path: str = ('mediapipe/modules/objectron/' + 'object_detection_3d_sneakers.tflite') + label_name: str = 'Footwear' + + +@attr.s(auto_attribs=True, frozen=True) +class ChairModel(ObjectronModel): + model_path: str = ('mediapipe/modules/objectron/' + 'object_detection_3d_chair.tflite') + label_name: str = 'Chair' + + +@attr.s(auto_attribs=True, frozen=True) +class CameraModel(ObjectronModel): + model_path: str = ('mediapipe/modules/objectron/' + 'object_detection_3d_camera.tflite') + label_name: str = 'Camera' + + 
+@attr.s(auto_attribs=True, frozen=True) +class CupModel(ObjectronModel): + model_path: str = ('mediapipe/modules/objectron/' + 'object_detection_3d_cup.tflite') + label_name: str = 'Coffee cup, Mug' + +_MODEL_DICT = { + 'Shoe': ShoeModel(), + 'Chair': ChairModel(), + 'Cup': CupModel(), + 'Camera': CameraModel() +} + + +def GetModelByName(name: str) -> ObjectronModel: + if name not in _MODEL_DICT: + raise ValueError(f'{name} is not a valid model name for Objectron.') + return _MODEL_DICT[name] + + +@attr.s(auto_attribs=True) +class ObjectronOutputs(object): + landmarks_2d: landmark_pb2.NormalizedLandmarkList + landmarks_3d: landmark_pb2.LandmarkList + rotation: np.ndarray + translation: np.ndarray + scale: np.ndarray + + +class Objectron(SolutionBase): + """MediaPipe Objectron. + + MediaPipe Objectron processes an RGB image and returns the 3D box landmarks + and 2D rectangular bounding box of each detected object. + """ + + def __init__(self, + static_image_mode: bool = False, + max_num_objects: int = 5, + min_detection_confidence: float = 0.5, + min_tracking_confidence: float = 0.99, + model_name: str = 'Shoe', + focal_length: Tuple[float, float] = (1.0, 1.0), + principal_point: Tuple[float, float] = (0.0, 0.0), + image_size: Optional[Tuple[int, int]] = None, + ): + """Initializes a MediaPipe Objectron class. + + Args: + static_image_mode: Whether to treat the input images as a batch of static + and possibly unrelated images, or a video stream. + max_num_objects: Maximum number of objects to detect. + min_detection_confidence: Minimum confidence value ([0.0, 1.0]) for object + detection to be considered successful. + min_tracking_confidence: Minimum confidence value ([0.0, 1.0]) for the + box landmarks to be considered tracked successfully. + model_name: Name of model to use for predicting box landmarks, currently + support {'Shoe', 'Chair', 'Cup', 'Camera'}. + focal_length: Camera focal length `(fx, fy)`, by default is defined in NDC + space. 
To use focal length (fx_pixel, fy_pixel) in pixel space, users + should provide image_size = (image_width, image_height) to enable + conversions inside the API. + principal_point: Camera principal point (px, py), by default is defined in + NDC space. To use principal point (px_pixel, py_pixel) in pixel space, + users should provide image_size = (image_width, image_height) to enable + conversions inside the API. + image_size (Optional): size (image_width, image_height) of the input image, + ONLY needed when using focal_length and principal_point in pixel space. + """ + # Get Camera parameters. + fx, fy = focal_length + px, py = principal_point + if image_size is not None: + half_width = image_size[0] / 2.0 + half_height = image_size[1] / 2.0 + fx = fx / half_width + fy = fy / half_height + px = - (px - half_width) / half_width + py = - (py - half_height) / half_height + + # Create and init model. + model = GetModelByName(model_name) + super().__init__( + binary_graph_path=BINARYPB_FILE_PATH, + side_inputs={ + 'box_landmark_model_path': model.model_path, + 'allowed_labels': model.label_name, + 'max_num_objects': max_num_objects, + }, + calculator_params={ + 'ConstantSidePacketCalculator.packet': [ + constant_side_packet_calculator_pb2 + .ConstantSidePacketCalculatorOptions.ConstantSidePacket( + bool_value=not static_image_mode) + ], + ('objectdetectionoidv4subgraph' + '__TensorsToDetectionsCalculator.min_score_thresh'): + min_detection_confidence, + ('boxlandmarksubgraph__ThresholdingCalculator' + '.threshold'): + min_tracking_confidence, + ('Lift2DFrameAnnotationTo3DCalculator' + '.normalized_focal_x'): fx, + ('Lift2DFrameAnnotationTo3DCalculator' + '.normalized_focal_y'): fy, + ('Lift2DFrameAnnotationTo3DCalculator' + '.normalized_principal_point_x'): px, + ('Lift2DFrameAnnotationTo3DCalculator' + '.normalized_principal_point_y'): py, + }, + outputs=['detected_objects']) + + def process(self, image: np.ndarray) -> NamedTuple: + """Processes an RGB image and returns 
the box landmarks and rectangular bounding box of each detected object. + + Args: + image: An RGB image represented as a numpy ndarray. + + Raises: + RuntimeError: If the underlying graph throws any error. + ValueError: If the input image is not three channel RGB. + + Returns: + A NamedTuple object with a "detected_objects" field that contains a list + of detected 3D bounding boxes. Each detected box is represented as an + "ObjectronOutputs" instance. + """ + + results = super().process(input_data={'image': image}) + if results.detected_objects: + results.detected_objects = self._convert_format(results.detected_objects) + else: + results.detected_objects = None + return results + + def _convert_format( + self, + inputs: annotation_data_pb2.FrameAnnotation) -> List[ObjectronOutputs]: + new_outputs = list() + for annotation in inputs.annotations: + # Get 3d object pose. + rotation = np.reshape(np.array(annotation.rotation), (3, 3)) + translation = np.array(annotation.translation) + scale = np.array(annotation.scale) + # Get 2d/3d landmarks. + landmarks_2d = landmark_pb2.NormalizedLandmarkList() + landmarks_3d = landmark_pb2.LandmarkList() + for keypoint in annotation.keypoints: + point_2d = keypoint.point_2d + landmarks_2d.landmark.add(x=point_2d.x, y=point_2d.y) + point_3d = keypoint.point_3d + landmarks_3d.landmark.add(x=point_3d.x, y=point_3d.y, z=point_3d.z) + + # Add to objectron outputs. + new_outputs.append(ObjectronOutputs(landmarks_2d, landmarks_3d, + rotation, translation, scale=scale)) + return new_outputs + diff --git a/mediapipe/python/solutions/objectron_test.py b/mediapipe/python/solutions/objectron_test.py new file mode 100644 index 000000000..79982eb15 --- /dev/null +++ b/mediapipe/python/solutions/objectron_test.py @@ -0,0 +1,81 @@ +# Copyright 2020 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for mediapipe.python.solutions.objectron.""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import cv2 +import numpy as np +import numpy.testing as npt + +# resources dependency +from mediapipe.python.solutions import objectron as mp_objectron + +TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' +DIFF_THRESHOLD = 30 # pixels +EXPECTED_BOX_COORDINATES_PREDICTION = [[[236, 413], [408, 474], [135, 457], + [383, 505], [80, 478], [408, 345], + [130, 347], [384, 355], [72, 353]], + [[241, 206], [411, 279], [131, 280], + [392, 249], [78, 252], [412, 155], + [140, 178], [396, 105], [89, 137]]] + + +class ObjectronTest(parameterized.TestCase): + + def test_invalid_image_shape(self): + with mp_objectron.Objectron() as objectron: + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + objectron.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) + + def test_blank_image(self): + with mp_objectron.Objectron() as objectron: + image = np.zeros([100, 100, 3], dtype=np.uint8) + image.fill(255) + results = objectron.process(image) + self.assertIsNone(results.detected_objects) + + @parameterized.named_parameters(('static_image_mode', True, 1), + ('video_mode', False, 5)) + def test_multi_objects(self, static_image_mode, num_frames): + image_path = os.path.join(os.path.dirname(__file__), 'testdata/shoes.jpg') + image = cv2.imread(image_path) + + with mp_objectron.Objectron( + static_image_mode=static_image_mode, + max_num_objects=2, + min_detection_confidence=0.5) as 
objectron: + for _ in range(num_frames): + results = objectron.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + multi_box_coordinates = [] + for detected_object in results.detected_objects: + landmarks = detected_object.landmarks_2d + self.assertLen(landmarks.landmark, 9) + x = [landmark.x for landmark in landmarks.landmark] + y = [landmark.y for landmark in landmarks.landmark] + box_coordinates = np.transpose(np.stack((y, x))) * image.shape[0:2] + multi_box_coordinates.append(box_coordinates) + self.assertLen(multi_box_coordinates, 2) + prediction_error = np.abs( + np.asarray(multi_box_coordinates) - + np.asarray(EXPECTED_BOX_COORDINATES_PREDICTION)) + npt.assert_array_less(prediction_error, DIFF_THRESHOLD) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/python/solutions/pose.py b/mediapipe/python/solutions/pose.py index 98e73683b..c295bbe7a 100644 --- a/mediapipe/python/solutions/pose.py +++ b/mediapipe/python/solutions/pose.py @@ -77,7 +77,7 @@ class PoseLandmark(enum.IntEnum): RIGHT_FOOT_INDEX = 32 BINARYPB_FILE_PATH = 'mediapipe/modules/pose_landmark/pose_landmark_cpu.binarypb' -POSE_CONNECTIONS = frozenset([ +UPPER_BODY_POSE_CONNECTIONS = frozenset([ (PoseLandmark.NOSE, PoseLandmark.RIGHT_EYE_INNER), (PoseLandmark.RIGHT_EYE_INNER, PoseLandmark.RIGHT_EYE), (PoseLandmark.RIGHT_EYE, PoseLandmark.RIGHT_EYE_OUTER), @@ -103,18 +103,21 @@ POSE_CONNECTIONS = frozenset([ (PoseLandmark.RIGHT_SHOULDER, PoseLandmark.RIGHT_HIP), (PoseLandmark.LEFT_SHOULDER, PoseLandmark.LEFT_HIP), (PoseLandmark.RIGHT_HIP, PoseLandmark.LEFT_HIP), - (PoseLandmark.RIGHT_HIP, PoseLandmark.LEFT_HIP), - (PoseLandmark.RIGHT_HIP, PoseLandmark.RIGHT_KNEE), - (PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_KNEE), - (PoseLandmark.RIGHT_KNEE, PoseLandmark.RIGHT_ANKLE), - (PoseLandmark.LEFT_KNEE, PoseLandmark.LEFT_ANKLE), - (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_HEEL), - (PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_HEEL), - (PoseLandmark.RIGHT_HEEL, 
PoseLandmark.RIGHT_FOOT_INDEX), - (PoseLandmark.LEFT_HEEL, PoseLandmark.LEFT_FOOT_INDEX), - (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_FOOT_INDEX), - (PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_FOOT_INDEX), ]) +POSE_CONNECTIONS = frozenset.union( + UPPER_BODY_POSE_CONNECTIONS, + frozenset([ + (PoseLandmark.RIGHT_HIP, PoseLandmark.RIGHT_KNEE), + (PoseLandmark.LEFT_HIP, PoseLandmark.LEFT_KNEE), + (PoseLandmark.RIGHT_KNEE, PoseLandmark.RIGHT_ANKLE), + (PoseLandmark.LEFT_KNEE, PoseLandmark.LEFT_ANKLE), + (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_HEEL), + (PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_HEEL), + (PoseLandmark.RIGHT_HEEL, PoseLandmark.RIGHT_FOOT_INDEX), + (PoseLandmark.LEFT_HEEL, PoseLandmark.LEFT_FOOT_INDEX), + (PoseLandmark.RIGHT_ANKLE, PoseLandmark.RIGHT_FOOT_INDEX), + (PoseLandmark.LEFT_ANKLE, PoseLandmark.LEFT_FOOT_INDEX), + ])) class Pose(SolutionBase): @@ -178,7 +181,7 @@ class Pose(SolutionBase): image: An RGB image represented as a numpy ndarray. Raises: - RuntimeError: If the underlying graph occurs any error. + RuntimeError: If the underlying graph throws any error. ValueError: If the input image is not three channel RGB. Returns: diff --git a/mediapipe/python/solutions/pose_test.py b/mediapipe/python/solutions/pose_test.py index 38131c596..b15408b39 100644 --- a/mediapipe/python/solutions/pose_test.py +++ b/mediapipe/python/solutions/pose_test.py @@ -13,7 +13,6 @@ # limitations under the License. 
"""Tests for mediapipe.python.solutions.pose.""" -import math import os from absl.testing import absltest @@ -26,71 +25,79 @@ import numpy.testing as npt from mediapipe.python.solutions import pose as mp_pose TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata' -DIFF_THRESHOLOD = 30 # pixels -EXPECTED_POSE_COORDINATES_PREDICTION = [[593, 645], [593, 626], [599, 621], - [605, 617], [575, 637], [569, 640], - [563, 643], [621, 616], [565, 652], - [617, 652], [595, 667], [714, 662], - [567, 749], [792, 559], [497, 844], - [844, 435], [407, 906], [866, 403], - [381, 921], [859, 392], [366, 922], - [850, 405], [381, 918], [707, 948], - [631, 940], [582, 1122], [599, 1097], - [495, 1277], [641, 1239], [485, 1300], - [658, 1257], [453, 1332], [626, 1308]] +DIFF_THRESHOLD = 30 # pixels +EXPECTED_UPPER_BODY_LANDMARKS = np.array([[457, 289], [465, 278], [467, 278], + [470, 277], [461, 279], [461, 279], + [461, 279], [485, 277], [474, 278], + [468, 296], [463, 297], [542, 324], + [449, 327], [614, 321], [376, 318], + [680, 322], [312, 310], [697, 320], + [293, 305], [699, 314], [289, 302], + [693, 316], [296, 305], [515, 451], + [467, 453]]) +EXPECTED_FULL_BODY_LANDMARKS = np.array([[460, 287], [469, 277], [472, 276], + [475, 276], [464, 277], [463, 277], + [463, 276], [492, 277], [472, 277], + [471, 295], [465, 295], [542, 323], + [448, 318], [619, 319], [372, 313], + [695, 316], [296, 308], [717, 313], + [273, 304], [718, 304], [280, 298], + [709, 307], [289, 303], [521, 470], + [459, 466], [626, 533], [364, 500], + [704, 616], [347, 614], [710, 631], + [357, 633], [737, 625], [306, 639]]) class PoseTest(parameterized.TestCase): - def _verify_output_landmarks(self, landmark_list, image_shape, num_landmarks): - self.assertLen(landmark_list.landmark, num_landmarks) - image_rows, image_cols, _ = image_shape - pose_coordinates = [(math.floor(landmark.x * image_cols), - math.floor(landmark.y * image_rows)) - for landmark in landmark_list.landmark] - prediction_error = np.abs( 
- np.asarray(pose_coordinates) - - np.asarray(EXPECTED_POSE_COORDINATES_PREDICTION[:num_landmarks])) - npt.assert_array_less(prediction_error, DIFF_THRESHOLOD) + def _landmarks_list_to_array(self, landmark_list, image_shape): + rows, cols, _ = image_shape + return np.asarray([(lmk.x * cols, lmk.y * rows) + for lmk in landmark_list.landmark]) + + def _assert_diff_less(self, array1, array2, threshold): + npt.assert_array_less(np.abs(array1 - array2), threshold) def test_invalid_image_shape(self): - pose = mp_pose.Pose() - with self.assertRaisesRegex( - ValueError, 'Input image must contain three channel rgb data.'): - pose.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) + with mp_pose.Pose() as pose: + with self.assertRaisesRegex( + ValueError, 'Input image must contain three channel rgb data.'): + pose.process(np.arange(36, dtype=np.uint8).reshape(3, 3, 4)) def test_blank_image(self): - pose = mp_pose.Pose() - image = np.zeros([100, 100, 3], dtype=np.uint8) - image.fill(255) - results = pose.process(image) - self.assertIsNone(results.pose_landmarks) - pose.close() + with mp_pose.Pose() as pose: + image = np.zeros([100, 100, 3], dtype=np.uint8) + image.fill(255) + results = pose.process(image) + self.assertIsNone(results.pose_landmarks) @parameterized.named_parameters(('static_image_mode', True, 3), ('video_mode', False, 3)) def test_upper_body_model(self, static_image_mode, num_frames): image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') - pose = mp_pose.Pose(static_image_mode=static_image_mode, - upper_body_only=True) - image = cv2.imread(image_path) - - for _ in range(num_frames): - results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - self._verify_output_landmarks(results.pose_landmarks, image.shape, 25) - pose.close() + with mp_pose.Pose( + static_image_mode=static_image_mode, upper_body_only=True) as pose: + image = cv2.imread(image_path) + for _ in range(num_frames): + results = pose.process(cv2.cvtColor(image, 
cv2.COLOR_BGR2RGB)) + self._assert_diff_less( + self._landmarks_list_to_array(results.pose_landmarks, image.shape), + EXPECTED_UPPER_BODY_LANDMARKS, + DIFF_THRESHOLD) @parameterized.named_parameters(('static_image_mode', True, 3), ('video_mode', False, 3)) def test_full_body_model(self, static_image_mode, num_frames): image_path = os.path.join(os.path.dirname(__file__), 'testdata/pose.jpg') - pose = mp_pose.Pose(static_image_mode=static_image_mode) image = cv2.imread(image_path) - for _ in range(num_frames): - results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - self._verify_output_landmarks(results.pose_landmarks, image.shape, 33) - pose.close() + with mp_pose.Pose(static_image_mode=static_image_mode) as pose: + for _ in range(num_frames): + results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + self._assert_diff_less( + self._landmarks_list_to_array(results.pose_landmarks, image.shape), + EXPECTED_FULL_BODY_LANDMARKS, + DIFF_THRESHOLD) if __name__ == '__main__': diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 9347fc009..d115dd087 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -78,7 +78,6 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "//mediapipe/framework/port:integral_types", - "//mediapipe/framework:port", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", ] + select({ @@ -168,13 +167,14 @@ cc_library( ], deps = [ "//mediapipe/framework/port:ret_check", + "//mediapipe/framework/port:singleton", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", + "//mediapipe/framework/port:file_helpers", "@com_google_absl//absl/strings", ] + select({ "//conditions:default": [ "//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:file_helpers", "@com_google_absl//absl/flags:flag", ], "//mediapipe:android": [ @@ -184,7 +184,6 @@ cc_library( "//mediapipe:ios": [], "//mediapipe:macos": [ 
"//mediapipe/framework/deps:file_path", - "//mediapipe/framework/port:file_helpers", "@com_google_absl//absl/flags:flag", ], }), diff --git a/mediapipe/util/android/asset_manager_util.cc b/mediapipe/util/android/asset_manager_util.cc index 315078137..8b5803d64 100644 --- a/mediapipe/util/android/asset_manager_util.cc +++ b/mediapipe/util/android/asset_manager_util.cc @@ -95,16 +95,23 @@ bool AssetManager::InitializeFromActivity(JNIEnv* env, jobject activity, return InitializeFromContext(env, activity, cache_dir_path); } -bool AssetManager::FileExists(const std::string& filename) { +bool AssetManager::FileExists(const std::string& filename, bool* is_dir) { if (!asset_manager_) { LOG(ERROR) << "Asset manager was not initialized from JNI"; return false; } + auto safe_set_is_dir = [is_dir](bool is_dir_value) { + if (is_dir) { + *is_dir = is_dir_value; + } + }; + AAsset* asset = AAssetManager_open(asset_manager_, filename.c_str(), AASSET_MODE_RANDOM); if (asset != nullptr) { AAsset_close(asset); + safe_set_is_dir(false); return true; } @@ -117,6 +124,7 @@ bool AssetManager::FileExists(const std::string& filename) { // unusable (i.e. not considered a valid path). 
bool dir_exists = AAssetDir_getNextFileName(asset_dir) != nullptr; AAssetDir_close(asset_dir); + safe_set_is_dir(dir_exists); return dir_exists; } @@ -143,7 +151,7 @@ bool AssetManager::ReadFile(const std::string& filename, std::string* output) { return true; } -mediapipe::StatusOr AssetManager::CachedFileFromAsset( +absl::StatusOr AssetManager::CachedFileFromAsset( const std::string& asset_path) { RET_CHECK(cache_dir_path_.size()) << "asset manager not initialized"; @@ -170,8 +178,8 @@ mediapipe::StatusOr AssetManager::CachedFileFromAsset( return file_path; } -mediapipe::Status AssetManager::ReadContentUri(const std::string& content_uri, - std::string* output) { +absl::Status AssetManager::ReadContentUri(const std::string& content_uri, + std::string* output) { RET_CHECK(mediapipe::java::HasJavaVM()) << "JVM instance not set"; JNIEnv* env = mediapipe::java::GetJNIEnv(); RET_CHECK(env != nullptr) << "Unable to retrieve JNIEnv"; @@ -242,7 +250,7 @@ mediapipe::Status AssetManager::ReadContentUri(const std::string& content_uri, reinterpret_cast(&output->at(0))); RET_CHECK(!ExceptionPrintClear(env)) << "failed to copy array data"; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/util/android/asset_manager_util.h b/mediapipe/util/android/asset_manager_util.h index 677d7c42c..2d2582c2e 100644 --- a/mediapipe/util/android/asset_manager_util.h +++ b/mediapipe/util/android/asset_manager_util.h @@ -65,16 +65,18 @@ class AssetManager { bool InitializeFromContext(JNIEnv* env, jobject context, const std::string& cache_dir_path); - // Checks if a file exists. Returns true on success, false otherwise. - bool FileExists(const std::string& filename); + // Checks if a file exists. Returns true on success, false otherwise. If it + // does exist, then 'is_dir' will be set to indicate whether the file is a + // directory. + bool FileExists(const std::string& filename, bool* is_dir = nullptr); // Reads a file into output. 
Returns true on success, false otherwise. bool ReadFile(const std::string& filename, std::string* output); // Reads the raw bytes referred to by the supplied content URI. Returns true // on success, false otherwise. - mediapipe::Status ReadContentUri(const std::string& content_uri, - std::string* output); + absl::Status ReadContentUri(const std::string& content_uri, + std::string* output); // Returns the path to the Android cache directory. Will be empty if // InitializeFromActivity has not been called. @@ -83,7 +85,7 @@ class AssetManager { // Caches the contents of the given asset as a file, and returns a path to // that file. This can be used to pass an asset to APIs that require a path // to a filesystem file. - ::mediapipe::StatusOr CachedFileFromAsset( + absl::StatusOr CachedFileFromAsset( const std::string& asset_path); private: diff --git a/mediapipe/util/android/file/base/BUILD b/mediapipe/util/android/file/base/BUILD index 6e5b2390a..f97bf2710 100644 --- a/mediapipe/util/android/file/base/BUILD +++ b/mediapipe/util/android/file/base/BUILD @@ -28,6 +28,7 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:status", "@com_google_absl//absl/base", diff --git a/mediapipe/util/android/file/base/filesystem.cc b/mediapipe/util/android/file/base/filesystem.cc index 53cdb16e1..aee010543 100644 --- a/mediapipe/util/android/file/base/filesystem.cc +++ b/mediapipe/util/android/file/base/filesystem.cc @@ -25,10 +25,10 @@ static_assert(sizeof(off_t) == 8, "Large file support is required"); namespace mediapipe { namespace file { -mediapipe::Status RecursivelyCreateDir(absl::string_view path, - const file::Options& options) { +absl::Status RecursivelyCreateDir(absl::string_view path, + const file::Options& options) { if (path.empty()) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } std::vector path_comp = absl::StrSplit(path, '/'); @@ 
-45,44 +45,41 @@ mediapipe::Status RecursivelyCreateDir(absl::string_view path, if (S_ISDIR(stat_buf.st_mode)) { continue; } - return mediapipe::Status(mediapipe::StatusCode::kInternal, - "Could not stat " + std::string(crpath)); + return absl::Status(absl::StatusCode::kInternal, + "Could not stat " + std::string(crpath)); } else { int mkval = mkdir(crpath, options.permissions()); if (mkval == -1) { - return mediapipe::Status(mediapipe::StatusCode::kInternal, - "Could not create " + std::string(crpath)); + return absl::Status(absl::StatusCode::kInternal, + "Could not create " + std::string(crpath)); } } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status Exists(absl::string_view path, const file::Options& ignored) { +absl::Status Exists(absl::string_view path, const file::Options& ignored) { struct stat64 stat_buf; int statval = lstat64(std::string(path).c_str(), &stat_buf); if (statval == 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { - return mediapipe::Status(mediapipe::StatusCode::kNotFound, - "Could not stat file."); + return absl::Status(absl::StatusCode::kNotFound, "Could not stat file."); } } -mediapipe::Status IsDirectory(absl::string_view path, - const file::Options& /*ignored*/) { +absl::Status IsDirectory(absl::string_view path, + const file::Options& /*ignored*/) { struct stat64 stat_buf; int statval = lstat64(std::string(path).c_str(), &stat_buf); bool is_dir = (statval == 0 && S_ISREG(stat_buf.st_mode)); if (is_dir) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else if (statval != 0) { - return mediapipe::Status(mediapipe::StatusCode::kNotFound, - "File does not exists"); + return absl::Status(absl::StatusCode::kNotFound, "File does not exists"); } else { - return mediapipe::Status(mediapipe::StatusCode::kNotFound, - "Not a directory"); + return absl::Status(absl::StatusCode::kNotFound, "Not a directory"); } } diff --git a/mediapipe/util/android/file/base/filesystem.h 
b/mediapipe/util/android/file/base/filesystem.h index af86de70c..3e726f20d 100644 --- a/mediapipe/util/android/file/base/filesystem.h +++ b/mediapipe/util/android/file/base/filesystem.h @@ -22,13 +22,12 @@ namespace mediapipe { namespace file { -mediapipe::Status RecursivelyCreateDir(absl::string_view path, - const file::Options& options); +absl::Status RecursivelyCreateDir(absl::string_view path, + const file::Options& options); -mediapipe::Status Exists(absl::string_view path, const file::Options& options); +absl::Status Exists(absl::string_view path, const file::Options& options); -mediapipe::Status IsDirectory(absl::string_view path, - const file::Options& options); +absl::Status IsDirectory(absl::string_view path, const file::Options& options); } // namespace file. } // namespace mediapipe diff --git a/mediapipe/util/android/file/base/helpers.cc b/mediapipe/util/android/file/base/helpers.cc index 417069069..add4c9b36 100644 --- a/mediapipe/util/android/file/base/helpers.cc +++ b/mediapipe/util/android/file/base/helpers.cc @@ -21,6 +21,7 @@ #include +#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/status.h" namespace mediapipe { @@ -42,16 +43,15 @@ class FdCloser { } // namespace // Read contents of a file to a std::string. -mediapipe::Status GetContents(int fd, std::string* output) { +absl::Status GetContents(int fd, std::string* output) { // Determine the length of the file. 
struct stat buf; if (fstat(fd, &buf) != 0) { - return mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to get file status"); + return absl::Status(absl::StatusCode::kUnknown, + "Failed to get file status"); } if (buf.st_size < 0 || buf.st_size > SIZE_MAX) { - return mediapipe::Status(mediapipe::StatusCode::kInternal, - "Invalid file size"); + return absl::Status(absl::StatusCode::kInternal, "Invalid file size"); } size_t length = buf.st_size; @@ -61,62 +61,35 @@ mediapipe::Status GetContents(int fd, std::string* output) { while (length != 0) { const ssize_t nread = read(fd, output_ptr, length); if (nread <= 0) { - return mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to read file"); + return absl::Status(absl::StatusCode::kUnknown, "Failed to read file"); } output_ptr += nread; length -= nread; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Read contents of a file to a std::string. -mediapipe::Status GetContents(absl::string_view file_name, std::string* output, - const file::Options& /*options*/) { +absl::Status GetContents(absl::string_view file_name, std::string* output, + const file::Options& /*options*/) { int fd = open(std::string(file_name).c_str(), O_RDONLY); if (fd < 0) { - return mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to open file: " + std::string(file_name)); + return absl::Status(absl::StatusCode::kUnknown, + "Failed to open file: " + std::string(file_name)); } FdCloser closer(fd); return GetContents(fd, output); } -mediapipe::Status GetContents(absl::string_view file_name, - std::string* output) { +absl::Status GetContents(absl::string_view file_name, std::string* output) { return GetContents(file_name, output, file::Defaults()); } -mediapipe::Status SetContents(absl::string_view file_name, - absl::string_view content, - const file::Options& options) { - // Mode -rw-r--r-- - mode_t mode = S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH; - int fd = - open(std::string(file_name).c_str(), O_WRONLY | 
O_CREAT | O_TRUNC, mode); - if (fd < 0) { - return mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to open file: " + std::string(file_name)); - } - - int bytes_written = 0; - if (content.size() > 0) { - bytes_written = write(fd, content.data(), content.size()); - } - - close(fd); - if (bytes_written == content.size()) { - return mediapipe::OkStatus(); - } else { - return mediapipe::Status(mediapipe::StatusCode::kUnknown, - "Failed to write file"); - } -} - -mediapipe::Status SetContents(absl::string_view file_name, - absl::string_view content) { - return SetContents(file_name, content, file::Defaults()); +absl::Status SetContents(absl::string_view file_name, absl::string_view content, + const file::Options& options) { + // Options are currently ignored. + return SetContents(file_name, content); } } // namespace file diff --git a/mediapipe/util/android/file/base/helpers.h b/mediapipe/util/android/file/base/helpers.h index cd92cfc86..df61d423e 100644 --- a/mediapipe/util/android/file/base/helpers.h +++ b/mediapipe/util/android/file/base/helpers.h @@ -25,23 +25,22 @@ namespace mediapipe { namespace file { // Read contents of a file to a std::string. -mediapipe::Status GetContents(absl::string_view file_name, std::string* output, - const file::Options& options); +absl::Status GetContents(absl::string_view file_name, std::string* output, + const file::Options& options); // Read contents of a file to a std::string with default file options. -mediapipe::Status GetContents(absl::string_view file_name, std::string* output); +absl::Status GetContents(absl::string_view file_name, std::string* output); // Read contents of a file to a std::string from an open file descriptor. -mediapipe::Status GetContents(int fd, std::string* output); +absl::Status GetContents(int fd, std::string* output); // Write std::string to file. 
-mediapipe::Status SetContents(absl::string_view file_name, - absl::string_view content, - const file::Options& options); +absl::Status SetContents(absl::string_view file_name, absl::string_view content, + const file::Options& options); // Write std::string to file with default file options. -mediapipe::Status SetContents(absl::string_view file_name, - absl::string_view content); +absl::Status SetContents(absl::string_view file_name, + absl::string_view content); } // namespace file } // namespace mediapipe diff --git a/mediapipe/util/audio_decoder.cc b/mediapipe/util/audio_decoder.cc index 1aab2baf7..9ebc79aad 100644 --- a/mediapipe/util/audio_decoder.cc +++ b/mediapipe/util/audio_decoder.cc @@ -152,8 +152,7 @@ std::string AvErrorToString(int error) { } // Send a packet to the decoder. -mediapipe::Status SendPacket(const AVPacket& packet, - AVCodecContext* avcodec_ctx) { +absl::Status SendPacket(const AVPacket& packet, AVCodecContext* avcodec_ctx) { const int error = avcodec_send_packet(avcodec_ctx, &packet); if (error != 0 && error != AVERROR_EOF) { // Not consider AVERROR_EOF as an error because it can happen when more @@ -162,12 +161,12 @@ mediapipe::Status SendPacket(const AVPacket& packet, " (", AvErrorToString(error), "). Packet size: ", packet.size)); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Receive a decoded frame from the decoder. 
-mediapipe::Status ReceiveFrame(AVCodecContext* avcodec_ctx, AVFrame* frame, - bool* received) { +absl::Status ReceiveFrame(AVCodecContext* avcodec_ctx, AVFrame* frame, + bool* received) { const int error = avcodec_receive_frame(avcodec_ctx, frame); *received = error == 0; if (error != 0 && error != AVERROR_EOF && error != AVERROR(EAGAIN)) { @@ -177,13 +176,12 @@ mediapipe::Status ReceiveFrame(AVCodecContext* avcodec_ctx, AVFrame* frame, return UnknownError(absl::StrCat(" Failed to receive frame: error=", error, " (", AvErrorToString(error), ").")); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status LogStatus(const mediapipe::Status& status, - const AVCodecContext& avcodec_ctx, - const AVPacket& packet, - bool always_return_ok_status) { +absl::Status LogStatus(const absl::Status& status, + const AVCodecContext& avcodec_ctx, + const AVPacket& packet, bool always_return_ok_status) { if (status.ok()) { return status; } @@ -199,7 +197,7 @@ mediapipe::Status LogStatus(const mediapipe::Status& status, if (always_return_ok_status) { LOG(WARNING) << status.message(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { return status; } @@ -228,16 +226,16 @@ BasePacketProcessor::~BasePacketProcessor() { Close(); } bool BasePacketProcessor::HasData() { return !buffer_.empty(); } -mediapipe::Status BasePacketProcessor::GetData(Packet* packet) { +absl::Status BasePacketProcessor::GetData(Packet* packet) { CHECK(packet); CHECK(!buffer_.empty()); *packet = buffer_.front(); buffer_.pop_front(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status BasePacketProcessor::Flush() { +absl::Status BasePacketProcessor::Flush() { int64 last_num_frames_processed; do { std::unique_ptr av_packet(new AVPacket()); @@ -254,7 +252,7 @@ mediapipe::Status BasePacketProcessor::Flush() { } while (last_num_frames_processed != num_frames_processed_); flushed_ = true; - return mediapipe::OkStatus(); + return absl::OkStatus(); } 
void BasePacketProcessor::Close() { @@ -273,8 +271,8 @@ void BasePacketProcessor::Close() { } } -mediapipe::Status BasePacketProcessor::Decode(const AVPacket& packet, - bool ignore_decode_failures) { +absl::Status BasePacketProcessor::Decode(const AVPacket& packet, + bool ignore_decode_failures) { MP_RETURN_IF_ERROR(LogStatus(SendPacket(packet, avcodec_ctx_), *avcodec_ctx_, packet, ignore_decode_failures)); while (true) { @@ -290,7 +288,7 @@ mediapipe::Status BasePacketProcessor::Decode(const AVPacket& packet, break; } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } int64 BasePacketProcessor::CorrectPtsForRollover(int64 media_pts) { @@ -340,11 +338,11 @@ AudioPacketProcessor::AudioPacketProcessor(const AudioStreamOptions& options) DCHECK(absl::little_endian::IsLittleEndian()); } -mediapipe::Status AudioPacketProcessor::Open(int id, AVStream* stream) { +absl::Status AudioPacketProcessor::Open(int id, AVStream* stream) { id_ = id; avcodec_ = avcodec_find_decoder(stream->codecpar->codec_id); if (!avcodec_) { - return mediapipe::InvalidArgumentError("Failed to find codec"); + return absl::InvalidArgumentError("Failed to find codec"); } avcodec_ctx_ = avcodec_alloc_context3(avcodec_); avcodec_parameters_to_context(avcodec_ctx_, stream->codecpar); @@ -377,17 +375,17 @@ mediapipe::Status AudioPacketProcessor::Open(int id, AVStream* stream) { id_, num_channels_, sample_rate_, source_time_base_.num, source_time_base_.den); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioPacketProcessor::ValidateSampleFormat() { +absl::Status AudioPacketProcessor::ValidateSampleFormat() { switch (avcodec_ctx_->sample_fmt) { case AV_SAMPLE_FMT_S16: case AV_SAMPLE_FMT_S16P: case AV_SAMPLE_FMT_S32: case AV_SAMPLE_FMT_FLT: case AV_SAMPLE_FMT_FLTP: - return mediapipe::OkStatus(); + return absl::OkStatus(); default: return mediapipe::UnimplementedErrorBuilder(MEDIAPIPE_LOC) << "sample_fmt = " << avcodec_ctx_->sample_fmt; @@ -411,7 +409,7 @@ int64 
AudioPacketProcessor::SampleNumberToMicroseconds( return av_rescale_q(sample_number, sample_time_base_, {1, 1000000}); } -mediapipe::Status AudioPacketProcessor::ProcessPacket(AVPacket* packet) { +absl::Status AudioPacketProcessor::ProcessPacket(AVPacket* packet) { CHECK(packet); if (flushed_) { return UnknownError( @@ -424,8 +422,7 @@ mediapipe::Status AudioPacketProcessor::ProcessPacket(AVPacket* packet) { return Decode(*packet, options_.ignore_decode_failures()); } -mediapipe::Status AudioPacketProcessor::ProcessDecodedFrame( - const AVPacket& packet) { +absl::Status AudioPacketProcessor::ProcessDecodedFrame(const AVPacket& packet) { RET_CHECK_EQ(decoded_frame_->channels, num_channels_); int buf_size_bytes = av_samples_get_buffer_size(nullptr, num_channels_, decoded_frame_->nb_samples, @@ -450,7 +447,8 @@ mediapipe::Status AudioPacketProcessor::ProcessDecodedFrame( SampleNumberToMicroseconds(expected_sample_number_); const int64 actual_us = TimestampToMicroseconds(pts); if (absl::Microseconds(std::abs(expected_us - actual_us)) > - absl::Seconds(FLAGS_media_decoder_allowed_audio_gap_merge)) { + absl::Seconds( + absl::GetFlag(FLAGS_media_decoder_allowed_audio_gap_merge))) { LOG(ERROR) << "The expected time based on how many samples we have seen (" << expected_us << " microseconds) no longer matches the time based " @@ -458,8 +456,8 @@ mediapipe::Status AudioPacketProcessor::ProcessDecodedFrame( << actual_us << " microseconds). The difference is more than " "--media_decoder_allowed_audio_gap_merge (" - << absl::FormatDuration(absl::Seconds( - FLAGS_media_decoder_allowed_audio_gap_merge)) + << absl::FormatDuration(absl::Seconds(absl::GetFlag( + FLAGS_media_decoder_allowed_audio_gap_merge))) << " microseconds). 
Resetting the timestamps to track what " "the audio stream is telling us."; expected_sample_number_ = TimestampToSampleNumber(pts); @@ -472,14 +470,14 @@ mediapipe::Status AudioPacketProcessor::ProcessDecodedFrame( data_ptr, buf_size_bytes)); ++num_frames_processed_; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioPacketProcessor::AddAudioDataToBuffer( +absl::Status AudioPacketProcessor::AddAudioDataToBuffer( const Timestamp output_timestamp, uint8* const* raw_audio, int buf_size_bytes) { if (buf_size_bytes == 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } if (buf_size_bytes % (num_channels_ * bytes_per_sample_) != 0) { @@ -568,15 +566,14 @@ mediapipe::Status AudioPacketProcessor::AddAudioDataToBuffer( } expected_sample_number_ += num_samples; - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioPacketProcessor::FillHeader( - TimeSeriesHeader* header) const { +absl::Status AudioPacketProcessor::FillHeader(TimeSeriesHeader* header) const { CHECK(header); header->set_sample_rate(sample_rate_); header->set_num_channels(num_channels_); - return mediapipe::OkStatus(); + return absl::OkStatus(); } int64 AudioPacketProcessor::MaybeCorrectPtsForRollover(int64 media_pts) { @@ -588,18 +585,18 @@ int64 AudioPacketProcessor::MaybeCorrectPtsForRollover(int64 media_pts) { AudioDecoder::AudioDecoder() { av_register_all(); } AudioDecoder::~AudioDecoder() { - mediapipe::Status status = Close(); + absl::Status status = Close(); if (!status.ok()) { LOG(ERROR) << "Encountered error while closing media file: " << status.message(); } } -mediapipe::Status AudioDecoder::Initialize( +absl::Status AudioDecoder::Initialize( const std::string& input_file, const mediapipe::AudioDecoderOptions options) { if (options.audio_stream().empty()) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( "At least one audio_stream must be defined in AudioDecoderOptions"); } std::map 
stream_index_to_audio_options_index; @@ -611,7 +608,7 @@ mediapipe::Status AudioDecoder::Initialize( } Cleanup> decoder_closer([this]() { - mediapipe::Status status = Close(); + absl::Status status = Close(); if (!status.ok()) { LOG(ERROR) << "Encountered error while closing media file: " << status.message(); @@ -620,12 +617,12 @@ mediapipe::Status AudioDecoder::Initialize( avformat_ctx_ = avformat_alloc_context(); if (avformat_open_input(&avformat_ctx_, input_file.c_str(), NULL, NULL) < 0) { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("Could not open file: ", input_file)); } if (avformat_find_stream_info(avformat_ctx_, NULL) < 0) { - return mediapipe::InvalidArgumentError(absl::StrCat( + return absl::InvalidArgumentError(absl::StrCat( "Could not find stream information of file: ", input_file)); } @@ -686,10 +683,10 @@ mediapipe::Status AudioDecoder::Initialize( is_first_packet_.resize(avformat_ctx_->nb_streams, true); decoder_closer.release(); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioDecoder::GetData(int* options_index, Packet* data) { +absl::Status AudioDecoder::GetData(int* options_index, Packet* data) { while (true) { for (auto& item : audio_processor_) { while (item.second && item.second->HasData()) { @@ -697,7 +694,7 @@ mediapipe::Status AudioDecoder::GetData(int* options_index, Packet* data) { is_first_packet_[item.first] = false; *options_index = FindOrDie(stream_id_to_audio_options_index_, item.first); - mediapipe::Status status = item.second->GetData(data); + absl::Status status = item.second->GetData(data); // Ignore packets which are out of the requested timestamp range. 
if (start_time_ != Timestamp::Unset()) { if (is_first_packet && data->Timestamp() > start_time_) { @@ -735,10 +732,10 @@ mediapipe::Status AudioDecoder::GetData(int* options_index, Packet* data) { } MP_RETURN_IF_ERROR(ProcessPacket()); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioDecoder::Close() { +absl::Status AudioDecoder::Close() { for (auto& item : audio_processor_) { if (item.second) { item.second->Close(); @@ -749,10 +746,10 @@ mediapipe::Status AudioDecoder::Close() { if (avformat_ctx_) { avformat_close_input(&avformat_ctx_); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioDecoder::FillAudioHeader( +absl::Status AudioDecoder::FillAudioHeader( const AudioStreamOptions& stream_option, TimeSeriesHeader* header) const { const std::unique_ptr* processor_ptr_ = FindOrNull( audio_processor_, @@ -760,10 +757,10 @@ mediapipe::Status AudioDecoder::FillAudioHeader( RET_CHECK(processor_ptr_ && *processor_ptr_) << "audio stream is not open."; MP_RETURN_IF_ERROR((*processor_ptr_)->FillHeader(header)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status AudioDecoder::ProcessPacket() { +absl::Status AudioDecoder::ProcessPacket() { std::unique_ptr av_packet(new AVPacket()); av_init_packet(av_packet.get()); av_packet->size = 0; @@ -785,14 +782,14 @@ mediapipe::Status AudioDecoder::ProcessPacket() { } else { VLOG(3) << "Ignoring packet for stream " << stream_id; } - return mediapipe::OkStatus(); + return absl::OkStatus(); } VLOG(1) << "Demuxing returned error (or EOF): " << AvErrorToString(ret); if (ret == AVERROR(EAGAIN)) { // EAGAIN is used to signify that the av_packet should be skipped // (maybe the demuxer is trying to re-sync). This definitely // occurs in the FLV and MpegT demuxers. - return mediapipe::OkStatus(); + return absl::OkStatus(); } // Unrecoverable demuxing error with details in avformat_ctx_->pb->error. 
@@ -819,8 +816,8 @@ mediapipe::Status AudioDecoder::ProcessPacket() { "Failed to read a frame: retval = $0 ($1)", ret, AvErrorToString(ret)); } -mediapipe::Status AudioDecoder::Flush() { - std::vector statuses; +absl::Status AudioDecoder::Flush() { + std::vector statuses; for (auto& item : audio_processor_) { if (item.second) { statuses.push_back(item.second->Flush()); diff --git a/mediapipe/util/audio_decoder.h b/mediapipe/util/audio_decoder.h index b0ae65e17..ae2ab3b33 100644 --- a/mediapipe/util/audio_decoder.h +++ b/mediapipe/util/audio_decoder.h @@ -50,10 +50,10 @@ class BasePacketProcessor { virtual ~BasePacketProcessor(); // Opens the codec. - virtual mediapipe::Status Open(int id, AVStream* stream) = 0; + virtual absl::Status Open(int id, AVStream* stream) = 0; // Processes a packet of data. Caller retains ownership of packet. - virtual mediapipe::Status ProcessPacket(AVPacket* packet) = 0; + virtual absl::Status ProcessPacket(AVPacket* packet) = 0; // Returns true if the processor has data immediately available // (without providing more data with ProcessPacket()). @@ -61,11 +61,11 @@ class BasePacketProcessor { // Fills packet with the next frame of data. Returns an empty packet // if there is nothing to return. - mediapipe::Status GetData(Packet* packet); + absl::Status GetData(Packet* packet); // Once no more AVPackets are available in the file, each stream must // be flushed to get any remaining frames which the codec is buffering. - mediapipe::Status Flush(); + absl::Status Flush(); // Closes the Processor, this does not close the file. You may not // call ProcessPacket() after calling Close(). Close() may be called @@ -74,11 +74,11 @@ class BasePacketProcessor { protected: // Decodes frames in a packet. - virtual mediapipe::Status Decode(const AVPacket& packet, - bool ignore_decode_failures); + virtual absl::Status Decode(const AVPacket& packet, + bool ignore_decode_failures); // Processes a decoded frame. 
- virtual mediapipe::Status ProcessDecodedFrame(const AVPacket& packet) = 0; + virtual absl::Status ProcessDecodedFrame(const AVPacket& packet) = 0; // Corrects the given PTS for MPEG PTS rollover. Assumed to be called with // the PTS of each frame in decode order. We detect a rollover whenever the @@ -132,17 +132,17 @@ class AudioPacketProcessor : public BasePacketProcessor { public: explicit AudioPacketProcessor(const AudioStreamOptions& options); - mediapipe::Status Open(int id, AVStream* stream) override; + absl::Status Open(int id, AVStream* stream) override; - mediapipe::Status ProcessPacket(AVPacket* packet) override; + absl::Status ProcessPacket(AVPacket* packet) override; - mediapipe::Status FillHeader(TimeSeriesHeader* header) const; + absl::Status FillHeader(TimeSeriesHeader* header) const; private: // Appends audio in buffer(s) to the output buffer (buffer_). - mediapipe::Status AddAudioDataToBuffer(const Timestamp output_timestamp, - uint8* const* raw_audio, - int buf_size_bytes); + absl::Status AddAudioDataToBuffer(const Timestamp output_timestamp, + uint8* const* raw_audio, + int buf_size_bytes); // Converts a number of samples into an approximate stream timestamp value. int64 SampleNumberToTimestamp(const int64 sample_number); @@ -154,11 +154,11 @@ class AudioPacketProcessor : public BasePacketProcessor { // Returns an error if the sample format in avformat_ctx_.sample_format // is not supported. - mediapipe::Status ValidateSampleFormat(); + absl::Status ValidateSampleFormat(); // Processes a decoded audio frame. audio_frame_ must have been filled // with the frame before calling this function. - mediapipe::Status ProcessDecodedFrame(const AVPacket& packet) override; + absl::Status ProcessDecodedFrame(const AVPacket& packet) override; // Corrects PTS for rollover if correction is enabled. 
int64 MaybeCorrectPtsForRollover(int64 media_pts); @@ -194,19 +194,19 @@ class AudioDecoder { AudioDecoder(); ~AudioDecoder(); - mediapipe::Status Initialize(const std::string& input_file, - const mediapipe::AudioDecoderOptions options); + absl::Status Initialize(const std::string& input_file, + const mediapipe::AudioDecoderOptions options); - mediapipe::Status GetData(int* options_index, Packet* data); + absl::Status GetData(int* options_index, Packet* data); - mediapipe::Status Close(); + absl::Status Close(); - mediapipe::Status FillAudioHeader(const AudioStreamOptions& stream_option, - TimeSeriesHeader* header) const; + absl::Status FillAudioHeader(const AudioStreamOptions& stream_option, + TimeSeriesHeader* header) const; private: - mediapipe::Status ProcessPacket(); - mediapipe::Status Flush(); + absl::Status ProcessPacket(); + absl::Status Flush(); std::map stream_id_to_audio_options_index_; std::map stream_index_to_stream_id_; diff --git a/mediapipe/util/cpu_util.cc b/mediapipe/util/cpu_util.cc index 99d1315dd..33e0dacde 100644 --- a/mediapipe/util/cpu_util.cc +++ b/mediapipe/util/cpu_util.cc @@ -38,18 +38,18 @@ namespace { constexpr uint32 kBufferLength = 64; -mediapipe::StatusOr GetFilePath(int cpu) { +absl::StatusOr GetFilePath(int cpu) { return absl::Substitute( "/sys/devices/system/cpu/cpu$0/cpufreq/cpuinfo_max_freq", cpu); } -mediapipe::StatusOr GetCpuMaxFrequency(int cpu) { +absl::StatusOr GetCpuMaxFrequency(int cpu) { auto path_or_status = GetFilePath(cpu); if (!path_or_status.ok()) { return path_or_status.status(); } std::ifstream file; - file.open(path_or_status.ValueOrDie()); + file.open(path_or_status.value()); if (file.is_open()) { char buffer[kBufferLength]; file.getline(buffer, kBufferLength); @@ -58,12 +58,12 @@ mediapipe::StatusOr GetCpuMaxFrequency(int cpu) { if (absl::SimpleAtoi(buffer, &frequency)) { return frequency; } else { - return mediapipe::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("Invalid frequency: 
", buffer)); } } else { - return mediapipe::NotFoundError( - absl::StrCat("Couldn't read ", path_or_status.ValueOrDie())); + return absl::NotFoundError( + absl::StrCat("Couldn't read ", path_or_status.value())); } } @@ -72,7 +72,7 @@ std::set InferLowerOrHigherCoreIds(bool lower) { for (int cpu = 0; cpu < NumCPUCores(); ++cpu) { auto freq_or_status = GetCpuMaxFrequency(cpu); if (freq_or_status.ok()) { - cpu_freq_pairs.push_back({cpu, freq_or_status.ValueOrDie()}); + cpu_freq_pairs.push_back({cpu, freq_or_status.value()}); } } if (cpu_freq_pairs.empty()) { diff --git a/mediapipe/util/header_util.cc b/mediapipe/util/header_util.cc index a4db9ea5f..ee9946acb 100644 --- a/mediapipe/util/header_util.cc +++ b/mediapipe/util/header_util.cc @@ -19,8 +19,8 @@ namespace mediapipe { -mediapipe::Status CopyInputHeadersToOutputs(const InputStreamSet& inputs, - const OutputStreamSet& outputs) { +absl::Status CopyInputHeadersToOutputs(const InputStreamSet& inputs, + const OutputStreamSet& outputs) { for (auto id = inputs.BeginId(); id < inputs.EndId(); ++id) { std::pair tag_index = inputs.TagAndIndexFromId(id); auto output_id = outputs.GetId(tag_index.first, tag_index.second); @@ -29,11 +29,11 @@ mediapipe::Status CopyInputHeadersToOutputs(const InputStreamSet& inputs, } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } -mediapipe::Status CopyInputHeadersToOutputs(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs) { +absl::Status CopyInputHeadersToOutputs(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs) { for (auto id = inputs.BeginId(); id < inputs.EndId(); ++id) { std::pair tag_index = inputs.TagAndIndexFromId(id); auto output_id = outputs->GetId(tag_index.first, tag_index.second); @@ -42,7 +42,7 @@ mediapipe::Status CopyInputHeadersToOutputs(const InputStreamShardSet& inputs, } } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/util/header_util.h 
b/mediapipe/util/header_util.h index a11dd78ef..abe1a35ec 100644 --- a/mediapipe/util/header_util.h +++ b/mediapipe/util/header_util.h @@ -22,11 +22,11 @@ namespace mediapipe { // Copies headers from |inputs| into |outputs| respectively. The size of // |inputs| and |outputs| must be equal. -mediapipe::Status CopyInputHeadersToOutputs(const InputStreamSet& inputs, - const OutputStreamSet& outputs); +absl::Status CopyInputHeadersToOutputs(const InputStreamSet& inputs, + const OutputStreamSet& outputs); -mediapipe::Status CopyInputHeadersToOutputs(const InputStreamShardSet& inputs, - OutputStreamShardSet* outputs); +absl::Status CopyInputHeadersToOutputs(const InputStreamShardSet& inputs, + OutputStreamShardSet* outputs); } // namespace mediapipe diff --git a/mediapipe/util/resource_util.cc b/mediapipe/util/resource_util.cc index 206667ebe..87659b7f0 100644 --- a/mediapipe/util/resource_util.cc +++ b/mediapipe/util/resource_util.cc @@ -27,14 +27,13 @@ ABSL_FLAG( namespace mediapipe { -mediapipe::StatusOr PathToResourceAsFile(const std::string& path) { - return mediapipe::file::JoinPath(FLAGS_resource_root_dir.CurrentValue(), +absl::StatusOr PathToResourceAsFile(const std::string& path) { + return mediapipe::file::JoinPath(absl::GetFlag(FLAGS_resource_root_dir), path); } -mediapipe::Status GetResourceContents(const std::string& path, - std::string* output, - bool read_as_binary) { +absl::Status GetResourceContents(const std::string& path, std::string* output, + bool read_as_binary) { return mediapipe::file::GetContents(path, output, read_as_binary); } diff --git a/mediapipe/util/resource_util.h b/mediapipe/util/resource_util.h index 92aabf49a..c870900e2 100644 --- a/mediapipe/util/resource_util.h +++ b/mediapipe/util/resource_util.h @@ -39,13 +39,12 @@ namespace mediapipe { // accepts file paths. Code that can access data as a stream or as a buffer // should read from an asset directly on Android; an API for this will be // provided later. TODO. 
-mediapipe::StatusOr PathToResourceAsFile(const std::string& path); +absl::StatusOr PathToResourceAsFile(const std::string& path); // Reads the entire contents of a resource. The search path is as in // PathToResourceAsFile. -mediapipe::Status GetResourceContents(const std::string& path, - std::string* output, - bool read_as_binary = true); +absl::Status GetResourceContents(const std::string& path, std::string* output, + bool read_as_binary = true); } // namespace mediapipe diff --git a/mediapipe/util/resource_util_android.cc b/mediapipe/util/resource_util_android.cc index 4c45fbe98..323c31c02 100644 --- a/mediapipe/util/resource_util_android.cc +++ b/mediapipe/util/resource_util_android.cc @@ -15,6 +15,7 @@ #include #include "absl/strings/match.h" +#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/singleton.h" #include "mediapipe/util/android/asset_manager_util.h" @@ -24,13 +25,13 @@ namespace mediapipe { namespace { -mediapipe::StatusOr PathToResourceAsFileInternal( +absl::StatusOr PathToResourceAsFileInternal( const std::string& path) { return Singleton::get()->CachedFileFromAsset(path); } } // namespace -mediapipe::StatusOr PathToResourceAsFile(const std::string& path) { +absl::StatusOr PathToResourceAsFile(const std::string& path) { // Return full path. if (absl::StartsWith(path, "/")) { return path; @@ -51,14 +52,24 @@ mediapipe::StatusOr PathToResourceAsFile(const std::string& path) { CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. auto base_name = path.substr(last_slash_idx + 1); auto status_or_path = PathToResourceAsFileInternal(base_name); - if (status_or_path.ok()) LOG(INFO) << "Successfully loaded: " << base_name; - return status_or_path; + if (status_or_path.ok()) { + LOG(INFO) << "Successfully loaded: " << base_name; + return status_or_path; + } } + + // Try the test environment. 
+ absl::string_view workspace = "mediapipe"; + auto test_path = file::JoinPath(std::getenv("TEST_SRCDIR"), workspace, path); + if (file::Exists(test_path).ok()) { + return test_path; + } + + return path; } -mediapipe::Status GetResourceContents(const std::string& path, - std::string* output, - bool read_as_binary) { +absl::Status GetResourceContents(const std::string& path, std::string* output, + bool read_as_binary) { if (!read_as_binary) { LOG(WARNING) << "Setting \"read_as_binary\" to false is a no-op on Android."; @@ -70,12 +81,12 @@ mediapipe::Status GetResourceContents(const std::string& path, if (absl::StartsWith(path, "content://")) { MP_RETURN_IF_ERROR( Singleton::get()->ReadContentUri(path, output)); - return mediapipe::OkStatus(); + return absl::OkStatus(); } RET_CHECK(Singleton::get()->ReadFile(path, output)) << "could not read asset: " << path; - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/util/resource_util_apple.cc b/mediapipe/util/resource_util_apple.cc index e5dcc12e1..1750c67e4 100644 --- a/mediapipe/util/resource_util_apple.cc +++ b/mediapipe/util/resource_util_apple.cc @@ -18,13 +18,14 @@ #include #include "absl/strings/match.h" +#include "mediapipe/framework/port/file_helpers.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/resource_util.h" namespace mediapipe { namespace { -mediapipe::StatusOr PathToResourceAsFileInternal( +absl::StatusOr PathToResourceAsFileInternal( const std::string& path) { NSString* ns_path = [NSString stringWithUTF8String:path.c_str()]; Class mediapipeGraphClass = NSClassFromString(@"MPPGraph"); @@ -39,7 +40,7 @@ mediapipe::StatusOr PathToResourceAsFileInternal( } } // namespace -mediapipe::StatusOr PathToResourceAsFile(const std::string& path) { +absl::StatusOr PathToResourceAsFile(const std::string& path) { // Return full path. 
if (absl::StartsWith(path, "/")) { return path; @@ -60,14 +61,30 @@ mediapipe::StatusOr PathToResourceAsFile(const std::string& path) { CHECK_NE(last_slash_idx, std::string::npos); // Make sure it's a path. auto base_name = path.substr(last_slash_idx + 1); auto status_or_path = PathToResourceAsFileInternal(base_name); - if (status_or_path.ok()) LOG(INFO) << "Successfully loaded: " << base_name; - return status_or_path; + if (status_or_path.ok()) { + LOG(INFO) << "Successfully loaded: " << base_name; + return status_or_path; + } } + + // Try the test environment. + { + absl::string_view workspace = "mediapipe"; + auto test_path = + file::JoinPath(std::getenv("TEST_SRCDIR"), workspace, path); + if ([[NSFileManager defaultManager] + fileExistsAtPath:[NSString + stringWithUTF8String:test_path.c_str()]]) { + LOG(INFO) << "Successfully loaded: " << test_path; + return test_path; + } + } + + return path; } -mediapipe::Status GetResourceContents(const std::string& path, - std::string* output, - bool read_as_binary) { +absl::Status GetResourceContents(const std::string& path, std::string* output, + bool read_as_binary) { if (!read_as_binary) { LOG(WARNING) << "Setting \"read_as_binary\" to false is a no-op on ios."; } @@ -77,7 +94,7 @@ mediapipe::Status GetResourceContents(const std::string& path, std::stringstream buffer; buffer << input_file.rdbuf(); buffer.str().swap(*output); - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/util/sequence/README.md b/mediapipe/util/sequence/README.md index 4af003092..e3e017290 100644 --- a/mediapipe/util/sequence/README.md +++ b/mediapipe/util/sequence/README.md @@ -548,3 +548,21 @@ where the STFT is computed over audio at some other rate. 
|`PREFIX/feature/packet_rate`|context float|`set_feature_packet_rate` / `SetFeaturePacketRate`|The number of packets per second.| |`PREFIX/feature/audio_sample_rate`|context float|`set_feature_audio_sample_rate` / `SetFeatureAudioSampleRate`|The sample rate of the original audio for derived features.| +### Keys related to text, captions, and ASR +Text features may be timed with the media such as captions or automatic +speech recognition results, or may be descriptions. This collection of keys +should be used for many, very short text features. For a few, longer segments +please use the Segment keys in the context as described above. As always, +prefixes can be used to store different types of text such as automated and +ground truth transcripts. + +| key | type | python call / c++ call | description | +|-----|------|------------------------|-------------| +|`text/language`|context bytes|`set_text_language` / `SetTextLanguage`|The language for the corresponding text.| +|`text/context/content`|context bytes|`set_text_context_content` / `SetTextContextContent`|Storage for large blocks of text in the context.| +|`text/content`|feature list bytes|`add_text_content` / `AddTextContent`|One (or a few) text tokens that occur at one timestamp.| +|`text/timestamp`|feature list int|`add_text_timestamp` / `AddTextTimestamp`|When a text token occurs in microseconds.| +|`text/duration`|feature list int|`add_text_duration` / `AddTextDuration`|The duration in microseconds for the corresponding text tokens.| +|`text/confidence`|feature list float|`add_text_confidence` / `AddTextConfidence`|How likely the text is correct.| +|`text/embedding`|feature list float list|`add_text_embedding` / `AddTextEmbedding`|A floating point vector for the corresponding text token.| +|`text/token/id`|feature list int|`add_text_token_id` / `AddTextTokenId`|An integer id for the corresponding text token.| diff --git a/mediapipe/util/sequence/media_sequence.cc b/mediapipe/util/sequence/media_sequence.cc
index 88f73771e..4e705d114 100644 --- a/mediapipe/util/sequence/media_sequence.cc +++ b/mediapipe/util/sequence/media_sequence.cc @@ -85,10 +85,10 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { // "segment/start/index" and "segment/end/index" by finding the closest // timestamps in the "image/timestamp" FeatureList if image timestamps are // present. -::mediapipe::Status ReconcileAnnotationIndicesByImageTimestamps( +absl::Status ReconcileAnnotationIndicesByImageTimestamps( tensorflow::SequenceExample* sequence) { if (GetImageTimestampSize(*sequence) == 0) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } int index; @@ -118,15 +118,15 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { } SetSegmentEndIndex(end_indices, sequence); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Sets the values of "image/format", "image/channels", "image/height", // "image/width", and "image/frame_rate" based image metadata and timestamps. -::mediapipe::Status ReconcileMetadataImages( - const std::string& prefix, tensorflow::SequenceExample* sequence) { +absl::Status ReconcileMetadataImages(const std::string& prefix, + tensorflow::SequenceExample* sequence) { if (GetImageEncodedSize(prefix, *sequence) == 0) { - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } std::string format; int height, width, channels; @@ -144,7 +144,7 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { GetImageTimestampAt(prefix, *sequence, 1)); SetImageFrameRate(prefix, rate, sequence); } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Sets the values of "feature/${TAG}/dimensions", and @@ -152,7 +152,7 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { // dimensions are already present as a context feature, this method verifies // the number of elements in the feature. 
Otherwise, it will write the // dimensions as a 1D vector with the number of elements. -::mediapipe::Status ReconcileMetadataFeatureFloats( +absl::Status ReconcileMetadataFeatureFloats( tensorflow::SequenceExample* sequence) { // Loop through all keys and see if they contain "/feature/floats" // If so, check dimensions and set rate. @@ -182,7 +182,7 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } // Go through all bounding box annotations and move the annotation to the @@ -190,7 +190,7 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { // nothing. If two or more annotations are closest to the same frame, then only // the closest annotation is saved. This matches the behavior of downsampling // images streams in time. -::mediapipe::Status ReconcileMetadataBoxAnnotations( +absl::Status ReconcileMetadataBoxAnnotations( const std::string& prefix, tensorflow::SequenceExample* sequence) { int num_bboxes = GetBBoxTimestampSize(prefix, *sequence); int num_frames = GetImageTimestampSize(*sequence); @@ -355,10 +355,10 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { } } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } -::mediapipe::Status ReconcileMetadataRegionAnnotations( +absl::Status ReconcileMetadataRegionAnnotations( tensorflow::SequenceExample* sequence) { // Copy keys for fixed iteration order while updating feature_lists. 
std::vector key_ptrs; @@ -376,7 +376,7 @@ float TimestampsToRate(int64 first_timestamp, int64 second_timestamp) { RET_CHECK_OK(ReconcileMetadataBoxAnnotations(prefix, sequence)); } } - return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -393,6 +393,7 @@ std::vector<::mediapipe::Location> GetBBoxAt( const auto& ymins = GetBBoxYMinAt(prefix, sequence, index); const auto& xmaxs = GetBBoxXMaxAt(prefix, sequence, index); const auto& ymaxs = GetBBoxYMaxAt(prefix, sequence, index); + bboxes.reserve(xmins.size()); for (int i = 0; i < xmins.size(); ++i) { bboxes.push_back(::mediapipe::Location::CreateRelativeBBoxLocation( xmins[i], ymins[i], xmaxs[i] - xmins[i], ymaxs[i] - ymins[i])); @@ -537,9 +538,9 @@ void AddAudioAsFeature(const std::string& prefix, .Swap(value_list); } -::mediapipe::Status ReconcileMetadata(bool reconcile_bbox_annotations, - bool reconcile_region_annotations, - tensorflow::SequenceExample* sequence) { +absl::Status ReconcileMetadata(bool reconcile_bbox_annotations, + bool reconcile_region_annotations, + tensorflow::SequenceExample* sequence) { RET_CHECK_OK(ReconcileAnnotationIndicesByImageTimestamps(sequence)); RET_CHECK_OK(ReconcileMetadataImages("", sequence)); RET_CHECK_OK(ReconcileMetadataImages(kForwardFlowPrefix, sequence)); @@ -553,7 +554,7 @@ void AddAudioAsFeature(const std::string& prefix, RET_CHECK_OK(ReconcileMetadataRegionAnnotations(sequence)); } // audio is always reconciled in the framework. 
- return ::mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediasequence diff --git a/mediapipe/util/sequence/media_sequence.h b/mediapipe/util/sequence/media_sequence.h index 81e18656e..c8be33e37 100644 --- a/mediapipe/util/sequence/media_sequence.h +++ b/mediapipe/util/sequence/media_sequence.h @@ -614,6 +614,36 @@ PREFIXED_IMAGE(ForwardFlow, kForwardFlowPrefix); PREFIXED_IMAGE(ClassSegmentation, kClassSegmentationPrefix); PREFIXED_IMAGE(InstanceSegmentation, kInstanceSegmentationPrefix); +// ************************** TEXT **************************************** +// Context keys: +// Which language text tokens are likely to be in. +const char kTextLanguageKey[] = "text/language"; +// A large block of text that applies to the media. +const char kTextContextContentKey[] = "text/context/content"; + +// Feature list keys: +// The text contents for a given time. +const char kTextContentKey[] = "text/content"; +// The start time for the text becoming relevant. +const char kTextTimestampKey[] = "text/timestamp"; +// The duration where the text is relevant. +const char kTextDurationKey[] = "text/duration"; +// The confidence that this is the correct text. +const char kTextConfidenceKey[] = "text/confidence"; +// A floating point embedding corresponding to the text. +const char kTextEmbeddingKey[] = "text/embedding"; +// An integer id corresponding to the text. 
+const char kTextTokenIdKey[] = "text/token/id"; + +BYTES_CONTEXT_FEATURE(TextLanguage, kTextLanguageKey); +BYTES_CONTEXT_FEATURE(TextContextContent, kTextContextContentKey); +BYTES_FEATURE_LIST(TextContent, kTextContentKey); +INT64_FEATURE_LIST(TextTimestamp, kTextTimestampKey); +INT64_FEATURE_LIST(TextDuration, kTextDurationKey); +FLOAT_FEATURE_LIST(TextConfidence, kTextConfidenceKey); +VECTOR_FLOAT_FEATURE_LIST(TextEmbedding, kTextEmbeddingKey); +INT64_FEATURE_LIST(TextTokenId, kTextTokenIdKey); + // *********************** FEATURES ************************************* // Context keys: // The dimensions of the feature. @@ -691,9 +721,9 @@ PREFIXED_FLOAT_CONTEXT_FEATURE(FeatureAudioSampleRate, // code verifies the number of elements matches the dimensions. // Reconciling bounding box annotations is optional because will remove // annotations if the sequence rate is lower than the annotation rate. -::mediapipe::Status ReconcileMetadata(bool reconcile_bbox_annotations, - bool reconcile_region_annotations, - tensorflow::SequenceExample* sequence); +absl::Status ReconcileMetadata(bool reconcile_bbox_annotations, + bool reconcile_region_annotations, + tensorflow::SequenceExample* sequence); } // namespace mediasequence } // namespace mediapipe diff --git a/mediapipe/util/sequence/media_sequence.py b/mediapipe/util/sequence/media_sequence.py index 034fd8937..abee3b5e6 100644 --- a/mediapipe/util/sequence/media_sequence.py +++ b/mediapipe/util/sequence/media_sequence.py @@ -149,13 +149,12 @@ identify common storage patterns (e.g. storing an image along with the height and width) under different names (e.g. storing a left and right image in a stereo pair.) 
An example creating functions such as add_left_image_encoded that adds a string under the key "LEFT/image/encoded" - add_left_image_encoded = functools.partial(add_image_encoded, prefix="LEFT") + add_left_image_encoded = msu.function_with_default(add_image_encoded, "LEFT") """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools import numpy as np from mediapipe.util.sequence import media_sequence_util msu = media_sequence_util @@ -463,39 +462,39 @@ def _create_region_with_prefix(name, prefix): # pylint: enable=undefined-variable msu.add_functions_to_module({ "get_" + name + "_at": - functools.partial(get_prefixed_bbox_at, prefix=prefix), + msu.function_with_default(get_prefixed_bbox_at, prefix), "add_" + name: - functools.partial(add_prefixed_bbox, prefix=prefix), + msu.function_with_default(add_prefixed_bbox, prefix), "get_" + name + "_size": - functools.partial(get_prefixed_bbox_size, prefix=prefix), + msu.function_with_default(get_prefixed_bbox_size, prefix), "has_" + name: - functools.partial(has_prefixed_bbox, prefix=prefix), + msu.function_with_default(has_prefixed_bbox, prefix), "clear_" + name: - functools.partial(clear_prefixed_bbox, prefix=prefix), + msu.function_with_default(clear_prefixed_bbox, prefix), }, module_dict=globals()) msu.add_functions_to_module({ "get_" + name + "_point_at": - functools.partial(get_prefixed_point_at, prefix=prefix), + msu.function_with_default(get_prefixed_point_at, prefix), "add_" + name + "_point": - functools.partial(add_prefixed_point, prefix=prefix), + msu.function_with_default(add_prefixed_point, prefix), "get_" + name + "_point_size": - functools.partial(get_prefixed_point_size, prefix=prefix), + msu.function_with_default(get_prefixed_point_size, prefix), "has_" + name + "_point": - functools.partial(has_prefixed_point, prefix=prefix), + msu.function_with_default(has_prefixed_point, prefix), "clear_" + name + "_point": - 
functools.partial(clear_prefixed_point, prefix=prefix), + msu.function_with_default(clear_prefixed_point, prefix), }, module_dict=globals()) msu.add_functions_to_module({ "get_" + name + "_3d_point_at": - functools.partial(get_prefixed_3d_point_at, prefix=prefix), + msu.function_with_default(get_prefixed_3d_point_at, prefix), "add_" + name + "_3d_point": - functools.partial(add_prefixed_3d_point, prefix=prefix), + msu.function_with_default(add_prefixed_3d_point, prefix), "get_" + name + "_3d_point_size": - functools.partial(get_prefixed_3d_point_size, prefix=prefix), + msu.function_with_default(get_prefixed_3d_point_size, prefix), "has_" + name + "_3d_point": - functools.partial(has_prefixed_3d_point, prefix=prefix), + msu.function_with_default(has_prefixed_3d_point, prefix), "clear_" + name + "_3d_point": - functools.partial(clear_prefixed_3d_point, prefix=prefix), + msu.function_with_default(clear_prefixed_3d_point, prefix), }, module_dict=globals()) @@ -580,6 +579,42 @@ _create_image_with_prefix("forward_flow", FORWARD_FLOW_PREFIX) _create_image_with_prefix("class_segmentation", CLASS_SEGMENTATION_PREFIX) _create_image_with_prefix("instance_segmentation", INSTANCE_SEGMENTATION_PREFIX) +################################## TEXT ################################# +# Which language text tokens are likely to be in. +TEXT_LANGUAGE_KEY = "text/language" +# A large block of text that applies to the media. +TEXT_CONTEXT_CONTENT_KEY = "text/context/content" + +# The text contents for a given time. +TEXT_CONTENT_KEY = "text/content" +# The start time for the text becoming relevant. +TEXT_TIMESTAMP_KEY = "text/timestamp" +# The duration where the text is relevant. +TEXT_DURATION_KEY = "text/duration" +# The confidence that this is the correct text. +TEXT_CONFIDENCE_KEY = "text/confidence" +# A floating point embedding corresponding to the text. +TEXT_EMBEDDING_KEY = "text/embedding" +# An integer id corresponding to the text. 
+TEXT_TOKEN_ID_KEY = "text/token/id" + +msu.create_bytes_context_feature( + "text_language", TEXT_LANGUAGE_KEY, module_dict=globals()) +msu.create_bytes_context_feature( + "text_context_content", TEXT_CONTEXT_CONTENT_KEY, module_dict=globals()) +msu.create_bytes_feature_list( + "text_content", TEXT_CONTENT_KEY, module_dict=globals()) +msu.create_int_feature_list( + "text_timestamp", TEXT_TIMESTAMP_KEY, module_dict=globals()) +msu.create_int_feature_list( + "text_duration", TEXT_DURATION_KEY, module_dict=globals()) +msu.create_float_feature_list( + "text_confidence", TEXT_CONFIDENCE_KEY, module_dict=globals()) +msu.create_float_list_feature_list( + "text_embedding", TEXT_EMBEDDING_KEY, module_dict=globals()) +msu.create_int_feature_list( + "text_token_id", TEXT_TOKEN_ID_KEY, module_dict=globals()) + ################################## FEATURES ################################# # The dimensions of the feature. FEATURE_DIMENSIONS_KEY = "feature/dimensions" diff --git a/mediapipe/util/sequence/media_sequence_test.cc b/mediapipe/util/sequence/media_sequence_test.cc index d22128747..c4e482600 100644 --- a/mediapipe/util/sequence/media_sequence_test.cc +++ b/mediapipe/util/sequence/media_sequence_test.cc @@ -640,6 +640,90 @@ TEST(MediaSequenceTest, RoundTripOpticalFlowTimestamp) { ASSERT_EQ(GetForwardFlowTimestampSize(sequence), 0); } +TEST(MediaSequenceTest, RoundTripTextLanguage) { + tensorflow::SequenceExample sequence; + ASSERT_FALSE(HasTextLanguage(sequence)); + SetTextLanguage("test", &sequence); + ASSERT_TRUE(HasTextLanguage(sequence)); + ASSERT_EQ("test", GetTextLanguage(sequence)); + ClearTextLanguage(&sequence); + ASSERT_FALSE(HasTextLanguage(sequence)); +} + +TEST(MediaSequenceTest, RoundTripTextContextContent) { + tensorflow::SequenceExample sequence; + ASSERT_FALSE(HasTextContextContent(sequence)); + SetTextContextContent("test", &sequence); + ASSERT_TRUE(HasTextContextContent(sequence)); + ASSERT_EQ("test", GetTextContextContent(sequence)); + 
ClearTextContextContent(&sequence); + ASSERT_FALSE(HasTextContextContent(sequence)); +} + +TEST(MediaSequenceTest, RoundTripTextContent) { + tensorflow::SequenceExample sequence; + std::vector text = {"test", "again"}; + for (int i = 0; i < text.size(); ++i) { + AddTextContent(text[i], &sequence); + ASSERT_EQ(GetTextContentSize(sequence), i + 1); + ASSERT_EQ(GetTextContentAt(sequence, i), text[i]); + } + ClearTextContent(&sequence); + ASSERT_EQ(GetTextContentSize(sequence), 0); +} + +TEST(MediaSequenceTest, RoundTripTextDuration) { + tensorflow::SequenceExample sequence; + std::vector timestamps = {4, 7}; + for (int i = 0; i < timestamps.size(); ++i) { + AddTextTimestamp(timestamps[i], &sequence); + ASSERT_EQ(GetTextTimestampSize(sequence), i + 1); + ASSERT_EQ(GetTextTimestampAt(sequence, i), timestamps[i]); + } + ClearTextTimestamp(&sequence); + ASSERT_EQ(GetTextTimestampSize(sequence), 0); +} + +TEST(MediaSequenceTest, RoundTripTextConfidence) { + tensorflow::SequenceExample sequence; + std::vector confidence = {0.25, 1.0}; + for (int i = 0; i < confidence.size(); ++i) { + AddTextConfidence(confidence[i], &sequence); + ASSERT_EQ(GetTextConfidenceSize(sequence), i + 1); + ASSERT_EQ(GetTextConfidenceAt(sequence, i), confidence[i]); + } + ClearTextConfidence(&sequence); + ASSERT_EQ(GetTextConfidenceSize(sequence), 0); +} + +TEST(MediaSequenceTest, RoundTripTextEmbedding) { + tensorflow::SequenceExample sequence; + int num_features = 3; + int num_floats_in_feature = 4; + for (int i = 0; i < num_features; ++i) { + std::vector vf(num_floats_in_feature, 2 << i); + AddTextEmbedding(vf, &sequence); + ASSERT_EQ(GetTextEmbeddingSize(sequence), i + 1); + for (float value : GetTextEmbeddingAt(sequence, i)) { + ASSERT_EQ(value, 2 << i); + } + } + ClearTextEmbedding(&sequence); + ASSERT_EQ(GetTextEmbeddingSize(sequence), 0); +} + +TEST(MediaSequenceTest, RoundTripTextTokenId) { + tensorflow::SequenceExample sequence; + std::vector ids = {4, 7}; + for (int i = 0; i < ids.size(); 
++i) { + AddTextTokenId(ids[i], &sequence); + ASSERT_EQ(GetTextTokenIdSize(sequence), i + 1); + ASSERT_EQ(GetTextTokenIdAt(sequence, i), ids[i]); + } + ClearTextTokenId(&sequence); + ASSERT_EQ(GetTextTokenIdSize(sequence), 0); +} + TEST(MediaSequenceTest, ReconcileMetadataOnEmptySequence) { tensorflow::SequenceExample sequence; MP_ASSERT_OK(ReconcileMetadata(true, false, &sequence)); diff --git a/mediapipe/util/sequence/media_sequence_test.py b/mediapipe/util/sequence/media_sequence_test.py index 0c6ff9be7..9a282ed2d 100644 --- a/mediapipe/util/sequence/media_sequence_test.py +++ b/mediapipe/util/sequence/media_sequence_test.py @@ -123,6 +123,14 @@ class MediaSequenceTest(tf.test.TestCase): ms.add_bbox_embedding_floats((0.47, 0.49), example) ms.add_bbox_embedding_encoded((b"text", b"stings"), example) ms.add_bbox_embedding_confidence((0.47, 0.49), example) + ms.set_text_language(b"test", example) + ms.set_text_context_content(b"text", example) + ms.add_text_content(b"one", example) + ms.add_text_timestamp(47, example) + ms.add_text_confidence(0.47, example) + ms.add_text_duration(47, example) + ms.add_text_token_id(47, example) + ms.add_text_embedding((0.47, 0.49), example) def test_bbox_round_trip(self): example = tf.train.SequenceExample() @@ -149,6 +157,18 @@ class MediaSequenceTest(tf.test.TestCase): ms.clear_bbox_point(example) self.assertEqual(0, ms.get_bbox_point_size(example)) + def test_prefixed_point_round_trip(self): + example = tf.train.SequenceExample() + points = np.array([[0.1, 0.2], + [0.5, 0.6]]) + ms.add_bbox_point(points, example, "test") + ms.add_bbox_point(points, example, "test") + self.assertEqual(2, ms.get_bbox_point_size(example, "test")) + self.assertAllClose(points, ms.get_bbox_point_at(0, example, "test")) + self.assertTrue(ms.has_bbox_point(example, "test")) + ms.clear_bbox_point(example, "test") + self.assertEqual(0, ms.get_bbox_point_size(example, "test")) + def test_3d_point_round_trip(self): example = tf.train.SequenceExample() 
points = np.array([[0.1, 0.2, 0.3], diff --git a/mediapipe/util/sequence/media_sequence_util.py b/mediapipe/util/sequence/media_sequence_util.py index c1cf2115f..adf71f62c 100644 --- a/mediapipe/util/sequence/media_sequence_util.py +++ b/mediapipe/util/sequence/media_sequence_util.py @@ -22,10 +22,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - +import types import tensorflow.compat.v1 as tf +def function_with_default(f, default): + """Creates a new function with a default last parameter.""" + return types.FunctionType(f.__code__, f.__globals__, f.__name__, + (default,), f.__closure__) + + def add_functions_to_module(function_dict, module_dict=None): """Adds functions to another module. diff --git a/mediapipe/util/tensor_to_detection.cc b/mediapipe/util/tensor_to_detection.cc index 57f03e3b6..4326067bc 100644 --- a/mediapipe/util/tensor_to_detection.cc +++ b/mediapipe/util/tensor_to_detection.cc @@ -201,7 +201,7 @@ Status TensorsToDetections(const ::tensorflow::Tensor& num_detections, } detections->emplace_back(detection); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } } // namespace mediapipe diff --git a/mediapipe/util/tensor_to_detection.h b/mediapipe/util/tensor_to_detection.h index 767e639b2..3b87b8981 100644 --- a/mediapipe/util/tensor_to_detection.h +++ b/mediapipe/util/tensor_to_detection.h @@ -28,21 +28,23 @@ Detection TensorToDetection( const ::tensorflow::TTypes::Vec& box, float score, const ::absl::variant& class_label); -mediapipe::Status TensorsToDetections( - const ::tensorflow::Tensor& num_detections, - const ::tensorflow::Tensor& boxes, const ::tensorflow::Tensor& scores, - const ::tensorflow::Tensor& classes, - const std::map& label_map, - std::vector* detections); +absl::Status TensorsToDetections(const ::tensorflow::Tensor& num_detections, + const ::tensorflow::Tensor& boxes, + const ::tensorflow::Tensor& scores, + const ::tensorflow::Tensor& classes, + const 
std::map& label_map, + std::vector* detections); // Use this version if keypoints or masks are available. -mediapipe::Status TensorsToDetections( - const ::tensorflow::Tensor& num_detections, - const ::tensorflow::Tensor& boxes, const ::tensorflow::Tensor& scores, - const ::tensorflow::Tensor& classes, const ::tensorflow::Tensor& keypoints, - const ::tensorflow::Tensor& masks, float mask_threshold, - const std::map& label_map, - std::vector* detections); +absl::Status TensorsToDetections(const ::tensorflow::Tensor& num_detections, + const ::tensorflow::Tensor& boxes, + const ::tensorflow::Tensor& scores, + const ::tensorflow::Tensor& classes, + const ::tensorflow::Tensor& keypoints, + const ::tensorflow::Tensor& masks, + float mask_threshold, + const std::map& label_map, + std::vector* detections); } // namespace mediapipe #endif // MEDIAPIPE_TENSORFLOW_UTIL_TENSOR_TO_DETECTION_H_ diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index 6e33cd181..4d66bbe21 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -104,7 +104,7 @@ cc_library( srcs = ["tflite_model_loader.cc"], hdrs = ["tflite_model_loader.h"], deps = [ - "//mediapipe/framework:packet", + "//mediapipe/framework/api2:packet", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", diff --git a/mediapipe/util/tflite/tflite_gpu_runner.cc b/mediapipe/util/tflite/tflite_gpu_runner.cc index a4d97acbf..c77c09524 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.cc +++ b/mediapipe/util/tflite/tflite_gpu_runner.cc @@ -83,7 +83,7 @@ ObjectDef GetSSBOObjectDef(int channels) { } // namespace -mediapipe::Status TFLiteGPURunner::InitializeWithModel( +absl::Status TFLiteGPURunner::InitializeWithModel( const tflite::FlatBufferModel& flatbuffer, const tflite::OpResolver& op_resolver) { // GraphFloat32 is created twice because, when OpenCL and OpenGL backends are @@ -111,23 +111,23 @@ mediapipe::Status 
TFLiteGPURunner::InitializeWithModel( return absl::OkStatus(); } -mediapipe::StatusOr TFLiteGPURunner::GetInputElements(int id) { +absl::StatusOr TFLiteGPURunner::GetInputElements(int id) { if (id >= input_shapes_.size()) { - return mediapipe::InternalError("Wrong input tensor id."); + return absl::InternalError("Wrong input tensor id."); } else { return input_shapes_[id].DimensionsProduct(); } } -mediapipe::StatusOr TFLiteGPURunner::GetOutputElements(int id) { +absl::StatusOr TFLiteGPURunner::GetOutputElements(int id) { if (id >= output_shapes_.size()) { - return mediapipe::InternalError("Wrong output tensor id."); + return absl::InternalError("Wrong output tensor id."); } else { return output_shapes_[id].DimensionsProduct(); } } -mediapipe::Status TFLiteGPURunner::Build() { +absl::Status TFLiteGPURunner::Build() { // 1. Prepare inference builder. std::unique_ptr builder; // By default, we try CL first & fall back to GL if that fails. @@ -164,23 +164,23 @@ mediapipe::Status TFLiteGPURunner::Build() { return builder->Build(&runner_); } -mediapipe::Status TFLiteGPURunner::BindSSBOToInputTensor(GLuint ssbo_id, - int input_id) { +absl::Status TFLiteGPURunner::BindSSBOToInputTensor(GLuint ssbo_id, + int input_id) { OpenGlBuffer buffer; buffer.id = ssbo_id; return runner_->SetInputObject(input_id, std::move(buffer)); } -mediapipe::Status TFLiteGPURunner::BindSSBOToOutputTensor(GLuint ssbo_id, - int output_id) { +absl::Status TFLiteGPURunner::BindSSBOToOutputTensor(GLuint ssbo_id, + int output_id) { OpenGlBuffer buffer; buffer.id = ssbo_id; return runner_->SetOutputObject(output_id, std::move(buffer)); } -mediapipe::Status TFLiteGPURunner::Invoke() { return runner_->Run(); } +absl::Status TFLiteGPURunner::Invoke() { return runner_->Run(); } -mediapipe::Status TFLiteGPURunner::InitializeOpenGL( +absl::Status TFLiteGPURunner::InitializeOpenGL( std::unique_ptr* builder) { gl::InferenceEnvironmentOptions env_options; gl::InferenceEnvironmentProperties properties; diff --git 
a/mediapipe/util/tflite/tflite_gpu_runner.h b/mediapipe/util/tflite/tflite_gpu_runner.h index 389e33b94..d88556e55 100644 --- a/mediapipe/util/tflite/tflite_gpu_runner.h +++ b/mediapipe/util/tflite/tflite_gpu_runner.h @@ -53,24 +53,23 @@ class TFLiteGPURunner { explicit TFLiteGPURunner(const InferenceOptions& options) : options_(options) {} - mediapipe::Status InitializeWithModel( - const tflite::FlatBufferModel& flatbuffer, - const tflite::OpResolver& op_resolver); + absl::Status InitializeWithModel(const tflite::FlatBufferModel& flatbuffer, + const tflite::OpResolver& op_resolver); void ForceOpenGL() { opengl_is_forced_ = true; } void ForceOpenCL() { opencl_is_forced_ = true; } - mediapipe::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id); - mediapipe::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id); + absl::Status BindSSBOToInputTensor(GLuint ssbo_id, int input_id); + absl::Status BindSSBOToOutputTensor(GLuint ssbo_id, int output_id); int inputs_size() const { return input_shapes_.size(); } int outputs_size() const { return output_shapes_.size(); } - mediapipe::StatusOr GetInputElements(int id); - mediapipe::StatusOr GetOutputElements(int id); + absl::StatusOr GetInputElements(int id); + absl::StatusOr GetOutputElements(int id); - mediapipe::Status Build(); - mediapipe::Status Invoke(); + absl::Status Build(); + absl::Status Invoke(); std::vector GetInputShapes() { return input_shapes_; } std::vector GetOutputShapes() { return output_shapes_; } @@ -93,10 +92,8 @@ class TFLiteGPURunner { #endif private: - mediapipe::Status InitializeOpenGL( - std::unique_ptr* builder); - mediapipe::Status InitializeOpenCL( - std::unique_ptr* builder); + absl::Status InitializeOpenGL(std::unique_ptr* builder); + absl::Status InitializeOpenCL(std::unique_ptr* builder); InferenceOptions options_; std::unique_ptr gl_environment_; diff --git a/mediapipe/util/tflite/tflite_model_loader.cc b/mediapipe/util/tflite/tflite_model_loader.cc index 941d08ef4..7a27b1ea3 
100644 --- a/mediapipe/util/tflite/tflite_model_loader.cc +++ b/mediapipe/util/tflite/tflite_model_loader.cc @@ -19,7 +19,7 @@ namespace mediapipe { -mediapipe::StatusOr TfLiteModelLoader::LoadFromPath( +absl::StatusOr> TfLiteModelLoader::LoadFromPath( const std::string& path) { std::string model_path = path; @@ -27,8 +27,8 @@ mediapipe::StatusOr TfLiteModelLoader::LoadFromPath( auto model = tflite::FlatBufferModel::BuildFromFile(model_path.c_str()); RET_CHECK(model) << "Failed to load model from path " << model_path; - return MakePacket(TfLiteModelPtr( - model.release(), [](tflite::FlatBufferModel* model) { delete model; })); + return api2::MakePacket( + model.release(), [](tflite::FlatBufferModel* model) { delete model; }); } } // namespace mediapipe diff --git a/mediapipe/util/tflite/tflite_model_loader.h b/mediapipe/util/tflite/tflite_model_loader.h index 5e759649e..c1baf128a 100644 --- a/mediapipe/util/tflite/tflite_model_loader.h +++ b/mediapipe/util/tflite/tflite_model_loader.h @@ -15,7 +15,7 @@ #ifndef MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_ #define MEDIAPIPE_UTIL_TFLITE_TFLITE_MODEL_LOADER_H_ -#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/statusor.h" #include "tensorflow/lite/model.h" @@ -30,7 +30,8 @@ class TfLiteModelLoader { public: // Returns a Packet containing a TfLiteModelPtr, pointing to a model loaded // from the specified file path. 
- static mediapipe::StatusOr LoadFromPath(const std::string& path); + static absl::StatusOr> LoadFromPath( + const std::string& path); }; } // namespace mediapipe diff --git a/mediapipe/util/time_series_test_util.h b/mediapipe/util/time_series_test_util.h index 2ac08206a..bf1a0a461 100644 --- a/mediapipe/util/time_series_test_util.h +++ b/mediapipe/util/time_series_test_util.h @@ -308,7 +308,7 @@ class TimeSeriesCalculatorTest : public ::testing::Test { AppendInputPacket(payload, Timestamp(timestamp), input_tag); } - mediapipe::Status RunGraph() { return runner_->Run(); } + absl::Status RunGraph() { return runner_->Run(); } bool HasInputHeader(const size_t input_index = 0) const { return input(input_index) diff --git a/mediapipe/util/time_series_util.cc b/mediapipe/util/time_series_util.cc index 35e89108a..1e20daa59 100644 --- a/mediapipe/util/time_series_util.cc +++ b/mediapipe/util/time_series_util.cc @@ -62,10 +62,10 @@ bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp, } } -mediapipe::Status IsTimeSeriesHeaderValid(const TimeSeriesHeader& header) { +absl::Status IsTimeSeriesHeaderValid(const TimeSeriesHeader& header) { if (header.has_sample_rate() && header.sample_rate() >= 0 && header.has_num_channels() && header.num_channels() >= 0) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } else { std::string error_message = "TimeSeriesHeader is missing necessary fields: " @@ -77,8 +77,8 @@ mediapipe::Status IsTimeSeriesHeaderValid(const TimeSeriesHeader& header) { } } -mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, - TimeSeriesHeader* header) { +absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header) { CHECK(header); if (header_packet.IsEmpty()) { return tool::StatusFail("No header found."); @@ -90,7 +90,7 @@ mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, return IsTimeSeriesHeaderValid(*header); } -mediapipe::Status 
FillMultiStreamTimeSeriesHeaderIfValid( +absl::Status FillMultiStreamTimeSeriesHeaderIfValid( const Packet& header_packet, MultiStreamTimeSeriesHeader* header) { CHECK(header); if (header_packet.IsEmpty()) { @@ -107,8 +107,8 @@ mediapipe::Status FillMultiStreamTimeSeriesHeaderIfValid( return IsTimeSeriesHeaderValid(header->time_series_header()); } -mediapipe::Status IsMatrixShapeConsistentWithHeader( - const Matrix& matrix, const TimeSeriesHeader& header) { +absl::Status IsMatrixShapeConsistentWithHeader(const Matrix& matrix, + const TimeSeriesHeader& header) { if (header.has_num_samples() && matrix.cols() != header.num_samples()) { return tool::StatusInvalid(absl::StrCat( "Matrix size is inconsistent with header. Expected ", @@ -119,7 +119,7 @@ mediapipe::Status IsMatrixShapeConsistentWithHeader( "Matrix size is inconsistent with header. Expected ", header.num_channels(), " rows, but found ", matrix.rows())); } - return mediapipe::OkStatus(); + return absl::OkStatus(); } int64 SecondsToSamples(double time_in_seconds, double sample_rate) { diff --git a/mediapipe/util/time_series_util.h b/mediapipe/util/time_series_util.h index 292749552..a6a5911a6 100644 --- a/mediapipe/util/time_series_util.h +++ b/mediapipe/util/time_series_util.h @@ -43,27 +43,27 @@ bool LogWarningIfTimestampIsInconsistent(const Timestamp& current_timestamp, int64 cumulative_samples, double sample_rate); -// Returns mediapipe::status::OK if the header is valid. Otherwise, returns a +// Returns absl::Status::OK if the header is valid. Otherwise, returns a // Status object with an error message. -mediapipe::Status IsTimeSeriesHeaderValid(const TimeSeriesHeader& header); +absl::Status IsTimeSeriesHeaderValid(const TimeSeriesHeader& header); -// Fills header and returns mediapipe::status::OK if the header is non-empty and +// Fills header and returns absl::Status::OK if the header is non-empty and // valid. Otherwise, returns a Status object with an error message. 
-mediapipe::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, - TimeSeriesHeader* header); +absl::Status FillTimeSeriesHeaderIfValid(const Packet& header_packet, + TimeSeriesHeader* header); -// Fills header and returns mediapipe::status::OK if the header contains a +// Fills header and returns absl::Status::OK if the header contains a // non-empty and valid TimeSeriesHeader. Otherwise, returns a Status object with // an error message. -mediapipe::Status FillMultiStreamTimeSeriesHeaderIfValid( +absl::Status FillMultiStreamTimeSeriesHeaderIfValid( const Packet& header_packet, MultiStreamTimeSeriesHeader* header); -// Returnsmediapipe::Status::OK iff options contains an extension of type +// Returns absl::Status::OK iff options contains an extension of type // OptionsClass. template -mediapipe::Status HasOptionsExtension(const CalculatorOptions& options) { +absl::Status HasOptionsExtension(const CalculatorOptions& options) { if (options.HasExtension(OptionsClass::ext)) { - return mediapipe::OkStatus(); + return absl::OkStatus(); } std::string error_message = "Options proto does not contain extension "; absl::StrAppend(&error_message, @@ -72,16 +72,16 @@ mediapipe::Status HasOptionsExtension(const CalculatorOptions& options) { // Avoid lite proto APIs on mobile targets. absl::StrAppend(&error_message, " : ", options.DebugString()); #endif - return mediapipe::InvalidArgumentError(error_message); + return absl::InvalidArgumentError(error_message); } -// Returnsmediapipe::Status::OK if the shape of 'matrix' is consistent +// Returns absl::Status::OK if the shape of 'matrix' is consistent // with the num_samples and num_channels fields present in 'header'. // The corresponding matrix dimensions of unset header fields are // ignored, so e.g. an empty header (which is not valid according to // FillTimeSeriesHeaderIfValid) is considered consistent with any matrix. 
-mediapipe::Status IsMatrixShapeConsistentWithHeader( - const Matrix& matrix, const TimeSeriesHeader& header); +absl::Status IsMatrixShapeConsistentWithHeader(const Matrix& matrix, + const TimeSeriesHeader& header); template void FillOptionsExtensionOrDie(const CalculatorOptions& options, diff --git a/mediapipe/util/tracking/BUILD b/mediapipe/util/tracking/BUILD index f1114e86f..db9b004f9 100644 --- a/mediapipe/util/tracking/BUILD +++ b/mediapipe/util/tracking/BUILD @@ -319,6 +319,7 @@ cc_library( "//mediapipe/framework/port:logging", "//mediapipe/framework/port:vector", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", ], ) diff --git a/mediapipe/util/tracking/box_tracker.cc b/mediapipe/util/tracking/box_tracker.cc index 2111eac4b..7c5cd2b94 100644 --- a/mediapipe/util/tracking/box_tracker.cc +++ b/mediapipe/util/tracking/box_tracker.cc @@ -667,7 +667,7 @@ bool BoxTracker::WaitForChunkFile(int id, int checkpoint, } } - usleep(wait_time_msec * 1000); + absl::SleepFor(absl::Milliseconds(wait_time_msec)); total_wait_msec += wait_time_msec; struct stat tmp; diff --git a/mediapipe/util/tracking/motion_estimation.cc b/mediapipe/util/tracking/motion_estimation.cc index 18e74392d..e06acf1d1 100644 --- a/mediapipe/util/tracking/motion_estimation.cc +++ b/mediapipe/util/tracking/motion_estimation.cc @@ -5120,6 +5120,7 @@ bool MotionEstimation::MixtureHomographyFromFeature( MixtureHomography norm_model; // Initialize with identity. 
+ norm_model.mutable_model()->Reserve(num_mixtures); for (int k = 0; k < num_mixtures; ++k) { norm_model.add_model(); } diff --git a/mediapipe/util/tracking/parallel_invoker_forbid_mixed.cc b/mediapipe/util/tracking/parallel_invoker_forbid_mixed.cc index 383ad53a3..4768eab09 100644 --- a/mediapipe/util/tracking/parallel_invoker_forbid_mixed.cc +++ b/mediapipe/util/tracking/parallel_invoker_forbid_mixed.cc @@ -15,14 +15,10 @@ // Guard to ensure clients do not link against both, // single and parallel version. #ifdef PARALLEL_INVOKER_ACTIVE -int LinkageAgainstBothSingleAndParallelStabilizationVersionsDetected() { - return 0; -} +int LinkageAgainstBothSingleAndParallelTrackingVersionsDetected() { return 0; } #endif // PARALLEL_INVOKER_ACTIVE #ifdef PARALLEL_INVOKER_INACTIVE -int LinkageAgainstBothSingleAndParallelStabilizationVersionsDetected() { - return 1; -} +int LinkageAgainstBothSingleAndParallelTrackingVersionsDetected() { return 1; } #endif // PARALLEL_INVOKER_INACTIVE diff --git a/mediapipe/util/tracking/region_flow.cc b/mediapipe/util/tracking/region_flow.cc index 1012f3f65..cdd6bcd88 100644 --- a/mediapipe/util/tracking/region_flow.cc +++ b/mediapipe/util/tracking/region_flow.cc @@ -21,6 +21,7 @@ #include #include "absl/container/node_hash_map.h" +#include "absl/container/node_hash_set.h" #include "absl/strings/str_cat.h" #include "mediapipe/framework/port/integral_types.h" #include "mediapipe/util/tracking/measure_time.h" @@ -580,7 +581,7 @@ void LongFeatureStream::AddFeatures(const RegionFlowFeatureList& feature_list, } // Record id of each track that is present in the current feature_list. - std::unordered_set present_tracks; + absl::node_hash_set present_tracks; for (auto feature : feature_list.feature()) { // Copy feature. 
if (feature.track_id() < 0) { LOG_IF(WARNING, []() { diff --git a/mediapipe/util/tracking/region_flow_computation_test.cc b/mediapipe/util/tracking/region_flow_computation_test.cc index 66319990c..91437681a 100644 --- a/mediapipe/util/tracking/region_flow_computation_test.cc +++ b/mediapipe/util/tracking/region_flow_computation_test.cc @@ -115,7 +115,7 @@ void RegionFlowComputationTest::MakeMovie( // First generate random positions. int seed = 900913; // google. - if (FLAGS_time_seed) { + if (absl::GetFlag(FLAGS_time_seed)) { seed = ToUnixMillis(absl::Now()) % (1 << 16); LOG(INFO) << "Using time seed: " << seed; } diff --git a/mediapipe/util/tracking/tracking.cc b/mediapipe/util/tracking/tracking.cc index 72245c0df..8d0afb08f 100644 --- a/mediapipe/util/tracking/tracking.cc +++ b/mediapipe/util/tracking/tracking.cc @@ -1652,7 +1652,8 @@ bool MotionBox::GetVectorsAndWeights( vectors->push_back(&motion_vectors[k]); - auto is_close_to_test_vector = [test_vector](const Vector2_f v) -> bool { + auto is_close_to_test_vector = [test_vector, + kSqProximity](const Vector2_f v) -> bool { return (v - test_vector.pos).Norm2() < kSqProximity; }; diff --git a/requirements.txt b/requirements.txt index cee4e454a..37cad28fe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ absl-py +attrs dataclasses numpy == 1.19.3 opencv-python diff --git a/setup.py b/setup.py index 61848de4f..5acad7bdb 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,8 @@ SUBDIR_INIT_PY_FILES = [ os.path.join(MP_ROOT_PATH, 'mediapipe/calculators/__init__.py'), os.path.join(MP_ROOT_PATH, 'mediapipe/modules/__init__.py'), os.path.join(MP_ROOT_PATH, - 'mediapipe/modules/holistic_landmark/__init__.py') + 'mediapipe/modules/holistic_landmark/__init__.py'), + os.path.join(MP_ROOT_PATH, 'mediapipe/modules/objectron/__init__.py') ] if not os.path.exists(ROOT_INIT_PY): open(ROOT_INIT_PY, 'w').close() @@ -219,9 +220,10 @@ class BuildBinaryGraphs(build.build): def run(self): _check_bazel() binary_graphs 
= [ + 'face_detection/face_detection_front_cpu', 'face_landmark/face_landmark_front_cpu', 'hand_landmark/hand_landmark_tracking_cpu', - 'holistic_landmark/holistic_landmark_cpu', + 'holistic_landmark/holistic_landmark_cpu', 'objectron/objectron_cpu', 'pose_landmark/pose_landmark_cpu' ] for binary_graph in binary_graphs: