diff --git a/LICENSE b/LICENSE index 261eeb9e9..0e03e3911 100644 --- a/LICENSE +++ b/LICENSE @@ -199,3 +199,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + +=========================================================================== +For files under tasks/cc/text/language_detector/custom_ops/utils/utf/ +=========================================================================== +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. + * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + */ diff --git a/WORKSPACE b/WORKSPACE index 10f0c1ac5..17e96c0b2 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -270,7 +270,7 @@ new_local_repository( # For local MacOS builds, the path should point to an opencv@3 installation. # If you edit the path here, you will also need to update the corresponding # prefix in "opencv_macos.BUILD". - path = "/usr/local", + path = "/usr/local", # e.g. /usr/local/Cellar for HomeBrew ) new_local_repository( @@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool") # Node dependencies http_archive( name = "build_bazel_rules_nodejs", - sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda", - urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"], + sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a", + urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"], ) load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies") @@ -543,3 +543,43 @@ external_files() load("@//third_party:wasm_files.bzl", "wasm_files") wasm_files() + +# Halide + +new_local_repository( + name = "halide", + build_file = "@//third_party/halide:BUILD.bazel", + path = "third_party/halide" +) + +http_archive( + name = "linux_halide", + sha256 = "f62b2914823d6e33d18693f5b74484f274523bf5402ce51988e24393d123b375", + strip_prefix = "Halide-15.0.0-x86-64-linux", + urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-linux-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"], + build_file = "@//third_party:halide.BUILD", +) + +http_archive( + name = "macos_x86_64_halide", + sha256 = "3d832aed942080ea89aa832462c68fbb906f3055c440b7b6d35093d7c52f6aab", + strip_prefix = "Halide-15.0.0-x86-64-osx", + urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"], + build_file = "@//third_party:halide.BUILD", +) + +http_archive( + name = "macos_arm_64_halide", + sha256 = "b1fad3c9810122b187303d7031d9e35fb43761f345d18cc4492c00ed5877f641", + strip_prefix = "Halide-15.0.0-arm-64-osx", + urls = 
["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-arm-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"], + build_file = "@//third_party:halide.BUILD", +) + +http_archive( + name = "windows_halide", + sha256 = "5acf6fe161dd375856a2b43f4bb0a32815ba958b0585ee312c44e008aa7b0b64", + strip_prefix = "Halide-15.0.0-x86-64-windows", + urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-windows-d7651f4b32f9dbd764f243134001f7554378d62d.zip"], + build_file = "@//third_party:halide.BUILD", +) diff --git a/docs/framework_concepts/synchronization.md b/docs/framework_concepts/synchronization.md index e35e1032d..e12d077a7 100644 --- a/docs/framework_concepts/synchronization.md +++ b/docs/framework_concepts/synchronization.md @@ -113,14 +113,14 @@ Warning: On the other hand, it is not guaranteed that an input packet will always be available for all streams. To explain how it works, we need to introduce the definition of a settled -timestamp. We say that a timestamp in a stream is *settled* if it lower than the -timestamp bound. In other words, a timestamp is settled for a stream once the -state of the input at that timestamp is irrevocably known: either there is a +timestamp. We say that a timestamp in a stream is *settled* if it is lower than +the timestamp bound. In other words, a timestamp is settled for a stream once +the state of the input at that timestamp is irrevocably known: either there is a packet, or there is the certainty that a packet with that timestamp will not arrive. Note: For this reason, MediaPipe also allows a stream producer to explicitly -advance the timestamp bound farther that what the last packet implies, i.e. to +advance the timestamp bound farther than what the last packet implies, i.e. to provide a tighter bound. This can allow the downstream nodes to settle their inputs sooner. diff --git a/docs/solutions/models.md b/docs/solutions/models.md index 1172f2cfc..c45aefa44 100644 --- a/docs/solutions/models.md +++ b/docs/solutions/models.md @@ -108,6 +108,8 @@ one over the other. * [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite) * [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite) +* [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip) +* [Model information](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md) ### [Objectron](https://google.github.io/mediapipe/solutions/objectron) diff --git a/docs/solutions/object_detection.md b/docs/solutions/object_detection.md index f08d0a928..ef7db8671 100644 --- a/docs/solutions/object_detection.md +++ b/docs/solutions/object_detection.md @@ -118,9 +118,9 @@ on how to build MediaPipe examples. 
* With a TensorFlow Model This uses the - [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) + [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip) ( see also - [model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)), + [model info](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)), and the pipeline is implemented in this [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt). diff --git a/docs/solutions/object_detection_saved_model.md b/docs/solutions/object_detection_saved_model.md new file mode 100644 index 000000000..6acac0a1b --- /dev/null +++ b/docs/solutions/object_detection_saved_model.md @@ -0,0 +1,62 @@ +## TensorFlow/TFLite Object Detection Model + +### TensorFlow model + +The model is trained on [MSCOCO 2014](http://cocodataset.org) dataset using [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with 0.5 depth multiplier. Detailed training configuration is in the provided `pipeline.config`. The model is a relatively compact model which has `0.171 mAP` to achieve real-time performance on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md). + + +### TFLite model + +The TFLite model is converted from the TensorFlow above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) with minor modifications. Assuming now we have a trained TensorFlow model which includes the checkpoint files and the training configuration file, for example the files provided in this repo: + + * `model.ckpt.index` + * `model.ckpt.meta` + * `model.ckpt.data-00000-of-00001` + * `pipeline.config` + +Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command: + +```bash +$ PATH_TO_MODEL=path/to/the/model +$ bazel run object_detection:export_tflite_ssd_graph -- \ + --pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \ + --trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \ + --output_directory ${PATH_TO_MODEL} \ + --add_postprocessing_op=False +``` + +The exported model contains two files: + + * `tflite_graph.pb` + * `tflite_graph.pbtxt` + +The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. In MediaPipe, we have provided all the calculators needed for post-processing such that we can exclude the custom TFLite ops for post-processing in the original graph, e.g., non-maximum suppression. This enables the flexibility to integrate with different post-processing algorithms and implementations. 
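To make the `add_postprocessing_op=False` choice concrete: in MediaPipe, decoding the raw tensors and running non-maximum suppression happen in ordinary calculators rather than in a custom TFLite op. Below is a hedged sketch of that wiring using the calculator and stream names from the public `object_detection_mobile_cpu.pbtxt` graph; the calculator options and the `SsdAnchorsCalculator` that supplies the anchors side packet are left out, so this parses but is not a complete, runnable graph by itself.

```cpp
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Sketch only: decoding + NMS run as MediaPipe calculators, not inside the
// TFLite graph. Options (num_boxes, num_classes, scales, NMS thresholds) and
// the anchors side-packet source are omitted here.
CalculatorGraphConfig PostProcessingSketch() {
  return ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
    input_stream: "detection_tensors"
    input_side_packet: "anchors"
    node {
      calculator: "TfLiteTensorsToDetectionsCalculator"
      input_stream: "TENSORS:detection_tensors"
      input_side_packet: "ANCHORS:anchors"
      output_stream: "DETECTIONS:detections"
    }
    node {
      calculator: "NonMaxSuppressionCalculator"
      input_stream: "detections"
      output_stream: "filtered_detections"
    }
  )pb");
}

}  // namespace mediapipe
```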
+ +Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model: + +```bash +$ bazel run graph_transforms:summarize_graph -- \ + --in_graph=${PATH_TO_MODEL}/tflite_graph.pb +``` + +You should be able to see the input image size of the model is 320x320 and the outputs of the model are: + + * `raw_outputs/box_encodings` + * `raw_outputs/class_predictions` + +The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run: + +```bash +$ tflite_convert -- \ + --graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \ + --output_file=${PATH_TO_MODEL}/model.tflite \ + --input_format=TENSORFLOW_GRAPHDEF \ + --output_format=TFLITE \ + --inference_type=FLOAT \ + --input_shapes=1,320,320,3 \ + --input_arrays=normalized_input_image_tensor \ + --output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions + +``` + +Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail. diff --git a/docs/solutions/pose.md b/docs/solutions/pose.md index 3226ddc2f..ce0197ebd 100644 --- a/docs/solutions/pose.md +++ b/docs/solutions/pose.md @@ -269,6 +269,7 @@ Supported configuration options: ```python import cv2 import mediapipe as mp +import numpy as np mp_drawing = mp.solutions.drawing_utils mp_drawing_styles = mp.solutions.drawing_styles mp_pose = mp.solutions.pose diff --git a/mediapipe/calculators/image/BUILD b/mediapipe/calculators/image/BUILD index 18a1d60ae..7b94a031f 100644 --- a/mediapipe/calculators/image/BUILD +++ b/mediapipe/calculators/image/BUILD @@ -156,6 +156,7 @@ cc_library( "//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:status", "//mediapipe/framework/port:vector", + "//mediapipe/framework/port:opencv_imgproc", ] + select({ "//mediapipe/gpu:disable_gpu": [], "//conditions:default": [ @@ -168,6 +169,25 @@ cc_library( alwayslink = 1, ) +cc_test( + name = "set_alpha_calculator_test", + srcs = ["set_alpha_calculator_test.cc"], + deps = [ + ":set_alpha_calculator", + ":set_alpha_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/framework/port:opencv_imgproc", + "//mediapipe/framework/port:parse_text_proto", + "//mediapipe/framework/port:status", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "bilateral_filter_calculator", srcs = ["bilateral_filter_calculator.cc"], @@ -748,6 +768,7 @@ cc_test( "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", + "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png", ], tags = ["desktop_only_test"], diff --git a/mediapipe/calculators/image/affine_transformation.h 
b/mediapipe/calculators/image/affine_transformation.h index 40793e7a1..3e40e46dc 100644 --- a/mediapipe/calculators/image/affine_transformation.h +++ b/mediapipe/calculators/image/affine_transformation.h @@ -29,6 +29,9 @@ class AffineTransformation { // pixels will be calculated. enum class BorderMode { kZero, kReplicate }; + // Pixel sampling interpolation method. + enum class Interpolation { kLinear, kCubic }; + struct Size { int width; int height; diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.cc b/mediapipe/calculators/image/affine_transformation_runner_gl.cc index 361dfc902..006416916 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_gl.cc +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.cc @@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner std::unique_ptr> { public: GlTextureWarpAffineRunner(std::shared_ptr gl_helper, - GpuOrigin::Mode gpu_origin) - : gl_helper_(gl_helper), gpu_origin_(gpu_origin) {} + GpuOrigin::Mode gpu_origin, + AffineTransformation::Interpolation interpolation) + : gl_helper_(gl_helper), + gpu_origin_(gpu_origin), + interpolation_(interpolation) {} absl::Status Init() { return gl_helper_->RunInGlContext([this]() -> absl::Status { const GLint attr_location[kNumAttributes] = { @@ -103,28 +106,83 @@ class GlTextureWarpAffineRunner } )"; + // TODO Move bicubic code to common shared place. constexpr GLchar kFragShader[] = R"( - DEFAULT_PRECISION(highp, float) - in vec2 sample_coordinate; - uniform sampler2D input_texture; + DEFAULT_PRECISION(highp, float) - #ifdef GL_ES - #define fragColor gl_FragColor - #else - out vec4 fragColor; - #endif // defined(GL_ES); + in vec2 sample_coordinate; + uniform sampler2D input_texture; + uniform vec2 input_size; - void main() { - vec4 color = texture2D(input_texture, sample_coordinate); - #ifdef CUSTOM_ZERO_BORDER_MODE - float out_of_bounds = - float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || - sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0); - color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); - #endif // defined(CUSTOM_ZERO_BORDER_MODE) - fragColor = color; - } - )"; + #ifdef GL_ES + #define fragColor gl_FragColor + #else + out vec4 fragColor; + #endif // defined(GL_ES); + + #ifdef CUBIC_INTERPOLATION + vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) { + const vec2 halve = vec2(0.5,0.5); + const vec2 one = vec2(1.0,1.0); + const vec2 two = vec2(2.0,2.0); + const vec2 three = vec2(3.0,3.0); + const vec2 six = vec2(6.0,6.0); + + // Calculate the fraction and integer. + tex_coord = tex_coord * tex_size - halve; + vec2 frac = fract(tex_coord); + vec2 index = tex_coord - frac + halve; + + // Calculate weights for Catmull-Rom filter. + vec2 w0 = frac * (-halve + frac * (one - halve * frac)); + vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac); + vec2 w2 = frac * (halve + frac * (two - three/two * frac)); + vec2 w3 = frac * frac * (-halve + halve * frac); + + // Calculate weights to take advantage of bilinear texture lookup. + vec2 w12 = w1 + w2; + vec2 offset12 = w2 / (w1 + w2); + + vec2 index_tl = index - one; + vec2 index_br = index + two; + vec2 index_eq = index + offset12; + + index_tl /= tex_size; + index_br /= tex_size; + index_eq /= tex_size; + + // 9 texture lookup and linear blending. 
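+          // The 4x4 Catmull-Rom footprint is folded into a 3x3 grid of
+          // bilinear fetches: w12/offset12 place one sample between the two
+          // middle taps on each axis, so 16 taps are covered by the 9
+          // texture2D calls below.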
+ vec4 color = vec4(0.0); + color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y; + color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y; + color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y; + + color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y; + color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y; + color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y; + + color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y; + color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y; + color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y; + return color; + } + #else + vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) { + return texture2D(tex, tex_coord); + } + #endif // defined(CUBIC_INTERPOLATION) + + void main() { + vec4 color = sample(input_texture, sample_coordinate, input_size); + #ifdef CUSTOM_ZERO_BORDER_MODE + float out_of_bounds = + float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || + sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0); + color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); + #endif // defined(CUSTOM_ZERO_BORDER_MODE) + fragColor = color; + } + )"; // Create program and set parameters. auto create_fn = [&](const std::string& vs, @@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner glUseProgram(program); glUniform1i(glGetUniformLocation(program, "input_texture"), 1); GLint matrix_id = glGetUniformLocation(program, "transform_matrix"); - return Program{.id = program, .matrix_id = matrix_id}; + GLint size_id = glGetUniformLocation(program, "input_size"); + return Program{ + .id = program, .matrix_id = matrix_id, .size_id = size_id}; }; const std::string vert_src = absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader); - const std::string frag_src = absl::StrCat( - mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader); + std::string interpolation_def; + switch (interpolation_) { + case AffineTransformation::Interpolation::kCubic: + interpolation_def = R"( + #define CUBIC_INTERPOLATION + )"; + break; + case AffineTransformation::Interpolation::kLinear: + break; + } + + const std::string frag_src = + absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, + interpolation_def, kFragShader); ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src)); @@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner std::string custom_zero_border_mode_def = R"( #define CUSTOM_ZERO_BORDER_MODE )"; - const std::string frag_custom_zero_src = - absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, - custom_zero_border_mode_def, kFragShader); + const std::string frag_custom_zero_src = absl::StrCat( + mediapipe::kMediaPipeFragmentShaderPreamble, + custom_zero_border_mode_def, interpolation_def, kFragShader); return create_fn(vert_src, frag_custom_zero_src); }; #if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED @@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner } glUseProgram(program->id); + // uniforms Eigen::Matrix eigen_mat(matrix.data()); if (IsMatrixVerticalFlipNeeded(gpu_origin_)) { // @matrix describes affine transformation in terms of TOP LEFT origin, so @@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner eigen_mat.transposeInPlace(); glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data()); + if (interpolation_ == AffineTransformation::Interpolation::kCubic) { + glUniform2f(program->size_id, texture.width(), texture.height()); + } + // vao glBindVertexArray(vao_); @@ -327,6 +404,7 @@ class 
GlTextureWarpAffineRunner struct Program { GLuint id; GLint matrix_id; + GLint size_id; }; std::shared_ptr gl_helper_; GpuOrigin::Mode gpu_origin_; @@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner Program program_; std::optional program_custom_zero_; GLuint framebuffer_ = 0; + AffineTransformation::Interpolation interpolation_ = + AffineTransformation::Interpolation::kLinear; }; #undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED @@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner absl::StatusOr>>> CreateAffineTransformationGlRunner( - std::shared_ptr gl_helper, GpuOrigin::Mode gpu_origin) { - auto runner = - absl::make_unique(gl_helper, gpu_origin); + std::shared_ptr gl_helper, GpuOrigin::Mode gpu_origin, + AffineTransformation::Interpolation interpolation) { + auto runner = absl::make_unique( + gl_helper, gpu_origin, interpolation); MP_RETURN_IF_ERROR(runner->Init()); return runner; } diff --git a/mediapipe/calculators/image/affine_transformation_runner_gl.h b/mediapipe/calculators/image/affine_transformation_runner_gl.h index 677e0720d..826c7b5c1 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_gl.h +++ b/mediapipe/calculators/image/affine_transformation_runner_gl.h @@ -29,7 +29,8 @@ absl::StatusOr>>> CreateAffineTransformationGlRunner( std::shared_ptr gl_helper, - mediapipe::GpuOrigin::Mode gpu_origin); + mediapipe::GpuOrigin::Mode gpu_origin, + AffineTransformation::Interpolation interpolation); } // namespace mediapipe diff --git a/mediapipe/calculators/image/affine_transformation_runner_opencv.cc b/mediapipe/calculators/image/affine_transformation_runner_opencv.cc index 46026a987..c43d73ff7 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_opencv.cc +++ b/mediapipe/calculators/image/affine_transformation_runner_opencv.cc @@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv( } } +int GetInterpolationForOpenCv( + AffineTransformation::Interpolation interpolation) { + switch (interpolation) { + case AffineTransformation::Interpolation::kLinear: + return cv::INTER_LINEAR; + case AffineTransformation::Interpolation::kCubic: + return cv::INTER_CUBIC; + } +} + class OpenCvRunner : public AffineTransformation::Runner { public: + OpenCvRunner(AffineTransformation::Interpolation interpolation) + : interpolation_(GetInterpolationForOpenCv(interpolation)) {} + absl::StatusOr Run( const ImageFrame& input, const std::array& matrix, const AffineTransformation::Size& size, @@ -142,19 +155,23 @@ class OpenCvRunner cv::warpAffine(in_mat, out_mat, cv_affine_transform, cv::Size(out_mat.cols, out_mat.rows), - /*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP, + /*flags=*/interpolation_ | cv::WARP_INVERSE_MAP, GetBorderModeForOpenCv(border_mode)); return out_image; } + + private: + int interpolation_ = cv::INTER_LINEAR; }; } // namespace absl::StatusOr< std::unique_ptr>> -CreateAffineTransformationOpenCvRunner() { - return absl::make_unique(); +CreateAffineTransformationOpenCvRunner( + AffineTransformation::Interpolation interpolation) { + return absl::make_unique(interpolation); } } // namespace mediapipe diff --git a/mediapipe/calculators/image/affine_transformation_runner_opencv.h b/mediapipe/calculators/image/affine_transformation_runner_opencv.h index 200281c95..6de48d4cf 100644 --- a/mediapipe/calculators/image/affine_transformation_runner_opencv.h +++ b/mediapipe/calculators/image/affine_transformation_runner_opencv.h @@ -25,7 +25,8 @@ namespace mediapipe { absl::StatusOr< std::unique_ptr>> -CreateAffineTransformationOpenCvRunner(); 
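As an aside on the OpenCV runner above: it now forwards the selected interpolation flag to `cv::warpAffine` together with `cv::WARP_INVERSE_MAP`. A minimal standalone sketch of that call pattern follows; the helper name and the zero border value are my own, not part of the change.

```cpp
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"

// Warps `src` by the 2x3 affine matrix `m` (interpreted as the inverse map,
// matching the runner above) with either bilinear or bicubic sampling.
cv::Mat WarpWithInterpolation(const cv::Mat& src, const cv::Mat& m,
                              const cv::Size& out_size, bool use_cubic) {
  cv::Mat dst;
  const int interpolation = use_cubic ? cv::INTER_CUBIC : cv::INTER_LINEAR;
  cv::warpAffine(src, dst, m, out_size,
                 /*flags=*/interpolation | cv::WARP_INVERSE_MAP,
                 /*borderMode=*/cv::BORDER_CONSTANT,
                 /*borderValue=*/cv::Scalar());
  return dst;
}
```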
+CreateAffineTransformationOpenCvRunner( + AffineTransformation::Interpolation interpolation); } // namespace mediapipe diff --git a/mediapipe/calculators/image/image_clone_calculator.cc b/mediapipe/calculators/image/image_clone_calculator.cc index 1e76848b1..6660d55b4 100644 --- a/mediapipe/calculators/image/image_clone_calculator.cc +++ b/mediapipe/calculators/image/image_clone_calculator.cc @@ -81,7 +81,8 @@ class ImageCloneCalculator : public Node { absl::Status Process(CalculatorContext* cc) override { std::unique_ptr output; const auto& input = *kIn(cc); - if (input.UsesGpu()) { + bool input_on_gpu = input.UsesGpu(); + if (input_on_gpu) { #if !MEDIAPIPE_DISABLE_GPU // Create an output Image that co-owns the underlying texture buffer as // the input Image. @@ -97,15 +98,15 @@ class ImageCloneCalculator : public Node { // Image. This ensures a correct life span of the shared pixel data. output = std::make_unique(std::make_unique( input.image_format(), input.width(), input.height(), input.step(), - const_cast(input.GetImageFrameSharedPtr()->PixelData()), - [packet_copy_ptr](uint8*) { delete packet_copy_ptr; })); + const_cast(input.GetImageFrameSharedPtr()->PixelData()), + [packet_copy_ptr](uint8_t*) { delete packet_copy_ptr; })); } - if (output_on_gpu_) { + if (output_on_gpu_ && !input_on_gpu) { #if !MEDIAPIPE_DISABLE_GPU gpu_helper_.RunInGlContext([&output]() { output->ConvertToGpu(); }); #endif // !MEDIAPIPE_DISABLE_GPU - } else { + } else if (!output_on_gpu_ && input_on_gpu) { output->ConvertToCpu(); } kOut(cc).Send(std::move(output)); diff --git a/mediapipe/calculators/image/set_alpha_calculator.cc b/mediapipe/calculators/image/set_alpha_calculator.cc index ea29278de..e20621e8d 100644 --- a/mediapipe/calculators/image/set_alpha_calculator.cc +++ b/mediapipe/calculators/image/set_alpha_calculator.cc @@ -22,6 +22,7 @@ #include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/vector.h" @@ -53,24 +54,16 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; // range of [0, 1). Only the first channel of Alpha is used. Input & output Mat // must be uchar. 
template -absl::Status MergeRGBA8Image(const cv::Mat input_mat, const cv::Mat& alpha_mat, - cv::Mat& output_mat) { - RET_CHECK_EQ(input_mat.rows, alpha_mat.rows); - RET_CHECK_EQ(input_mat.cols, alpha_mat.cols); - RET_CHECK_EQ(input_mat.rows, output_mat.rows); - RET_CHECK_EQ(input_mat.cols, output_mat.cols); +absl::Status CopyAlphaImage(const cv::Mat& alpha_mat, cv::Mat& output_mat) { + RET_CHECK_EQ(output_mat.rows, alpha_mat.rows); + RET_CHECK_EQ(output_mat.cols, alpha_mat.cols); for (int i = 0; i < output_mat.rows; ++i) { - const uchar* in_ptr = input_mat.ptr(i); const AlphaType* alpha_ptr = alpha_mat.ptr(i); uchar* out_ptr = output_mat.ptr(i); for (int j = 0; j < output_mat.cols; ++j) { const int out_idx = j * kNumChannelsRGBA; - const int in_idx = j * input_mat.channels(); const int alpha_idx = j * alpha_mat.channels(); - out_ptr[out_idx + 0] = in_ptr[in_idx + 0]; - out_ptr[out_idx + 1] = in_ptr[in_idx + 1]; - out_ptr[out_idx + 2] = in_ptr[in_idx + 2]; if constexpr (std::is_same::value) { out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask } else { @@ -273,7 +266,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { // Setup source image const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get(); - const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame); + const cv::Mat input_mat = formats::MatView(&input_frame); if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) { LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported"; } @@ -281,38 +274,38 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) { // Setup destination image auto output_frame = absl::make_unique( ImageFormat::SRGBA, input_mat.cols, input_mat.rows); - cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get()); + cv::Mat output_mat = formats::MatView(output_frame.get()); const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) && !cc->Inputs().Tag(kInputAlphaTag).IsEmpty(); const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask; - // Setup alpha image and Update image in CPU. + // Copy rgb part of the image in CPU + if (input_mat.channels() == 3) { + cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA); + } else { + input_mat.copyTo(output_mat); + } + + // Setup alpha image in CPU. 
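+  // The ALPHA input overrides the fixed alpha only when alpha_value_ is unset
+  // (< 0); float masks are expected in [0, 1), uchar masks in [0, 255].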
if (use_alpha_mask) { const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get(); - cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask); + cv::Mat alpha_mat = formats::MatView(&alpha_mask); const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F; RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U); if (alpha_is_float) { - MP_RETURN_IF_ERROR( - MergeRGBA8Image(input_mat, alpha_mat, output_mat)); + MP_RETURN_IF_ERROR(CopyAlphaImage(alpha_mat, output_mat)); } else { - MP_RETURN_IF_ERROR( - MergeRGBA8Image(input_mat, alpha_mat, output_mat)); + MP_RETURN_IF_ERROR(CopyAlphaImage(alpha_mat, output_mat)); } } else { const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f); for (int i = 0; i < output_mat.rows; ++i) { - const uchar* in_ptr = input_mat.ptr(i); uchar* out_ptr = output_mat.ptr(i); for (int j = 0; j < output_mat.cols; ++j) { const int out_idx = j * kNumChannelsRGBA; - const int in_idx = j * input_mat.channels(); - out_ptr[out_idx + 0] = in_ptr[in_idx + 0]; - out_ptr[out_idx + 1] = in_ptr[in_idx + 1]; - out_ptr[out_idx + 2] = in_ptr[in_idx + 2]; out_ptr[out_idx + 3] = alpha_value; // use value from options } } diff --git a/mediapipe/calculators/image/set_alpha_calculator_test.cc b/mediapipe/calculators/image/set_alpha_calculator_test.cc new file mode 100644 index 000000000..cb2352d08 --- /dev/null +++ b/mediapipe/calculators/image/set_alpha_calculator_test.cc @@ -0,0 +1,156 @@ +#include + +#include "mediapipe/calculators/image/set_alpha_calculator.pb.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgproc_inc.h" +#include "mediapipe/framework/port/parse_text_proto.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "testing/base/public/benchmark.h" + +namespace mediapipe { + +namespace { + +constexpr int input_width = 100; +constexpr int input_height = 100; + +std::unique_ptr GetInputFrame(int width, int height, int channel) { + const int total_size = width * height * channel; + + ImageFormat::Format image_format; + if (channel == 4) { + image_format = ImageFormat::SRGBA; + } else if (channel == 3) { + image_format = ImageFormat::SRGB; + } else { + image_format = ImageFormat::GRAY8; + } + + auto input_frame = std::make_unique(image_format, width, height, + /*alignment_boundary =*/1); + for (int i = 0; i < total_size; ++i) { + input_frame->MutablePixelData()[i] = i % 256; + } + return input_frame; +} + +// Test SetAlphaCalculator with RGB IMAGE input. +TEST(SetAlphaCalculatorTest, CpuRgb) { + auto calculator_node = ParseTextProtoOrDie( + R"pb( + calculator: "SetAlphaCalculator" + input_stream: "IMAGE:input_frames" + input_stream: "ALPHA:masks" + output_stream: "IMAGE:output_frames" + )pb"); + CalculatorRunner runner(calculator_node); + + // Input frames. 
+ const auto input_frame = GetInputFrame(input_width, input_height, 3); + const auto mask_frame = GetInputFrame(input_width, input_height, 1); + auto input_frame_packet = MakePacket(std::move(*input_frame)); + auto mask_frame_packet = MakePacket(std::move(*mask_frame)); + runner.MutableInputs()->Tag("IMAGE").packets.push_back( + input_frame_packet.At(Timestamp(1))); + runner.MutableInputs()->Tag("ALPHA").packets.push_back( + mask_frame_packet.At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()); + + const auto& outputs = runner.Outputs(); + EXPECT_EQ(outputs.NumEntries(), 1); + const auto& output_image = outputs.Tag("IMAGE").packets[0].Get(); + + // Generate ground truth (expected_mat). + const auto image = GetInputFrame(input_width, input_height, 3); + const auto input_mat = formats::MatView(image.get()); + const auto mask = GetInputFrame(input_width, input_height, 1); + const auto mask_mat = formats::MatView(mask.get()); + const std::array input_mats = {input_mat, mask_mat}; + cv::Mat expected_mat(input_width, input_height, CV_8UC4); + cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 3, 3}); + + cv::Mat output_mat = formats::MatView(&output_image); + double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF); + EXPECT_FLOAT_EQ(max_diff, 0); +} // TEST + +// Test SetAlphaCalculator with RGBA IMAGE input. +TEST(SetAlphaCalculatorTest, CpuRgba) { + auto calculator_node = ParseTextProtoOrDie( + R"pb( + calculator: "SetAlphaCalculator" + input_stream: "IMAGE:input_frames" + input_stream: "ALPHA:masks" + output_stream: "IMAGE:output_frames" + )pb"); + CalculatorRunner runner(calculator_node); + + // Input frames. + const auto input_frame = GetInputFrame(input_width, input_height, 4); + const auto mask_frame = GetInputFrame(input_width, input_height, 1); + auto input_frame_packet = MakePacket(std::move(*input_frame)); + auto mask_frame_packet = MakePacket(std::move(*mask_frame)); + runner.MutableInputs()->Tag("IMAGE").packets.push_back( + input_frame_packet.At(Timestamp(1))); + runner.MutableInputs()->Tag("ALPHA").packets.push_back( + mask_frame_packet.At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()); + + const auto& outputs = runner.Outputs(); + EXPECT_EQ(outputs.NumEntries(), 1); + const auto& output_image = outputs.Tag("IMAGE").packets[0].Get(); + + // Generate ground truth (expected_mat). + const auto image = GetInputFrame(input_width, input_height, 4); + const auto input_mat = formats::MatView(image.get()); + const auto mask = GetInputFrame(input_width, input_height, 1); + const auto mask_mat = formats::MatView(mask.get()); + const std::array input_mats = {input_mat, mask_mat}; + cv::Mat expected_mat(input_width, input_height, CV_8UC4); + cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 4, 3}); + + cv::Mat output_mat = formats::MatView(&output_image); + double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF); + EXPECT_FLOAT_EQ(max_diff, 0); +} // TEST + +static void BM_SetAlpha3ChannelImage(benchmark::State& state) { + auto calculator_node = ParseTextProtoOrDie( + R"pb( + calculator: "SetAlphaCalculator" + input_stream: "IMAGE:input_frames" + input_stream: "ALPHA:masks" + output_stream: "IMAGE:output_frames" + )pb"); + CalculatorRunner runner(calculator_node); + + // Input frames. 
+ const auto input_frame = GetInputFrame(input_width, input_height, 3); + const auto mask_frame = GetInputFrame(input_width, input_height, 1); + auto input_frame_packet = MakePacket(std::move(*input_frame)); + auto mask_frame_packet = MakePacket(std::move(*mask_frame)); + runner.MutableInputs()->Tag("IMAGE").packets.push_back( + input_frame_packet.At(Timestamp(1))); + runner.MutableInputs()->Tag("ALPHA").packets.push_back( + mask_frame_packet.At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()); + const auto& outputs = runner.Outputs(); + ASSERT_EQ(1, outputs.NumEntries()); + + for (const auto _ : state) { + MP_ASSERT_OK(runner.Run()); + } +} + +BENCHMARK(BM_SetAlpha3ChannelImage); + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/calculators/image/warp_affine_calculator.cc b/mediapipe/calculators/image/warp_affine_calculator.cc index 615d1697c..388701773 100644 --- a/mediapipe/calculators/image/warp_affine_calculator.cc +++ b/mediapipe/calculators/image/warp_affine_calculator.cc @@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode( } } +AffineTransformation::Interpolation GetInterpolation( + mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) { + switch (interpolation) { + case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED: + case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR: + return AffineTransformation::Interpolation::kLinear; + case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC: + return AffineTransformation::Interpolation::kCubic; + } +} + template class WarpAffineRunnerHolder {}; @@ -61,16 +72,22 @@ template <> class WarpAffineRunnerHolder { public: using RunnerType = AffineTransformation::Runner; - absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); } + absl::Status Open(CalculatorContext* cc) { + interpolation_ = GetInterpolation( + cc->Options().interpolation()); + return absl::OkStatus(); + } absl::StatusOr GetRunner() { if (!runner_) { - ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner()); + ASSIGN_OR_RETURN(runner_, + CreateAffineTransformationOpenCvRunner(interpolation_)); } return runner_.get(); } private: std::unique_ptr runner_; + AffineTransformation::Interpolation interpolation_; }; #endif // !MEDIAPIPE_DISABLE_OPENCV @@ -85,12 +102,14 @@ class WarpAffineRunnerHolder { gpu_origin_ = cc->Options().gpu_origin(); gl_helper_ = std::make_shared(); + interpolation_ = GetInterpolation( + cc->Options().interpolation()); return gl_helper_->Open(cc); } absl::StatusOr GetRunner() { if (!runner_) { - ASSIGN_OR_RETURN( - runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_)); + ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner( + gl_helper_, gpu_origin_, interpolation_)); } return runner_.get(); } @@ -99,6 +118,7 @@ class WarpAffineRunnerHolder { mediapipe::GpuOrigin::Mode gpu_origin_; std::shared_ptr gl_helper_; std::unique_ptr runner_; + AffineTransformation::Interpolation interpolation_; }; #endif // !MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/calculators/image/warp_affine_calculator.proto b/mediapipe/calculators/image/warp_affine_calculator.proto index 20e6c1b07..b68f71ac3 100644 --- a/mediapipe/calculators/image/warp_affine_calculator.proto +++ b/mediapipe/calculators/image/warp_affine_calculator.proto @@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions { BORDER_REPLICATE = 2; } + // Pixel sampling interpolation methods. See @interpolation. 
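+  // INTER_CUBIC maps to a Catmull-Rom kernel in the GL runner and to
+  // cv::INTER_CUBIC in the OpenCV runner; INTER_LINEAR keeps the previous
+  // bilinear sampling.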
+ enum Interpolation { + INTER_UNSPECIFIED = 0; + INTER_LINEAR = 1; + INTER_CUBIC = 2; + } + // Pixel extrapolation method. // When converting image to tensor it may happen that tensor needs to read // pixels outside image boundaries. Border mode helps to specify how such @@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions { // to be flipped vertically as tensors are expected to start at top. // (DEFAULT or unset interpreted as CONVENTIONAL.) optional GpuOrigin.Mode gpu_origin = 2; + + // Sampling method for neighboring pixels. + // INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors. + // INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights. + // INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR. + optional Interpolation interpolation = 3; } diff --git a/mediapipe/calculators/image/warp_affine_calculator_test.cc b/mediapipe/calculators/image/warp_affine_calculator_test.cc index 959912cc9..b911b66fd 100644 --- a/mediapipe/calculators/image/warp_affine_calculator_test.cc +++ b/mediapipe/calculators/image/warp_affine_calculator_test.cc @@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag, const cv::Mat& input, cv::Mat expected_result, float similarity_threshold, std::array matrix, int out_width, int out_height, - absl::optional border_mode) { + std::optional border_mode, + std::optional interpolation) { std::string border_mode_str; if (border_mode) { switch (*border_mode) { @@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag, break; } } + std::string interpolation_str; + if (interpolation) { + switch (*interpolation) { + case AffineTransformation::Interpolation::kLinear: + interpolation_str = "interpolation: INTER_LINEAR"; + break; + case AffineTransformation::Interpolation::kCubic: + interpolation_str = "interpolation: INTER_CUBIC"; + break; + } + } auto graph_config = mediapipe::ParseTextProtoOrDie( - absl::Substitute(graph_text, /*$0=*/border_mode_str)); + absl::Substitute(graph_text, /*$0=*/border_mode_str, + /*$1=*/interpolation_str)); std::vector output_packets; tool::AddVectorSink("output_image", &graph_config, &output_packets); @@ -132,7 +145,8 @@ struct SimilarityConfig { void RunTest(cv::Mat input, cv::Mat expected_result, const SimilarityConfig& similarity, std::array matrix, int out_width, int out_height, - absl::optional border_mode) { + std::optional border_mode, + std::optional interpolation) { RunTest(R"( input_stream: "input_image" input_stream: "output_size" @@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result, options { [mediapipe.WarpAffineCalculatorOptions.ext] { $0 # border mode + $1 # interpolation } } } )", "cpu", input, expected_result, similarity.threshold_on_cpu, matrix, - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); RunTest(R"( input_stream: "input_image" @@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, options { [mediapipe.WarpAffineCalculatorOptions.ext] { $0 # border mode + $1 # interpolation } } } @@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, } )", "cpu_image", input, expected_result, similarity.threshold_on_cpu, - matrix, out_width, out_height, border_mode); + matrix, out_width, out_height, border_mode, interpolation); RunTest(R"( input_stream: "input_image" @@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, options { [mediapipe.WarpAffineCalculatorOptions.ext] { $0 # border mode + $1 # interpolation 
gpu_origin: TOP_LEFT } } @@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, } )", "gpu", input, expected_result, similarity.threshold_on_gpu, matrix, - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); RunTest(R"( input_stream: "input_image" @@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, options { [mediapipe.WarpAffineCalculatorOptions.ext] { $0 # border mode + $1 # interpolation gpu_origin: TOP_LEFT } } @@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, } )", "gpu_image", input, expected_result, similarity.threshold_on_gpu, - matrix, out_width, out_height, border_mode); + matrix, out_width, out_height, border_mode, interpolation); } std::array GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi, @@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) { int out_height = 256; bool keep_aspect_ratio = true; std::optional border_mode = {}; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { @@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kZero; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { @@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kReplicate; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { @@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kZero; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { @@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { bool keep_aspect_ratio = false; std::optional border_mode = AffineTransformation::BorderMode::kReplicate; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { @@ -411,10 +434,38 @@ 
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { bool keep_aspect_ratio = false; std::optional border_mode = AffineTransformation::BorderMode::kZero; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); +} + +TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) { + mediapipe::NormalizedRect roi; + roi.set_x_center(0.65f); + roi.set_y_center(0.4f); + roi.set_width(0.5f); + roi.set_height(0.5f); + roi.set_rotation(M_PI * -45.0f / 180.0f); + auto input = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/input.jpg"); + auto expected_output = GetRgb( + "/mediapipe/calculators/" + "tensor/testdata/image_to_tensor/" + "medium_sub_rect_with_rotation_border_zero_interp_cubic.png"); + int out_width = 256; + int out_height = 256; + bool keep_aspect_ratio = false; + std::optional border_mode = + AffineTransformation::BorderMode::kZero; + std::optional interpolation = + AffineTransformation::Interpolation::kCubic; + RunTest(input, expected_output, + {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78}, + GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, LargeSubRect) { @@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) { bool keep_aspect_ratio = false; std::optional border_mode = AffineTransformation::BorderMode::kReplicate; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { @@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { bool keep_aspect_ratio = false; std::optional border_mode = AffineTransformation::BorderMode::kZero; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { @@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kReplicate; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { @@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kZero; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, 
LargeSubRectKeepAspectWithRotation) { @@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) { int out_height = 128; bool keep_aspect_ratio = true; std::optional border_mode = {}; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { @@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kZero; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, NoOp) { @@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kReplicate; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } TEST(WarpAffineCalculatorTest, NoOpBorderZero) { @@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) { bool keep_aspect_ratio = true; std::optional border_mode = AffineTransformation::BorderMode::kZero; + std::optional interpolation = {}; RunTest(input, expected_output, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), - out_width, out_height, border_mode); + out_width, out_height, border_mode, interpolation); } } // namespace diff --git a/mediapipe/calculators/tensor/BUILD b/mediapipe/calculators/tensor/BUILD index 7a29c3af8..a76b75494 100644 --- a/mediapipe/calculators/tensor/BUILD +++ b/mediapipe/calculators/tensor/BUILD @@ -997,17 +997,20 @@ cc_library( ":image_to_tensor_converter_gl_buffer", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_service", ], "//mediapipe:apple": [ ":image_to_tensor_converter_metal", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_service", ], "//conditions:default": [ ":image_to_tensor_converter_gl_buffer", "//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gpu_buffer", + "//mediapipe/gpu:gpu_service", ], }), ) @@ -1045,6 +1048,10 @@ cc_test( ":image_to_tensor_calculator", ":image_to_tensor_converter", ":image_to_tensor_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_runner", "//mediapipe/framework/deps:file_path", @@ -1061,11 +1068,10 @@ cc_test( "//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:parse_text_proto", "//mediapipe/util:image_test_utils", - "@com_google_absl//absl/flags:flag", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - ], + ] + 
select({ + "//mediapipe:apple": [], + "//conditions:default": ["//mediapipe/gpu:gl_context"], + }), ) cc_library( diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc index 499b497b0..d15d35086 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator.cc @@ -45,9 +45,11 @@ #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h" #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_service.h" #else #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h" #include "mediapipe/gpu/gl_calculator_helper.h" +#include "mediapipe/gpu/gpu_service.h" #endif // MEDIAPIPE_METAL_ENABLED #endif // !MEDIAPIPE_DISABLE_GPU @@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node { #if MEDIAPIPE_METAL_ENABLED MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); #else - MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); + cc->UseService(kGpuService).Optional(); #endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_DISABLE_GPU diff --git a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc index ed7d93886..3795b1fa0 100644 --- a/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc +++ b/mediapipe/calculators/tensor/image_to_tensor_calculator_test.cc @@ -41,6 +41,10 @@ #include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/util/image_test_utils.h" +#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED +#include "mediapipe/gpu/gl_context.h" +#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED + namespace mediapipe { namespace { @@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) { /*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt, /*keep_aspect=*/false, BorderMode::kZero, roi); } + +TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) { + auto graph_config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input_image" + node { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_float_range { min: 0.0f max: 1.0f } + } + } + } + )pb"); + CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization()); + MP_ASSERT_OK(graph.StartRun({})); + auto image_frame = + std::make_shared(ImageFormat::SRGBA, 128, 256, 4); + Image image = Image(std::move(image_frame)); + Packet packet = MakePacket(std::move(image)); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1)))); + MP_ASSERT_OK(graph.WaitUntilIdle()); + MP_ASSERT_OK(graph.CloseAllPacketSources()); + MP_ASSERT_OK(graph.WaitUntilDone()); +} + +#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED + +TEST(ImageToTensorCalculatorTest, + FailsGracefullyWhenGpuServiceNeededButNotAvailable) { + auto graph_config = + mediapipe::ParseTextProtoOrDie(R"pb( + input_stream: "input_image" + node { + calculator: "ImageToTensorCalculator" + input_stream: "IMAGE:input_image" + output_stream: "TENSORS:tensor" + options { + [mediapipe.ImageToTensorCalculatorOptions.ext] { + output_tensor_float_range { min: 0.0f max: 1.0f } + } + } + } + )pb"); + 
CalculatorGraph graph; + MP_ASSERT_OK(graph.Initialize(graph_config)); + MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization()); + MP_ASSERT_OK(graph.StartRun({})); + + MP_ASSERT_OK_AND_ASSIGN(auto context, + GlContext::Create(nullptr, /*create_thread=*/true)); + Packet packet; + context->Run([&packet]() { + auto image_frame = + std::make_shared(ImageFormat::SRGBA, 128, 256, 4); + Image image = Image(std::move(image_frame)); + // Ensure image is available on GPU to force ImageToTensorCalculator to + // run on GPU. + ASSERT_TRUE(image.ConvertToGpu()); + packet = MakePacket(std::move(image)); + }); + MP_ASSERT_OK( + graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1)))); + EXPECT_THAT(graph.WaitUntilIdle(), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("GPU service not available"))); +} +#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED + } // namespace } // namespace mediapipe diff --git a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc index 8fd55efa7..8aee46185 100644 --- a/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc +++ b/mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc @@ -141,7 +141,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process( } // Run inference. { - MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc); + MEDIAPIPE_PROFILING(GPU_TASK_INVOKE_ADVANCED, cc); return tflite_gpu_runner_->Invoke(); } })); diff --git a/mediapipe/calculators/tensor/tensors_readback_calculator.proto b/mediapipe/calculators/tensor/tensors_readback_calculator.proto new file mode 100644 index 000000000..3c6611f1f --- /dev/null +++ b/mediapipe/calculators/tensor/tensors_readback_calculator.proto @@ -0,0 +1,41 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The option proto for the TensorsReadbackCalculator. + +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; + +message TensorsReadbackCalculatorOptions { + extend mediapipe.CalculatorOptions { + optional TensorsReadbackCalculatorOptions ext = 514750372; + } + + // Expected shapes of the input tensors. + // The calculator uses these shape to build the GPU programs during + // initialization, and check the actual tensor shapes against the expected + // shapes during runtime. + // Batch size of the tensor is set to be 1. `TensorShape` here can be C, WC, + // or HWC. + // For example {dims: 1 dims: 2} represents a tensor with batch_size = 1, + // width = 1, and num_channels = 2. + message TensorShape { + repeated int32 dims = 1 [packed = true]; + } + // tensor_shape specifies the shape of each input tensors. 
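+  // For example, two inputs of shapes WC=(256, 3) and C=(1) would be declared
+  // in the options as:
+  //   tensor_shape { dims: 256 dims: 3 }
+  //   tensor_shape { dims: 1 }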
+ repeated TensorShape tensor_shape = 1; +} diff --git a/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png new file mode 100644 index 000000000..8d2e266a9 Binary files /dev/null and b/mediapipe/calculators/tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png differ diff --git a/mediapipe/calculators/tflite/BUILD b/mediapipe/calculators/tflite/BUILD index 9dcb0f733..b5b01d937 100644 --- a/mediapipe/calculators/tflite/BUILD +++ b/mediapipe/calculators/tflite/BUILD @@ -14,6 +14,7 @@ # load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") load("@bazel_skylib//lib:selects.bzl", "selects") licenses(["notice"]) @@ -312,15 +313,19 @@ cc_library( alwayslink = 1, ) -cc_library( +# TODO: Re-evaluate which of these libraries we can avoid making +# cc_library_with_tflite and can be changed back to cc_library. +cc_library_with_tflite( name = "tflite_model_calculator", srcs = ["tflite_model_calculator.cc"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite:framework_stable", + ], deps = [ "//mediapipe/framework:calculator_framework", "//mediapipe/framework:packet", "//mediapipe/framework/port:ret_check", "@com_google_absl//absl/status", - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", ], alwayslink = 1, ) diff --git a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc index 3a5ae8282..950d742a9 100644 --- a/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.cc @@ -66,7 +66,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase { } else { cc->OutputSidePackets() .Index(0) - .Set(); + .Set(); } return absl::OkStatus(); } @@ -77,7 +77,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase { const TfLiteCustomOpResolverCalculatorOptions& options = cc->Options(); - std::unique_ptr op_resolver; + std::unique_ptr op_resolver; if (options.use_gpu()) { op_resolver = absl::make_unique(); } else { diff --git a/mediapipe/calculators/tflite/tflite_model_calculator.cc b/mediapipe/calculators/tflite/tflite_model_calculator.cc index 435ea0127..a8a85ed78 100644 --- a/mediapipe/calculators/tflite/tflite_model_calculator.cc +++ b/mediapipe/calculators/tflite/tflite_model_calculator.cc @@ -21,7 +21,7 @@ #include "mediapipe/framework/packet.h" #include "mediapipe/framework/port/ret_check.h" #include "tensorflow/lite/allocation.h" -#include "tensorflow/lite/core/shims/cc/model.h" +#include "tensorflow/lite/model.h" namespace mediapipe { @@ -82,7 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase { } if (cc->InputSidePackets().HasTag("MODEL_FD")) { -#ifdef ABSL_HAVE_MMAP +#if defined(ABSL_HAVE_MMAP) && !TFLITE_WITH_STABLE_ABI model_packet = cc->InputSidePackets().Tag("MODEL_FD"); const auto& model_fd = model_packet.Get>(); diff --git a/mediapipe/calculators/util/BUILD b/mediapipe/calculators/util/BUILD index 6ac60d2c1..710a60d8a 100644 --- a/mediapipe/calculators/util/BUILD +++ b/mediapipe/calculators/util/BUILD @@ -1270,6 +1270,50 @@ cc_library( alwayslink = 1, ) +mediapipe_proto_library( + name = "flat_color_image_calculator_proto", + srcs = 
["flat_color_image_calculator.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/util:color_proto", + ], +) + +cc_library( + name = "flat_color_image_calculator", + srcs = ["flat_color_image_calculator.cc"], + deps = [ + ":flat_color_image_calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:node", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/formats:image_frame_opencv", + "//mediapipe/framework/port:opencv_core", + "//mediapipe/util:color_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_test( + name = "flat_color_image_calculator_test", + srcs = ["flat_color_image_calculator_test.cc"], + deps = [ + ":flat_color_image_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_runner", + "//mediapipe/framework:packet", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_frame", + "//mediapipe/framework/port:gtest", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/util:color_cc_proto", + ], +) + cc_library( name = "from_image_calculator", srcs = ["from_image_calculator.cc"], diff --git a/mediapipe/calculators/util/flat_color_image_calculator.cc b/mediapipe/calculators/util/flat_color_image_calculator.cc new file mode 100644 index 000000000..71d3582c5 --- /dev/null +++ b/mediapipe/calculators/util/flat_color_image_calculator.cc @@ -0,0 +1,138 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" +#include "mediapipe/framework/api2/node.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/util/color.pb.h" + +namespace mediapipe { + +namespace { + +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Node; +using ::mediapipe::api2::Output; +} // namespace + +// A calculator for generating an image filled with a single color. +// +// Inputs: +// IMAGE (Image, optional) +// If provided, the output will have the same size +// COLOR (Color proto, optional) +// Color to paint the output with. Takes precedence over the equivalent +// calculator options. +// +// Outputs: +// IMAGE (Image) +// Image filled with the requested color. 
+//
+// Example usage:
+// node {
+//   calculator: "FlatColorImageCalculator"
+//   input_stream: "IMAGE:image"
+//   input_stream: "COLOR:color"
+//   output_stream: "IMAGE:blank_image"
+//   options {
+//     [mediapipe.FlatColorImageCalculatorOptions.ext] {
+//       color: {
+//         r: 255
+//         g: 255
+//         b: 255
+//       }
+//     }
+//   }
+// }
+
+class FlatColorImageCalculator : public Node {
+ public:
+  static constexpr Input<Image>::Optional kInImage{"IMAGE"};
+  static constexpr Input<Color>::Optional kInColor{"COLOR"};
+  static constexpr Output<Image> kOutImage{"IMAGE"};
+
+  MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
+
+  static absl::Status UpdateContract(CalculatorContract* cc) {
+    const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
+
+    RET_CHECK(kInImage(cc).IsConnected() ^
+              (options.has_output_height() || options.has_output_width()))
+        << "Either set IMAGE input stream, or set through options";
+    RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
+        << "Either set COLOR input stream, or set through options";
+
+    return absl::OkStatus();
+  }
+
+  absl::Status Open(CalculatorContext* cc) override;
+  absl::Status Process(CalculatorContext* cc) override;
+
+ private:
+  bool use_dimension_from_option_ = false;
+  bool use_color_from_option_ = false;
+};
+MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
+
+absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
+  use_dimension_from_option_ = !kInImage(cc).IsConnected();
+  use_color_from_option_ = !kInColor(cc).IsConnected();
+  return absl::OkStatus();
+}
+
+absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
+  const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
+
+  int output_height = -1;
+  int output_width = -1;
+  if (use_dimension_from_option_) {
+    output_height = options.output_height();
+    output_width = options.output_width();
+  } else if (!kInImage(cc).IsEmpty()) {
+    const Image& input_image = kInImage(cc).Get();
+    output_height = input_image.height();
+    output_width = input_image.width();
+  } else {
+    return absl::OkStatus();
+  }
+
+  Color color;
+  if (use_color_from_option_) {
+    color = options.color();
+  } else if (!kInColor(cc).IsEmpty()) {
+    color = kInColor(cc).Get();
+  } else {
+    return absl::OkStatus();
+  }
+
+  auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
+                                                   output_width, output_height);
+  cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
+
+  output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
+
+  kOutImage(cc).Send(Image(output_frame));
+
+  return absl::OkStatus();
+}
+
+}  // namespace mediapipe
diff --git a/mediapipe/calculators/util/flat_color_image_calculator.proto b/mediapipe/calculators/util/flat_color_image_calculator.proto
new file mode 100644
index 000000000..183bc796e
--- /dev/null
+++ b/mediapipe/calculators/util/flat_color_image_calculator.proto
@@ -0,0 +1,32 @@
+// Copyright 2023 The MediaPipe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
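Beyond the options shown in the calculator's header comment, the COLOR input stream lets a client change the fill color per timestamp. The following is a minimal sketch, not part of this patch, of feeding such packets; it assumes a running CalculatorGraph whose config exposes an input stream named "color", and the helper name SendFillColor is illustrative.

    #include <cstdint>

    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/packet.h"
    #include "mediapipe/framework/timestamp.h"
    #include "mediapipe/util/color.pb.h"

    // Illustrative helper: push a red fill color into the graph at the given
    // timestamp (in microseconds) for FlatColorImageCalculator to consume.
    absl::Status SendFillColor(mediapipe::CalculatorGraph& graph, int64_t ts) {
      mediapipe::Color color;
      color.set_r(255);
      color.set_g(0);
      color.set_b(0);
      return graph.AddPacketToInputStream(
          "color", mediapipe::MakePacket<mediapipe::Color>(color).At(
                       mediapipe::Timestamp(ts)));
    }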
+ +syntax = "proto2"; + +package mediapipe; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/util/color.proto"; + +message FlatColorImageCalculatorOptions { + extend CalculatorOptions { + optional FlatColorImageCalculatorOptions ext = 515548435; + } + + // Output dimensions. + optional int32 output_width = 1; + optional int32 output_height = 2; + // The color to fill with in the output image. + optional Color color = 3; +} diff --git a/mediapipe/calculators/util/flat_color_image_calculator_test.cc b/mediapipe/calculators/util/flat_color_image_calculator_test.cc new file mode 100644 index 000000000..53c6de1b1 --- /dev/null +++ b/mediapipe/calculators/util/flat_color_image_calculator_test.cc @@ -0,0 +1,210 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_runner.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/util/color.pb.h" + +namespace mediapipe { +namespace { + +using ::testing::HasSubstr; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kColorTag[] = "COLOR"; +constexpr int kImageWidth = 256; +constexpr int kImageHeight = 256; + +TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "IMAGE:image" + output_stream: "IMAGE:out_image" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + auto image_frame = std::make_shared(ImageFormat::SRGB, + kImageWidth, kImageHeight); + + for (int ts = 0; ts < 3; ++ts) { + runner.MutableInputs()->Tag(kImageTag).packets.push_back( + MakePacket(image_frame).At(Timestamp(ts))); + } + MP_ASSERT_OK(runner.Run()); + + const auto& outputs = runner.Outputs().Tag(kImageTag).packets; + ASSERT_EQ(outputs.size(), 3); + + for (const auto& packet : outputs) { + const auto& image = packet.Get(); + EXPECT_EQ(image.width(), kImageWidth); + EXPECT_EQ(image.height(), kImageHeight); + auto image_frame = image.GetImageFrameSharedPtr(); + auto* pixel_data = image_frame->PixelData(); + EXPECT_EQ(pixel_data[0], 100); + EXPECT_EQ(pixel_data[1], 200); + EXPECT_EQ(pixel_data[2], 255); + } +} + +TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "COLOR:color" + output_stream: "IMAGE:out_image" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + output_width: 7, + output_height: 13, + } + } + )pb"); + + Color color; + color.set_r(0); + color.set_g(5); + color.set_b(0); + + for (int ts = 0; ts < 3; ++ts) { + 
runner.MutableInputs()->Tag(kColorTag).packets.push_back( + MakePacket(color).At(Timestamp(ts))); + } + MP_ASSERT_OK(runner.Run()); + + const auto& outputs = runner.Outputs().Tag(kImageTag).packets; + ASSERT_EQ(outputs.size(), 3); + + for (const auto& packet : outputs) { + const auto& image = packet.Get(); + EXPECT_EQ(image.width(), 7); + EXPECT_EQ(image.height(), 13); + auto image_frame = image.GetImageFrameSharedPtr(); + const uint8_t* pixel_data = image_frame->PixelData(); + EXPECT_EQ(pixel_data[0], 0); + EXPECT_EQ(pixel_data[1], 5); + EXPECT_EQ(pixel_data[2], 0); + } +} + +TEST(FlatColorImageCalculatorTest, FailureMissingDimension) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "COLOR:color" + output_stream: "IMAGE:out_image" + )pb"); + + Color color; + color.set_r(0); + color.set_g(5); + color.set_b(0); + + for (int ts = 0; ts < 3; ++ts) { + runner.MutableInputs()->Tag(kColorTag).packets.push_back( + MakePacket(color).At(Timestamp(ts))); + } + ASSERT_THAT(runner.Run().message(), + HasSubstr("Either set IMAGE input stream")); +} + +TEST(FlatColorImageCalculatorTest, FailureMissingColor) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "IMAGE:image" + output_stream: "IMAGE:out_image" + )pb"); + + auto image_frame = std::make_shared(ImageFormat::SRGB, + kImageWidth, kImageHeight); + + for (int ts = 0; ts < 3; ++ts) { + runner.MutableInputs()->Tag(kImageTag).packets.push_back( + MakePacket(image_frame).At(Timestamp(ts))); + } + ASSERT_THAT(runner.Run().message(), + HasSubstr("Either set COLOR input stream")); +} + +TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "IMAGE:image" + input_stream: "COLOR:color" + output_stream: "IMAGE:out_image" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + output_width: 7, + output_height: 13, + } + } + )pb"); + + auto image_frame = std::make_shared(ImageFormat::SRGB, + kImageWidth, kImageHeight); + + for (int ts = 0; ts < 3; ++ts) { + runner.MutableInputs()->Tag(kImageTag).packets.push_back( + MakePacket(image_frame).At(Timestamp(ts))); + } + ASSERT_THAT(runner.Run().message(), + HasSubstr("Either set IMAGE input stream")); +} + +TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) { + CalculatorRunner runner(R"pb( + calculator: "FlatColorImageCalculator" + input_stream: "IMAGE:image" + input_stream: "COLOR:color" + output_stream: "IMAGE:out_image" + options { + [mediapipe.FlatColorImageCalculatorOptions.ext] { + color: { + r: 100, + g: 200, + b: 255, + } + } + } + )pb"); + + Color color; + color.set_r(0); + color.set_g(5); + color.set_b(0); + + for (int ts = 0; ts < 3; ++ts) { + runner.MutableInputs()->Tag(kColorTag).packets.push_back( + MakePacket(color).At(Timestamp(ts))); + } + ASSERT_THAT(runner.Run().message(), + HasSubstr("Either set COLOR input stream")); +} + +} // namespace +} // namespace mediapipe diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar index e708b1c02..943f0cbfa 100644 Binary files a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar and b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.jar differ diff --git a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties index 
070cb702f..508322917 100644 --- a/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties +++ b/mediapipe/examples/android/solutions/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip +networkTimeout=10000 zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/mediapipe/examples/android/solutions/gradlew b/mediapipe/examples/android/solutions/gradlew index 4f906e0c8..65dcd68d6 100755 --- a/mediapipe/examples/android/solutions/gradlew +++ b/mediapipe/examples/android/solutions/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,101 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME -# Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. 
-while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi -done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null -APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +121,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -106,80 +140,105 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? 
-ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi -fi - -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi - -# For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi - # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" - fi - i=`expr $i + 1` - done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. 
-# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/mediapipe/examples/android/solutions/gradlew.bat b/mediapipe/examples/android/solutions/gradlew.bat index ac1b06f93..93e3f59f1 100755 --- a/mediapipe/examples/android/solutions/gradlew.bat +++ b/mediapipe/examples/android/solutions/gradlew.bat @@ -1,89 +1,92 @@ -@rem -@rem Copyright 2015 the original author or authors. -@rem -@rem Licensed under the Apache License, Version 2.0 (the "License"); -@rem you may not use this file except in compliance with the License. 
-@rem You may obtain a copy of the License at -@rem -@rem https://www.apache.org/licenses/LICENSE-2.0 -@rem -@rem Unless required by applicable law or agreed to in writing, software -@rem distributed under the License is distributed on an "AS IS" BASIS, -@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@rem See the License for the specific language governing permissions and -@rem limitations under the License. -@rem - -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Resolve any "." and ".." in APP_HOME to make it shorter. -for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto execute - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto execute - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. 
+@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc index 787baa370..84b229d80 100644 --- a/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc +++ b/mediapipe/examples/desktop/autoflip/quality/padding_effect_generator_test.cc @@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio, std::string result_image; MP_ASSERT_OK( mediapipe::file::GetContents(result_string_path, &result_image)); - EXPECT_EQ(result_image, output_string); + if (result_image != output_string) { + // There may be slight differences due to the way the JPEG was encoded or + // the OpenCV version used to generate the reference files. Compare + // pixel-by-pixel using the Peak Signal-to-Noise Ratio instead. 
+ cv::Mat result_mat = + cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1, + const_cast(result_image.data())), + cv::IMREAD_UNCHANGED); + cv::Mat output_mat = + cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1, + const_cast(output_string.data())), + cv::IMREAD_UNCHANGED); + ASSERT_EQ(result_mat.rows, output_mat.rows); + ASSERT_EQ(result_mat.cols, output_mat.cols); + ASSERT_EQ(result_mat.type(), output_mat.type()); + EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0); + } } else { std::string output_string_path = mediapipe::file::JoinPath( absl::GetFlag(FLAGS_output_folder), diff --git a/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.3.jpg b/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.3.jpg index 53ebcf770..3602046e6 100644 Binary files a/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.3.jpg and b/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.3.jpg differ diff --git a/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.6.jpg b/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.6.jpg index 2ffde6739..16842d8b7 100644 Binary files a/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.6.jpg and b/mediapipe/examples/desktop/autoflip/quality/testdata/result_0.6.jpg differ diff --git a/mediapipe/examples/desktop/autoflip/quality/testdata/result_1.jpg b/mediapipe/examples/desktop/autoflip/quality/testdata/result_1.jpg index c03cd4610..741e078e9 100644 Binary files a/mediapipe/examples/desktop/autoflip/quality/testdata/result_1.jpg and b/mediapipe/examples/desktop/autoflip/quality/testdata/result_1.jpg differ diff --git a/mediapipe/examples/desktop/autoflip/quality/testdata/result_3.4.jpg b/mediapipe/examples/desktop/autoflip/quality/testdata/result_3.4.jpg index 5ec4ea6ec..4efbe7da2 100644 Binary files a/mediapipe/examples/desktop/autoflip/quality/testdata/result_3.4.jpg and b/mediapipe/examples/desktop/autoflip/quality/testdata/result_3.4.jpg differ diff --git a/mediapipe/examples/ios/helloworld/BUILD b/mediapipe/examples/ios/helloworld/BUILD index aed0c35a5..6bfcfaaef 100644 --- a/mediapipe/examples/ios/helloworld/BUILD +++ b/mediapipe/examples/ios/helloworld/BUILD @@ -56,5 +56,6 @@ objc_library( deps = [ "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", "//mediapipe/graphs/edge_detection:mobile_calculators", + "//third_party/apple_frameworks:Metal", ], ) diff --git a/mediapipe/framework/calculator_graph.cc b/mediapipe/framework/calculator_graph.cc index d803d7141..b49930b7a 100644 --- a/mediapipe/framework/calculator_graph.cc +++ b/mediapipe/framework/calculator_graph.cc @@ -631,7 +631,13 @@ absl::Status CalculatorGraph::PrepareServices() { for (const auto& [key, request] : node->Contract().ServiceRequests()) { auto packet = service_manager_.GetServicePacket(request.Service()); if (!packet.IsEmpty()) continue; - auto packet_or = request.Service().CreateDefaultObject(); + absl::StatusOr packet_or; + if (allow_service_default_initialization_) { + packet_or = request.Service().CreateDefaultObject(); + } else { + packet_or = absl::FailedPreconditionError( + "Service default initialization is disallowed."); + } if (packet_or.ok()) { MP_RETURN_IF_ERROR(service_manager_.SetServicePacket( request.Service(), std::move(packet_or).value())); diff --git a/mediapipe/framework/calculator_graph.h b/mediapipe/framework/calculator_graph.h index 93dbfe8dc..354694e39 100644 --- a/mediapipe/framework/calculator_graph.h +++ b/mediapipe/framework/calculator_graph.h @@ -405,6 +405,34 @@ class 
CalculatorGraph {
     return service_manager_.GetServiceObject(service);
   }

+  // Disallows/disables default initialization of MediaPipe graph services.
+  //
+  // IMPORTANT: MediaPipe graph services, essentially graph-level singletons,
+  // are designed so that they may provide default initialization. For
+  // example, this allows running OpenGL processing within the graph without
+  // providing a particular OpenGL context, as one can be provided by the
+  // default-initializable `kGpuService`. (One caveat: you may still need
+  // to initialize it manually to share the graph context with an external one.)
+  //
+  // Even if calculators require some service optionally
+  // (`calculator_contract->UseService(kSomeService).Optional()`), it will
+  // still be initialized if it allows default initialization.
+  //
+  // In rare cases this may be unwanted, and strict control of which
+  // services are allowed in the graph can be achieved by calling this method,
+  // followed by `SetServiceObject` calls for the services that are allowed in
+  // the graph.
+  //
+  // Recommendation: do not use unless you have to (for example, default
+  // initialization has side effects).
+  //
+  // NOTE: must be called before `StartRun`/`Run`, where services are checked
+  // and can be default-initialized.
+  absl::Status DisallowServiceDefaultInitialization() {
+    allow_service_default_initialization_ = false;
+    return absl::OkStatus();
+  }
+
   // Sets a service object, essentially a graph-level singleton, which can be
   // accessed by calculators and subgraphs without requiring an explicit
   // connection.
@@ -644,6 +672,9 @@ class CalculatorGraph {
   // Object to manage graph services.
   GraphServiceManager service_manager_;

+  // Indicates whether service default initialization is allowed.
+  bool allow_service_default_initialization_ = true;
+
   // Vector of errors encountered while running graph. Always use RecordError()
   // to add an error to this vector.
   std::vector<absl::Status> errors_ ABSL_GUARDED_BY(error_mutex_);
diff --git a/mediapipe/framework/calculator_profile.proto b/mediapipe/framework/calculator_profile.proto
index d86162ea5..0b5498c4e 100644
--- a/mediapipe/framework/calculator_profile.proto
+++ b/mediapipe/framework/calculator_profile.proto
@@ -136,6 +136,8 @@ message GraphTrace {
     GPU_TASK_INVOKE = 16;
     TPU_TASK_INVOKE = 17;
     CPU_TASK_INVOKE = 18;
+    GPU_TASK_INVOKE_ADVANCED = 19;
+    TPU_TASK_INVOKE_ASYNC = 20;
   }

   // The timing for one packet set being processed at one caclulator node.
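The new DisallowServiceDefaultInitialization comment above describes the intended pattern: disable default initialization, then register exactly the services the graph may use. A minimal sketch of that pattern for the GPU service follows; it is not part of this patch, the function name is illustrative, and it assumes the usual framework and GPU headers.

    #include "mediapipe/framework/calculator_framework.h"
    #include "mediapipe/framework/port/status.h"
    #include "mediapipe/gpu/gpu_service.h"
    #include "mediapipe/gpu/gpu_shared_data_internal.h"

    // Illustrative: run a graph under strict service control, providing
    // kGpuService explicitly instead of relying on default initialization.
    absl::Status StartWithExplicitGpuService(
        const mediapipe::CalculatorGraphConfig& config) {
      mediapipe::CalculatorGraph graph;
      MP_RETURN_IF_ERROR(graph.Initialize(config));
      // Must be called before StartRun/Run, where services would otherwise be
      // default-initialized.
      MP_RETURN_IF_ERROR(graph.DisallowServiceDefaultInitialization());
      auto gpu_resources = mediapipe::GpuResources::Create();
      MP_RETURN_IF_ERROR(gpu_resources.status());
      MP_RETURN_IF_ERROR(
          graph.SetServiceObject(mediapipe::kGpuService, *gpu_resources));
      // Feed inputs, then CloseAllPacketSources() and WaitUntilDone() as usual.
      return graph.StartRun({});
    }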
diff --git a/mediapipe/framework/profiler/BUILD b/mediapipe/framework/profiler/BUILD index 6184ed45b..38d72b265 100644 --- a/mediapipe/framework/profiler/BUILD +++ b/mediapipe/framework/profiler/BUILD @@ -315,11 +315,11 @@ cc_library( visibility = ["//visibility:private"], deps = [ "@com_google_absl//absl/flags:flag", + "//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/port:statusor", "//mediapipe/framework/port:status", - "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:statusor", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", @@ -328,8 +328,8 @@ cc_library( "//mediapipe/framework/port:file_helpers", ], "//mediapipe:android": [ - "//mediapipe/java/com/google/mediapipe/framework/jni:jni_util", "//mediapipe/framework/port:file_helpers", + "//mediapipe/java/com/google/mediapipe/framework/jni:jni_util", ], "//mediapipe:apple": [ "//mediapipe/framework/port:file_helpers", diff --git a/mediapipe/framework/profiler/trace_buffer.h b/mediapipe/framework/profiler/trace_buffer.h index b44d8f0bf..b5e2d9994 100644 --- a/mediapipe/framework/profiler/trace_buffer.h +++ b/mediapipe/framework/profiler/trace_buffer.h @@ -112,6 +112,10 @@ struct TraceEvent { static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE; static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE; static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE; + static constexpr EventType GPU_TASK_INVOKE_ADVANCED = + GraphTrace::GPU_TASK_INVOKE_ADVANCED; + static constexpr EventType TPU_TASK_INVOKE_ASYNC = + GraphTrace::TPU_TASK_INVOKE_ASYNC; }; // Packet trace log buffer. diff --git a/mediapipe/framework/profiler/trace_builder.cc b/mediapipe/framework/profiler/trace_builder.cc index ce5bf1e25..9c3661ffe 100644 --- a/mediapipe/framework/profiler/trace_builder.cc +++ b/mediapipe/framework/profiler/trace_builder.cc @@ -57,7 +57,6 @@ struct hash { namespace mediapipe { namespace { - void BasicTraceEventTypes(TraceEventRegistry* result) { // The initializer arguments below are: event_type, description, // is_packet_event, is_stream_event, id_event_data. 
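For reference, the two new event types are emitted through the same MEDIAPIPE_PROFILING macro already used for GPU_TASK_INVOKE, as in the inference_calculator_gl_advanced.cc hunk earlier in this patch. A minimal sketch of scoping a timed region follows; the function and the GPU work are illustrative, and the macro is assumed to be available through the framework/profiler headers a calculator normally includes.

    #include "mediapipe/framework/calculator_framework.h"

    // Illustrative: record the CPU-side cost of launching a GPU task so it
    // shows up as GPU_TASK_INVOKE_ADVANCED in the graph trace.
    absl::Status InvokeGpuTaskWithTracing(mediapipe::CalculatorContext* cc) {
      {
        // The event covers this scope; it is a no-op when tracing is disabled.
        MEDIAPIPE_PROFILING(GPU_TASK_INVOKE_ADVANCED, cc);
        // ... enqueue or invoke the GPU work here ...
      }
      return absl::OkStatus();
    }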
@@ -84,6 +83,15 @@ void BasicTraceEventTypes(TraceEventRegistry* result) { "A time measured by GPU clock and by CPU clock.", true, false}, {TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.", true, true, false}, + + {TraceEvent::GPU_TASK_INVOKE, "CPU timing for initiating a GPU task."}, + {TraceEvent::TPU_TASK_INVOKE, "CPU timing for initiating a TPU task."}, + {TraceEvent::CPU_TASK_INVOKE, "CPU timing for initiating a CPU task."}, + {TraceEvent::GPU_TASK_INVOKE_ADVANCED, + "CPU timing for initiating a GPU task bypassing the TFLite " + "interpreter."}, + {TraceEvent::TPU_TASK_INVOKE_ASYNC, + "CPU timing for async initiation of a TPU task."}, }; for (const TraceEventType& t : basic_types) { (*result)[t.event_type()] = t; diff --git a/mediapipe/framework/tool/BUILD b/mediapipe/framework/tool/BUILD index c34194fbf..56ca0dc65 100644 --- a/mediapipe/framework/tool/BUILD +++ b/mediapipe/framework/tool/BUILD @@ -77,7 +77,6 @@ mediapipe_proto_library( name = "calculator_graph_template_proto", srcs = ["calculator_graph_template.proto"], def_options_lib = False, - def_py_proto = False, visibility = ["//visibility:public"], deps = [ "//mediapipe/framework:calculator_options_proto", diff --git a/mediapipe/framework/tool/mediapipe_proto.bzl b/mediapipe/framework/tool/mediapipe_proto.bzl index 598bc08f2..01888c0d0 100644 --- a/mediapipe/framework/tool/mediapipe_proto.bzl +++ b/mediapipe/framework/tool/mediapipe_proto.bzl @@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs): 'import public "' + join_path + '";', ) rewrite_ref = SubsituteCommand( - r"mediapipe\\.(" + rewrite_message_regex + ")", + r"mediapipe\.(" + rewrite_message_regex + ")", r"mediapipe.\\1", ) rewrite_objc = SubsituteCommand( @@ -284,7 +284,7 @@ def mediapipe_proto_library( def_jspb_proto: define the jspb_proto_library target def_go_proto: define the go_proto_library target def_options_lib: define the mediapipe_options_library target - def_rewrite: define a sibbling mediapipe_proto_library with package "mediapipe" + def_rewrite: define a sibling mediapipe_proto_library with package "mediapipe" """ mediapipe_proto_library_impl( diff --git a/mediapipe/framework/tool/subgraph_expansion.cc b/mediapipe/framework/tool/subgraph_expansion.cc index 9f81153f1..dcd055f59 100644 --- a/mediapipe/framework/tool/subgraph_expansion.cc +++ b/mediapipe/framework/tool/subgraph_expansion.cc @@ -183,12 +183,13 @@ absl::Status FindCorrespondingStreams( // name, calculator, input_stream, output_stream, input_side_packet, // output_side_packet, options. // All other fields are only applicable to calculators. +// TODO: Check whether executor is not set in the subgraph node +// after this issues is properly solved. 
absl::Status ValidateSubgraphFields( const CalculatorGraphConfig::Node& subgraph_node) { if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() || subgraph_node.has_output_stream_handler() || - subgraph_node.input_stream_info_size() != 0 || - !subgraph_node.executor().empty()) { + subgraph_node.input_stream_info_size() != 0) { return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) << "Subgraph \"" << subgraph_node.name() << "\" has a field that is only applicable to calculators."; diff --git a/mediapipe/gpu/BUILD b/mediapipe/gpu/BUILD index 83948226a..9b538a7f2 100644 --- a/mediapipe/gpu/BUILD +++ b/mediapipe/gpu/BUILD @@ -272,14 +272,6 @@ selects.config_setting_group( ], ) -selects.config_setting_group( - name = "platform_ios_without_gpu", - match_all = [ - ":disable_gpu", - "//mediapipe:ios", - ], -) - selects.config_setting_group( name = "platform_macos_with_gpu", match_all = [ @@ -296,32 +288,33 @@ cc_library( deps = [ ":gpu_buffer_format", ":gpu_buffer_storage", + ":gpu_buffer_storage_image_frame", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", - ":gpu_buffer_storage_image_frame", ] + select({ "//conditions:default": [ - ":gl_texture_view", ":gl_texture_buffer", + ":gl_texture_view", ], ":platform_ios_with_gpu": [ ":gl_texture_view", ":gpu_buffer_storage_cv_pixel_buffer", - "//mediapipe/objc:util", "//mediapipe/objc:CFHolder", ], ":platform_macos_with_gpu": [ - "//mediapipe/objc:CFHolder", - ":gl_texture_view", ":gl_texture_buffer", - ], - ":platform_ios_without_gpu": [ - "//mediapipe/objc:util", + ":gl_texture_view", + "//mediapipe/objc:CFHolder", ], ":disable_gpu": [], + }) + select({ + "//conditions:default": [], + "//mediapipe:ios": [ + "//mediapipe/objc:util", + ], }), ) @@ -331,9 +324,9 @@ cc_library( hdrs = ["gpu_buffer_format.h"], visibility = ["//visibility:public"], deps = [ - "//mediapipe/framework/formats:image_format_cc_proto", "@com_google_absl//absl/container:flat_hash_map", "//mediapipe/framework/deps:no_destructor", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/port:logging", ] + select({ "//conditions:default": [ @@ -474,6 +467,7 @@ cc_library( "//mediapipe/framework/formats:frame_buffer", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:yuv_image", + "//mediapipe/util/frame_buffer:frame_buffer_util", "//third_party/libyuv", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -619,22 +613,22 @@ cc_library( }), visibility = ["//visibility:private"], deps = [ - ":gl_context_options_cc_proto", - ":graph_support", - "//mediapipe/framework:calculator_context", - "//mediapipe/framework:executor", - "//mediapipe/framework:calculator_node", - "//mediapipe/framework/port:ret_check", - "//mediapipe/framework/deps:no_destructor", ":gl_base", ":gl_context", + ":gl_context_options_cc_proto", ":gpu_buffer_multi_pool", ":gpu_shared_data_header", + ":graph_support", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_node", + "//mediapipe/framework:executor", + "//mediapipe/framework/deps:no_destructor", + "//mediapipe/framework/port:ret_check", ] + select({ "//conditions:default": [], "//mediapipe:apple": [ - ":metal_shared_resources", ":cv_texture_cache_manager", + ":metal_shared_resources", ], }), ) @@ -703,13 +697,13 @@ cc_library( ":gpu_buffer", ":gpu_shared_data_header", ":multi_pool", + 
"@com_google_absl//absl/hash", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", "//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_node", "//mediapipe/framework/port:logging", "//mediapipe/util:resource_cache", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", ] + select({ "//conditions:default": [ ":gl_texture_buffer", @@ -725,9 +719,9 @@ cc_library( "//mediapipe:macos": [ ":cv_pixel_buffer_pool_wrapper", ":cv_texture_cache_manager", - ":pixel_buffer_pool_util", ":gl_texture_buffer", ":gl_texture_buffer_pool", + ":pixel_buffer_pool_util", ], }), ) @@ -795,31 +789,31 @@ cc_library( ":gpu_buffer", ":gpu_buffer_format", ":gpu_buffer_multi_pool", - ":gpu_shared_data_internal", ":gpu_service", + ":gpu_shared_data_internal", ":graph_support", ":image_frame_view", ":shader_util", - "//mediapipe/framework:calculator_framework", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/synchronization", "//mediapipe/framework:calculator_context", - "//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_contract", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_node", "//mediapipe/framework:demangle", "//mediapipe/framework:legacy_calculator_support", "//mediapipe/framework:packet", "//mediapipe/framework:packet_set", "//mediapipe/framework:packet_type", "//mediapipe/framework:timestamp", + "//mediapipe/framework/deps:registration", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:logging", + "//mediapipe/framework/port:map_util", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/synchronization", - "//mediapipe/framework/deps:registration", - "//mediapipe/framework/port:map_util", ] + select({ "//conditions:default": [ ], @@ -918,8 +912,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":gl_calculator_helper", - ":gpu_buffer_storage_image_frame", - "//mediapipe/framework/api2:node", "//mediapipe/framework:calculator_framework", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/port:status", @@ -941,7 +933,7 @@ mediapipe_proto_library( ], ) -proto_library( +mediapipe_proto_library( name = "gl_scaler_calculator_proto", srcs = ["gl_scaler_calculator.proto"], visibility = ["//visibility:public"], @@ -951,17 +943,6 @@ proto_library( ], ) -mediapipe_cc_proto_library( - name = "gl_scaler_calculator_cc_proto", - srcs = ["gl_scaler_calculator.proto"], - cc_deps = [ - ":scale_mode_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - ], - visibility = ["//visibility:public"], - deps = [":gl_scaler_calculator_proto"], -) - cc_library( name = "gl_scaler_calculator", srcs = ["gl_scaler_calculator.cc"], diff --git a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc index c67fb0c62..2a8331db8 100644 --- a/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc +++ b/mediapipe/gpu/image_frame_to_gpu_buffer_calculator.cc @@ -12,63 +12,73 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/gpu/gl_calculator_helper.h" +#ifdef __APPLE__ +#include "mediapipe/objc/util.h" +#endif + namespace mediapipe { -namespace api2 { -class ImageFrameToGpuBufferCalculator - : public RegisteredNode { +// Convert ImageFrame to GpuBuffer. +class ImageFrameToGpuBufferCalculator : public CalculatorBase { public: - static constexpr Input kIn{""}; - static constexpr Output kOut{""}; + ImageFrameToGpuBufferCalculator() {} - MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut); - - static absl::Status UpdateContract(CalculatorContract* cc); + static absl::Status GetContract(CalculatorContract* cc); absl::Status Open(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override; private: +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER GlCalculatorHelper helper_; +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER }; +REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator); // static -absl::Status ImageFrameToGpuBufferCalculator::UpdateContract( +absl::Status ImageFrameToGpuBufferCalculator::GetContract( CalculatorContract* cc) { + cc->Inputs().Index(0).Set(); + cc->Outputs().Index(0).Set(); // Note: we call this method even on platforms where we don't use the helper, // to ensure the calculator's contract is the same. In particular, the helper // enables support for the legacy side packet, which several graphs still use. - return GlCalculatorHelper::UpdateContract(cc); + MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); + return absl::OkStatus(); } absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { + // Inform the framework that we always output at the same timestamp + // as we receive a packet at. + cc->SetOffset(TimestampDiff(0)); +#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER MP_RETURN_IF_ERROR(helper_.Open(cc)); +#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { - auto image_frame = std::const_pointer_cast( - mediapipe::SharedPtrWithPacket(kIn(cc).packet())); - auto gpu_buffer = api2::MakePacket( - std::make_shared( - std::move(image_frame))) - .At(cc->InputTimestamp()); - // This calculator's behavior has been to do the texture upload eagerly, and - // some graphs may rely on running this on a separate GL context to avoid - // blocking another context with the read operation. So let's request GPU - // access here to ensure that the behavior stays the same. - // TODO: have a better way to do this, or defer until later. 
- helper_.RunInGlContext( - [&gpu_buffer] { auto view = gpu_buffer->GetReadView(0); }); - kOut(cc).Send(std::move(gpu_buffer)); +#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER + CFHolder buffer; + MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket( + cc->Inputs().Index(0).Value(), &buffer)); + cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp()); +#else + const auto& input = cc->Inputs().Index(0).Get(); + helper_.RunInGlContext([this, &input, &cc]() { + auto src = helper_.CreateSourceTexture(input); + auto output = src.GetFrame(); + glFlush(); + cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); + src.Release(); + }); +#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER return absl::OkStatus(); } -} // namespace api2 } // namespace mediapipe diff --git a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java index 6a2c97b94..a253a8289 100644 --- a/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java +++ b/mediapipe/java/com/google/mediapipe/framework/GraphTextureFrame.java @@ -66,7 +66,8 @@ public class GraphTextureFrame implements TextureFrame { if (nativeBufferHandle == 0) { return 0; } - if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) { + long contextHandle = nativeGetCurrentExternalContextHandle(); + if (contextHandle != 0 && activeConsumerContextHandleSet.add(contextHandle)) { // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using // PacketGetter.getTextureFrameDeferredSync(). if (deferredSync) { @@ -116,7 +117,14 @@ public class GraphTextureFrame implements TextureFrame { GlSyncToken consumerToken = null; // Note that this remove should be moved to the other overload of release when b/68808951 is // addressed. 
- if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) { + final long contextHandle = nativeGetCurrentExternalContextHandle(); + if (contextHandle == 0 && !activeConsumerContextHandleSet.isEmpty()) { + logger.atWarning().log( + "GraphTextureFrame is being released on non GL thread while having active consumers," + + " which may lead to external / internal GL contexts synchronization issues."); + } + + if (contextHandle != 0 && activeConsumerContextHandleSet.remove(contextHandle)) { consumerToken = new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); } @@ -169,7 +177,9 @@ public class GraphTextureFrame implements TextureFrame { private native void nativeReleaseBuffer(long nativeHandle); private native int nativeGetTextureName(long nativeHandle); + private native int nativeGetWidth(long nativeHandle); + private native int nativeGetHeight(long nativeHandle); private native void nativeGpuWait(long nativeHandle); diff --git a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD index 4540f63a6..fa9ccffe9 100644 --- a/mediapipe/java/com/google/mediapipe/framework/jni/BUILD +++ b/mediapipe/java/com/google/mediapipe/framework/jni/BUILD @@ -30,11 +30,11 @@ cc_library( "compat_jni.cc", "graph.cc", "graph_jni.cc", + "graph_profiler_jni.cc", "graph_service_jni.cc", "packet_context_jni.cc", "packet_creator_jni.cc", "packet_getter_jni.cc", - "graph_profiler_jni.cc", ] + select({ "//conditions:default": [], "//mediapipe:android": [ @@ -54,11 +54,11 @@ cc_library( "compat_jni.h", "graph.h", "graph_jni.h", + "graph_profiler_jni.h", "graph_service_jni.h", "packet_context_jni.h", "packet_creator_jni.h", "packet_getter_jni.h", - "graph_profiler_jni.h", ] + select({ "//conditions:default": [], "//mediapipe:android": [ @@ -84,40 +84,40 @@ cc_library( deps = [ ":class_registry", ":jni_util", - "//mediapipe/framework/formats:image_format_cc_proto", - "//mediapipe/framework/formats:time_series_header_cc_proto", - "//mediapipe/framework:calculator_cc_proto", - "//mediapipe/framework:calculator_profile_cc_proto", - "//mediapipe/framework:calculator_framework", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@eigen_archive//:eigen3", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_profile_cc_proto", "//mediapipe/framework:camera_intrinsics", "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:video_stream_header", - "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", - "//mediapipe/framework/tool:name_util", - "//mediapipe/framework/tool:executor_util", "//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:logging", - "//mediapipe/framework/port:threadpool", "//mediapipe/framework/port:singleton", "//mediapipe/framework/port:status", + "//mediapipe/framework/port:threadpool", + "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", + "//mediapipe/framework/tool:executor_util", + "//mediapipe/framework/tool:name_util", ] + select({ "//conditions:default": [ "//mediapipe/framework/port:file_helpers", ], "//mediapipe:android": [ - "//mediapipe/util/android/file/base", 
"//mediapipe/util/android:asset_manager_util", + "//mediapipe/util/android/file/base", ], }) + select({ "//conditions:default": [ - "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_calculator_helper", + "//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_surface_sink_calculator", "//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_shared_data_internal", @@ -153,9 +153,9 @@ cc_library( srcs = ["class_registry.cc"], hdrs = ["class_registry.h"], deps = [ + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/container:node_hash_map", ] + select({ "//conditions:default": [ ], @@ -172,9 +172,9 @@ cc_library( ":class_registry", ":loose_headers", ":mediapipe_framework_jni", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/container:node_hash_map", "//mediapipe/framework/port:logging", ] + select({ "//conditions:default": [ diff --git a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl index 645e8b722..1b80744e8 100644 --- a/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl +++ b/mediapipe/java/com/google/mediapipe/mediapipe_aar.bzl @@ -357,6 +357,22 @@ def mediapipe_java_proto_srcs(name = ""): target = "//mediapipe/framework/formats:rect_java_proto_lite", src_out = "com/google/mediapipe/formats/proto/RectProto.java", )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util:color_java_proto_lite", + src_out = "com/google/mediapipe/util/proto/ColorProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util:label_map_java_proto_lite", + src_out = "com/google/mediapipe/util/proto/LabelMapProto.java", + )) + + proto_src_list.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/util:render_data_java_proto_lite", + src_out = "com/google/mediapipe/util/proto/RenderDataProto.java", + )) + return proto_src_list def mediapipe_logging_java_proto_srcs(name = ""): diff --git a/mediapipe/model_maker/models/face_stylizer/BUILD b/mediapipe/model_maker/models/face_stylizer/BUILD new file mode 100644 index 000000000..74ca71554 --- /dev/null +++ b/mediapipe/model_maker/models/face_stylizer/BUILD @@ -0,0 +1,24 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +licenses(["notice"]) + +package(default_visibility = ["//mediapipe/model_maker/python/vision/face_stylizer:__subpackages__"]) + +filegroup( + name = "models", + srcs = glob([ + "**", + ]), +) diff --git a/mediapipe/model_maker/python/core/data/dataset.py b/mediapipe/model_maker/python/core/data/dataset.py index a92b05c0d..113969384 100644 --- a/mediapipe/model_maker/python/core/data/dataset.py +++ b/mediapipe/model_maker/python/core/data/dataset.py @@ -18,7 +18,7 @@ from __future__ import division from __future__ import print_function import functools -from typing import Callable, Optional, Tuple, TypeVar +from typing import Any, Callable, Optional, Tuple, TypeVar # Dependency imports import tensorflow as tf @@ -66,12 +66,14 @@ class Dataset(object): """ return self._size - def gen_tf_dataset(self, - batch_size: int = 1, - is_training: bool = False, - shuffle: bool = False, - preprocess: Optional[Callable[..., bool]] = None, - drop_remainder: bool = False) -> tf.data.Dataset: + def gen_tf_dataset( + self, + batch_size: int = 1, + is_training: bool = False, + shuffle: bool = False, + preprocess: Optional[Callable[..., Any]] = None, + drop_remainder: bool = False, + ) -> tf.data.Dataset: """Generates a batched tf.data.Dataset for training/evaluation. Args: diff --git a/mediapipe/model_maker/python/core/tasks/classifier.py b/mediapipe/model_maker/python/core/tasks/classifier.py index abcfff835..bfe0f027f 100644 --- a/mediapipe/model_maker/python/core/tasks/classifier.py +++ b/mediapipe/model_maker/python/core/tasks/classifier.py @@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel): self._hparams: hp.BaseHParams = None self._history: tf.keras.callbacks.History = None - def _train_model(self, - train_data: classification_ds.ClassificationDataset, - validation_data: classification_ds.ClassificationDataset, - preprocessor: Optional[Callable[..., bool]] = None, - checkpoint_path: Optional[str] = None): + def _train_model( + self, + train_data: classification_ds.ClassificationDataset, + validation_data: classification_ds.ClassificationDataset, + preprocessor: Optional[Callable[..., Any]] = None, + checkpoint_path: Optional[str] = None, + ): """Trains the classifier model. Compiles and fits the tf.keras `_model` and records the `_history`. diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 69a8654ec..7a0b8fcf0 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None, def convert_to_tflite( model: tf.keras.Model, quantization_config: Optional[quantization.QuantizationConfig] = None, - supported_ops: Tuple[tf.lite.OpsSet, - ...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), - preprocess: Optional[Callable[..., bool]] = None) -> bytearray: + supported_ops: Tuple[tf.lite.OpsSet, ...] = ( + tf.lite.OpsSet.TFLITE_BUILTINS, + ), + preprocess: Optional[Callable[..., Any]] = None, +) -> bytearray: """Converts the input Keras model to TFLite format. Args: diff --git a/mediapipe/model_maker/python/vision/face_stylizer/BUILD b/mediapipe/model_maker/python/vision/face_stylizer/BUILD new file mode 100644 index 000000000..804511540 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/BUILD @@ -0,0 +1,48 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Placeholder for internal Python strict test compatibility macro. +# Placeholder for internal Python strict library and test compatibility macro. + +licenses(["notice"]) + +package(default_visibility = ["//mediapipe:__subpackages__"]) + +filegroup( + name = "testdata", + srcs = glob([ + "testdata/**", + ]), +) + +py_library( + name = "dataset", + srcs = ["dataset.py"], + deps = [ + "//mediapipe/model_maker/python/core/data:classification_dataset", + "//mediapipe/model_maker/python/vision/core:image_utils", + ], +) + +py_test( + name = "dataset_test", + srcs = ["dataset_test.py"], + data = [ + ":testdata", + ], + deps = [ + ":dataset", + "//mediapipe/tasks/python/test:test_utils", + ], +) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/__init__.py b/mediapipe/model_maker/python/vision/face_stylizer/__init__.py new file mode 100644 index 000000000..e935c0c76 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe Model Maker Python Public API For Face Stylization.""" diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py new file mode 100644 index 000000000..b6c85d6f3 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset.py @@ -0,0 +1,98 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Face stylizer dataset library.""" + +import logging +import os + +import tensorflow as tf + +from mediapipe.model_maker.python.core.data import classification_dataset +from mediapipe.model_maker.python.vision.core import image_utils + + +# TODO: Change to a unlabeled dataset if it makes sense. 
+class Dataset(classification_dataset.ClassificationDataset): + """Dataset library for face stylizer fine tuning.""" + + @classmethod + def from_folder( + cls, dirname: str + ) -> classification_dataset.ClassificationDataset: + """Loads images from the given directory. + + The style image dataset directory is expected to contain one subdirectory + whose name represents the label of the style. There can be one or multiple + images of the same style in that subdirectory. Supported input image formats + include 'jpg', 'jpeg', 'png'. + + Args: + dirname: Name of the directory containing the image files. + + Returns: + Dataset containing images and labels and other related info. + Raises: + ValueError: if the input data directory is empty. + """ + data_root = os.path.abspath(dirname) + + # Assumes the image data of the same label are in the same subdirectory, + # gets image path and label names. + all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*')) + all_image_size = len(all_image_paths) + if all_image_size == 0: + raise ValueError('Invalid input data directory') + if not any( + fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths + ): + raise ValueError('No images found under given directory') + + label_names = sorted( + name + for name in os.listdir(data_root) + if os.path.isdir(os.path.join(data_root, name)) + ) + all_label_size = len(label_names) + index_by_label = dict( + (name, index) for index, name in enumerate(label_names) + ) + # Get the style label from the subdirectory name. + all_image_labels = [ + index_by_label[os.path.basename(os.path.dirname(path))] + for path in all_image_paths + ] + + path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths) + + image_ds = path_ds.map( + image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE + ) + + # Load label + label_ds = tf.data.Dataset.from_tensor_slices( + tf.cast(all_image_labels, tf.int64) + ) + + # Create a dataset of (image, label) pairs + image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)) + + logging.info( + 'Load images dataset with size: %d, num_label: %d, labels: %s.', + all_image_size, + all_label_size, + ', '.join(label_names), + ) + return Dataset( + dataset=image_label_ds, size=all_image_size, label_names=label_names + ) diff --git a/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py new file mode 100644 index 000000000..a8af222d4 --- /dev/null +++ b/mediapipe/model_maker/python/vision/face_stylizer/dataset_test.py @@ -0,0 +1,48 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +from mediapipe.model_maker.python.vision.face_stylizer import dataset +from mediapipe.tasks.python.test import test_utils + + +class DatasetTest(tf.test.TestCase): + + def setUp(self): + super().setUp() + # TODO: Replace the stylize image dataset with licensed images. 
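+    # The test directory follows the layout documented in
+    # dataset.Dataset.from_folder: one subdirectory per style label (here
+    # 'cartoon' and 'sketch'), each containing one or more style images.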
+ self._test_data_dirname = 'testdata' + + def test_from_folder(self): + input_data_dir = test_utils.get_test_data_path(self._test_data_dirname) + data = dataset.Dataset.from_folder(dirname=input_data_dir) + self.assertEqual(data.num_classes, 2) + self.assertEqual(data.label_names, ['cartoon', 'sketch']) + self.assertLen(data, 2) + + def test_from_folder_raise_value_error_for_invalid_path(self): + with self.assertRaisesRegex(ValueError, 'Invalid input data directory'): + dataset.Dataset.from_folder(dirname='invalid') + + def test_from_folder_raise_value_error_for_valid_no_data_path(self): + input_data_dir = test_utils.get_test_data_path('face_stylizer') + with self.assertRaisesRegex( + ValueError, 'No images found under given directory' + ): + dataset.Dataset.from_folder(dirname=input_data_dir) + + +if __name__ == '__main__': + tf.test.main() diff --git a/mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png b/mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png new file mode 100644 index 000000000..87e9d3d8d Binary files /dev/null and b/mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png differ diff --git a/mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png b/mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png new file mode 100644 index 000000000..169192c96 Binary files /dev/null and b/mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png differ diff --git a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py index ce167df93..ad2f211f5 100644 --- a/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py +++ b/mediapipe/model_maker/python/vision/gesture_recognizer/gesture_recognizer_test.py @@ -15,6 +15,7 @@ import io import os import tempfile +import unittest from unittest import mock as unittest_mock import zipfile @@ -31,6 +32,7 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat tf.keras.backend.experimental.enable_tf_random_generator() +@unittest.skip('b/273818271') class GestureRecognizerTest(tf.test.TestCase): def _load_data(self): @@ -72,8 +74,10 @@ class GestureRecognizerTest(tf.test.TestCase): self._test_accuracy(model) + @unittest.skip('b/273818271') @unittest_mock.patch.object( - tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) + tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense + ) def test_gesture_recognizer_model_layer_widths(self, mock_dense): layer_widths = [64, 32] mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths) @@ -143,12 +147,14 @@ class GestureRecognizerTest(tf.test.TestCase): hyperparameters, 'HParams', autospec=True, - return_value=gesture_recognizer.HParams(epochs=1)) + return_value=gesture_recognizer.HParams(epochs=1), + ) @unittest_mock.patch.object( model_options, 'GestureRecognizerModelOptions', autospec=True, - return_value=gesture_recognizer.ModelOptions()) + return_value=gesture_recognizer.ModelOptions(), + ) def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( self, mock_hparams, mock_model_options): options = gesture_recognizer.GestureRecognizerOptions() diff --git a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py index ef44f86e6..d46cafe6b 100644 --- 
a/mediapipe/model_maker/python/vision/image_classifier/model_spec.py +++ b/mediapipe/model_maker/python/vision/image_classifier/model_spec.py @@ -28,7 +28,7 @@ class ModelSpec(object): uri: str, input_image_shape: Optional[List[int]] = None, name: str = ''): - """Initializes a new instance of the `ImageModelSpec` class. + """Initializes a new instance of the image classifier `ModelSpec` class. Args: uri: str, URI to the pretrained model. diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index d7e4a950f..29e5426e0 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -5,4 +5,4 @@ opencv-python tensorflow>=2.10 tensorflow-datasets tensorflow-hub -tf-models-official>=2.10.1 +tf-models-official>=2.11.4 diff --git a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc index db0f27484..29f4c79d2 100644 --- a/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc +++ b/mediapipe/modules/objectron/calculators/filter_detection_calculator.cc @@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION"; constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kLabelsTag[] = "LABELS"; constexpr char kLabelsCsvTag[] = "LABELS_CSV"; +constexpr char kLabelMapTag[] = "LABEL_MAP"; using mediapipe::RE2; using Detections = std::vector; @@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { cc->InputSidePackets().Tag(kLabelsCsvTag).Set(); } + if (cc->InputSidePackets().HasTag(kLabelMapTag)) { + cc->InputSidePackets() + .Tag(kLabelMapTag) + .Set>>(); + } return absl::OkStatus(); } @@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) { cc->SetOffset(TimestampDiff(0)); options_ = cc->Options(); limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) || - cc->InputSidePackets().HasTag(kLabelsCsvTag); + cc->InputSidePackets().HasTag(kLabelsCsvTag) || + cc->InputSidePackets().HasTag(kLabelMapTag); if (limit_labels_) { Strings allowlist_labels; if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { @@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) { for (auto& e : allowlist_labels) { absl::StripAsciiWhitespace(&e); } - } else { + } else if (cc->InputSidePackets().HasTag(kLabelsTag)) { allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get(); + } else if (cc->InputSidePackets().HasTag(kLabelMapTag)) { + auto label_map = cc->InputSidePackets() + .Tag(kLabelMapTag) + .Get>>() + .get(); + for (const auto& [_, v] : *label_map) { + allowlist_labels.push_back(v); + } } allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end()); } diff --git a/mediapipe/modules/objectron/calculators/filter_detection_calculator_test.cc b/mediapipe/modules/objectron/calculators/filter_detection_calculator_test.cc index 958fe4c54..10e750a49 100644 --- a/mediapipe/modules/objectron/calculators/filter_detection_calculator_test.cc +++ b/mediapipe/modules/objectron/calculators/filter_detection_calculator_test.cc @@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) { )); } +TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) { + auto runner = std::make_unique( + ParseTextProtoOrDie(R"pb( + calculator: "FilterDetectionCalculator" + input_stream: "DETECTION:input" + input_side_packet: "LABEL_MAP:input_map" + output_stream: 
"DETECTION:output" + options { + [mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 } + } + )pb")); + + runner->MutableInputs()->Tag("DETECTION").packets = { + MakePacket(ParseTextProtoOrDie(R"pb( + label: "a" + label: "b" + label: "c" + label: "d" + score: 1 + score: 0.8 + score: 0.3 + score: 0.9 + )pb")) + .At(Timestamp(20)), + MakePacket(ParseTextProtoOrDie(R"pb( + label: "a" + label: "b" + label: "c" + label: "e" + score: 0.6 + score: 0.4 + score: 0.2 + score: 0.7 + )pb")) + .At(Timestamp(40)), + }; + + auto label_map = std::make_unique>(); + (*label_map)[0] = "a"; + (*label_map)[1] = "b"; + (*label_map)[2] = "c"; + runner->MutableSidePackets()->Tag("LABEL_MAP") = + AdoptAsUniquePtr(label_map.release()); + + // Run graph. + MP_ASSERT_OK(runner->Run()); + + // Check output. + EXPECT_THAT( + runner->Outputs().Tag("DETECTION").packets, + ElementsAre(PacketContainsTimestampAndPayload( + Eq(Timestamp(20)), + EqualsProto(R"pb( + label: "a" label: "b" score: 1 score: 0.8 + )pb")), // Packet 1 at timestamp 20. + PacketContainsTimestampAndPayload( + Eq(Timestamp(40)), + EqualsProto(R"pb( + label: "a" score: 0.6 + )pb")) // Packet 2 at timestamp 40. + )); +} + } // namespace } // namespace mediapipe diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD index 9d5ea26ad..aacb4bfcb 100644 --- a/mediapipe/python/BUILD +++ b/mediapipe/python/BUILD @@ -57,6 +57,7 @@ pybind_extension( "//mediapipe/framework/formats:landmark_registration", "//mediapipe/framework/formats:rect_registration", "//mediapipe/modules/objectron/calculators:annotation_registration", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration", ], ) @@ -95,6 +96,8 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph", + "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", + "//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph", ] + select({ # TODO: Build text_classifier_graph and text_embedder_graph on Windows. "//mediapipe:windows": [], diff --git a/mediapipe/tasks/cc/common.h b/mediapipe/tasks/cc/common.h index 1295177df..70892c5cd 100644 --- a/mediapipe/tasks/cc/common.h +++ b/mediapipe/tasks/cc/common.h @@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus"; // // At runtime, such codes are meant to be attached (where applicable) to a // `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and -// stringifed error code as value (aka payload). This logic is encapsulated in +// stringified error code as value (aka payload). This logic is encapsulated in // the `CreateStatusWithPayload` helper below for convenience. 
// // The returned status includes: diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc index 5867be49b..2f53ff2d5 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.cc @@ -51,12 +51,11 @@ ModelAssetBundleResources::Create( auto model_bundle_resources = absl::WrapUnique( new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file))); MP_RETURN_IF_ERROR( - model_bundle_resources->ExtractModelFilesFromExternalFileProto()); + model_bundle_resources->ExtractFilesFromExternalFileProto()); return model_bundle_resources; } -absl::Status -ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() { +absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() { if (model_asset_bundle_file_->has_file_name()) { // If the model asset bundle file name is a relative path, searches the file // in a platform-specific location and returns the absolute path on success. @@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() { model_asset_bundle_file_handler_->GetFileContent().data(); size_t buffer_size = model_asset_bundle_file_handler_->GetFileContent().size(); - return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, - &model_files_); + return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_); } -absl::StatusOr ModelAssetBundleResources::GetModelFile( +absl::StatusOr ModelAssetBundleResources::GetFile( const std::string& filename) const { - auto it = model_files_.find(filename); - if (it == model_files_.end()) { - auto model_files = ListModelFiles(); - std::string all_model_files = - absl::StrJoin(model_files.begin(), model_files.end(), ", "); + auto it = files_.find(filename); + if (it == files_.end()) { + auto files = ListFiles(); + std::string all_files = absl::StrJoin(files.begin(), files.end(), ", "); return CreateStatusWithPayload( StatusCode::kNotFound, - absl::StrFormat("No model file with name: %s. All model files in the " - "model asset bundle are: %s.", - filename, all_model_files), + absl::StrFormat("No file with name: %s. All files in the model asset " + "bundle are: %s.", + filename, all_files), MediaPipeTasksStatus::kFileNotFoundError); } return it->second; } -std::vector ModelAssetBundleResources::ListModelFiles() const { - std::vector model_names; - for (const auto& [model_name, _] : model_files_) { - model_names.push_back(model_name); +std::vector ModelAssetBundleResources::ListFiles() const { + std::vector file_names; + for (const auto& [file_name, _] : files_) { + file_names.push_back(file_name); } - return model_names; + return file_names; } } // namespace core diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources.h b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h index 61474d3ad..02d989d4b 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources.h +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources.h @@ -28,8 +28,8 @@ namespace core { // The mediapipe task model asset bundle resources class. // A ModelAssetBundleResources object, created from an external file proto, // contains model asset bundle related resources and the method to extract the -// tflite models or model asset bundles for the mediapipe sub-tasks. As the -// resources are owned by the ModelAssetBundleResources object +// tflite models, resource files or model asset bundles for the mediapipe +// sub-tasks. 
As the resources are owned by the ModelAssetBundleResources object // callers must keep ModelAssetBundleResources alive while using any of the // resources. class ModelAssetBundleResources { @@ -50,14 +50,13 @@ class ModelAssetBundleResources { // Returns the model asset bundle resources tag. std::string GetTag() const { return tag_; } - // Gets the contents of the model file (either tflite model file or model - // bundle file) with the provided name. An error is returned if there is no - // such model file. - absl::StatusOr GetModelFile( - const std::string& filename) const; + // Gets the contents of the model file (either tflite model file, resource + // file or model bundle file) with the provided name. An error is returned if + // there is no such model file. + absl::StatusOr GetFile(const std::string& filename) const; - // Lists all the model file names in the model asset model. - std::vector ListModelFiles() const; + // Lists all the file names in the model asset model. + std::vector ListFiles() const; private: // Constructor. @@ -65,9 +64,9 @@ class ModelAssetBundleResources { const std::string& tag, std::unique_ptr model_asset_bundle_file); - // Extracts the model files (either tflite model file or model bundle file) - // from the external file proto. - absl::Status ExtractModelFilesFromExternalFileProto(); + // Extracts the model files (either tflite model file, resource file or model + // bundle file) from the external file proto. + absl::Status ExtractFilesFromExternalFileProto(); // The model asset bundle resources tag. const std::string tag_; @@ -78,11 +77,11 @@ class ModelAssetBundleResources { // The ExternalFileHandler for the model asset bundle. std::unique_ptr model_asset_bundle_file_handler_; - // The model files bundled in model asset bundle, as a map with the filename + // The files bundled in model asset bundle, as a map with the filename // (corresponding to a basename, e.g. "hand_detector.tflite") as key and - // a pointer to the file contents as value. Each model file can be either - // a TFLite model file or a model bundle file for sub-task. - absl::flat_hash_map model_files_; + // a pointer to the file contents as value. Each file can be either a TFLite + // model file, resource file or a model bundle file for sub-task. 
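+  // Illustrative sketch of the GetFile()/ListFiles() API declared above; the
+  // tag and file name below are hypothetical:
+  //
+  //   auto resources = ModelAssetBundleResources::Create(
+  //       "my_tag", std::move(external_file));
+  //   if (resources.ok()) {
+  //     std::vector<std::string> names = (*resources)->ListFiles();
+  //     absl::StatusOr<absl::string_view> contents =
+  //         (*resources)->GetFile("some_model.tflite");
+  //   }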
+ absl::flat_hash_map files_; }; } // namespace core diff --git a/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc index 359deef91..85a94ccc7 100644 --- a/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc +++ b/mediapipe/tasks/cc/core/model_asset_bundle_resources_test.cc @@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } @@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } @@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } #endif // _WIN32 @@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") - .status()); + model_bundle_resources->GetFile("dummy_hand_landmarker.task").status()); MP_EXPECT_OK( - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite") .status()); } @@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); auto status_or_model_bundle_file = - model_bundle_resources->GetModelFile("dummy_hand_landmarker.task"); + model_bundle_resources->GetFile("dummy_hand_landmarker.task"); MP_EXPECT_OK(status_or_model_bundle_file.status()); // Creates sub-task model asset bundle resources. 
@@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(hand_landmaker_model_file))); MP_EXPECT_OK(hand_landmaker_model_bundle_resources - ->GetModelFile("dummy_hand_detector.tflite") + ->GetFile("dummy_hand_detector.tflite") .status()); MP_EXPECT_OK(hand_landmaker_model_bundle_resources - ->GetModelFile("dummy_hand_landmarker.tflite") + ->GetFile("dummy_hand_landmarker.tflite") .status()); } @@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) { ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); auto status_or_model_bundle_file = - model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite"); + model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite"); MP_EXPECT_OK(status_or_model_bundle_file.status()); // Verify tflite model works. @@ -200,12 +196,12 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) { auto model_bundle_resources, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); - auto status = model_bundle_resources->GetModelFile("not_found.task").status(); + auto status = model_bundle_resources->GetFile("not_found.task").status(); EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); - EXPECT_THAT(status.message(), - testing::HasSubstr( - "No model file with name: not_found.task. All model files in " - "the model asset bundle are: ")); + EXPECT_THAT( + status.message(), + testing::HasSubstr("No file with name: not_found.task. All files in " + "the model asset bundle are: ")); EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), testing::Optional(absl::Cord( absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError)))); @@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) { auto model_bundle_resources, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, std::move(model_file))); - auto model_files = model_bundle_resources->ListModelFiles(); + auto model_files = model_bundle_resources->ListFiles(); std::vector expected_model_files = { "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"}; std::sort(model_files.begin(), model_files.end()); diff --git a/mediapipe/tasks/cc/core/model_resources_calculator.cc b/mediapipe/tasks/cc/core/model_resources_calculator.cc index 72a7b33a3..d5c8cd502 100644 --- a/mediapipe/tasks/cc/core/model_resources_calculator.cc +++ b/mediapipe/tasks/cc/core/model_resources_calculator.cc @@ -77,9 +77,11 @@ class ModelResourcesCalculator : public api2::Node { if (options.has_model_file()) { RET_CHECK(options.model_file().has_file_content() || options.model_file().has_file_descriptor_meta() || - options.model_file().has_file_name()) + options.model_file().has_file_name() || + options.model_file().has_file_pointer_meta()) << "'model_file' must specify at least one of " - "'file_content', 'file_descriptor_meta', or 'file_name'"; + "'file_content', 'file_descriptor_meta', 'file_name', or " + "'file_pointer_meta'"; } return absl::OkStatus(); } diff --git a/mediapipe/tasks/cc/core/model_resources_calculator_test.cc b/mediapipe/tasks/cc/core/model_resources_calculator_test.cc index 58659c77d..83134a8c7 100644 --- a/mediapipe/tasks/cc/core/model_resources_calculator_test.cc +++ b/mediapipe/tasks/cc/core/model_resources_calculator_test.cc @@ -179,9 +179,9 @@ TEST_F(ModelResourcesCalculatorTest, EmptyExternalFileProto) { auto status = graph.Initialize(graph_config); 
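  // Initialization is expected to fail: the external file proto in this test
  // sets none of 'file_content', 'file_descriptor_meta', 'file_name', or
  // 'file_pointer_meta', so ModelResourcesCalculator rejects the config.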
ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - testing::HasSubstr( - "'model_file' must specify at least one of " - "'file_content', 'file_descriptor_meta', or 'file_name'")); + testing::HasSubstr("'model_file' must specify at least one of " + "'file_content', 'file_descriptor_meta', " + "'file_name', or 'file_pointer_meta'")); } TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) { diff --git a/mediapipe/tasks/cc/core/model_task_graph.cc b/mediapipe/tasks/cc/core/model_task_graph.cc index 0cb556ec2..653c6b9ff 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.cc +++ b/mediapipe/tasks/cc/core/model_task_graph.cc @@ -138,7 +138,7 @@ class InferenceSubgraph : public Subgraph { delegate.mutable_tflite()->CopyFrom(acceleration.tflite()); break; case Acceleration::DELEGATE_NOT_SET: - // Deafult inference calculator setting. + // Default inference calculator setting. break; } return delegate; diff --git a/mediapipe/tasks/cc/core/model_task_graph.h b/mediapipe/tasks/cc/core/model_task_graph.h index 3068b2c46..aa864c9fc 100644 --- a/mediapipe/tasks/cc/core/model_task_graph.h +++ b/mediapipe/tasks/cc/core/model_task_graph.h @@ -124,10 +124,10 @@ class ModelTaskGraph : public Subgraph { // Inserts a mediapipe task inference subgraph into the provided // GraphBuilder. The returned node provides the following interfaces to the // the rest of the graph: - // - a tensor vector (std::vector) input stream with tag + // - a tensor vector (std::vector) input stream with tag // "TENSORS", representing the input tensors to be consumed by the // inference engine. - // - a tensor vector (std::vector) output stream with tag + // - a tensor vector (std::vector) output stream with tag // "TENSORS", representing the output tensors generated by the inference // engine. // - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR". diff --git a/mediapipe/tasks/cc/core/task_runner.cc b/mediapipe/tasks/cc/core/task_runner.cc index 9a87551e7..fc933d547 100644 --- a/mediapipe/tasks/cc/core/task_runner.cc +++ b/mediapipe/tasks/cc/core/task_runner.cc @@ -301,7 +301,7 @@ absl::Status TaskRunner::Close() { } is_running_ = false; MP_RETURN_IF_ERROR( - AddPayload(graph_.CloseAllInputStreams(), "Fail to close intput streams", + AddPayload(graph_.CloseAllInputStreams(), "Fail to close input streams", MediaPipeTasksStatus::kRunnerFailsToCloseError)); MP_RETURN_IF_ERROR(AddPayload( graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.", diff --git a/mediapipe/tasks/cc/core/task_runner.h b/mediapipe/tasks/cc/core/task_runner.h index 0d049c782..8123a45aa 100644 --- a/mediapipe/tasks/cc/core/task_runner.h +++ b/mediapipe/tasks/cc/core/task_runner.h @@ -65,7 +65,7 @@ class TaskRunner { // Creates the task runner with a CalculatorGraphConfig proto. // If a tflite op resolver object is provided, the task runner will take // it as the global op resolver for all models running within this task. - // The op resolver's owernship will be transferred into the pipeleine runner. + // The op resolver's ownership will be transferred into the pipeleine runner. // When a user-defined PacketsCallback is provided, clients must use the // asynchronous method, Send(), to provide the input packets. If the packets // callback is absent, clients must use the synchronous method, Process(), to @@ -84,7 +84,7 @@ class TaskRunner { // frames from a video file and an audio file. The call blocks the current // thread until a failure status or a successful result is returned. 
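// A hedged sketch of this synchronous mode (the stream name and packet type
// are illustrative assumptions only):
//
//   auto runner = TaskRunner::Create(std::move(graph_config)).value();
//   auto outputs = runner->Process(
//       {{"text_in", MakePacket<std::string>("hello world")}});
//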
// If the input packets have no timestamp, an internal timestamp will be - // assigend per invocation. Otherwise, when the timestamp is set in the + // assigned per invocation. Otherwise, when the timestamp is set in the // input packets, the caller must ensure that the input packet timestamps are // greater than the timestamps of the previous invocation. This method is // thread-unsafe and it is the caller's responsibility to synchronize access diff --git a/mediapipe/tasks/cc/metadata/metadata_populator.h b/mediapipe/tasks/cc/metadata/metadata_populator.h index 024ad785f..c0554f704 100644 --- a/mediapipe/tasks/cc/metadata/metadata_populator.h +++ b/mediapipe/tasks/cc/metadata/metadata_populator.h @@ -64,7 +64,7 @@ class ModelMetadataPopulator { // Loads associated files into the TFLite FlatBuffer model. The input is a map // of {filename, file contents}. // - // Warning: this method removes any previoulsy present associated files. + // Warning: this method removes any previously present associated files. // Calling this method multiple time removes any associated files from // previous calls, so this method should usually be called only once. void LoadAssociatedFiles( diff --git a/mediapipe/tasks/cc/metadata/metadata_version.cc b/mediapipe/tasks/cc/metadata/metadata_version.cc index 923d3aa56..7e2414dd5 100644 --- a/mediapipe/tasks/cc/metadata/metadata_version.cc +++ b/mediapipe/tasks/cc/metadata/metadata_version.cc @@ -213,7 +213,7 @@ void UpdateMinimumVersionForTable(const tflite::Content* table, Version* min_version) { if (table == nullptr) return; - // Checks the ContenProperties field. + // Checks the ContentProperties field. if (table->content_properties_type() == ContentProperties_AudioProperties) { UpdateMinimumVersion( GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties), diff --git a/mediapipe/tasks/cc/metadata/python/metadata_version.cc b/mediapipe/tasks/cc/metadata/python/metadata_version.cc index 860a00e4f..e3072bc9e 100644 --- a/mediapipe/tasks/cc/metadata/python/metadata_version.cc +++ b/mediapipe/tasks/cc/metadata/python/metadata_version.cc @@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) { // Using pybind11 type conversions to convert between Python and native // C++ types. There are other options to provide access to native Python types - // in C++ and vice versa. See the pybind 11 instrcution [1] for more details. - // Type converstions is recommended by pybind11, though the main downside + // in C++ and vice versa. See the pybind 11 instruction [1] for more details. + // Type conversions is recommended by pybind11, though the main downside // is that a copy of the data must be made on every Python to C++ transition: // this is needed since the C++ and Python versions of the same type generally // won’t have the same memory layout. diff --git a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc index 273c91685..32ff51482 100644 --- a/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc +++ b/mediapipe/tasks/cc/metadata/tests/metadata_version_test.cc @@ -79,7 +79,7 @@ TEST(MetadataVersionTest, auto metadata = metadata_builder.Finish(); FinishModelMetadataBuffer(builder, metadata); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. 
std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -100,7 +100,7 @@ TEST(MetadataVersionTest, auto metadata = metadata_builder.Finish(); builder.Finish(metadata); - // Gets the mimimum metadata parser version and triggers error. + // Gets the minimum metadata parser version and triggers error. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -121,7 +121,7 @@ TEST(MetadataVersionTest, metadata_builder.add_associated_files(associated_files); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -147,7 +147,7 @@ TEST(MetadataVersionTest, metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -172,7 +172,7 @@ TEST(MetadataVersionTest, std::vector>{tensor_builder.Finish()}); CreateModelWithMetadata(tensors, builder); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -203,7 +203,7 @@ TEST(MetadataVersionTest, metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -234,7 +234,7 @@ TEST(MetadataVersionTest, metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -265,7 +265,7 @@ TEST(MetadataVersionTest, metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -294,7 +294,7 @@ TEST(MetadataVersionTest, std::vector>{tensor_builder.Finish()}); CreateModelWithMetadata(tensors, builder); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -323,7 +323,7 @@ TEST(MetadataVersionTest, std::vector>{tensor_builder.Finish()}); CreateModelWithMetadata(tensors, builder); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. 
std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -348,7 +348,7 @@ TEST(MetadataVersionTest, metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -373,7 +373,7 @@ TEST(MetadataVersionTest, metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -404,7 +404,7 @@ TEST(MetadataVersionTest, metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -431,7 +431,7 @@ TEST(MetadataVersionTest, std::vector>{tensor_builder.Finish()}); CreateModelWithMetadata(tensors, builder); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -453,7 +453,7 @@ TEST(MetadataVersionTest, metadata_builder.add_associated_files(associated_files); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -476,7 +476,7 @@ TEST(MetadataVersionTest, metadata_builder.add_associated_files(associated_files); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), @@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) { metadata_builder.add_subgraph_metadata(subgraphs); FinishModelMetadataBuffer(builder, metadata_builder.Finish()); - // Gets the mimimum metadata parser version. + // Gets the minimum metadata parser version. 
std::string min_version; EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), builder.GetSize(), &min_version), diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD index 5e7c5afa5..090f528ef 100644 --- a/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/BUILD @@ -42,3 +42,36 @@ cc_test( "@org_tensorflow//tensorflow/lite/kernels:test_util", ], ) + +cc_library( + name = "ngram_hash", + srcs = ["ngram_hash.cc"], + hdrs = ["ngram_hash.h"], + copts = tflite_copts(), + deps = [ + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils", + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:kernel_util", + ], + alwayslink = 1, +) + +cc_test( + name = "ngram_hash_test", + srcs = ["ngram_hash_test.cc"], + deps = [ + ":ngram_hash", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur", + "@com_google_absl//absl/types:optional", + "@flatbuffers", + "@org_tensorflow//tensorflow/lite:framework", + "@org_tensorflow//tensorflow/lite:string_util", + "@org_tensorflow//tensorflow/lite/c:common", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + "@org_tensorflow//tensorflow/lite/kernels:test_util", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc new file mode 100644 index 000000000..738fa1128 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.cc @@ -0,0 +1,264 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" + +#include +#include +#include + +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite::ops::custom { + +namespace ngram_op { + +namespace { + +using ::flexbuffers::GetRoot; +using ::flexbuffers::Map; +using ::flexbuffers::TypedVector; +using ::mediapipe::tasks::text::language_detector::custom_ops:: + LowercaseUnicodeStr; +using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize; +using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput; +using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: + MurmurHash64WithSeed; +using ::tflite::GetString; +using ::tflite::StringRef; + +constexpr int kInputMessage = 0; +constexpr int kOutputLabel = 0; +constexpr int kDefaultMaxSplits = 128; + +// This op takes in a string, finds the character ngrams for it and then +// maps each of these ngrams to an index using the specified vocabulary sizes. + +// Input(s): +// - input: Input string. +// - seeds: Seed for the random number generator. +// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would +// be interpreted as generating unigrams, bigrams, and trigrams. +// - vocab_sizes: Size of the vocabulary for each of the ngram features +// respectively. The op would generate vocab ids to be less than or equal to +// the vocab size. The index 0 implies an invalid ngram. +// - max_splits: Maximum number of tokens in the output. If this is unset, the +// limit is `kDefaultMaxSplits`. +// - lower_case_input: If this is set to true, the input string would be +// lower-cased before any processing. + +// Output(s): +// - output: A tensor of size [number of ngrams, number of tokens + 2], +// where 2 tokens are reserved for the padding. If `max_splits` is set, this +// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`. + +// Helper class used for pre-processing the input. +class NGramHashParams { + public: + NGramHashParams(const uint64_t seed, const std::vector& ngram_lengths, + const std::vector& vocab_sizes, int max_splits, + bool lower_case_input) + : seed_(seed), + ngram_lengths_(ngram_lengths), + vocab_sizes_(vocab_sizes), + max_splits_(max_splits), + lower_case_input_(lower_case_input) {} + + TfLiteStatus PreprocessInput(const TfLiteTensor* input_t, + TfLiteContext* context) { + if (input_t->bytes == 0) { + context->ReportError(context, "Empty input not supported."); + return kTfLiteError; + } + + // Do sanity checks on the input. + if (ngram_lengths_.empty()) { + context->ReportError(context, "`ngram_lengths` must be non-empty."); + return kTfLiteError; + } + + if (vocab_sizes_.empty()) { + context->ReportError(context, "`vocab_sizes` must be non-empty."); + return kTfLiteError; + } + + if (ngram_lengths_.size() != vocab_sizes_.size()) { + context->ReportError( + context, + "Sizes of `ngram_lengths` and `vocab_sizes` must be the same."); + return kTfLiteError; + } + + if (max_splits_ <= 0) { + context->ReportError(context, "`max_splits` must be > 0."); + return kTfLiteError; + } + + // Obtain and tokenize the input. 
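+    // `Tokenize` (called below) returns a `TokenizedOutput` whose `str`
+    // member holds the processed text and whose `tokens` member holds
+    // (byte offset, byte length) pairs, one per token; `GetNGramHashIndices`
+    // later hashes the bytes spanned by each ngram of consecutive tokens.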
+ StringRef inputref = GetString(input_t, /*string_index=*/0); + if (lower_case_input_) { + std::string lower_cased_str; + LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str); + + tokenized_output_ = + Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_, + /*exclude_nonalphaspace_tokens=*/true); + } else { + tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_, + /*exclude_nonalphaspace_tokens=*/true); + } + return kTfLiteOk; + } + uint64_t GetSeed() const { return seed_; } + + int GetNumTokens() const { return tokenized_output_.tokens.size(); } + + int GetNumNGrams() const { return ngram_lengths_.size(); } + + std::vector GetNGramLengths() const { return ngram_lengths_; } + + std::vector GetVocabSizes() const { return vocab_sizes_; } + + const TokenizedOutput& GetTokenizedOutput() const { + return tokenized_output_; + } + + TokenizedOutput tokenized_output_; + + private: + const uint64_t seed_; + std::vector ngram_lengths_; + std::vector vocab_sizes_; + const int max_splits_; + const bool lower_case_input_; +}; + +// Convert the TypedVector into a regular std::vector. +std::vector GetIntVector(TypedVector typed_vec) { + std::vector vec(typed_vec.size()); + for (int j = 0; j < typed_vec.size(); j++) { + vec[j] = typed_vec[j].AsInt32(); + } + return vec; +} + +void GetNGramHashIndices(NGramHashParams* params, int32_t* data) { + const int max_unicode_length = params->GetNumTokens(); + const auto ngram_lengths = params->GetNGramLengths(); + const auto vocab_sizes = params->GetVocabSizes(); + const auto& tokenized_output = params->GetTokenizedOutput(); + const auto seed = params->GetSeed(); + + // Compute for each ngram. + for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) { + const int vocab_size = vocab_sizes[ngram]; + const int ngram_length = ngram_lengths[ngram]; + + // Compute for each token within the input. + for (int start = 0; start < tokenized_output.tokens.size(); start++) { + // Compute the number of bytes for the ngram starting at the given + // token. + int num_bytes = 0; + for (int i = start; + i < tokenized_output.tokens.size() && i < (start + ngram_length); + i++) { + num_bytes += tokenized_output.tokens[i].second; + } + + // Compute the hash for the ngram starting at the token. + const auto str_hash = MurmurHash64WithSeed( + tokenized_output.str.c_str() + tokenized_output.tokens[start].first, + num_bytes, seed); + + // Map the hash to an index in the vocab. + data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1; + } + } +} + +} // namespace + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const uint8_t* buffer_t = reinterpret_cast(buffer); + const Map& m = GetRoot(buffer_t, length).AsMap(); + + const uint64_t seed = m["seed"].AsUInt64(); + const std::vector ngram_lengths = + GetIntVector(m["ngram_lengths"].AsTypedVector()); + const std::vector vocab_sizes = + GetIntVector(m["vocab_sizes"].AsTypedVector()); + const int max_splits = + m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32(); + const bool lowercase_input = + m["lowercase_input"].IsNull() ? 
true : m["lowercase_input"].AsBool(); + + return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits, + lowercase_input); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + SetTensorToDynamic(output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + NGramHashParams* params = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_OK( + context, + params->PreprocessInput(GetInput(context, node, kInputMessage), context)); + + TfLiteTensor* output = GetOutput(context, node, kOutputLabel); + TF_LITE_ENSURE(context, output != nullptr); + if (IsDynamicTensor(output)) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(3); + output_size->data[0] = 1; + output_size->data[1] = params->GetNumNGrams(); + output_size->data[2] = params->GetNumTokens(); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + } else { + context->ReportError(context, "Output must by dynamic."); + return kTfLiteError; + } + + if (output->type == kTfLiteInt32) { + GetNGramHashIndices(params, output->data.i32); + } else { + context->ReportError(context, "Output type must be Int32."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace ngram_op + +TfLiteRegistration* Register_NGRAM_HASH() { + static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free, + ngram_op::Resize, ngram_op::Eval}; + return &r; +} + +} // namespace tflite::ops::custom diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h new file mode 100644 index 000000000..a061357bd --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h @@ -0,0 +1,27 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ + +#include "tensorflow/lite/kernels/register.h" + +namespace tflite::ops::custom { + +TfLiteRegistration* Register_NGRAM_HASH(); + +} // namespace tflite::ops::custom + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc new file mode 100644 index 000000000..28d2dea6e --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash_test.cc @@ -0,0 +1,313 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h" + +#include +#include +#include +#include + +#include "absl/types/optional.h" +#include "flatbuffers/flexbuffers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace tflite::ops::custom { +namespace { + +using ::flexbuffers::Builder; +using ::mediapipe::tasks::text::language_detector::custom_ops::hash:: + MurmurHash64WithSeed; +using ::testing::ElementsAreArray; +using ::testing::Message; + +// Helper class for testing the op. +class NGramHashModel : public SingleOpModel { + public: + explicit NGramHashModel(const uint64_t seed, + const std::vector& ngram_lengths, + const std::vector& vocab_sizes, + const absl::optional max_splits = std::nullopt) { + // Setup the model inputs. + Builder fbb; + size_t start = fbb.StartMap(); + fbb.UInt("seed", seed); + { + size_t start = fbb.StartVector("ngram_lengths"); + for (const int& ngram_len : ngram_lengths) { + fbb.Int(ngram_len); + } + fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); + } + { + size_t start = fbb.StartVector("vocab_sizes"); + for (const int& vocab_size : vocab_sizes) { + fbb.Int(vocab_size); + } + fbb.EndVector(start, /*typed=*/true, /*fixed=*/false); + } + if (max_splits) { + fbb.Int("max_splits", *max_splits); + } + fbb.EndMap(start); + fbb.Finish(); + output_ = AddOutput({TensorType_INT32, {}}); + SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH); + BuildInterpreter({GetShape(input_)}); + } + + void SetupInputTensor(const std::string& input) { + PopulateStringTensor(input_, {input}); + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + << "Cannot allocate tensors"; + } + + void Invoke(const std::string& input) { + SetupInputTensor(input); + CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk); + } + + TfLiteStatus InvokeUnchecked(const std::string& input) { + SetupInputTensor(input); + return SingleOpModel::Invoke(); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_ = AddInput(TensorType_STRING); + int output_; +}; + +TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) { + // Checks that the op returns the expected value when the input is sane. + // Also checks that when `max_splits` is not specified, the entire string is + // tokenized. + const uint64_t kSeed = 123; + const std::vector vocab_sizes({100, 200}); + std::vector ngram_lengths({1, 2}); + const std::vector testcase_inputs({ + "hi", + "wow", + "!", + "HI", + }); + + // A hash function that maps the given string to an index in the embedding + // table denoted by `vocab_idx`. 
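+  // The `+ 1` below mirrors the op's convention that vocab index 0 is
+  // reserved for an invalid ngram, so valid ids fall in [1, vocab_size].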
+ auto hash = [vocab_sizes](std::string str, const int vocab_idx) { + const auto hash_value = + MurmurHash64WithSeed(str.c_str(), str.size(), kSeed); + return static_cast((hash_value % vocab_sizes[vocab_idx]) + 1); + }; + const std::vector> expected_testcase_outputs( + {{ + // Unigram & Bigram output for "hi". + hash("^", 0), + hash("h", 0), + hash("i", 0), + hash("$", 0), + hash("^h", 1), + hash("hi", 1), + hash("i$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow". + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "!" (which will get replaced by " "). + hash("^", 0), + hash(" ", 0), + hash("$", 0), + hash("^ ", 1), + hash(" $", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "HI" (which will get lower-cased). + hash("^", 0), + hash("h", 0), + hash("i", 0), + hash("$", 0), + hash("^h", 1), + hash("hi", 1), + hash("i$", 1), + hash("$", 1), + }}); + + NGramHashModel m(kSeed, ngram_lengths, vocab_sizes); + for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) { + const string& testcase_input = testcase_inputs[test_idx]; + m.Invoke(testcase_input); + SCOPED_TRACE(Message() << "Where the testcases' input is: " + << testcase_input); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(expected_testcase_outputs[test_idx])); + EXPECT_THAT(m.GetOutputShape(), + ElementsAreArray( + {/*batch_size=*/1, static_cast(ngram_lengths.size()), + static_cast(testcase_input.size()) + /*padding*/ 2})); + } +} + +TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) { + // Checks that the op returns the expected value when the input is correct + // when `max_splits` is specified. + const uint64_t kSeed = 123; + const std::vector vocab_sizes({100, 200}); + std::vector ngram_lengths({1, 2}); + + const std::string testcase_input = "wow"; + const std::vector max_splits({2, 3, 4, 5, 6}); + + // A hash function that maps the given string to an index in the embedding + // table denoted by `vocab_idx`. + auto hash = [vocab_sizes](std::string str, const int vocab_idx) { + const auto hash_value = + MurmurHash64WithSeed(str.c_str(), str.size(), kSeed); + return static_cast((hash_value % vocab_sizes[vocab_idx]) + 1); + }; + + const std::vector> expected_testcase_outputs( + {{ + // Unigram & Bigram output for "wow", when `max_splits` == 2. + // We cannot include any of the actual tokens, since `max_splits` + // only allows enough space for the delimiters. + hash("^", 0), + hash("$", 0), + hash("^$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 3. + // We can start to include some tokens from the input string. + hash("^", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 4. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("o$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 5. + // We can include the full input string. + hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }, + { + // Unigram & Bigram output for "wow", when `max_splits` == 6. + // `max_splits` is more than the full input string. 
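+          // The expected output is therefore identical to the
+          // `max_splits` == 5 case above.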
+ hash("^", 0), + hash("w", 0), + hash("o", 0), + hash("w", 0), + hash("$", 0), + hash("^w", 1), + hash("wo", 1), + hash("ow", 1), + hash("w$", 1), + hash("$", 1), + }}); + + for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) { + const int testcase_max_splits = max_splits[test_idx]; + NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits); + m.Invoke(testcase_input); + SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(expected_testcase_outputs[test_idx])); + EXPECT_THAT( + m.GetOutputShape(), + ElementsAreArray( + {/*batch_size=*/1, static_cast(ngram_lengths.size()), + std::min( + // Longest possible tokenization when using the entire + // input. + static_cast(testcase_input.size()) + /*padding*/ 2, + // Longest possible string when the `max_splits` value + // is < testcase_input.size() + 2 for padding. + testcase_max_splits)})); + } +} + +TEST(NGramHashTest, InvalidMaxSplitsValue) { + // Check that the op errors out when given an invalid max splits value. + const std::vector invalid_max_splits({0, -1, -5, -100}); + for (const int max_splits : invalid_max_splits) { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200}, + /*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } +} + +TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) { + // Check that the op errors out when ngram lengths and vocab sizes mistmatch. + { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300}, + /*vocab_sizes=*/{1, 2}); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } + { + NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200}, + /*vocab_sizes=*/{1, 2, 3}); + EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError); + } +} + +} // namespace +} // namespace tflite::ops::custom diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD new file mode 100644 index 000000000..9f2fe298a --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/BUILD @@ -0,0 +1,42 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
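+
+# String-processing helpers (Unicode-aware tokenization and lower-casing)
+# shared by the NGramHash custom op.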
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "ngram_hash_ops_utils", + srcs = [ + "ngram_hash_ops_utils.cc", + ], + hdrs = [ + "ngram_hash_ops_utils.h", + ], + deps = [ + "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf", + ], +) + +cc_test( + name = "ngram_hash_ops_utils_test", + size = "small", + srcs = [ + "ngram_hash_ops_utils_test.cc", + ], + deps = [ + ":ngram_hash_ops_utils", + "//mediapipe/framework/port:gtest_main", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/BUILD new file mode 100644 index 000000000..86b659245 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/BUILD @@ -0,0 +1,38 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "murmur", + srcs = ["murmur.cc"], + hdrs = ["murmur.h"], + deps = [ + "//mediapipe/framework/port:integral_types", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + ], +) + +cc_test( + name = "murmur_test", + srcs = ["murmur_test.cc"], + deps = [ + ":murmur", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:integral_types", + ], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.cc new file mode 100644 index 000000000..75dd161bf --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.cc @@ -0,0 +1,95 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Forked from a library written by Austin Appelby and Jyrki Alakuijala. +// Original copyright message below. +// Copyright 2009 Google Inc. All Rights Reserved. 
+// Author: aappleby@google.com (Austin Appleby) +// jyrki@google.com (Jyrki Alakuijala) + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" + +#include + +#include "absl/base/internal/endian.h" +#include "absl/base/optimization.h" +#include "mediapipe/framework/port/integral_types.h" + +namespace mediapipe::tasks::text::language_detector::custom_ops::hash { + +namespace { + +using ::absl::little_endian::Load64; + +// Murmur 2.0 multiplication constant. +static const uint64_t kMul = 0xc6a4a7935bd1e995ULL; + +// We need to mix some of the bits that get propagated and mixed into the +// high bits by multiplication back into the low bits. 17 last bits get +// a more efficiently mixed with this. +inline uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); } + +// Accumulate 8 bytes into 64-bit Murmur hash +inline uint64_t MurmurStep(uint64_t hash, uint64_t data) { + hash ^= ShiftMix(data * kMul) * kMul; + hash *= kMul; + return hash; +} + +// Build a uint64 from 1-8 bytes. +// 8 * len least significant bits are loaded from the memory with +// LittleEndian order. The 64 - 8 * len most significant bits are +// set all to 0. +// In latex-friendly words, this function returns: +// $\sum_{i=0}^{len-1} p[i] 256^{i}$, where p[i] is unsigned. +// +// This function is equivalent to: +// uint64 val = 0; +// memcpy(&val, p, len); +// return ToHost64(val); +// +// The caller needs to guarantee that 0 <= len <= 8. +uint64_t Load64VariableLength(const void* const p, int len) { + ABSL_ASSUME(len >= 0 && len <= 8); + uint64_t val = 0; + const uint8_t* const src = static_cast(p); + for (int i = 0; i < len; ++i) { + val |= static_cast(src[i]) << (8 * i); + } + return val; +} + +} // namespace + +unsigned long long MurmurHash64WithSeed(const char* buf, // NOLINT + const size_t len, const uint64_t seed) { + // Let's remove the bytes not divisible by the sizeof(uint64). + // This allows the inner loop to process the data as 64 bit integers. + const size_t len_aligned = len & ~0x7; + const char* const end = buf + len_aligned; + uint64_t hash = seed ^ (len * kMul); + for (const char* p = buf; p != end; p += 8) { + hash = MurmurStep(hash, Load64(p)); + } + if ((len & 0x7) != 0) { + const uint64_t data = Load64VariableLength(end, len & 0x7); + hash ^= data; + hash *= kMul; + } + hash = ShiftMix(hash) * kMul; + hash = ShiftMix(hash); + return hash; +} + +} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h new file mode 100644 index 000000000..abcb41a6b --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// Forked from a library written by Austin Appelby and Jyrki Alakuijala. +// Original copyright message below. +// Copyright 2009 Google Inc. All Rights Reserved. +// Author: aappleby@google.com (Austin Appelby) +// jyrki@google.com (Jyrki Alakuijala) +// +// MurmurHash is a fast multiplication and shifting based algorithm, +// based on Austin Appleby's MurmurHash 2.0 algorithm. + +#ifndef UTIL_HASH_MURMUR_H_ +#define UTIL_HASH_MURMUR_H_ + +#include +#include // for size_t. + +#include + +#include "mediapipe/framework/port/integral_types.h" + +namespace mediapipe::tasks::text::language_detector::custom_ops::hash { + +// Hash function for a byte array. Has a seed which allows this hash function to +// be used in algorithms that need a family of parameterized hash functions. +// e.g. Minhash. +unsigned long long MurmurHash64WithSeed(const char* buf, size_t len, // NOLINT + uint64_t seed); +} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash + +#endif // UTIL_HASH_MURMUR_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur_test.cc new file mode 100644 index 000000000..6658965bf --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Forked from a test library written by Jyrki Alakuijala. +// Original copyright message below. +// Copyright 2009 Google Inc. All Rights Reserved. +// Author: jyrki@google.com (Jyrki Alakuijala) +// +// Tests for the fast hashing algorithm based on Austin Appleby's +// MurmurHash 2.0 algorithm. See http://murmurhash.googlepages.com/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h" + +#include + +#include +#include + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/integral_types.h" + +namespace mediapipe::tasks::text::language_detector::custom_ops::hash { + +TEST(Murmur, EmptyData64) { + EXPECT_EQ(uint64_t{0}, MurmurHash64WithSeed(nullptr, uint64_t{0}, 0)); +} + +TEST(Murmur, VaryWithDifferentSeeds) { + // While in theory different seeds could return the same + // hash for the same data this is unlikely. + char data1 = 'x'; + EXPECT_NE(MurmurHash64WithSeed(&data1, 1, 100), + MurmurHash64WithSeed(&data1, 1, 101)); +} + +// Hashes don't change. 
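+// Calling MurmurHash64WithSeed twice on the same buffer with the same seed
+// must return the same value: the hash is a pure function of
+// (data, length, seed).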
+TEST(Murmur, Idempotence) { + const char data[] = "deadbeef"; + const size_t dlen = strlen(data); + + for (int i = 0; i < 10; i++) { + EXPECT_EQ(MurmurHash64WithSeed(data, dlen, i), + MurmurHash64WithSeed(data, dlen, i)); + } + + const char next_data[] = "deadbeef000---"; + const size_t next_dlen = strlen(next_data); + + for (int i = 0; i < 10; i++) { + EXPECT_EQ(MurmurHash64WithSeed(next_data, next_dlen, i), + MurmurHash64WithSeed(next_data, next_dlen, i)); + } +} +} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc new file mode 100644 index 000000000..f1ad71fc1 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.cc @@ -0,0 +1,96 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" + +#include +#include +#include + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h" + +namespace mediapipe::tasks::text::language_detector::custom_ops { + +TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens, + bool exclude_nonalphaspace_tokens) { + const std::string kPrefix = "^"; + const std::string kSuffix = "$"; + const std::string kReplacementToken = " "; + + TokenizedOutput output; + + size_t token_start = 0; + output.str.reserve(len + 2); + output.tokens.reserve(len + 2); + + output.str.append(kPrefix); + output.tokens.push_back(std::make_pair(token_start, kPrefix.size())); + token_start += kPrefix.size(); + + Rune token; + for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) { + // Use the standard UTF-8 library to find the next token. + size_t bytes_read = utf_charntorune(&token, input_str + i, len - i); + + // Stop processing, if we can't read any more tokens, or we have reached + // maximum allowed tokens, allocating one token for the suffix. + if (bytes_read == 0) { + break; + } + + // If `exclude_nonalphaspace_tokens` is set to true, and the token is not + // alphanumeric, replace it with a replacement token. + if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) { + output.str.append(kReplacementToken); + output.tokens.push_back( + std::make_pair(token_start, kReplacementToken.size())); + token_start += kReplacementToken.size(); + i += bytes_read; + continue; + } + + // Append the token in the output string, and note its position and the + // number of bytes that token consumed. 
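+    // For multi-byte UTF-8 characters `bytes_read` can be up to UTFmax (4)
+    // bytes, so the recorded (position, size) pairs are byte offsets into
+    // `str`, not character counts.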
+ output.str.append(input_str + i, bytes_read); + output.tokens.push_back(std::make_pair(token_start, bytes_read)); + token_start += bytes_read; + i += bytes_read; + } + output.str.append(kSuffix); + output.tokens.push_back(std::make_pair(token_start, kSuffix.size())); + token_start += kSuffix.size(); + + return output; +} + +void LowercaseUnicodeStr(const char* input_str, int len, + std::string* output_str) { + for (int i = 0; i < len;) { + Rune token; + + // Tokenize the given string, and get the appropriate lowercase token. + size_t bytes_read = utf_charntorune(&token, input_str + i, len - i); + token = utf_isalpharune(token) ? utf_tolowerrune(token) : token; + + // Write back the token to the output string. + char token_buf[UTFmax]; + size_t bytes_to_write = utf_runetochar(token_buf, &token); + output_str->append(token_buf, bytes_to_write); + + i += bytes_read; + } +} + +} // namespace mediapipe::tasks::text::language_detector::custom_ops diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h new file mode 100644 index 000000000..9a80554c8 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h @@ -0,0 +1,56 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_ + +#include +#include +#include + +namespace mediapipe::tasks::text::language_detector::custom_ops { + +struct TokenizedOutput { + // The processed string (with necessary prefix, suffix, skipped tokens, etc.). + std::string str; + + // This vector contains pairs, where each pair has two members. The first + // denoting the starting index of the token in the `str` string, and the + // second denoting the length of that token in bytes. + std::vector> tokens; +}; + +// Tokenizes the given input string on Unicode token boundaries, with a maximum +// of `max_tokens` tokens. +// +// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores +// non-alphanumeric tokens, and replaces them with a replacement token (" "). +// +// The method returns the output in the `TokenizedOutput` struct, which stores +// both, the processed input string, and the indices and sizes of each token +// within that string. +TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens, + bool exclude_nonalphaspace_tokens); + +// Converts the given unicode string (`input_str`) with the specified length +// (`len`) to a lowercase string. +// +// The method populates the lowercased string in `output_str`. 
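+//
+// Characters with no defined lowercase mapping (digits, punctuation, and
+// scripts without case, e.g. Japanese) are copied through unchanged.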
+void LowercaseUnicodeStr(const char* input_str, int len, + std::string* output_str); + +} // namespace mediapipe::tasks::text::language_detector::custom_ops + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_ diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc new file mode 100644 index 000000000..d22af1c95 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h" + +#include + +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +namespace mediapipe::tasks::text::language_detector::custom_ops { + +namespace { + +using ::testing::Values; + +std::string ReconstructStringFromTokens(TokenizedOutput output) { + std::string reconstructed_str; + for (int i = 0; i < output.tokens.size(); i++) { + reconstructed_str.append( + output.str.c_str() + output.tokens[i].first, + output.str.c_str() + output.tokens[i].first + output.tokens[i].second); + } + return reconstructed_str; +} + +struct TokenizeTestParams { + std::string input_str; + size_t max_tokens; + bool exclude_nonalphaspace_tokens; + std::string expected_output_str; +}; + +class TokenizeParameterizedTest + : public ::testing::Test, + public testing::WithParamInterface {}; + +TEST_P(TokenizeParameterizedTest, Tokenize) { + // Checks that the Tokenize method returns the expected value. + const TokenizeTestParams params = TokenizeParameterizedTest::GetParam(); + const TokenizedOutput output = Tokenize( + /*input_str=*/params.input_str.c_str(), + /*len=*/params.input_str.size(), + /*max_tokens=*/params.max_tokens, + /*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens); + + // The output string should have the necessary prefixes, and the "!" token + // should have been replaced with a " ". + EXPECT_EQ(output.str, params.expected_output_str); + EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str); +} + +INSTANTIATE_TEST_SUITE_P( + TokenizeParameterizedTests, TokenizeParameterizedTest, + Values( + // Test including non-alphanumeric characters. + TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100, + /*exclude_alphanonspace=*/false, + /*expected_output_str=*/"^hi!$"}), + // Test not including non-alphanumeric characters. + TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100, + /*exclude_alphanonspace=*/true, + /*expected_output_str=*/"^hi $"}), + // Test with a maximum of 3 tokens. + TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3, + /*exclude_alphanonspace=*/true, + /*expected_output_str=*/"^h$"}), + // Test with non-latin characters. 
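+        // Each character below is a multi-byte UTF-8 sequence, which also
+        // exercises the byte-offset bookkeeping in `TokenizedOutput::tokens`.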
+ TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100, + /*exclude_alphanonspace=*/true, + /*expected_output_str=*/"^ありがと$"}))); + +TEST(LowercaseUnicodeTest, TestLowercaseUnicode) { + { + // Check that the method is a no-op when the string is lowercase. + std::string input_str = "hello"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "hello"); + } + { + // Check that the method has uppercase characters. + std::string input_str = "hElLo"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "hello"); + } + { + // Check that the method works with non-latin scripts. + // Cyrillic has the concept of cases, so it should change the input. + std::string input_str = "БЙп"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "бйп"); + } + { + // Check that the method works with non-latin scripts. + // Japanese doesn't have the concept of cases, so it should not change. + std::string input_str = "ありがと"; + std::string output_str; + LowercaseUnicodeStr( + /*input_str=*/input_str.c_str(), + /*len=*/input_str.size(), + /*output_str=*/&output_str); + + EXPECT_EQ(output_str, "ありがと"); + } +} + +} // namespace +} // namespace mediapipe::tasks::text::language_detector::custom_ops diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD new file mode 100644 index 000000000..a71845305 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/BUILD @@ -0,0 +1,27 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +cc_library( + name = "utf", + srcs = [ + "rune.c", + "runetype.c", + "runetypebody.h", + ], + hdrs = ["utf.h"], +) diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c new file mode 100644 index 000000000..b74450f44 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/rune.c @@ -0,0 +1,233 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Forked from a library written by Rob Pike and Ken Thompson. Original +// copyright message below. +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. + * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + */ +#include +#include +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h" + +enum +{ + Bit1 = 7, + Bitx = 6, + Bit2 = 5, + Bit3 = 4, + Bit4 = 3, + Bit5 = 2, + + T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */ + Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */ + T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */ + T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */ + T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */ + T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */ + + Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */ + Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */ + Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */ + Rune4 = (1<<(Bit4+3*Bitx))-1, + /* 0001 1111 1111 1111 1111 1111 */ + + Maskx = (1< T1 + */ + c = *(uchar*)str; + if(c < Tx) { + *rune = c; + return 1; + } + + // If we can't read more than one character we must stop + if(length <= 1) { + goto badlen; + } + + /* + * two character sequence (11-bit value) + * 0080-07FF => T2 Tx + */ + c1 = *(uchar*)(str+1) ^ Tx; + if(c1 & Testx) + goto bad; + if(c < T3) { + if(c < T2) + goto bad; + l = ((c << Bitx) | c1) & Rune2; + if(l <= Rune1) + goto bad; + *rune = l; + return 2; + } + + // If we can't read more than two characters we must stop + if(length <= 2) { + goto badlen; + } + + /* + * three character sequence (16-bit value) + * 0800-FFFF => T3 Tx Tx + */ + c2 = *(uchar*)(str+2) ^ Tx; + if(c2 & Testx) + goto bad; + if(c < T4) { + l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3; + if(l <= Rune2) + goto bad; + *rune = l; + return 3; + } + + if (length <= 3) + goto badlen; + + /* + * four character sequence (21-bit value) + * 10000-1FFFFF => T4 Tx Tx Tx + */ + c3 = *(uchar*)(str+3) ^ Tx; + if (c3 & Testx) + goto bad; + if (c < T5) { + l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4; + if (l <= Rune3) + goto bad; + if (l > Runemax) + goto bad; + *rune = l; + return 4; + } + + // Support for 5-byte or longer UTF-8 would go here, but + // since we don't have that, we'll just fall through to bad. + + /* + * bad decoding + */ +bad: + *rune = Bad; + return 1; +badlen: + *rune = Bad; + return 0; + +} + +int +utf_runetochar(char *str, const Rune *rune) +{ + /* Runes are signed, so convert to unsigned for range check. 
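+ * A negative Rune compared as signed would pass the `c <= Rune1` test below
+ * and be truncated to a single byte; converted to unsigned long it instead
+ * exceeds Runemax and is encoded as Runeerror.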
*/ + unsigned long c; + + /* + * one character sequence + * 00000-0007F => 00-7F + */ + c = *rune; + if(c <= Rune1) { + str[0] = c; + return 1; + } + + /* + * two character sequence + * 0080-07FF => T2 Tx + */ + if(c <= Rune2) { + str[0] = T2 | (c >> 1*Bitx); + str[1] = Tx | (c & Maskx); + return 2; + } + + /* + * If the Rune is out of range, convert it to the error rune. + * Do this test here because the error rune encodes to three bytes. + * Doing it earlier would duplicate work, since an out of range + * Rune wouldn't have fit in one or two bytes. + */ + if (c > Runemax) + c = Runeerror; + + /* + * three character sequence + * 0800-FFFF => T3 Tx Tx + */ + if (c <= Rune3) { + str[0] = T3 | (c >> 2*Bitx); + str[1] = Tx | ((c >> 1*Bitx) & Maskx); + str[2] = Tx | (c & Maskx); + return 3; + } + + /* + * four character sequence (21-bit value) + * 10000-1FFFFF => T4 Tx Tx Tx + */ + str[0] = T4 | (c >> 3*Bitx); + str[1] = Tx | ((c >> 2*Bitx) & Maskx); + str[2] = Tx | ((c >> 1*Bitx) & Maskx); + str[3] = Tx | (c & Maskx); + return 4; +} diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c new file mode 100644 index 000000000..1dd8abdbd --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetype.c @@ -0,0 +1,54 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Forked from a library written by Rob Pike and Ken Thompson. Original +// copyright message below. +/* + * The authors of this software are Rob Pike and Ken Thompson. + * Copyright (c) 2002 by Lucent Technologies. + * Permission to use, copy, modify, and distribute this software for any + * purpose without fee is hereby granted, provided that this entire notice + * is included in all copies of any software which is or includes a copy + * or modification of this software and in all copies of the supporting + * documentation for such software. + * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED + * WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY + * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY + * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. 
+ */ +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h" + +static +Rune* +rbsearch(Rune c, Rune *t, int n, int ne) +{ + Rune *p; + int m; + + while(n > 1) { + m = n >> 1; + p = t + m*ne; + if(c >= p[0]) { + t = p; + n = n-m; + } else + n = m; + } + if(n && c >= t[0]) + return t; + return 0; +} + +#define RUNETYPEBODY +#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h" diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h new file mode 100644 index 000000000..66d1dfc19 --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h @@ -0,0 +1,212 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef RUNETYPEBODY + +static Rune __isalphar[] = { + 0x0041, 0x005a, 0x0061, 0x007a, 0x00c0, 0x00d6, 0x00d8, 0x00f6, + 0x00f8, 0x02c1, 0x02c6, 0x02d1, 0x02e0, 0x02e4, 0x0370, 0x0374, + 0x0376, 0x0377, 0x037a, 0x037d, 0x0388, 0x038a, 0x038e, 0x03a1, + 0x03a3, 0x03f5, 0x03f7, 0x0481, 0x048a, 0x0527, 0x0531, 0x0556, + 0x0561, 0x0587, 0x05d0, 0x05ea, 0x05f0, 0x05f2, 0x0620, 0x064a, + 0x066e, 0x066f, 0x0671, 0x06d3, 0x06e5, 0x06e6, 0x06ee, 0x06ef, + 0x06fa, 0x06fc, 0x0712, 0x072f, 0x074d, 0x07a5, 0x07ca, 0x07ea, + 0x07f4, 0x07f5, 0x0800, 0x0815, 0x0840, 0x0858, 0x08a2, 0x08ac, + 0x0904, 0x0939, 0x0958, 0x0961, 0x0971, 0x0977, 0x0979, 0x097f, + 0x0985, 0x098c, 0x098f, 0x0990, 0x0993, 0x09a8, 0x09aa, 0x09b0, + 0x09b6, 0x09b9, 0x09dc, 0x09dd, 0x09df, 0x09e1, 0x09f0, 0x09f1, + 0x0a05, 0x0a0a, 0x0a0f, 0x0a10, 0x0a13, 0x0a28, 0x0a2a, 0x0a30, + 0x0a32, 0x0a33, 0x0a35, 0x0a36, 0x0a38, 0x0a39, 0x0a59, 0x0a5c, + 0x0a72, 0x0a74, 0x0a85, 0x0a8d, 0x0a8f, 0x0a91, 0x0a93, 0x0aa8, + 0x0aaa, 0x0ab0, 0x0ab2, 0x0ab3, 0x0ab5, 0x0ab9, 0x0ae0, 0x0ae1, + 0x0b05, 0x0b0c, 0x0b0f, 0x0b10, 0x0b13, 0x0b28, 0x0b2a, 0x0b30, + 0x0b32, 0x0b33, 0x0b35, 0x0b39, 0x0b5c, 0x0b5d, 0x0b5f, 0x0b61, + 0x0b85, 0x0b8a, 0x0b8e, 0x0b90, 0x0b92, 0x0b95, 0x0b99, 0x0b9a, + 0x0b9e, 0x0b9f, 0x0ba3, 0x0ba4, 0x0ba8, 0x0baa, 0x0bae, 0x0bb9, + 0x0c05, 0x0c0c, 0x0c0e, 0x0c10, 0x0c12, 0x0c28, 0x0c2a, 0x0c33, + 0x0c35, 0x0c39, 0x0c58, 0x0c59, 0x0c60, 0x0c61, 0x0c85, 0x0c8c, + 0x0c8e, 0x0c90, 0x0c92, 0x0ca8, 0x0caa, 0x0cb3, 0x0cb5, 0x0cb9, + 0x0ce0, 0x0ce1, 0x0cf1, 0x0cf2, 0x0d05, 0x0d0c, 0x0d0e, 0x0d10, + 0x0d12, 0x0d3a, 0x0d60, 0x0d61, 0x0d7a, 0x0d7f, 0x0d85, 0x0d96, + 0x0d9a, 0x0db1, 0x0db3, 0x0dbb, 0x0dc0, 0x0dc6, 0x0e01, 0x0e30, + 0x0e32, 0x0e33, 0x0e40, 0x0e46, 0x0e81, 0x0e82, 0x0e87, 0x0e88, + 0x0e94, 0x0e97, 0x0e99, 0x0e9f, 0x0ea1, 0x0ea3, 0x0eaa, 0x0eab, + 0x0ead, 0x0eb0, 0x0eb2, 0x0eb3, 0x0ec0, 0x0ec4, 0x0edc, 0x0edf, + 0x0f40, 0x0f47, 0x0f49, 0x0f6c, 0x0f88, 0x0f8c, 0x1000, 0x102a, + 0x1050, 0x1055, 0x105a, 0x105d, 0x1065, 0x1066, 0x106e, 0x1070, + 0x1075, 0x1081, 0x10a0, 0x10c5, 0x10d0, 0x10fa, 0x10fc, 0x1248, 
+ 0x124a, 0x124d, 0x1250, 0x1256, 0x125a, 0x125d, 0x1260, 0x1288, + 0x128a, 0x128d, 0x1290, 0x12b0, 0x12b2, 0x12b5, 0x12b8, 0x12be, + 0x12c2, 0x12c5, 0x12c8, 0x12d6, 0x12d8, 0x1310, 0x1312, 0x1315, + 0x1318, 0x135a, 0x1380, 0x138f, 0x13a0, 0x13f4, 0x1401, 0x166c, + 0x166f, 0x167f, 0x1681, 0x169a, 0x16a0, 0x16ea, 0x1700, 0x170c, + 0x170e, 0x1711, 0x1720, 0x1731, 0x1740, 0x1751, 0x1760, 0x176c, + 0x176e, 0x1770, 0x1780, 0x17b3, 0x1820, 0x1877, 0x1880, 0x18a8, + 0x18b0, 0x18f5, 0x1900, 0x191c, 0x1950, 0x196d, 0x1970, 0x1974, + 0x1980, 0x19ab, 0x19c1, 0x19c7, 0x1a00, 0x1a16, 0x1a20, 0x1a54, + 0x1b05, 0x1b33, 0x1b45, 0x1b4b, 0x1b83, 0x1ba0, 0x1bae, 0x1baf, + 0x1bba, 0x1be5, 0x1c00, 0x1c23, 0x1c4d, 0x1c4f, 0x1c5a, 0x1c7d, + 0x1ce9, 0x1cec, 0x1cee, 0x1cf1, 0x1cf5, 0x1cf6, 0x1d00, 0x1dbf, + 0x1e00, 0x1f15, 0x1f18, 0x1f1d, 0x1f20, 0x1f45, 0x1f48, 0x1f4d, + 0x1f50, 0x1f57, 0x1f5f, 0x1f7d, 0x1f80, 0x1fb4, 0x1fb6, 0x1fbc, + 0x1fc2, 0x1fc4, 0x1fc6, 0x1fcc, 0x1fd0, 0x1fd3, 0x1fd6, 0x1fdb, + 0x1fe0, 0x1fec, 0x1ff2, 0x1ff4, 0x1ff6, 0x1ffc, 0x2090, 0x209c, + 0x210a, 0x2113, 0x2119, 0x211d, 0x212a, 0x212d, 0x212f, 0x2139, + 0x213c, 0x213f, 0x2145, 0x2149, 0x2183, 0x2184, 0x2c00, 0x2c2e, + 0x2c30, 0x2c5e, 0x2c60, 0x2ce4, 0x2ceb, 0x2cee, 0x2cf2, 0x2cf3, + 0x2d00, 0x2d25, 0x2d30, 0x2d67, 0x2d80, 0x2d96, 0x2da0, 0x2da6, + 0x2da8, 0x2dae, 0x2db0, 0x2db6, 0x2db8, 0x2dbe, 0x2dc0, 0x2dc6, + 0x2dc8, 0x2dce, 0x2dd0, 0x2dd6, 0x2dd8, 0x2dde, 0x3005, 0x3006, + 0x3031, 0x3035, 0x303b, 0x303c, 0x3041, 0x3096, 0x309d, 0x309f, + 0x30a1, 0x30fa, 0x30fc, 0x30ff, 0x3105, 0x312d, 0x3131, 0x318e, + 0x31a0, 0x31ba, 0x31f0, 0x31ff, 0x3400, 0x4db5, 0x4e00, 0x9fcc, + 0xa000, 0xa48c, 0xa4d0, 0xa4fd, 0xa500, 0xa60c, 0xa610, 0xa61f, + 0xa62a, 0xa62b, 0xa640, 0xa66e, 0xa67f, 0xa697, 0xa6a0, 0xa6e5, + 0xa717, 0xa71f, 0xa722, 0xa788, 0xa78b, 0xa78e, 0xa790, 0xa793, + 0xa7a0, 0xa7aa, 0xa7f8, 0xa801, 0xa803, 0xa805, 0xa807, 0xa80a, + 0xa80c, 0xa822, 0xa840, 0xa873, 0xa882, 0xa8b3, 0xa8f2, 0xa8f7, + 0xa90a, 0xa925, 0xa930, 0xa946, 0xa960, 0xa97c, 0xa984, 0xa9b2, + 0xaa00, 0xaa28, 0xaa40, 0xaa42, 0xaa44, 0xaa4b, 0xaa60, 0xaa76, + 0xaa80, 0xaaaf, 0xaab5, 0xaab6, 0xaab9, 0xaabd, 0xaadb, 0xaadd, + 0xaae0, 0xaaea, 0xaaf2, 0xaaf4, 0xab01, 0xab06, 0xab09, 0xab0e, + 0xab11, 0xab16, 0xab20, 0xab26, 0xab28, 0xab2e, 0xabc0, 0xabe2, + 0xac00, 0xd7a3, 0xd7b0, 0xd7c6, 0xd7cb, 0xd7fb, 0xf900, 0xfa6d, + 0xfa70, 0xfad9, 0xfb00, 0xfb06, 0xfb13, 0xfb17, 0xfb1f, 0xfb28, + 0xfb2a, 0xfb36, 0xfb38, 0xfb3c, 0xfb40, 0xfb41, 0xfb43, 0xfb44, + 0xfb46, 0xfbb1, 0xfbd3, 0xfd3d, 0xfd50, 0xfd8f, 0xfd92, 0xfdc7, + 0xfdf0, 0xfdfb, 0xfe70, 0xfe74, 0xfe76, 0xfefc, 0xff21, 0xff3a, + 0xff41, 0xff5a, 0xff66, 0xffbe, 0xffc2, 0xffc7, 0xffca, 0xffcf, + 0xffd2, 0xffd7, 0xffda, 0xffdc, 0x10000, 0x1000b, 0x1000d, 0x10026, + 0x10028, 0x1003a, 0x1003c, 0x1003d, 0x1003f, 0x1004d, 0x10050, 0x1005d, + 0x10080, 0x100fa, 0x10280, 0x1029c, 0x102a0, 0x102d0, 0x10300, 0x1031e, + 0x10330, 0x10340, 0x10342, 0x10349, 0x10380, 0x1039d, 0x103a0, 0x103c3, + 0x103c8, 0x103cf, 0x10400, 0x1049d, 0x10800, 0x10805, 0x1080a, 0x10835, + 0x10837, 0x10838, 0x1083f, 0x10855, 0x10900, 0x10915, 0x10920, 0x10939, + 0x10980, 0x109b7, 0x109be, 0x109bf, 0x10a10, 0x10a13, 0x10a15, 0x10a17, + 0x10a19, 0x10a33, 0x10a60, 0x10a7c, 0x10b00, 0x10b35, 0x10b40, 0x10b55, + 0x10b60, 0x10b72, 0x10c00, 0x10c48, 0x11003, 0x11037, 0x11083, 0x110af, + 0x110d0, 0x110e8, 0x11103, 0x11126, 0x11183, 0x111b2, 0x111c1, 0x111c4, + 0x11680, 0x116aa, 0x12000, 0x1236e, 0x13000, 0x1342e, 0x16800, 0x16a38, + 0x16f00, 0x16f44, 0x16f93, 0x16f9f, 
0x1b000, 0x1b001, 0x1d400, 0x1d454, + 0x1d456, 0x1d49c, 0x1d49e, 0x1d49f, 0x1d4a5, 0x1d4a6, 0x1d4a9, 0x1d4ac, + 0x1d4ae, 0x1d4b9, 0x1d4bd, 0x1d4c3, 0x1d4c5, 0x1d505, 0x1d507, 0x1d50a, + 0x1d50d, 0x1d514, 0x1d516, 0x1d51c, 0x1d51e, 0x1d539, 0x1d53b, 0x1d53e, + 0x1d540, 0x1d544, 0x1d54a, 0x1d550, 0x1d552, 0x1d6a5, 0x1d6a8, 0x1d6c0, + 0x1d6c2, 0x1d6da, 0x1d6dc, 0x1d6fa, 0x1d6fc, 0x1d714, 0x1d716, 0x1d734, + 0x1d736, 0x1d74e, 0x1d750, 0x1d76e, 0x1d770, 0x1d788, 0x1d78a, 0x1d7a8, + 0x1d7aa, 0x1d7c2, 0x1d7c4, 0x1d7cb, 0x1ee00, 0x1ee03, 0x1ee05, 0x1ee1f, + 0x1ee21, 0x1ee22, 0x1ee29, 0x1ee32, 0x1ee34, 0x1ee37, 0x1ee4d, 0x1ee4f, + 0x1ee51, 0x1ee52, 0x1ee61, 0x1ee62, 0x1ee67, 0x1ee6a, 0x1ee6c, 0x1ee72, + 0x1ee74, 0x1ee77, 0x1ee79, 0x1ee7c, 0x1ee80, 0x1ee89, 0x1ee8b, 0x1ee9b, + 0x1eea1, 0x1eea3, 0x1eea5, 0x1eea9, 0x1eeab, 0x1eebb, 0x20000, 0x2a6d6, + 0x2a700, 0x2b734, 0x2b740, 0x2b81d, 0x2f800, 0x2fa1d, +}; + +static Rune __isalphas[] = { + 0x00aa, 0x00b5, 0x00ba, 0x02ec, 0x02ee, 0x0386, 0x038c, 0x0559, + 0x06d5, 0x06ff, 0x0710, 0x07b1, 0x07fa, 0x081a, 0x0824, 0x0828, + 0x08a0, 0x093d, 0x0950, 0x09b2, 0x09bd, 0x09ce, 0x0a5e, 0x0abd, + 0x0ad0, 0x0b3d, 0x0b71, 0x0b83, 0x0b9c, 0x0bd0, 0x0c3d, 0x0cbd, + 0x0cde, 0x0d3d, 0x0d4e, 0x0dbd, 0x0e84, 0x0e8a, 0x0e8d, 0x0ea5, + 0x0ea7, 0x0ebd, 0x0ec6, 0x0f00, 0x103f, 0x1061, 0x108e, 0x10c7, + 0x10cd, 0x1258, 0x12c0, 0x17d7, 0x17dc, 0x18aa, 0x1aa7, 0x1f59, + 0x1f5b, 0x1f5d, 0x1fbe, 0x2071, 0x207f, 0x2102, 0x2107, 0x2115, + 0x2124, 0x2126, 0x2128, 0x214e, 0x2d27, 0x2d2d, 0x2d6f, 0x2e2f, + 0xa8fb, 0xa9cf, 0xaa7a, 0xaab1, 0xaac0, 0xaac2, 0xfb1d, 0xfb3e, + 0x10808, 0x1083c, 0x10a00, 0x16f50, 0x1d4a2, 0x1d4bb, 0x1d546, 0x1ee24, + 0x1ee27, 0x1ee39, 0x1ee3b, 0x1ee42, 0x1ee47, 0x1ee49, 0x1ee4b, 0x1ee54, + 0x1ee57, 0x1ee59, 0x1ee5b, 0x1ee5d, 0x1ee5f, 0x1ee64, 0x1ee7e, +}; + +int utf_isalpharune(Rune c) { + Rune *p; + + p = rbsearch(c, __isalphar, nelem(__isalphar) / 2, 2); + if (p && c >= p[0] && c <= p[1]) return 1; + p = rbsearch(c, __isalphas, nelem(__isalphas), 1); + if (p && c == p[0]) return 1; + return 0; +} + +static Rune __tolowerr[] = { + 0x0041, 0x005a, 1048608, 0x00c0, 0x00d6, 1048608, 0x00d8, 0x00de, 1048608, + 0x0189, 0x018a, 1048781, 0x01b1, 0x01b2, 1048793, 0x0388, 0x038a, 1048613, + 0x038e, 0x038f, 1048639, 0x0391, 0x03a1, 1048608, 0x03a3, 0x03ab, 1048608, + 0x03fd, 0x03ff, 1048446, 0x0400, 0x040f, 1048656, 0x0410, 0x042f, 1048608, + 0x0531, 0x0556, 1048624, 0x10a0, 0x10c5, 1055840, 0x1f08, 0x1f0f, 1048568, + 0x1f18, 0x1f1d, 1048568, 0x1f28, 0x1f2f, 1048568, 0x1f38, 0x1f3f, 1048568, + 0x1f48, 0x1f4d, 1048568, 0x1f68, 0x1f6f, 1048568, 0x1f88, 0x1f8f, 1048568, + 0x1f98, 0x1f9f, 1048568, 0x1fa8, 0x1faf, 1048568, 0x1fb8, 0x1fb9, 1048568, + 0x1fba, 0x1fbb, 1048502, 0x1fc8, 0x1fcb, 1048490, 0x1fd8, 0x1fd9, 1048568, + 0x1fda, 0x1fdb, 1048476, 0x1fe8, 0x1fe9, 1048568, 0x1fea, 0x1feb, 1048464, + 0x1ff8, 0x1ff9, 1048448, 0x1ffa, 0x1ffb, 1048450, 0x2160, 0x216f, 1048592, + 0x24b6, 0x24cf, 1048602, 0x2c00, 0x2c2e, 1048624, 0x2c7e, 0x2c7f, 1037761, + 0xff21, 0xff3a, 1048608, 0x10400, 0x10427, 1048616, +}; + +static Rune __tolowerp[] = { + 0x0100, 0x012e, 1048577, 0x0132, 0x0136, 1048577, 0x0139, 0x0147, 1048577, + 0x014a, 0x0176, 1048577, 0x017b, 0x017d, 1048577, 0x01a2, 0x01a4, 1048577, + 0x01b3, 0x01b5, 1048577, 0x01cd, 0x01db, 1048577, 0x01de, 0x01ee, 1048577, + 0x01f8, 0x021e, 1048577, 0x0222, 0x0232, 1048577, 0x0248, 0x024e, 1048577, + 0x0370, 0x0372, 1048577, 0x03d8, 0x03ee, 1048577, 0x0460, 0x0480, 1048577, + 0x048a, 0x04be, 1048577, 0x04c3, 0x04cd, 1048577, 
0x04d0, 0x0526, 1048577, + 0x1e00, 0x1e94, 1048577, 0x1ea0, 0x1efe, 1048577, 0x1f59, 0x1f5f, 1048568, + 0x2c67, 0x2c6b, 1048577, 0x2c80, 0x2ce2, 1048577, 0x2ceb, 0x2ced, 1048577, + 0xa640, 0xa66c, 1048577, 0xa680, 0xa696, 1048577, 0xa722, 0xa72e, 1048577, + 0xa732, 0xa76e, 1048577, 0xa779, 0xa77b, 1048577, 0xa780, 0xa786, 1048577, + 0xa790, 0xa792, 1048577, 0xa7a0, 0xa7a8, 1048577, +}; + +static Rune __tolowers[] = { + 0x0130, 1048377, 0x0178, 1048455, 0x0179, 1048577, 0x0181, 1048786, + 0x0182, 1048577, 0x0184, 1048577, 0x0186, 1048782, 0x0187, 1048577, + 0x018b, 1048577, 0x018e, 1048655, 0x018f, 1048778, 0x0190, 1048779, + 0x0191, 1048577, 0x0193, 1048781, 0x0194, 1048783, 0x0196, 1048787, + 0x0197, 1048785, 0x0198, 1048577, 0x019c, 1048787, 0x019d, 1048789, + 0x019f, 1048790, 0x01a0, 1048577, 0x01a6, 1048794, 0x01a7, 1048577, + 0x01a9, 1048794, 0x01ac, 1048577, 0x01ae, 1048794, 0x01af, 1048577, + 0x01b7, 1048795, 0x01b8, 1048577, 0x01bc, 1048577, 0x01c4, 1048578, + 0x01c5, 1048577, 0x01c7, 1048578, 0x01c8, 1048577, 0x01ca, 1048578, + 0x01cb, 1048577, 0x01f1, 1048578, 0x01f2, 1048577, 0x01f4, 1048577, + 0x01f6, 1048479, 0x01f7, 1048520, 0x0220, 1048446, 0x023a, 1059371, + 0x023b, 1048577, 0x023d, 1048413, 0x023e, 1059368, 0x0241, 1048577, + 0x0243, 1048381, 0x0244, 1048645, 0x0245, 1048647, 0x0246, 1048577, + 0x0376, 1048577, 0x0386, 1048614, 0x038c, 1048640, 0x03cf, 1048584, + 0x03f4, 1048516, 0x03f7, 1048577, 0x03f9, 1048569, 0x03fa, 1048577, + 0x04c0, 1048591, 0x04c1, 1048577, 0x10c7, 1055840, 0x10cd, 1055840, + 0x1e9e, 1040961, 0x1fbc, 1048567, 0x1fcc, 1048567, 0x1fec, 1048569, + 0x1ffc, 1048567, 0x2126, 1041059, 0x212a, 1040193, 0x212b, 1040314, + 0x2132, 1048604, 0x2183, 1048577, 0x2c60, 1048577, 0x2c62, 1037833, + 0x2c63, 1044762, 0x2c64, 1037849, 0x2c6d, 1037796, 0x2c6e, 1037827, + 0x2c6f, 1037793, 0x2c70, 1037794, 0x2c72, 1048577, 0x2c75, 1048577, + 0x2cf2, 1048577, 0xa77d, 1013244, 0xa77e, 1048577, 0xa78b, 1048577, + 0xa78d, 1006296, 0xa7aa, 1006268, +}; + +Rune utf_tolowerrune(Rune c) { + Rune *p; + + p = rbsearch(c, __tolowerr, nelem(__tolowerr) / 3, 3); + if (p && c >= p[0] && c <= p[1]) return c + p[2] - 1048576; + p = rbsearch(c, __tolowerp, nelem(__tolowerp) / 3, 3); + if (p && c >= p[0] && c <= p[1] && !((c - p[0]) & 1)) + return c + p[2] - 1048576; + p = rbsearch(c, __tolowers, nelem(__tolowers) / 2, 2); + if (p && c == p[0]) return c + p[1] - 1048576; + return c; +} + +#endif diff --git a/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h new file mode 100644 index 000000000..24d9b9dbe --- /dev/null +++ b/mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h @@ -0,0 +1,98 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Fork of several UTF utils originally written by Rob Pike and Ken Thompson. 
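+//
+// A minimal usage sketch (illustrative only; `p`, `n`, `r`, and `buf` are
+// placeholder names), assuming `p` points at `n` bytes of UTF-8:
+//
+//   Rune r;
+//   int consumed = utf_charntorune(&r, p, n);  // decode one code point
+//   if (utf_isalpharune(r)) r = utf_tolowerrune(r);
+//   char buf[UTFmax];
+//   int written = utf_runetochar(buf, &r);     // re-encode as UTF-8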
+#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ +#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ 1 + +#include + +// Code-point values in Unicode 4.0 are 21 bits wide. +typedef signed int Rune; + +#define uchar _utfuchar + +typedef unsigned char uchar; + +#define nelem(x) (sizeof(x) / sizeof((x)[0])) + +enum { + UTFmax = 4, // maximum bytes per rune + Runeerror = 0xFFFD, // decoding error in UTF + Runemax = 0x10FFFF, // maximum rune value +}; + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * rune routines + */ + +/* + * These routines were written by Rob Pike and Ken Thompson + * and first appeared in Plan 9. + * SEE ALSO + * utf (7) + * tcs (1) + */ + +// utf_runetochar copies (encodes) one rune, pointed to by r, to at most +// UTFmax bytes starting at s and returns the number of bytes generated. + +int utf_runetochar(char* s, const Rune* r); + +// utf_charntorune copies (decodes) at most UTFmax bytes starting at `str` to +// one rune, pointed to by `rune`, accesses at most `length` bytes of `str`, and +// returns the number of bytes consumed. +// If the UTF sequence is incomplete within n bytes, +// utf_charntorune will set *r to Runeerror and return 0. If it is complete +// but not in UTF format, it will set *r to Runeerror and return 1. +// +// Added 2004-09-24 by Wei-Hwa Huang + +int utf_charntorune(Rune* rune, const char* str, int length); + +// Unicode defines some characters as letters and +// specifies three cases: upper, lower, and title. Mappings among the +// cases are also defined, although they are not exhaustive: some +// upper case letters have no lower case mapping, and so on. Unicode +// also defines several character properties, a subset of which are +// checked by these routines. These routines are based on Unicode +// version 3.0.0. +// +// NOTE: The routines are implemented in C, so utf_isalpharune returns 0 for false +// and 1 for true. +// +// utf_tolowerrune is the Unicode case mapping. It returns the character +// unchanged if it has no defined mapping. + +Rune utf_tolowerrune(Rune r); + +// utf_isalpharune tests for Unicode letters; this includes ideographs in +// addition to alphabetic characters. + +int utf_isalpharune(Rune r); + +// (The comments in this file were copied from the manpage files rune.3, +// isalpharune.3, and runestrcat.3. Some formatting changes were also made +// to conform to Google style.
/JRM 11/11/05) + +#ifdef __cplusplus +} +#endif + +#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ diff --git a/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc index a42719446..88afabe1e 100644 --- a/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc +++ b/mediapipe/tasks/cc/text/tokenizers/sentencepiece_tokenizer_test.cc @@ -34,7 +34,7 @@ constexpr char kTestSPModelPath[] = std::unique_ptr CreateSentencePieceTokenizer( absl::string_view model_path) { - // We are using `LoadBinaryContent()` instead of loading the model direclty + // We are using `LoadBinaryContent()` instead of loading the model directly // via `SentencePieceTokenizer` so that the file can be located on Windows std::string buffer = LoadBinaryContent(kTestSPModelPath); return absl::make_unique(buffer.data(), diff --git a/mediapipe/tasks/cc/vision/face_detector/face_detector.h b/mediapipe/tasks/cc/vision/face_detector/face_detector.h index 78715528f..ae485819d 100644 --- a/mediapipe/tasks/cc/vision/face_detector/face_detector.h +++ b/mediapipe/tasks/cc/vision/face_detector/face_detector.h @@ -74,7 +74,7 @@ class FaceDetector : core::BaseVisionTaskApi { // three running modes: // 1) Image mode for detecting faces on single image inputs. Users // provide mediapipe::Image to the `Detect` method, and will receive the - // deteced face detection results as the return value. + // detected face detection results as the return value. // 2) Video mode for detecting faces on the decoded frames of a // video. Users call `DetectForVideo` method, and will receive the detected // face detection results as the return value. diff --git a/mediapipe/tasks/cc/vision/face_geometry/BUILD b/mediapipe/tasks/cc/vision/face_geometry/BUILD index 265b0dc9e..6bd9912b2 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/BUILD +++ b/mediapipe/tasks/cc/vision/face_geometry/BUILD @@ -19,9 +19,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) cc_library( name = "face_geometry_from_landmarks_graph", srcs = ["face_geometry_from_landmarks_graph.cc"], - data = [ - "//mediapipe/tasks/cc/vision/face_geometry/data:geometry_pipeline_metadata_landmarks", - ], deps = [ "//mediapipe/calculators/core:begin_loop_calculator", "//mediapipe/calculators/core:end_loop_calculator", @@ -39,6 +36,7 @@ cc_library( "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto", "//mediapipe/util:graph_builder_utils", "@com_google_absl//absl/status:statusor", ], diff --git a/mediapipe/tasks/cc/vision/face_geometry/calculators/BUILD b/mediapipe/tasks/cc/vision/face_geometry/calculators/BUILD index b134c81f4..3f2833f3b 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/face_geometry/calculators/BUILD @@ -45,6 +45,7 @@ mediapipe_proto_library( srcs = ["geometry_pipeline_calculator.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", + "//mediapipe/tasks/cc/core/proto:external_file_proto", ], ) @@ -59,6 +60,9 @@ cc_library( "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", + "//mediapipe/tasks/cc:common", + 
"//mediapipe/tasks/cc/core:external_file_handler", + "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline", "//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils", "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", @@ -66,6 +70,7 @@ cc_library( "//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto", "//mediapipe/util:resource_util", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings:str_format", ], alwayslink = 1, ) diff --git a/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.cc b/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.cc index d6082e62d..9dead7289 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.cc +++ b/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.cc @@ -18,12 +18,16 @@ #include #include "absl/memory/memory.h" +#include "absl/strings/str_format.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/statusor.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/core/external_file_handler.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h" #include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h" @@ -39,13 +43,50 @@ static constexpr char kEnvironmentTag[] = "ENVIRONMENT"; static constexpr char kImageSizeTag[] = "IMAGE_SIZE"; static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY"; static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS"; +static constexpr char kFaceGeometryTag[] = "FACE_GEOMETRY"; +static constexpr char kFaceLandmarksTag[] = "FACE_LANDMARKS"; using ::mediapipe::tasks::vision::face_geometry::proto::Environment; using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry; using ::mediapipe::tasks::vision::face_geometry::proto:: GeometryPipelineMetadata; -// A calculator that renders a visual effect for multiple faces. 
+absl::Status SanityCheck(CalculatorContract* cc) { + if (!(cc->Inputs().HasTag(kFaceLandmarksTag) ^ + cc->Inputs().HasTag(kMultiFaceLandmarksTag))) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Only one of %s and %s can be set at a time.", + kFaceLandmarksTag, kMultiFaceLandmarksTag)); + } + if (!(cc->Outputs().HasTag(kFaceGeometryTag) ^ + cc->Outputs().HasTag(kMultiFaceGeometryTag))) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat("Only one of %s and %s can be set at a time.", + kFaceGeometryTag, kMultiFaceGeometryTag)); + } + if (cc->Inputs().HasTag(kFaceLandmarksTag) != + cc->Outputs().HasTag(kFaceGeometryTag)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "%s and %s must both be set or neither be set.", + kFaceLandmarksTag, kFaceGeometryTag)); + } + if (cc->Inputs().HasTag(kMultiFaceLandmarksTag) != + cc->Outputs().HasTag(kMultiFaceGeometryTag)) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrFormat( + "%s and %s must both be set or neither be set.", + kMultiFaceLandmarksTag, kMultiFaceGeometryTag)); + } + return absl::OkStatus(); +} + +// A calculator that renders a visual effect for multiple faces. Supports single +// face landmarks or multiple face landmarks. // // Inputs: // IMAGE_SIZE (`std::pair`, required): @@ -56,8 +97,12 @@ using ::mediapipe::tasks::vision::face_geometry::proto:: // ratio. If used as-is, the resulting face geometry visualization should be // happening on a frame with the same ratio as well. // -// MULTI_FACE_LANDMARKS (`std::vector`, required): -// A vector of face landmark lists. +// MULTI_FACE_LANDMARKS (`std::vector`, optional): +// A vector of face landmark lists. If connected, the output stream +// MULTI_FACE_GEOMETRY must be connected. +// FACE_LANDMARKS (NormalizedLandmarkList, optional): +// A NormalizedLandmarkList of a single face's landmarks. If connected, the +// output stream FACE_GEOMETRY must be connected. // // Input side packets: // ENVIRONMENT (`proto::Environment`, required) @@ -65,12 +110,14 @@ using ::mediapipe::tasks::vision::face_geometry::proto:: // as well as virtual camera parameters. // // Output: -// MULTI_FACE_GEOMETRY (`std::vector`, required): -// A vector of face geometry data. +// MULTI_FACE_GEOMETRY (`std::vector`, optional): +// A vector of face geometry data if MULTI_FACE_LANDMARKS is connected. +// FACE_GEOMETRY (FaceGeometry, optional): +// A FaceGeometry of the face landmarks if FACE_LANDMARKS is connected. // // Options: -// metadata_path (`string`, optional): -// Defines a path for the geometry pipeline metadata file. +// metadata_file (`ExternalFile`, optional): +// Defines an ExternalFile for the geometry pipeline metadata file. // // The geometry pipeline metadata file format must be the binary // `GeometryPipelineMetadata` proto.
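For reference, here is a minimal sketch (not part of this change) of how the new single-face path of the calculator could be wired up from C++. It assumes the calculator is registered under the name that FaceGeometryFromLandmarksGraph uses, and the metadata path is a placeholder; the ENVIRONMENT side packet still has to be supplied when the graph is run.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

// Hypothetical wiring example for the new FACE_LANDMARKS/FACE_GEOMETRY pair.
mediapipe::CalculatorGraphConfig BuildSingleFaceGeometryConfig() {
  return mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(R"pb(
    input_side_packet: "ENVIRONMENT:environment"
    input_stream: "IMAGE_SIZE:image_size"
    input_stream: "FACE_LANDMARKS:face_landmarks"
    output_stream: "FACE_GEOMETRY:face_geometry"
    node {
      calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator"
      input_side_packet: "ENVIRONMENT:environment"
      input_stream: "IMAGE_SIZE:image_size"
      input_stream: "FACE_LANDMARKS:face_landmarks"
      output_stream: "FACE_GEOMETRY:face_geometry"
      options: {
        [mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculatorOptions
             .ext] {
          # Placeholder path; in the tasks graphs this field is populated from
          # the model asset bundle instead.
          metadata_file {
            file_name: "path/to/geometry_pipeline_metadata_landmarks.binarypb"
          }
        }
      }
    }
  )pb");
}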
@@ -79,13 +126,21 @@ class GeometryPipelineCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { cc->InputSidePackets().Tag(kEnvironmentTag).Set(); + MP_RETURN_IF_ERROR(SanityCheck(cc)); cc->Inputs().Tag(kImageSizeTag).Set>(); - cc->Inputs() - .Tag(kMultiFaceLandmarksTag) - .Set>(); - cc->Outputs().Tag(kMultiFaceGeometryTag).Set>(); - - return absl::OkStatus(); + if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) { + cc->Inputs() + .Tag(kMultiFaceLandmarksTag) + .Set>(); + cc->Outputs().Tag(kMultiFaceGeometryTag).Set>(); + return absl::OkStatus(); + } else { + cc->Inputs() + .Tag(kFaceLandmarksTag) + .Set(); + cc->Outputs().Tag(kFaceGeometryTag).Set(); + return absl::OkStatus(); + } } absl::Status Open(CalculatorContext* cc) override { @@ -95,7 +150,7 @@ class GeometryPipelineCalculator : public CalculatorBase { ASSIGN_OR_RETURN( GeometryPipelineMetadata metadata, - ReadMetadataFromFile(options.metadata_path()), + ReadMetadataFromFile(options.metadata_file()), _ << "Failed to read the geometry pipeline metadata from file!"); MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata)) @@ -110,41 +165,69 @@ class GeometryPipelineCalculator : public CalculatorBase { ASSIGN_OR_RETURN(geometry_pipeline_, CreateGeometryPipeline(environment, metadata), _ << "Failed to create a geometry pipeline!"); - return absl::OkStatus(); } absl::Status Process(CalculatorContext* cc) override { - // Both the `IMAGE_SIZE` and the `MULTI_FACE_LANDMARKS` streams are required - // to have a non-empty packet. In case this requirement is not met, there's - // nothing to be processed at the current timestamp. - if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() || - cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) { + // Both the `IMAGE_SIZE` and either the `FACE_LANDMARKS` or + // `MULTI_FACE_LANDMARKS` streams are required to have a non-empty packet. + // In case this requirement is not met, there's nothing to be processed at + // the current timestamp and we return early (checked here and below). 
+ if (cc->Inputs().Tag(kImageSizeTag).IsEmpty()) { return absl::OkStatus(); } const auto& image_size = cc->Inputs().Tag(kImageSizeTag).Get>(); - const auto& multi_face_landmarks = - cc->Inputs() - .Tag(kMultiFaceLandmarksTag) - .Get>(); - auto multi_face_geometry = absl::make_unique>(); + if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) { + if (cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) { + return absl::OkStatus(); + } - ASSIGN_OR_RETURN( - *multi_face_geometry, - geometry_pipeline_->EstimateFaceGeometry( - multi_face_landmarks, // - /*frame_width*/ image_size.first, - /*frame_height*/ image_size.second), - _ << "Failed to estimate face geometry for multiple faces!"); + const auto& multi_face_landmarks = + cc->Inputs() + .Tag(kMultiFaceLandmarksTag) + .Get>(); - cc->Outputs() - .Tag(kMultiFaceGeometryTag) - .AddPacket(mediapipe::Adopt>( - multi_face_geometry.release()) - .At(cc->InputTimestamp())); + auto multi_face_geometry = absl::make_unique>(); + + ASSIGN_OR_RETURN( + *multi_face_geometry, + geometry_pipeline_->EstimateFaceGeometry( + multi_face_landmarks, // + /*frame_width*/ image_size.first, + /*frame_height*/ image_size.second), + _ << "Failed to estimate face geometry for multiple faces!"); + + cc->Outputs() + .Tag(kMultiFaceGeometryTag) + .AddPacket(mediapipe::Adopt>( + multi_face_geometry.release()) + .At(cc->InputTimestamp())); + } else if (cc->Inputs().HasTag(kFaceLandmarksTag)) { + if (cc->Inputs().Tag(kFaceLandmarksTag).IsEmpty()) { + return absl::OkStatus(); + } + + const auto& face_landmarks = + cc->Inputs() + .Tag(kFaceLandmarksTag) + .Get(); + + ASSIGN_OR_RETURN( + std::vector multi_face_geometry, + geometry_pipeline_->EstimateFaceGeometry( + {face_landmarks}, // + /*frame_width*/ image_size.first, + /*frame_height*/ image_size.second), + _ << "Failed to estimate face geometry for multiple faces!"); + + cc->Outputs() + .Tag(kFaceGeometryTag) + .AddPacket(mediapipe::MakePacket(multi_face_geometry[0]) + .At(cc->InputTimestamp())); + } return absl::OkStatus(); } @@ -155,32 +238,19 @@ class GeometryPipelineCalculator : public CalculatorBase { private: static absl::StatusOr ReadMetadataFromFile( - const std::string& metadata_path) { - ASSIGN_OR_RETURN(std::string metadata_blob, - ReadContentBlobFromFile(metadata_path), - _ << "Failed to read a metadata blob from file!"); + const core::proto::ExternalFile& metadata_file) { + ASSIGN_OR_RETURN( + const auto file_handler, + core::ExternalFileHandler::CreateFromExternalFile(&metadata_file)); GeometryPipelineMetadata metadata; - RET_CHECK(metadata.ParseFromString(metadata_blob)) + RET_CHECK( + metadata.ParseFromString(std::string(file_handler->GetFileContent()))) << "Failed to parse a metadata proto from a binary blob!"; return metadata; } - static absl::StatusOr ReadContentBlobFromFile( - const std::string& unresolved_path) { - ASSIGN_OR_RETURN(std::string resolved_path, - mediapipe::PathToResourceAsFile(unresolved_path), - _ << "Failed to resolve path! Path = " << unresolved_path); - - std::string content_blob; - MP_RETURN_IF_ERROR( - mediapipe::GetResourceContents(resolved_path, &content_blob)) - << "Failed to read content blob! 
Resolved path = " << resolved_path; - - return content_blob; - } - std::unique_ptr geometry_pipeline_; }; diff --git a/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto b/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto index afcc20a13..a748cdf8b 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto +++ b/mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto @@ -17,11 +17,12 @@ syntax = "proto2"; package mediapipe.tasks.vision.face_geometry; import "mediapipe/framework/calculator_options.proto"; +import "mediapipe/tasks/cc/core/proto/external_file.proto"; message FaceGeometryPipelineCalculatorOptions { extend mediapipe.CalculatorOptions { optional FaceGeometryPipelineCalculatorOptions ext = 512499200; } - optional string metadata_path = 1; + optional core.proto.ExternalFile metadata_file = 1; } diff --git a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc index 08b3d1bf4..8c69a31fd 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc +++ b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" +#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h" #include "mediapipe/util/graph_builder_utils.h" namespace mediapipe::tasks::vision::face_geometry { @@ -49,10 +50,6 @@ constexpr char kIterableTag[] = "ITERABLE"; constexpr char kBatchEndTag[] = "BATCH_END"; constexpr char kItemTag[] = "ITEM"; -constexpr char kGeometryPipelineMetadataPath[] = - "mediapipe/tasks/cc/vision/face_geometry/data/" - "geometry_pipeline_metadata_landmarks.binarypb"; - struct FaceGeometryOuts { Stream> multi_face_geometry; }; @@ -127,6 +124,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph { } ASSIGN_OR_RETURN(auto outs, BuildFaceGeometryFromLandmarksGraph( + *sc->MutableOptions(), graph.In(kFaceLandmarksTag) .Cast>(), graph.In(kImageSizeTag).Cast>(), @@ -138,6 +136,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph { private: absl::StatusOr BuildFaceGeometryFromLandmarksGraph( + proto::FaceGeometryGraphOptions& graph_options, Stream> multi_face_landmarks, Stream> image_size, std::optional> environment, Graph& graph) { @@ -185,7 +184,8 @@ class FaceGeometryFromLandmarksGraph : public Subgraph { "mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator"); auto& geometry_pipeline_options = geometry_pipeline.GetOptions(); - geometry_pipeline_options.set_metadata_path(kGeometryPipelineMetadataPath); + geometry_pipeline_options.Swap( + graph_options.mutable_geometry_pipeline_options()); image_size >> geometry_pipeline.In(kImageSizeTag); multi_face_landmarks_no_iris >> geometry_pipeline.In(kMultiFaceLandmarksTag); diff --git a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc index df935135d..74baff5d8 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc +++ 
b/mediapipe/tasks/cc/vision/face_geometry/face_geometry_from_landmarks_graph_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_runner.h" @@ -31,6 +32,7 @@ limitations under the License. #include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/tool/sink.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" @@ -49,6 +51,9 @@ constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; constexpr char kFaceLandmarksFileName[] = "face_blendshapes_in_landmarks.prototxt"; constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt"; +constexpr char kGeometryPipelineMetadataPath[] = + "mediapipe/tasks/cc/vision/face_geometry/data/" + "geometry_pipeline_metadata_landmarks.binarypb"; std::vector GetLandmarks(absl::string_view filename) { NormalizedLandmarkList landmarks; @@ -89,17 +94,25 @@ void MakeInputPacketsAndRunGraph(CalculatorGraph& graph) { TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie< - CalculatorGraphConfig>(R"pb( - input_stream: "FACE_LANDMARKS:face_landmarks" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "FACE_GEOMETRY:face_geometry" - node { - calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" - input_stream: "FACE_LANDMARKS:face_landmarks" - input_stream: "IMAGE_SIZE:image_size" - output_stream: "FACE_GEOMETRY:face_geometry" - } - )pb"); + CalculatorGraphConfig>(absl::Substitute( + R"pb( + input_stream: "FACE_LANDMARKS:face_landmarks" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "FACE_GEOMETRY:face_geometry" + node { + calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" + input_stream: "FACE_LANDMARKS:face_landmarks" + input_stream: "IMAGE_SIZE:image_size" + output_stream: "FACE_GEOMETRY:face_geometry" + options: { + [mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions + .ext] { + geometry_pipeline_options { metadata_file { file_name: "$0" } } + } + } + } + )pb", + kGeometryPipelineMetadataPath)); std::vector output_packets; tool::AddVectorSink("face_geometry", &graph_config, &output_packets); @@ -116,19 +129,27 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) { TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) { CalculatorGraphConfig graph_config = ParseTextProtoOrDie< - CalculatorGraphConfig>(R"pb( - input_stream: "FACE_LANDMARKS:face_landmarks" - input_stream: "IMAGE_SIZE:image_size" - input_side_packet: "ENVIRONMENT:environment" - output_stream: "FACE_GEOMETRY:face_geometry" - node { - calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" - input_stream: "FACE_LANDMARKS:face_landmarks" - input_stream: "IMAGE_SIZE:image_size" - input_side_packet: "ENVIRONMENT:environment" - output_stream: "FACE_GEOMETRY:face_geometry" - } - )pb"); + CalculatorGraphConfig>(absl::Substitute( + R"pb( + input_stream: "FACE_LANDMARKS:face_landmarks" + input_stream: "IMAGE_SIZE:image_size" + input_side_packet: "ENVIRONMENT:environment" + 
output_stream: "FACE_GEOMETRY:face_geometry" + node { + calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" + input_stream: "FACE_LANDMARKS:face_landmarks" + input_stream: "IMAGE_SIZE:image_size" + input_side_packet: "ENVIRONMENT:environment" + output_stream: "FACE_GEOMETRY:face_geometry" + options: { + [mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions + .ext] { + geometry_pipeline_options { metadata_file { file_name: "$0" } } + } + } + } + )pb", + kGeometryPipelineMetadataPath)); std::vector output_packets; tool::AddVectorSink("face_geometry", &graph_config, &output_packets); diff --git a/mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.cc b/mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.cc index c7ac7c634..061f35f51 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.cc +++ b/mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.cc @@ -99,7 +99,7 @@ class ScreenToMetricSpaceConverter { // // (3) Use the canonical-to-runtime scale from (2) to unproject the screen // landmarks. The result is referenced as "intermediate landmarks" because - // they are the first estimation of the resuling metric landmarks, but are + // they are the first estimation of the resulting metric landmarks,but are // not quite there yet. // // (4) Estimate a canonical-to-runtime landmark set scale by running the @@ -347,7 +347,7 @@ class GeometryPipelineImpl : public GeometryPipeline { proto::Mesh3d* mutable_mesh = face_geometry.mutable_mesh(); // Copy the canonical face mesh as the face geometry mesh. mutable_mesh->CopyFrom(canonical_mesh_); - // Replace XYZ vertex mesh coodinates with the metric landmark positions. + // Replace XYZ vertex mesh coordinates with the metric landmark positions. for (int i = 0; i < canonical_mesh_num_vertices_; ++i) { uint32_t vertex_buffer_offset = canonical_mesh_vertex_size_ * i + canonical_mesh_vertex_position_offset_; diff --git a/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD b/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD index 9559448f3..e337a3452 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD +++ b/mediapipe/tasks/cc/vision/face_geometry/proto/BUILD @@ -13,6 +13,7 @@ # limitations under the License. 
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") +load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type") licenses(["notice"]) @@ -23,6 +24,16 @@ mediapipe_proto_library( srcs = ["environment.proto"], ) +mediapipe_register_type( + base_name = "face_geometry", + include_headers = ["mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"], + types = [ + "::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry", + "::std::vector<::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry>", + ], + deps = [":face_geometry_cc_proto"], +) + mediapipe_proto_library( name = "face_geometry_proto", srcs = ["face_geometry.proto"], @@ -44,3 +55,12 @@ mediapipe_proto_library( name = "mesh_3d_proto", srcs = ["mesh_3d.proto"], ) + +mediapipe_proto_library( + name = "face_geometry_graph_options_proto", + srcs = ["face_geometry_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/face_geometry/proto/environment.proto b/mediapipe/tasks/cc/vision/face_geometry/proto/environment.proto index e60f3c1e1..dd771dfbb 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/proto/environment.proto +++ b/mediapipe/tasks/cc/vision/face_geometry/proto/environment.proto @@ -16,7 +16,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.face_geometry.proto; -option java_package = "mediapipe.tasks.vision.facegeometry.proto"; +option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto"; option java_outer_classname = "EnvironmentProto"; // Defines the (0, 0) origin point location of the environment. diff --git a/mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.proto b/mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.proto index 1934828c3..149e10afd 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.proto +++ b/mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.proto @@ -19,7 +19,7 @@ package mediapipe.tasks.vision.face_geometry.proto; import "mediapipe/framework/formats/matrix_data.proto"; import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto"; -option java_package = "mediapipe.tasks.vision.facegeometry.proto"; +option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto"; option java_outer_classname = "FaceGeometryProto"; // Defines the face geometry pipeline estimation result format. @@ -28,7 +28,7 @@ message FaceGeometry { // the face landmark IDs. // // XYZ coordinates exist in the right-handed Metric 3D space configured by an - // environment. UV coodinates are taken from the canonical face mesh model. + // environment. UV coordinates are taken from the canonical face mesh model. // // XY coordinates are guaranteed to match the screen positions of // the input face landmarks after (1) being multiplied by the face pose diff --git a/mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.proto b/mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.proto new file mode 100644 index 000000000..03831d1dc --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.proto @@ -0,0 +1,28 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto2"; + +package mediapipe.tasks.vision.face_geometry.proto; + +import "mediapipe/framework/calculator_options.proto"; +import "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto"; + +message FaceGeometryGraphOptions { + extend mediapipe.CalculatorOptions { + optional FaceGeometryGraphOptions ext = 515723506; + } + + optional FaceGeometryPipelineCalculatorOptions geometry_pipeline_options = 1; +} diff --git a/mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.proto b/mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.proto index 54fcaf23c..53d1a1392 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.proto +++ b/mediapipe/tasks/cc/vision/face_geometry/proto/geometry_pipeline_metadata.proto @@ -18,7 +18,7 @@ package mediapipe.tasks.vision.face_geometry.proto; import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto"; -option java_package = "mediapipe.tasks.vision.facegeometry.proto"; +option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto"; option java_outer_classname = "GeometryPipelineMetadataProto"; enum InputSource { diff --git a/mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto b/mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto index 45913cf02..1131ae7ed 100644 --- a/mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto +++ b/mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto @@ -16,7 +16,7 @@ syntax = "proto2"; package mediapipe.tasks.vision.face_geometry.proto; -option java_package = "mediapipe.tasks.vision.facegeometry.proto"; +option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto"; option java_outer_classname = "Mesh3dProto"; message Mesh3d { diff --git a/mediapipe/tasks/cc/vision/face_landmarker/BUILD b/mediapipe/tasks/cc/vision/face_landmarker/BUILD index 7ecc93b21..ac78edda5 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/BUILD +++ b/mediapipe/tasks/cc/vision/face_landmarker/BUILD @@ -129,6 +129,37 @@ cc_library( ], ) +cc_library( + name = "face_landmarker", + srcs = ["face_landmarker.cc"], + hdrs = ["face_landmarker.h"], + deps = [ + ":face_landmarker_graph", + ":face_landmarker_result", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:classification_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:landmark_cc_proto", + "//mediapipe/framework/formats:matrix", + "//mediapipe/framework/formats:matrix_data_cc_proto", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:classification_result", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/core:base_task_api", + "//mediapipe/tasks/cc/core:task_runner", + "//mediapipe/tasks/cc/core:utils", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + 
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "face_landmarker_result_cc", srcs = ["face_landmarker_result.cc"], @@ -179,8 +210,10 @@ cc_library( "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry:face_geometry_from_landmarks_graph", + "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.cc new file mode 100644 index 000000000..e006b4490 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.cc @@ -0,0 +1,250 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h" + +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/matrix_data.pb.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/core/base_task_api.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/core/utils.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_landmarker { + +namespace { + +using FaceLandmarkerGraphOptionsProto = ::mediapipe::tasks::vision:: + face_landmarker::proto::FaceLandmarkerGraphOptions; + +constexpr char kFaceLandmarkerGraphTypeName[] = + "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; +constexpr char kNormLandmarksTag[] = "NORM_LANDMARKS"; +constexpr char kNormLandmarksStreamName[] = "norm_landmarks"; +constexpr char kBlendshapesTag[] = "BLENDSHAPES"; +constexpr char kBlendshapesStreamName[] = "blendshapes"; +constexpr char kFaceGeometryTag[] = "FACE_GEOMETRY"; +constexpr char kFaceGeometryStreamName[] = "face_geometry"; +constexpr int kMicroSecondsPerMilliSecond = 1000; + +// Creates a MediaPipe graph config that contains a subgraph node of +// "mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph". If the task is +// running in the live stream mode, a "FlowLimiterCalculator" will be added to +// limit the number of frames in flight.
+CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options, + bool output_face_blendshapes, bool output_facial_transformation_matrixes, + bool enable_flow_limiting) { + api2::builder::Graph graph; + auto& subgraph = graph.AddNode(kFaceLandmarkerGraphTypeName); + subgraph.GetOptions().Swap(options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); + subgraph.Out(kNormLandmarksTag).SetName(kNormLandmarksStreamName) >> + graph.Out(kNormLandmarksTag); + subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> graph.Out(kImageTag); + if (output_face_blendshapes) { + subgraph.Out(kBlendshapesTag).SetName(kBlendshapesStreamName) >> + graph.Out(kBlendshapesTag); + } + if (output_facial_transformation_matrixes) { + subgraph.Out(kFaceGeometryTag).SetName(kFaceGeometryStreamName) >> + graph.Out(kFaceGeometryTag); + } + if (enable_flow_limiting) { + return tasks::core::AddFlowLimiterCalculator( + graph, subgraph, {kImageTag, kNormRectTag}, kNormLandmarksTag); + } + graph.In(kImageTag) >> subgraph.In(kImageTag); + graph.In(kNormRectTag) >> subgraph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing FaceLandmarkerOptions struct to the internal +// FaceLandmarkerGraphOptions proto. +std::unique_ptr +ConvertFaceLandmarkerGraphOptionsProto(FaceLandmarkerOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + options_proto->mutable_base_options()->set_use_stream_mode( + options->running_mode != core::RunningMode::IMAGE); + + // Configure face detector options. + auto* face_detector_graph_options = + options_proto->mutable_face_detector_graph_options(); + face_detector_graph_options->set_num_faces(options->num_faces); + face_detector_graph_options->set_min_detection_confidence( + options->min_face_detection_confidence); + + // Configure face landmark detector options. 
+ options_proto->set_min_tracking_confidence(options->min_tracking_confidence); + auto* face_landmarks_detector_graph_options = + options_proto->mutable_face_landmarks_detector_graph_options(); + face_landmarks_detector_graph_options->set_min_detection_confidence( + options->min_face_presence_confidence); + + return options_proto; +} + +FaceLandmarkerResult GetFaceLandmarkerResultFromPacketMap( + const tasks::core::PacketMap& packet_map) { + const auto& face_landmarks = packet_map.at(kNormLandmarksStreamName) + .Get>(); + std::optional> face_blendshapes; + if (packet_map.find(kBlendshapesStreamName) != packet_map.end()) { + face_blendshapes = packet_map.at(kBlendshapesStreamName) + .Get>(); + } + std::optional> matrix_data_list; + if (packet_map.find(kFaceGeometryStreamName) != packet_map.end()) { + const auto& face_geometry_list = + packet_map.at(kFaceGeometryStreamName) + .Get>(); + matrix_data_list = std::vector(face_geometry_list.size()); + std::transform(face_geometry_list.begin(), face_geometry_list.end(), + matrix_data_list->begin(), + [](const face_geometry::proto::FaceGeometry& face_geometry) { + return face_geometry.pose_transform_matrix(); + }); + } + return ConvertToFaceLandmarkerResult( + /* face_landmarks_proto = */ face_landmarks, + /* face_blendshapes_proto= */ face_blendshapes, + /* facial_transformation_matrixes_proto= */ matrix_data_list); +} + +} // namespace + +absl::StatusOr> FaceLandmarker::Create( + std::unique_ptr options) { + auto options_proto = ConvertFaceLandmarkerGraphOptionsProto(options.get()); + tasks::core::PacketsCallback packets_callback = nullptr; + if (options->result_callback) { + auto result_callback = options->result_callback; + packets_callback = [=](absl::StatusOr packet_map) { + if (!packet_map.ok()) { + Image image; + result_callback(packet_map.status(), image, Timestamp::Unset().Value()); + return; + } + if (packet_map->at(kImageOutStreamName).IsEmpty()) { + return; + } + Packet image_packet = packet_map->at(kImageOutStreamName); + if (packet_map->at(kNormLandmarksStreamName).IsEmpty()) { + Packet empty_packet = packet_map->at(kNormLandmarksStreamName); + result_callback( + {FaceLandmarkerResult()}, image_packet.Get(), + empty_packet.Timestamp().Value() / kMicroSecondsPerMilliSecond); + return; + } + result_callback( + GetFaceLandmarkerResultFromPacketMap(*packet_map), + image_packet.Get(), + packet_map->at(kNormLandmarksStreamName).Timestamp().Value() / + kMicroSecondsPerMilliSecond); + }; + } + return core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), options->output_face_blendshapes, + options->output_facial_transformation_matrixes, + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); +} + +absl::StatusOr FaceLandmarker::Detect( + mediapipe::Image image, + std::optional image_processing_options) { + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, + /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, MakePacket(std::move(image))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); + if (output_packets[kNormLandmarksStreamName].IsEmpty()) { + return {FaceLandmarkerResult()}; + } + return GetFaceLandmarkerResultFromPacketMap(output_packets); +} + +absl::StatusOr FaceLandmarker::DetectForVideo( + mediapipe::Image image, int64_t timestamp_ms, + std::optional image_processing_options) { + 
ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, + /*roi_allowed=*/false)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessVideoData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}})); + if (output_packets[kNormLandmarksStreamName].IsEmpty()) { + return {FaceLandmarkerResult()}; + } + return GetFaceLandmarkerResultFromPacketMap(output_packets); +} + +absl::Status FaceLandmarker::DetectAsync( + mediapipe::Image image, int64_t timestamp_ms, + std::optional image_processing_options) { + ASSIGN_OR_RETURN(NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, + /*roi_allowed=*/false)); + return SendLiveStreamData( + {{kImageInStreamName, + MakePacket(std::move(image)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect)) + .At(Timestamp(timestamp_ms * kMicroSecondsPerMilliSecond))}}); +} + +} // namespace face_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h new file mode 100644 index 000000000..2c93fcba5 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h @@ -0,0 +1,198 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKER_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_landmarker { + +struct FaceLandmarkerOptions { + // Base options for configuring MediaPipe Tasks library, such as specifying + // the TfLite model bundle file with metadata, accelerator options, op + // resolver, etc. + tasks::core::BaseOptions base_options; + + // The running mode of the task. Default to the image mode. + // FaceLandmarker has three running modes: + // 1) The image mode for detecting face landmarks on single image inputs. + // 2) The video mode for detecting face landmarks on the decoded frames of a + // video. + // 3) The live stream mode for detecting face landmarks on the live stream of + // input data, such as from camera. 
In this mode, the "result_callback" + // below must be specified to receive the detection results asynchronously. + core::RunningMode running_mode = core::RunningMode::IMAGE; + + // The maximum number of faces that can be detected by the FaceLandmarker. + int num_faces = 1; + + // The minimum confidence score for the face detection to be considered + // successful. + float min_face_detection_confidence = 0.5; + + // The minimum confidence score of face presence score in the face landmark + // detection. + float min_face_presence_confidence = 0.5; + + // The minimum confidence score for the face tracking to be considered + // successful. + float min_tracking_confidence = 0.5; + + // Whether FaceLandmarker outputs face blendshapes classification. Face + // blendshapes are used for rendering the 3D face model. + bool output_face_blendshapes = false; + + // Whether FaceLandmarker outputs facial transformation_matrix. Facial + // transformation matrix is used to transform the face landmarks in canonical + // face to the detected face, so that users can apply face effects on the + // detected landmarks. + bool output_facial_transformation_matrixes = false; + + // The user-defined result callback for processing live stream data. + // The result callback should only be specified when the running mode is set + // to RunningMode::LIVE_STREAM. + std::function, const Image&, + int64_t)> + result_callback = nullptr; +}; + +// Performs face landmarks detection on the given image. +// +// TODO add the link to DevSite. +// This API expects a pre-trained face landmarker model asset bundle. +// +// Inputs: +// Image +// - The image that face landmarks detection runs on. +// std::optional +// - If provided, can be used to specify the rotation to apply to the image +// before performing face landmarks detection, by setting its 'rotation' +// field in radians (e.g. 'M_PI / 2' for a 90° anti-clockwise rotation). +// Note that specifying a region-of-interest using the 'x_center', +// 'y_center', 'width' and 'height' fields is NOT supported and will +// result in an invalid argument error being returned. +// Outputs: +// FaceLandmarkerResult +// - The face landmarks detection results. +class FaceLandmarker : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates a FaceLandmarker from a FaceLandmarkerOptions to process image data + // or streaming data. Face landmarker can be created with one of the following + // three running modes: + // 1) Image mode for detecting face landmarks on single image inputs. Users + // provide mediapipe::Image to the `Detect` method, and will receive the + // detected face landmarks results as the return value. + // 2) Video mode for detecting face landmarks on the decoded frames of a + // video. Users call `DetectForVideo` method, and will receive the detected + // face landmarks results as the return value. + // 3) Live stream mode for detecting face landmarks on the live stream of the + // input data, such as from camera. Users call `DetectAsync` to push the + // image data into the FaceLandmarker, the detected results along with the + // input timestamp and the image that face landmarker runs on will be + // available in the result callback when the face landmarker finishes the + // work. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs face landmarks detection on the given image. + // Only use this method when the FaceLandmarker is created with the image + // running mode. 
+ // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The image can be of any size with format RGB or RGBA. + // TODO: Describe how the input image will be preprocessed + // after the yuv support is implemented. + absl::StatusOr Detect( + Image image, + std::optional image_processing_options = + std::nullopt); + + // Performs face landmarks detection on the provided video frame. + // Only use this method when the FaceLandmarker is created with the video + // running mode. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide the video frame's timestamp (in milliseconds). The input timestamps + // must be monotonically increasing. + absl::StatusOr DetectForVideo( + Image image, int64_t timestamp_ms, + std::optional image_processing_options = + std::nullopt); + + // Sends live image data to perform face landmarks detection, and the results + // will be available via the "result_callback" provided in the + // FaceLandmarkerOptions. Only use this method when the FaceLandmarker + // is created with the live stream running mode. + // + // The image can be of any size with format RGB or RGBA. It's required to + // provide a timestamp (in milliseconds) to indicate when the input image is + // sent to the face landmarker. The input timestamps must be monotonically + // increasing. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing detection, by setting + // its 'rotation_degrees' field. Note that specifying a region-of-interest + // using the 'region_of_interest' field is NOT supported and will result in an + // invalid argument error being returned. + // + // The "result_callback" provides + // - A vector of FaceLandmarkerResult, each being the detected result + // for an input frame. + // - The const reference to the corresponding input image that the face + // landmarker runs on. Note that the const reference to the image will no + // longer be valid when the callback returns. To access the image data + // outside of the callback, callers need to make a copy of the image. + // - The input timestamp in milliseconds. + absl::Status DetectAsync(Image image, int64_t timestamp_ms, + std::optional + image_processing_options = std::nullopt); + + // Shuts down the FaceLandmarker when all the work is done.
+ absl::Status Close() { return runner_->Close(); } +}; + +} // namespace face_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_FACE_LANDMARKER_FACE_LANDMARKER_H_ diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc index 52c8b08a0..99d99466a 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_graph.cc @@ -40,8 +40,10 @@ limitations under the License. #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" #include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" +#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" @@ -93,6 +95,8 @@ constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite"; constexpr char kFaceLandmarksDetectorTFLiteName[] = "face_landmarks_detector.tflite"; constexpr char kFaceBlendshapeTFLiteName[] = "face_blendshapes.tflite"; +constexpr char kFaceGeometryPipelineMetadataName[] = + "geometry_pipeline_metadata_landmarks.binarypb"; struct FaceLandmarkerOutputs { Source> landmark_lists; @@ -112,7 +116,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, options->mutable_face_detector_graph_options(); if (!face_detector_graph_options->base_options().has_model_asset()) { ASSIGN_OR_RETURN(const auto face_detector_file, - resources.GetModelFile(kFaceDetectorTFLiteName)); + resources.GetFile(kFaceDetectorTFLiteName)); SetExternalFile(face_detector_file, face_detector_graph_options->mutable_base_options() ->mutable_model_asset(), @@ -128,7 +132,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, if (!face_landmarks_detector_graph_options->base_options() .has_model_asset()) { ASSIGN_OR_RETURN(const auto face_landmarks_detector_file, - resources.GetModelFile(kFaceLandmarksDetectorTFLiteName)); + resources.GetFile(kFaceLandmarksDetectorTFLiteName)); SetExternalFile( face_landmarks_detector_file, face_landmarks_detector_graph_options->mutable_base_options() @@ -142,7 +146,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->set_use_stream_mode(options->base_options().use_stream_mode()); absl::StatusOr face_blendshape_model = - resources.GetModelFile(kFaceBlendshapeTFLiteName); + resources.GetFile(kFaceBlendshapeTFLiteName); if (face_blendshape_model.ok()) { SetExternalFile(*face_blendshape_model, face_landmarks_detector_graph_options @@ -156,7 +160,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->mutable_acceleration() ->mutable_xnnpack(); LOG(WARNING) << "Face blendshape model contains CPU only ops. 
Sets " - << "FaceBlendshapesGraph acceleartion to Xnnpack."; + << "FaceBlendshapesGraph acceleration to Xnnpack."; } return absl::OkStatus(); @@ -176,7 +180,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, // would be triggered to detect faces. // // FaceGeometryFromLandmarksGraph finds the transformation from canonical face -// to the detected faces. This transformation is useful for renderring face +// to the detected faces. This transformation is useful for rendering face // effects on the detected faces. This subgraph is added if users request a // FaceGeometry Tag. // @@ -305,6 +309,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { absl::StatusOr GetConfig( SubgraphContext* sc) override { Graph graph; + bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag); if (sc->Options() .base_options() .has_model_asset()) { @@ -318,6 +323,18 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { sc->MutableOptions(), !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) .IsAvailable())); + if (output_geometry) { + // Set the face geometry metadata file for + // FaceGeometryFromLandmarksGraph. + ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file, + model_asset_bundle_resources->GetFile( + kFaceGeometryPipelineMetadataName)); + SetExternalFile(face_geometry_pipeline_metadata_file, + sc->MutableOptions() + ->mutable_face_geometry_graph_options() + ->mutable_geometry_pipeline_options() + ->mutable_metadata_file()); + } } std::optional> environment; if (HasSideInput(sc->OriginalNode(), kEnvironmentTag)) { @@ -338,7 +355,6 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { .face_landmarks_detector_graph_options() .has_face_blendshapes_graph_options())); } - bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag); ASSIGN_OR_RETURN( auto outs, BuildFaceLandmarkerGraph( @@ -481,6 +497,9 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { auto& face_geometry_from_landmarks = graph.AddNode( "mediapipe.tasks.vision.face_geometry." 
"FaceGeometryFromLandmarksGraph"); + face_geometry_from_landmarks + .GetOptions() + .Swap(tasks_options.mutable_face_geometry_graph_options()); if (environment.has_value()) { *environment >> face_geometry_from_landmarks.SideIn(kEnvironmentTag); } diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc index 3f369cc16..53a171ed5 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.cc @@ -34,7 +34,7 @@ FaceLandmarkerResult ConvertToFaceLandmarkerResult( std::optional> face_blendshapes_proto, std::optional> - facial_transformation_matrix_proto) { + facial_transformation_matrixes_proto) { FaceLandmarkerResult result; result.face_landmarks.resize(face_landmarks_proto.size()); std::transform(face_landmarks_proto.begin(), face_landmarks_proto.end(), @@ -52,12 +52,12 @@ FaceLandmarkerResult ConvertToFaceLandmarkerResult( classification_list); }); } - if (facial_transformation_matrix_proto.has_value()) { - result.facial_transformation_matrix = - std::vector(facial_transformation_matrix_proto->size()); - std::transform(facial_transformation_matrix_proto->begin(), - facial_transformation_matrix_proto->end(), - result.facial_transformation_matrix->begin(), + if (facial_transformation_matrixes_proto.has_value()) { + result.facial_transformation_matrixes = + std::vector(facial_transformation_matrixes_proto->size()); + std::transform(facial_transformation_matrixes_proto->begin(), + facial_transformation_matrixes_proto->end(), + result.facial_transformation_matrixes->begin(), [](const mediapipe::MatrixData& matrix_proto) { mediapipe::Matrix matrix; MatrixFromMatrixDataProto(matrix_proto, &matrix); diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h index 9774d80d9..bc097d6c3 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h @@ -34,13 +34,13 @@ namespace face_landmarker { // The face landmarks detection result from FaceLandmarker, where each vector // element represents a single face detected in the image. struct FaceLandmarkerResult { - // Detected hand landmarks in normalized image coordinates. + // Detected face landmarks in normalized image coordinates. std::vector face_landmarks; // Optional face blendshapes results. std::optional> face_blendshapes; // Optional facial transformation matrix. - std::optional> facial_transformation_matrix; + std::optional> facial_transformation_matrixes; }; // Convert face landmarks result from proto format to FaceLandmarkerResult. 
@@ -49,7 +49,7 @@ FaceLandmarkerResult ConvertToFaceLandmarkerResult( std::optional> face_blendshapes_proto = std::nullopt, std::optional> - facial_transformation_matrix_proto = std::nullopt); + facial_transformation_matrixes_proto = std::nullopt); } // namespace face_landmarker } // namespace vision diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc index c3ed2d371..4123a81f3 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result_test.cc @@ -73,9 +73,10 @@ TEST(FaceLandmarkerResultTest, Succeeds) { std::nullopt)); Matrix expected_matrix{{0, 3, 6}, {1, 4, 7}, {2, 5, 8}}; - ASSERT_TRUE(face_landmarker_result.facial_transformation_matrix.has_value()); - EXPECT_EQ(face_landmarker_result.facial_transformation_matrix->size(), 1); - EXPECT_EQ(face_landmarker_result.facial_transformation_matrix->at(0), + ASSERT_TRUE( + face_landmarker_result.facial_transformation_matrixes.has_value()); + EXPECT_EQ(face_landmarker_result.facial_transformation_matrixes->size(), 1); + EXPECT_EQ(face_landmarker_result.facial_transformation_matrixes->at(0), expected_matrix); } diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc new file mode 100644 index 000000000..033f92cf1 --- /dev/null +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_test.cc @@ -0,0 +1,455 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/flags/flag.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/classification.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/landmark.pb.h" +#include "mediapipe/framework/formats/matrix.h" +#include "mediapipe/framework/formats/matrix_data.pb.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/common.h" +#include "mediapipe/tasks/cc/components/containers/category.h" +#include "mediapipe/tasks/cc/components/containers/classification_result.h" +#include "mediapipe/tasks/cc/components/containers/landmark.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/face_landmarker/face_landmarker_result.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace face_landmarker { +namespace { + +using ::file::Defaults; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kFaceLandmarkerModelBundleName[] = "face_landmarker.task"; +constexpr char kFaceLandmarkerWithBlendshapesModelBundleName[] = + "face_landmarker_with_blendshapes.task"; +constexpr char kPortraitImageName[] = "portrait.jpg"; +constexpr char kPortraitExpectedFaceLandamrksName[] = + "portrait_expected_face_landmarks.pbtxt"; +constexpr char kPortraitExpectedFaceLandmarksWithAttentionName[] = + "portrait_expected_face_landmarks_with_attention.pbtxt"; +constexpr char kPortraitExpectedBlendshapesName[] = + "portrait_expected_blendshapes_with_attention.pbtxt"; + +constexpr float kLandmarksDiffMargin = 0.03; +constexpr float kBlendshapesDiffMargin = 0.1; +constexpr float kFacialTransformationMatrixDiffMargin = 0.02; + +template +ProtoT GetExpectedProto(absl::string_view filename) { + ProtoT expected_proto; + MP_EXPECT_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, filename), + &expected_proto, Defaults())); + return expected_proto; +} + +// Struct holding the parameters for parameterized FaceLandmarkerGraphTest +// class. +struct FaceLandmarkerTestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of the model to test. + std::string input_model_name; + // The filename of the test image. + std::string test_image_name; + // The rotation to apply to the test image before processing, in degrees + // clockwise. + int rotation; + // The expected output face landmarker result. 
+ FaceLandmarkerResult expected_result; +}; + +mediapipe::MatrixData MakePortraitExpectedFacialTransformationMatrix() { + const Matrix matrix{{0.9995292, -0.005092691, 0.030254554, -0.37340546}, + {0.0072318087, 0.99744856, -0.07102106, 22.212194}, + {-0.029815676, 0.07120642, 0.9970159, -64.76358}, + {0, 0, 0, 1}}; + mediapipe::MatrixData matrix_data; + MatrixDataProtoFromMatrix(matrix, &matrix_data); + return matrix_data; +} + +testing::Matcher LandmarkIs( + const components::containers::NormalizedLandmark& landmark) { + return testing::AllOf( + testing::Field(&components::containers::NormalizedLandmark::x, + testing::FloatNear(landmark.x, kLandmarksDiffMargin)), + testing::Field(&components::containers::NormalizedLandmark::y, + testing::FloatNear(landmark.y, kLandmarksDiffMargin))); +} + +void ExpectLandmarksCorrect( + const std::vector + actual_landmarks, + const std::vector + expected_landmarks) { + ASSERT_EQ(actual_landmarks.size(), expected_landmarks.size()); + for (int i = 0; i < actual_landmarks.size(); ++i) { + ASSERT_EQ(actual_landmarks[i].landmarks.size(), + expected_landmarks[i].landmarks.size()); + for (int j = 0; j < actual_landmarks[i].landmarks.size(); ++j) { + EXPECT_THAT(actual_landmarks[i].landmarks[j], + LandmarkIs(expected_landmarks[i].landmarks[j])); + } + } +} + +testing::Matcher CategoryIs( + const components::containers::Category& category) { + return testing::AllOf( + testing::Field(&components::containers::Category::index, + testing::Eq(category.index)), + testing::Field( + &components::containers::Category::score, + testing::FloatNear(category.score, kBlendshapesDiffMargin))); +} + +void ExpectBlendshapesCorrect( + const std::vector& + actual_blendshapes, + const std::vector& + expected_blendshapes) { + ASSERT_EQ(actual_blendshapes.size(), expected_blendshapes.size()); + for (int i = 0; i < actual_blendshapes.size(); ++i) { + ASSERT_EQ(actual_blendshapes[i].categories.size(), + expected_blendshapes[i].categories.size()); + for (int j = 0; j < actual_blendshapes[i].categories.size(); ++j) { + EXPECT_THAT(actual_blendshapes[i].categories[j], + CategoryIs(expected_blendshapes[i].categories[j])); + } + } +} + +void ExpectFacialTransformationMatrixCorrect( + const std::vector& actual_matrix_list, + const std::vector& expected_matrix_list) { + ASSERT_EQ(actual_matrix_list.size(), expected_matrix_list.size()); + for (int i = 0; i < actual_matrix_list.size(); ++i) { + const Matrix& actual_matrix = actual_matrix_list[i]; + const Matrix& expected_matrix = expected_matrix_list[i]; + ASSERT_EQ(actual_matrix.cols(), expected_matrix.cols()); + ASSERT_EQ(actual_matrix.rows(), expected_matrix.rows()); + for (int i = 0; i < actual_matrix.size(); ++i) { + EXPECT_NEAR(actual_matrix.data()[i], expected_matrix.data()[i], + kFacialTransformationMatrixDiffMargin); + } + } +} + +void ExpectFaceLandmarkerResultCorrect( + const FaceLandmarkerResult& actual_result, + const FaceLandmarkerResult& expected_result) { + ExpectLandmarksCorrect(actual_result.face_landmarks, + expected_result.face_landmarks); + + ASSERT_EQ(actual_result.face_blendshapes.has_value(), + expected_result.face_blendshapes.has_value()); + if (expected_result.face_blendshapes.has_value()) { + ASSERT_TRUE(actual_result.face_blendshapes.has_value()); + ExpectBlendshapesCorrect(*actual_result.face_blendshapes, + *expected_result.face_blendshapes); + } + + ASSERT_EQ(actual_result.facial_transformation_matrixes.has_value(), + expected_result.facial_transformation_matrixes.has_value()); + if 
(expected_result.facial_transformation_matrixes.has_value()) { + ASSERT_TRUE(actual_result.facial_transformation_matrixes.has_value()); + ExpectFacialTransformationMatrixCorrect( + *actual_result.facial_transformation_matrixes, + *expected_result.facial_transformation_matrixes); + } +} + +class ImageModeTest : public TestWithParam {}; + +TEST_P(ImageModeTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath( + "./", kTestDataDirectory, GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + file::JoinPath("./", kTestDataDirectory, GetParam().input_model_name); + options->running_mode = core::RunningMode::IMAGE; + options->output_face_blendshapes = + GetParam().expected_result.face_blendshapes.has_value(); + options->output_facial_transformation_matrixes = + GetParam().expected_result.facial_transformation_matrixes.has_value(); + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr face_landmarker, + FaceLandmarker::Create(std::move(options))); + FaceLandmarkerResult actual_result; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK_AND_ASSIGN( + actual_result, + face_landmarker->Detect(image, image_processing_options)); + } else { + MP_ASSERT_OK_AND_ASSIGN(actual_result, face_landmarker->Detect(image)); + } + ExpectFaceLandmarkerResultCorrect(actual_result, GetParam().expected_result); + MP_ASSERT_OK(face_landmarker->Close()); +} + +INSTANTIATE_TEST_SUITE_P( + FaceLandmarkerTest, ImageModeTest, + Values(FaceLandmarkerTestParams{ + /* test_name= */ "Portrait", + /* input_model_name= */ kFaceLandmarkerModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandamrksName)})}, + FaceLandmarkerTestParams{ + /* test_name= */ "PortraitWithAttention", + /* input_model_name= */ + kFaceLandmarkerWithBlendshapesModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandmarksWithAttentionName)})}, + FaceLandmarkerTestParams{ + /* test_name= */ "PortraitWithBlendshapes", + /* input_model_name= */ + kFaceLandmarkerWithBlendshapesModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandmarksWithAttentionName)}, + {{GetExpectedProto( + kPortraitExpectedBlendshapesName)}})}, + FaceLandmarkerTestParams{ + /* test_name= */ "PortraitWithBlendshapesWithFacialTransformatio" + "nMatrix", + /* input_model_name= */ + kFaceLandmarkerWithBlendshapesModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandmarksWithAttentionName)}, + {{GetExpectedProto( + kPortraitExpectedBlendshapesName)}}, + {{MakePortraitExpectedFacialTransformationMatrix()}})}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +class VideoModeTest : public TestWithParam {}; + +TEST_P(VideoModeTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath( + "./", kTestDataDirectory, GetParam().test_image_name))); + auto options = std::make_unique(); + 
options->base_options.model_asset_path = + file::JoinPath("./", kTestDataDirectory, GetParam().input_model_name); + options->running_mode = core::RunningMode::VIDEO; + options->output_face_blendshapes = + GetParam().expected_result.face_blendshapes.has_value(); + options->output_facial_transformation_matrixes = + GetParam().expected_result.facial_transformation_matrixes.has_value(); + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr face_landmarker, + FaceLandmarker::Create(std::move(options))); + for (int i = 0; i < 3; ++i) { + FaceLandmarkerResult actual_result; + if (GetParam().rotation != 0) { + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK_AND_ASSIGN( + actual_result, + face_landmarker->DetectForVideo(image, i, image_processing_options)); + } else { + MP_ASSERT_OK_AND_ASSIGN(actual_result, + face_landmarker->DetectForVideo(image, i)); + } + ExpectFaceLandmarkerResultCorrect(actual_result, + GetParam().expected_result); + } + MP_ASSERT_OK(face_landmarker->Close()); +} + +INSTANTIATE_TEST_SUITE_P( + FaceLandmarkerTest, VideoModeTest, + Values(FaceLandmarkerTestParams{ + /* test_name= */ "Portrait", + /* input_model_name= */ kFaceLandmarkerModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandamrksName)})}, + FaceLandmarkerTestParams{ + /* test_name= */ "PortraitWithAttention", + /* input_model_name= */ + kFaceLandmarkerWithBlendshapesModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandmarksWithAttentionName)})}, + FaceLandmarkerTestParams{ + /* test_name= */ "PortraitWithBlendshapes", + /* input_model_name= */ + kFaceLandmarkerWithBlendshapesModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandmarksWithAttentionName)}, + {{GetExpectedProto( + kPortraitExpectedBlendshapesName)}})}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +class LiveStreamModeTest : public TestWithParam {}; + +TEST_P(LiveStreamModeTest, Succeeds) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(file::JoinPath( + "./", kTestDataDirectory, GetParam().test_image_name))); + auto options = std::make_unique(); + options->base_options.model_asset_path = + file::JoinPath("./", kTestDataDirectory, GetParam().input_model_name); + options->running_mode = core::RunningMode::LIVE_STREAM; + options->output_face_blendshapes = + GetParam().expected_result.face_blendshapes.has_value(); + options->output_facial_transformation_matrixes = + GetParam().expected_result.facial_transformation_matrixes.has_value(); + + std::vector face_landmarker_results; + std::vector timestamps; + options->result_callback = [&face_landmarker_results, ×tamps]( + absl::StatusOr result, + const Image& image, int64_t timestamp_ms) { + MP_ASSERT_OK(result.status()); + face_landmarker_results.push_back(std::move(result.value())); + timestamps.push_back(timestamp_ms); + }; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr face_landmarker, + FaceLandmarker::Create(std::move(options))); + + const int iterations = 100; + for (int i = 0; i < iterations; ++i) { + FaceLandmarkerResult actual_result; + if (GetParam().rotation != 0) { + 
ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = GetParam().rotation; + MP_ASSERT_OK( + face_landmarker->DetectAsync(image, i, image_processing_options)); + } else { + MP_ASSERT_OK(face_landmarker->DetectAsync(image, i)); + } + } + MP_ASSERT_OK(face_landmarker->Close()); + + // Due to the flow limiter, the total of outputs will be smaller than the + // number of iterations. + ASSERT_LE(face_landmarker_results.size(), iterations); + ASSERT_GT(face_landmarker_results.size(), 0); + + for (int i = 0; i < face_landmarker_results.size(); ++i) { + ExpectFaceLandmarkerResultCorrect(face_landmarker_results[i], + GetParam().expected_result); + } + int64_t timestamp_ms = -1; + for (const auto& timestamp : timestamps) { + EXPECT_GT(timestamp, timestamp_ms); + timestamp_ms = timestamp; + } +} + +INSTANTIATE_TEST_SUITE_P( + FaceLandmarkerTest, LiveStreamModeTest, + Values(FaceLandmarkerTestParams{ + /* test_name= */ "Portrait", + /* input_model_name= */ kFaceLandmarkerModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandamrksName)})}, + FaceLandmarkerTestParams{ + /* test_name= */ "PortraitWithAttention", + /* input_model_name= */ + kFaceLandmarkerWithBlendshapesModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandmarksWithAttentionName)})}, + FaceLandmarkerTestParams{ + /* test_name= */ "PortraitWithBlendshapes", + /* input_model_name= */ + kFaceLandmarkerWithBlendshapesModelBundleName, + /* test_image_name= */ kPortraitImageName, + /* rotation= */ 0, + /* expected_result= */ + ConvertToFaceLandmarkerResult( + {GetExpectedProto( + kPortraitExpectedFaceLandmarksWithAttentionName)}, + {{GetExpectedProto( + kPortraitExpectedBlendshapesName)}})}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace face_landmarker +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc index a898f2fe9..df9cab5b5 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/face_landmarker/face_landmarks_detector_graph.cc @@ -462,7 +462,7 @@ REGISTER_MEDIAPIPE_GRAPH( // - Accepts an input image and a vector of face rect RoIs to detect the // multiple face landmarks enclosed by the RoIs. Output vectors of // face landmarks related results, where each element in the vectors -// corrresponds to the result of the same face. +// corresponds to the result of the same face. 
// // Inputs: // IMAGE - Image diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD b/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD index f943420c6..d3e236619 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/BUILD @@ -60,5 +60,6 @@ mediapipe_proto_library( "//mediapipe/framework:calculator_proto", "//mediapipe/tasks/cc/core/proto:base_options_proto", "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_proto", ], ) diff --git a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto index 67599295e..dc8654608 100644 --- a/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto +++ b/mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.proto @@ -21,6 +21,7 @@ import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/core/proto/base_options.proto"; import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto"; +import "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.proto"; import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto"; option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto"; @@ -45,4 +46,8 @@ message FaceLandmarkerGraphOptions { // Minimum confidence for face landmarks tracking to be considered // successfully. optional float min_tracking_confidence = 4 [default = 0.5]; + + // Options for FaceGeometryGraph to get facial transformation matrix. + optional face_geometry.proto.FaceGeometryGraphOptions + face_geometry_graph_options = 5; } diff --git a/mediapipe/tasks/cc/vision/face_stylizer/BUILD b/mediapipe/tasks/cc/vision/face_stylizer/BUILD index 7da4e6e74..f62991d45 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/BUILD +++ b/mediapipe/tasks/cc/vision/face_stylizer/BUILD @@ -47,6 +47,7 @@ cc_library( "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", "@com_google_absl//absl/status:statusor", ], + alwayslink = 1, ) cc_library( diff --git a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc index 03760c6b3..d9825b15f 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc +++ b/mediapipe/tasks/cc/vision/face_stylizer/calculators/tensors_to_image_calculator.cc @@ -294,7 +294,7 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { threadsPerThreadgroup:threads_per_group]; [compute_encoder endEncoding]; [command_buffer commit]; - + [command_buffer waitUntilCompleted]; kOutputImage(cc).Send(Image(output)); return absl::OkStatus(); } diff --git a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h index 27e64c934..58501c47b 100644 --- a/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h +++ b/mediapipe/tasks/cc/vision/face_stylizer/face_stylizer.h @@ -81,7 +81,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // running mode. // // The input image can be of any size with format RGB or RGBA. 
- // To ensure that the output image has reasonable quailty, the stylized output + // To ensure that the output image has reasonable quality, the stylized output // image size is the smaller of the model output size and the size of the // 'region_of_interest' specified in 'image_processing_options'. absl::StatusOr Stylize( @@ -106,7 +106,7 @@ class FaceStylizer : tasks::vision::core::BaseVisionTaskApi { // The image can be of any size with format RGB or RGBA. It's required to // provide the video frame's timestamp (in milliseconds). The input timestamps // must be monotonically increasing. - // To ensure that the output image has reasonable quailty, the stylized output + // To ensure that the output image has reasonable quality, the stylized output // image size is the smaller of the model output size and the size of the // 'region_of_interest' specified in 'image_processing_options'. absl::StatusOr StylizeForVideo( diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc index b6f6c88da..55db07cb8 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/gesture_recognizer_graph.cc @@ -92,7 +92,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, GestureRecognizerGraphOptions* options, bool is_copy) { ASSIGN_OR_RETURN(const auto hand_landmarker_file, - resources.GetModelFile(kHandLandmarkerBundleAssetName)); + resources.GetFile(kHandLandmarkerBundleAssetName)); auto* hand_landmarker_graph_options = options->mutable_hand_landmarker_graph_options(); SetExternalFile(hand_landmarker_file, @@ -105,9 +105,8 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, hand_landmarker_graph_options->mutable_base_options()->set_use_stream_mode( options->base_options().use_stream_mode()); - ASSIGN_OR_RETURN( - const auto hand_gesture_recognizer_file, - resources.GetModelFile(kHandGestureRecognizerBundleAssetName)); + ASSIGN_OR_RETURN(const auto hand_gesture_recognizer_file, + resources.GetFile(kHandGestureRecognizerBundleAssetName)); auto* hand_gesture_recognizer_graph_options = options->mutable_hand_gesture_recognizer_graph_options(); SetExternalFile(hand_gesture_recognizer_file, @@ -127,7 +126,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, ->mutable_acceleration() ->mutable_xnnpack(); LOG(WARNING) << "Hand Gesture Recognizer contains CPU only ops. 
Sets " - << "HandGestureRecognizerGraph acceleartion to Xnnpack."; + << "HandGestureRecognizerGraph acceleration to Xnnpack."; } hand_gesture_recognizer_graph_options->mutable_base_options() ->set_use_stream_mode(options->base_options().use_stream_mode()); diff --git a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc index 4db57e85b..3fe999937 100644 --- a/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc +++ b/mediapipe/tasks/cc/vision/gesture_recognizer/hand_gesture_recognizer_graph.cc @@ -207,7 +207,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { HandGestureRecognizerGraphOptions* options, bool is_copy) { ASSIGN_OR_RETURN(const auto gesture_embedder_file, - resources.GetModelFile(kGestureEmbedderTFLiteName)); + resources.GetFile(kGestureEmbedderTFLiteName)); auto* gesture_embedder_graph_options = options->mutable_gesture_embedder_graph_options(); SetExternalFile(gesture_embedder_file, @@ -218,9 +218,8 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { options->base_options(), gesture_embedder_graph_options->mutable_base_options()); - ASSIGN_OR_RETURN( - const auto canned_gesture_classifier_file, - resources.GetModelFile(kCannedGestureClassifierTFLiteName)); + ASSIGN_OR_RETURN(const auto canned_gesture_classifier_file, + resources.GetFile(kCannedGestureClassifierTFLiteName)); auto* canned_gesture_classifier_graph_options = options->mutable_canned_gesture_classifier_graph_options(); SetExternalFile( @@ -233,7 +232,7 @@ class SingleHandGestureRecognizerGraph : public core::ModelTaskGraph { canned_gesture_classifier_graph_options->mutable_base_options()); const auto custom_gesture_classifier_file = - resources.GetModelFile(kCustomGestureClassifierTFLiteName); + resources.GetFile(kCustomGestureClassifierTFLiteName); if (custom_gesture_classifier_file.ok()) { has_custom_gesture_classifier = true; auto* custom_gesture_classifier_graph_options = diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD index f45681fb3..73d3f38eb 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/BUILD @@ -36,6 +36,7 @@ cc_library( ":hand_association_calculator_cc_proto", "//mediapipe/calculators/util:association_calculator", "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:collection_item_id", "//mediapipe/framework/api2:node", "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/port:rectangle", diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc index dffdbdd38..011bce2b9 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/collection_item_id.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/port/rectangle.h" #include "mediapipe/framework/port/status.h" @@ -29,30 +30,55 @@ namespace mediapipe::api2 { using ::mediapipe::NormalizedRect; -// HandAssociationCalculator accepts multiple inputs of vectors of -// NormalizedRect. The output is a vector of NormalizedRect that contains -// rects from the input vectors that don't overlap with each other. When two -// rects overlap, the rect that comes in from an earlier input stream is -// kept in the output. If a rect has no ID (i.e. from detection stream), -// then a unique rect ID is assigned for it. - -// The rects in multiple input streams are effectively flattened to a single -// list. For example: -// Stream1 : rect 1, rect 2 -// Stream2: rect 3, rect 4 -// Stream3: rect 5, rect 6 -// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6 -// In the flattened list, if a rect with a higher index overlaps with a rect a -// lower index, beyond a specified IOU threshold, the rect with the lower -// index will be in the output, and the rect with higher index will be -// discarded. +// Input: +// BASE_RECTS - Vector of NormalizedRect. +// RECTS - Vector of NormalizedRect. +// +// Output: +// No tag - Vector of NormalizedRect. +// +// Example use: +// node { +// calculator: "HandAssociationCalculator" +// input_stream: "BASE_RECTS:base_rects" +// input_stream: "RECTS:0:rects0" +// input_stream: "RECTS:1:rects1" +// input_stream: "RECTS:2:rects2" +// output_stream: "output_rects" +// options { +// [mediapipe.HandAssociationCalculatorOptions.ext] { +// min_similarity_threshold: 0.1 +// } +// } +// +// IMPORTANT Notes: +// - Rects from input streams tagged with "BASE_RECTS" are always preserved. +// - This calculator checks for overlap among rects from input streams tagged +// with "RECTS". Rects are prioritized based on their index in the vector and +// input streams to the calculator. When two rects overlap, the rect that +// comes from an input stream with lower tag-index is kept in the output. +// - Example of inputs for the node above: +// "base_rects": rect 0, rect 1 +// "rects0": rect 2, rect 3 +// "rects1": rect 4, rect 5 +// "rects2": rect 6, rect 7 +// (Conceptually) flattened list: 0, 1, 2, 3, 4, 5, 6, 7. +// Rects 0, 1 will be preserved. Rects 2, 3, 4, 5, 6, 7 will be checked for +// overlap. If a rect with a higher index overlaps with a rect with lower +// index, beyond a specified IOU threshold, the rect with the lower index +// will be in the output, and the rect with higher index will be discarded. // TODO: Upgrade this to latest API for calculators class HandAssociationCalculator : public CalculatorBase { public: static absl::Status GetContract(CalculatorContract* cc) { // Initialize input and output streams. 
- for (auto& input_stream : cc->Inputs()) { - input_stream.Set>(); + for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS"); + id != cc->Inputs().EndId("BASE_RECTS"); ++id) { + cc->Inputs().Get(id).Set>(); + } + for (CollectionItemId id = cc->Inputs().BeginId("RECTS"); + id != cc->Inputs().EndId("RECTS"); ++id) { + cc->Inputs().Get(id).Set>(); } cc->Outputs().Index(0).Set>(); @@ -89,7 +115,24 @@ class HandAssociationCalculator : public CalculatorBase { CalculatorContext* cc) { std::vector result; - for (const auto& input_stream : cc->Inputs()) { + for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS"); + id != cc->Inputs().EndId("BASE_RECTS"); ++id) { + const auto& input_stream = cc->Inputs().Get(id); + if (input_stream.IsEmpty()) { + continue; + } + + for (auto rect : input_stream.Get>()) { + if (!rect.has_rect_id()) { + rect.set_rect_id(GetNextRectId()); + } + result.push_back(rect); + } + } + + for (CollectionItemId id = cc->Inputs().BeginId("RECTS"); + id != cc->Inputs().EndId("RECTS"); ++id) { + const auto& input_stream = cc->Inputs().Get(id); if (input_stream.IsEmpty()) { continue; } diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc index 138164209..c22b1a7e6 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/calculators/hand_association_calculator_test.cc @@ -27,6 +27,8 @@ namespace mediapipe { namespace { using ::mediapipe::NormalizedRect; +using ::testing::ElementsAre; +using ::testing::EqualsProto; class HandAssociationCalculatorTest : public testing::Test { protected: @@ -87,9 +89,9 @@ class HandAssociationCalculatorTest : public testing::Test { TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "HandAssociationCalculator" - input_stream: "input_vec_0" - input_stream: "input_vec_1" - input_stream: "input_vec_2" + input_stream: "BASE_RECTS:input_vec_0" + input_stream: "RECTS:0:input_vec_1" + input_stream: "RECTS:1:input_vec_2" output_stream: "output_vec" options { [mediapipe.HandAssociationCalculatorOptions.ext] { @@ -103,20 +105,23 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { input_vec_0->push_back(nr_0_); input_vec_0->push_back(nr_1_); input_vec_0->push_back(nr_2_); - runner.MutableInputs()->Index(0).packets.push_back( - Adopt(input_vec_0.release()).At(Timestamp(1))); + runner.MutableInputs() + ->Tag("BASE_RECTS") + .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); // Input Stream 1: nr_3, nr_4. auto input_vec_1 = std::make_unique>(); input_vec_1->push_back(nr_3_); input_vec_1->push_back(nr_4_); - runner.MutableInputs()->Index(1).packets.push_back( + auto index_id = runner.MutableInputs()->GetId("RECTS", 0); + runner.MutableInputs()->Get(index_id).packets.push_back( Adopt(input_vec_1.release()).At(Timestamp(1))); // Input Stream 2: nr_5. 
auto input_vec_2 = std::make_unique>(); input_vec_2->push_back(nr_5_); - runner.MutableInputs()->Index(2).packets.push_back( + index_id = runner.MutableInputs()->GetId("RECTS", 1); + runner.MutableInputs()->Get(index_id).packets.push_back( Adopt(input_vec_2.release()).At(Timestamp(1))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; @@ -134,25 +139,18 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { EXPECT_EQ(3, assoc_rects.size()); // Check that IDs are filled in and contents match. - EXPECT_EQ(assoc_rects[0].rect_id(), 1); - assoc_rects[0].clear_rect_id(); - EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); - - EXPECT_EQ(assoc_rects[1].rect_id(), 2); - assoc_rects[1].clear_rect_id(); - EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); - - EXPECT_EQ(assoc_rects[2].rect_id(), 3); - assoc_rects[2].clear_rect_id(); - EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); + nr_0_.set_rect_id(1); + nr_1_.set_rect_id(2); + nr_2_.set_rect_id(3); + EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_), + EqualsProto(nr_2_))); } TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "HandAssociationCalculator" - input_stream: "input_vec_0" - input_stream: "input_vec_1" - input_stream: "input_vec_2" + input_stream: "BASE_RECTS:input_vec_0" + input_stream: "RECTS:0:input_vec_1" output_stream: "output_vec" options { [mediapipe.HandAssociationCalculatorOptions.ext] { @@ -169,14 +167,15 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { input_vec_0->push_back(nr_0_); nr_1_.set_rect_id(-1); input_vec_0->push_back(nr_1_); - runner.MutableInputs()->Index(0).packets.push_back( - Adopt(input_vec_0.release()).At(Timestamp(1))); + runner.MutableInputs() + ->Tag("BASE_RECTS") + .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); // Input Stream 1: nr_2, nr_3. Newly detected palms. auto input_vec_1 = std::make_unique>(); input_vec_1->push_back(nr_2_); input_vec_1->push_back(nr_3_); - runner.MutableInputs()->Index(1).packets.push_back( + runner.MutableInputs()->Tag("RECTS").packets.push_back( Adopt(input_vec_1.release()).At(Timestamp(1))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; @@ -192,23 +191,17 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { EXPECT_EQ(3, assoc_rects.size()); // Check that IDs are filled in and contents match. - EXPECT_EQ(assoc_rects[0].rect_id(), -2); - EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); - - EXPECT_EQ(assoc_rects[1].rect_id(), -1); - EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); - - EXPECT_EQ(assoc_rects[2].rect_id(), 1); - assoc_rects[2].clear_rect_id(); - EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); + nr_2_.set_rect_id(1); + EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_), + EqualsProto(nr_2_))); } TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "HandAssociationCalculator" - input_stream: "input_vec_0" - input_stream: "input_vec_1" - input_stream: "input_vec_2" + input_stream: "BASE_RECTS:input_vec_0" + input_stream: "RECTS:0:input_vec_1" + input_stream: "RECTS:1:input_vec_2" output_stream: "output_vec" options { [mediapipe.HandAssociationCalculatorOptions.ext] { @@ -220,14 +213,16 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { // Input Stream 0: nr_5. 
auto input_vec_0 = std::make_unique>(); input_vec_0->push_back(nr_5_); - runner.MutableInputs()->Index(0).packets.push_back( - Adopt(input_vec_0.release()).At(Timestamp(1))); + runner.MutableInputs() + ->Tag("BASE_RECTS") + .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); // Input Stream 1: nr_4, nr_3 auto input_vec_1 = std::make_unique>(); input_vec_1->push_back(nr_4_); input_vec_1->push_back(nr_3_); - runner.MutableInputs()->Index(1).packets.push_back( + auto index_id = runner.MutableInputs()->GetId("RECTS", 0); + runner.MutableInputs()->Get(index_id).packets.push_back( Adopt(input_vec_1.release()).At(Timestamp(1))); // Input Stream 2: nr_2, nr_1, nr_0. @@ -235,7 +230,8 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { input_vec_2->push_back(nr_2_); input_vec_2->push_back(nr_1_); input_vec_2->push_back(nr_0_); - runner.MutableInputs()->Index(2).packets.push_back( + index_id = runner.MutableInputs()->GetId("RECTS", 1); + runner.MutableInputs()->Get(index_id).packets.push_back( Adopt(input_vec_2.release()).At(Timestamp(1))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; @@ -253,23 +249,78 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { EXPECT_EQ(3, assoc_rects.size()); // Outputs are in same order as inputs, and IDs are filled in. - EXPECT_EQ(assoc_rects[0].rect_id(), 1); - assoc_rects[0].clear_rect_id(); - EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_)); + nr_5_.set_rect_id(1); + nr_4_.set_rect_id(2); + nr_0_.set_rect_id(3); + EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_4_), + EqualsProto(nr_0_))); +} - EXPECT_EQ(assoc_rects[1].rect_id(), 2); - assoc_rects[1].clear_rect_id(); - EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_)); +TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReservesBaseRects) { + CalculatorRunner runner(ParseTextProtoOrDie(R"pb( + calculator: "HandAssociationCalculator" + input_stream: "BASE_RECTS:input_vec_0" + input_stream: "RECTS:0:input_vec_1" + input_stream: "RECTS:1:input_vec_2" + output_stream: "output_vec" + options { + [mediapipe.HandAssociationCalculatorOptions.ext] { + min_similarity_threshold: 0.1 + } + } + )pb")); - EXPECT_EQ(assoc_rects[2].rect_id(), 3); - assoc_rects[2].clear_rect_id(); - EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_)); + // Input Stream 0: nr_5, nr_3, nr_1. + auto input_vec_0 = std::make_unique>(); + input_vec_0->push_back(nr_5_); + input_vec_0->push_back(nr_3_); + input_vec_0->push_back(nr_1_); + runner.MutableInputs() + ->Tag("BASE_RECTS") + .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); + + // Input Stream 1: nr_4. + auto input_vec_1 = std::make_unique>(); + input_vec_1->push_back(nr_4_); + auto index_id = runner.MutableInputs()->GetId("RECTS", 0); + runner.MutableInputs()->Get(index_id).packets.push_back( + Adopt(input_vec_1.release()).At(Timestamp(1))); + + // Input Stream 2: nr_2, nr_0. + auto input_vec_2 = std::make_unique>(); + input_vec_2->push_back(nr_2_); + input_vec_2->push_back(nr_0_); + index_id = runner.MutableInputs()->GetId("RECTS", 1); + runner.MutableInputs()->Get(index_id).packets.push_back( + Adopt(input_vec_2.release()).At(Timestamp(1))); + + MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; + const std::vector& output = runner.Outputs().Index(0).packets; + EXPECT_EQ(1, output.size()); + auto assoc_rects = output[0].Get>(); + + // Rectangles are added in the following sequence: + // nr_5 is added because it is in BASE_RECTS input stream. 
+ // nr_3 is added because it is in BASE_RECTS input stream. + // nr_1 is added because it is in BASE_RECTS input stream. + // nr_4 is added because it does not overlap with nr_5. + // nr_2 is NOT added because it overlaps with nr_4. + // nr_0 is NOT added because it overlaps with nr_3. + EXPECT_EQ(4, assoc_rects.size()); + + // Outputs are in same order as inputs, and IDs are filled in. + nr_5_.set_rect_id(1); + nr_3_.set_rect_id(2); + nr_1_.set_rect_id(3); + nr_4_.set_rect_id(4); + EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_3_), + EqualsProto(nr_1_), EqualsProto(nr_4_))); } TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { CalculatorRunner runner(ParseTextProtoOrDie(R"pb( calculator: "HandAssociationCalculator" - input_stream: "input_vec" + input_stream: "BASE_RECTS:input_vec" output_stream: "output_vec" options { [mediapipe.HandAssociationCalculatorOptions.ext] { @@ -282,8 +333,9 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { auto input_vec = std::make_unique>(); input_vec->push_back(nr_3_); input_vec->push_back(nr_5_); - runner.MutableInputs()->Index(0).packets.push_back( - Adopt(input_vec.release()).At(Timestamp(1))); + runner.MutableInputs() + ->Tag("BASE_RECTS") + .packets.push_back(Adopt(input_vec.release()).At(Timestamp(1))); MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; const std::vector& output = runner.Outputs().Index(0).packets; @@ -292,12 +344,12 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { // Rectangles are added in the following sequence: // nr_3 is added 1st. - // nr_5 is NOT added because it overlaps with nr_3. - EXPECT_EQ(1, assoc_rects.size()); + // nr_5 is added 2nd. The calculator assumes it does not overlap with nr_3. + EXPECT_EQ(2, assoc_rects.size()); - EXPECT_EQ(assoc_rects[0].rect_id(), 1); - assoc_rects[0].clear_rect_id(); - EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_)); + nr_3_.set_rect_id(1); + nr_5_.set_rect_id(2); + EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_3_), EqualsProto(nr_5_))); } } // namespace diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h index 6f96fc68e..7a43d20d7 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker.h @@ -101,7 +101,7 @@ class HandLandmarker : tasks::vision::core::BaseVisionTaskApi { // three running modes: // 1) Image mode for detecting hand landmarks on single image inputs. Users // provide mediapipe::Image to the `Detect` method, and will receive the - // deteced hand landmarks results as the return value. + // detected hand landmarks results as the return value. // 2) Video mode for detecting hand landmarks on the decoded frames of a // video. Users call `DetectForVideo` method, and will receive the detected // hand landmarks results as the return value. 
diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc index 4a3db9f4d..b37141005 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarker_graph.cc @@ -97,7 +97,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, options->mutable_hand_detector_graph_options(); if (!hand_detector_graph_options->base_options().has_model_asset()) { ASSIGN_OR_RETURN(const auto hand_detector_file, - resources.GetModelFile(kHandDetectorTFLiteName)); + resources.GetFile(kHandDetectorTFLiteName)); SetExternalFile(hand_detector_file, hand_detector_graph_options->mutable_base_options() ->mutable_model_asset(), @@ -113,7 +113,7 @@ absl::Status SetSubTaskBaseOptions(const ModelAssetBundleResources& resources, if (!hand_landmarks_detector_graph_options->base_options() .has_model_asset()) { ASSIGN_OR_RETURN(const auto hand_landmarks_detector_file, - resources.GetModelFile(kHandLandmarksDetectorTFLiteName)); + resources.GetFile(kHandLandmarksDetectorTFLiteName)); SetExternalFile( hand_landmarks_detector_file, hand_landmarks_detector_graph_options->mutable_base_options() @@ -318,9 +318,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { .set_min_similarity_threshold( tasks_options.min_tracking_confidence()); prev_hand_rects_from_landmarks >> - hand_association[Input>::Multiple("")][0]; + hand_association[Input>("BASE_RECTS")]; hand_rects_from_hand_detector >> - hand_association[Input>::Multiple("")][1]; + hand_association[Input>("RECTS")]; auto hand_rects = hand_association.Out(""); hand_rects >> clip_hand_rects.In(""); } else { diff --git a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc index 6d232d3f1..f7fa83a11 100644 --- a/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc +++ b/mediapipe/tasks/cc/vision/hand_landmarker/hand_landmarks_detector_graph.cc @@ -409,7 +409,7 @@ REGISTER_MEDIAPIPE_GRAPH( // - Accepts CPU input image and a vector of hand rect RoIs to detect the // multiple hands landmarks enclosed by the RoIs. Output vectors of // hand landmarks related results, where each element in the vectors -// corrresponds to the result of the same hand. +// corresponds to the result of the same hand. // // Inputs: // IMAGE - Image diff --git a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc index dd602bef5..6aa0b85bc 100644 --- a/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc +++ b/mediapipe/tasks/cc/vision/image_embedder/image_embedder_test.cc @@ -52,7 +52,7 @@ constexpr char kMobileNetV3Embedder[] = constexpr double kSimilarityTolerancy = 1e-6; // Utility function to check the sizes, head_index and head_names of a result -// procuded by kMobileNetV3Embedder. +// produced by kMobileNetV3Embedder. 
void CheckMobileNetV3Result(const ImageEmbedderResult& result, bool quantized) { EXPECT_EQ(result.embeddings.size(), 1); EXPECT_EQ(result.embeddings[0].head_index, 0); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/BUILD index b084331c8..69833a5f6 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/BUILD @@ -25,6 +25,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":image_segmenter_graph", + "//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework/api2:builder", "//mediapipe/framework/formats:image", "//mediapipe/framework/formats:rect_cc_proto", @@ -34,10 +35,13 @@ cc_library( "//mediapipe/tasks/cc/vision/core:image_processing_options", "//mediapipe/tasks/cc/vision/core:running_mode", "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", + "//mediapipe/util:label_map_cc_proto", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", ], ) @@ -76,6 +80,7 @@ cc_library( "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", "//mediapipe/tasks/cc/vision/utils:image_tensor_specs", + "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_cc", "//mediapipe/tasks/metadata:metadata_schema_cc", "//mediapipe/util:label_map_cc_proto", "//mediapipe/util:label_map_util", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD index b54e7352b..c621016dc 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD @@ -23,7 +23,6 @@ mediapipe_proto_library( srcs = ["tensors_to_segmentation_calculator.proto"], deps = [ "//mediapipe/framework:calculator_options_proto", - "//mediapipe/framework:calculator_proto", "//mediapipe/framework/formats:image_format_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_proto", "//mediapipe/util:label_map_proto", diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc index 091e4d6c9..b6c1fe6b0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -79,6 +80,133 @@ void Sigmoid(absl::Span values, [](float value) { return 1. 
/ (1 + std::exp(-value)); }); } +std::vector ProcessForCategoryMaskCpu(const Shape& input_shape, + const Shape& output_shape, + const SegmenterOptions& options, + const float* tensors_buffer) { + cv::Mat resized_tensors_mat; + cv::Mat tensors_mat_view( + input_shape.height, input_shape.width, CV_32FC(input_shape.channels), + reinterpret_cast(const_cast(tensors_buffer))); + if (output_shape.height == input_shape.height && + output_shape.width == input_shape.width) { + resized_tensors_mat = tensors_mat_view; + } else { + // Resize input tensors to output size. + // TOOD(b/273633027) Use an efficient way to find values for category mask + // instead of resizing the whole tensor . + cv::resize(tensors_mat_view, resized_tensors_mat, + {output_shape.width, output_shape.height}, 0, 0, + cv::INTER_LINEAR); + } + + // Category mask Image. + ImageFrameSharedPtr image_frame_ptr = std::make_shared( + ImageFormat::GRAY8, output_shape.width, output_shape.height, 1); + Image category_mask(image_frame_ptr); + + // Fill in the maximum category in the category mask image. + cv::Mat category_mask_mat_view = + mediapipe::formats::MatView(image_frame_ptr.get()); + const int input_channels = input_shape.channels; + category_mask_mat_view.forEach( + [&resized_tensors_mat, &input_channels, &options](uint8_t& pixel, + const int position[]) { + float* tensors_buffer = + resized_tensors_mat.ptr(position[0], position[1]); + absl::Span confidence_scores(tensors_buffer, input_channels); + // Only process the activation function if it is SIGMOID. If NONE, + // we do nothing for activation, If SOFTMAX, it is required + // to have input_channels > 1, and for input_channels > 1, we don't need + // activation to find the maximum value. + if (options.activation() == SegmenterOptions::SIGMOID) { + Sigmoid(confidence_scores, confidence_scores); + } + if (input_channels == 1) { + // if the input tensor is a single mask, it is assumed to be a binary + // foreground segmentation mask. For such a mask, we make foreground + // category 1, and background category 0. + pixel = static_cast(*tensors_buffer > 0.5f); + } else { + const int maximum_category_idx = + std::max_element(confidence_scores.begin(), + confidence_scores.end()) - + confidence_scores.begin(); + pixel = maximum_category_idx; + } + }); + return {category_mask}; +} + +std::vector ProcessForConfidenceMaskCpu(const Shape& input_shape, + const Shape& output_shape, + const SegmenterOptions& options, + const float* tensors_buffer) { + std::function values, + absl::Span activated_values)> + activation_fn; + switch (options.activation()) { + case SegmenterOptions::SIGMOID: + activation_fn = &Sigmoid; + break; + case SegmenterOptions::SOFTMAX: + activation_fn = &StableSoftmax; + break; + case SegmenterOptions::NONE: + // Just copying for NONE activation. + activation_fn = [](absl::Span values, + absl::Span activated_values) { + std::copy(values.begin(), values.end(), activated_values.begin()); + }; + break; + } + + // TODO Use libyuv for resizing instead. + std::vector confidence_masks; + std::vector confidence_mask_mats; + confidence_masks.reserve(input_shape.channels); + confidence_mask_mats.reserve(input_shape.channels); + for (int i = 0; i < input_shape.channels; ++i) { + confidence_masks.push_back(Image(std::make_shared( + ImageFormat::VEC32F1, input_shape.width, input_shape.height, 1))); + confidence_mask_mats.push_back(mediapipe::formats::MatView( + confidence_masks.back().GetImageFrameSharedPtr().get())); + } + + // Applies activation function. 
+ const int tensor_size = input_shape.height * input_shape.width; + std::vector activated_values(input_shape.channels); + absl::Span activated_values_span(activated_values); + for (int i = 0; i < tensor_size; ++i) { + activation_fn(absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels], + input_shape.channels), + activated_values_span); + for (int j = 0; j < input_shape.channels; ++j) { + confidence_mask_mats[j].at( + i / input_shape.width, i % input_shape.width) = activated_values[j]; + } + } + if (output_shape.height == input_shape.height && + output_shape.width == input_shape.width) { + return confidence_masks; + } + std::vector resized_confidence_masks; + resized_confidence_masks.reserve(confidence_mask_mats.size()); + // Resizes segmented masks to required output size. + for (int i = 0; i < confidence_mask_mats.size(); i++) { + // Pre-allocates ImageFrame memory to avoid copying from cv::Mat + // afterward. + ImageFrameSharedPtr image_frame_ptr = std::make_shared( + ImageFormat::VEC32F1, output_shape.width, output_shape.height, 1); + cv::Mat resized_mask_mat_view = + mediapipe::formats::MatView(image_frame_ptr.get()); + cv::resize(confidence_mask_mats[i], resized_mask_mat_view, + resized_mask_mat_view.size(), 0, 0, cv::INTER_LINEAR); + resized_confidence_masks.push_back(Image(image_frame_ptr)); + } + return resized_confidence_masks; +} + } // namespace // Converts Tensors from a vector of Tensor to Segmentation. @@ -222,81 +350,16 @@ absl::Status TensorsToSegmentationCalculator::Process( std::vector TensorsToSegmentationCalculator::GetSegmentationResultCpu( const Shape& input_shape, const Shape& output_shape, const float* tensors_buffer) { - std::function values, - absl::Span activated_values)> - activation_fn; - switch (options_.segmenter_options().activation()) { - case SegmenterOptions::SIGMOID: - activation_fn = &Sigmoid; - break; - case SegmenterOptions::SOFTMAX: - activation_fn = &StableSoftmax; - break; - case SegmenterOptions::NONE: - // Just copying for NONE activation. - activation_fn = [](absl::Span values, - absl::Span activated_values) { - std::copy(values.begin(), values.end(), activated_values.begin()); - }; - break; - } - - const bool is_category_mask = options_.segmenter_options().output_type() == - SegmenterOptions::CATEGORY_MASK; - const int cv_mat_type = is_category_mask ? CV_8UC1 : CV_32FC1; - const int output_masks_num = output_shape.channels; - - // TODO Use libyuv for resizing instead. - std::vector segmented_mask_mats; - segmented_mask_mats.reserve(output_masks_num); - for (int i = 0; i < output_masks_num; ++i) { - segmented_mask_mats.push_back( - cv::Mat(input_shape.height, input_shape.width, cv_mat_type)); - } - - // Applies activation function. 
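For reference, the per-pixel rule implemented by ProcessForCategoryMaskCpu above reduces to the standalone sketch below: an argmax over channels, with a 0.5 threshold when the model emits a single binary foreground mask. This is an illustrative restatement with a made-up function name, not code from this change.

#include <algorithm>
#include <cstdint>

#include "absl/types/span.h"

// Illustrative restatement of the per-pixel CATEGORY_MASK rule. `scores`
// holds the (already activated) per-class confidences for one pixel.
uint8_t CategoryForPixel(absl::Span<const float> scores) {
  if (scores.size() == 1) {
    // A single-channel tensor is treated as a binary foreground mask:
    // category 1 is foreground, category 0 is background.
    return scores[0] > 0.5f ? 1 : 0;
  }
  // Otherwise the category is the channel with the maximum confidence.
  return static_cast<uint8_t>(
      std::max_element(scores.begin(), scores.end()) - scores.begin());
}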
- const int tensor_size = input_shape.height * input_shape.width; - if (is_category_mask) { - for (int i = 0; i < tensor_size; ++i) { - absl::Span confidence_scores( - &tensors_buffer[i * input_shape.channels], input_shape.channels); - const int maximum_category_idx = - std::max_element(confidence_scores.begin(), confidence_scores.end()) - - confidence_scores.begin(); - segmented_mask_mats[0].at( - i / input_shape.width, i % input_shape.width) = maximum_category_idx; - } + if (options_.segmenter_options().output_type() == + SegmenterOptions::CATEGORY_MASK) { + return ProcessForCategoryMaskCpu(input_shape, output_shape, + options_.segmenter_options(), + tensors_buffer); } else { - std::vector activated_values(input_shape.channels); - absl::Span activated_values_span(activated_values); - for (int i = 0; i < tensor_size; ++i) { - activation_fn( - absl::MakeConstSpan(&tensors_buffer[i * input_shape.channels], - input_shape.channels), - activated_values_span); - for (int j = 0; j < input_shape.channels; ++j) { - segmented_mask_mats[j].at( - i / input_shape.width, i % input_shape.width) = activated_values[j]; - } - } + return ProcessForConfidenceMaskCpu(input_shape, output_shape, + options_.segmenter_options(), + tensors_buffer); } - - std::vector segmented_masks; - segmented_masks.reserve(output_masks_num); - // Resizes segmented masks to required output size. - for (int i = 0; i < segmented_mask_mats.size(); i++) { - // Pre-allocates ImageFrame memory to avoid copying from cv::Mat afterward. - ImageFrameSharedPtr image_frame_ptr = std::make_shared( - is_category_mask ? ImageFormat::GRAY8 : ImageFormat::VEC32F1, - output_shape.width, output_shape.height, 1); - cv::Mat resized_mask_mat_view = - mediapipe::formats::MatView(image_frame_ptr.get()); - cv::resize(segmented_mask_mats[i], resized_mask_mat_view, - resized_mask_mat_view.size(), 0, 0, - cv_mat_type == CV_8UC1 ? cv::INTER_NEAREST : cv::INTER_LINEAR); - segmented_masks.push_back(Image(image_frame_ptr)); - } - return segmented_masks; } MEDIAPIPE_REGISTER_NODE(::mediapipe::tasks::TensorsToSegmentationCalculator); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto index b0fdfdd32..dbaf34db0 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto +++ b/mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.proto @@ -18,10 +18,13 @@ syntax = "proto2"; // TODO: consolidate TensorToSegmentationCalculator. package mediapipe.tasks; -import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; import "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.proto"; import "mediapipe/util/label_map.proto"; +option java_package = "com.google.mediapipe.tasks"; +option java_outer_classname = "TensorsToSegmentationCalculatorOptionsProto"; + message TensorsToSegmentationCalculatorOptions { extend mediapipe.CalculatorOptions { optional TensorsToSegmentationCalculatorOptions ext = 458105876; diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc index 7130c72e2..c12fe7f7e 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.cc @@ -15,15 +15,21 @@ limitations under the License. 
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" +#include + +#include "absl/strings/str_format.h" #include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/tasks/cc/core/utils.h" #include "mediapipe/tasks/cc/vision/core/image_processing_options.h" #include "mediapipe/tasks/cc/vision/core/running_mode.h" #include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" +#include "mediapipe/util/label_map.pb.h" namespace mediapipe { namespace tasks { @@ -95,23 +101,42 @@ ConvertImageSegmenterOptionsToProto(ImageSegmenterOptions* options) { SegmenterOptions::CONFIDENCE_MASK); break; } - switch (options->activation) { - case ImageSegmenterOptions::Activation::NONE: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::NONE); - break; - case ImageSegmenterOptions::Activation::SIGMOID: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::SIGMOID); - break; - case ImageSegmenterOptions::Activation::SOFTMAX: - options_proto->mutable_segmenter_options()->set_activation( - SegmenterOptions::SOFTMAX); - break; - } return options_proto; } +absl::StatusOr> GetLabelsFromGraphConfig( + const CalculatorGraphConfig& graph_config) { + bool found_tensor_to_segmentation_calculator = false; + std::vector labels; + for (const auto& node : graph_config.node()) { + if (node.calculator() == + "mediapipe.tasks.TensorsToSegmentationCalculator") { + if (!found_tensor_to_segmentation_calculator) { + found_tensor_to_segmentation_calculator = true; + } else { + return absl::Status(CreateStatusWithPayload( + absl::StatusCode::kFailedPrecondition, + "The graph has more than one " + "mediapipe.tasks.TensorsToSegmentationCalculator.")); + } + TensorsToSegmentationCalculatorOptions options = + node.options().GetExtension( + TensorsToSegmentationCalculatorOptions::ext); + if (!options.label_items().empty()) { + for (int i = 0; i < options.label_items_size(); ++i) { + if (!options.label_items().contains(i)) { + return absl::Status(CreateStatusWithPayload( + absl::StatusCode::kFailedPrecondition, + absl::StrFormat("The lablemap have no expected key: %d.", i))); + } + labels.push_back(options.label_items().at(i).name()); + } + } + } + } + return labels; +} + } // namespace absl::StatusOr> ImageSegmenter::Create( @@ -140,13 +165,22 @@ absl::StatusOr> ImageSegmenter::Create( kMicroSecondsPerMilliSecond); }; } - return core::VisionTaskApiFactory::Create( - CreateGraphConfig( - std::move(options_proto), - options->running_mode == core::RunningMode::LIVE_STREAM), - std::move(options->base_options.op_resolver), options->running_mode, - std::move(packets_callback)); + + auto image_segmenter = + core::VisionTaskApiFactory::Create( + CreateGraphConfig( + std::move(options_proto), + options->running_mode == core::RunningMode::LIVE_STREAM), + std::move(options->base_options.op_resolver), options->running_mode, + std::move(packets_callback)); + if (!image_segmenter.ok()) { + return image_segmenter.status(); + } + ASSIGN_OR_RETURN( + (*image_segmenter)->labels_, + 
      GetLabelsFromGraphConfig((*image_segmenter)->runner_->GetGraphConfig()));
+  return image_segmenter;
 }
 
 absl::StatusOr<std::vector<Image>> ImageSegmenter::Segment(
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
index 511d3b9c1..076a5016c 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h
@@ -64,15 +64,6 @@ struct ImageSegmenterOptions {
 
   OutputType output_type = OutputType::CATEGORY_MASK;
 
-  // The activation function used on the raw segmentation model output.
-  enum Activation {
-    NONE = 0,     // No activation function is used.
-    SIGMOID = 1,  // Assumes 1-channel input tensor.
-    SOFTMAX = 2,  // Assumes multi-channel input tensor.
-  };
-
-  Activation activation = Activation::NONE;
-
   // The user-defined result callback for processing live stream data.
   // The result callback should only be specified when the running mode is set
   // to RunningMode::LIVE_STREAM.
@@ -189,6 +180,18 @@ class ImageSegmenter : tasks::vision::core::BaseVisionTaskApi {
 
   // Shuts down the ImageSegmenter when all works are done.
   absl::Status Close() { return runner_->Close(); }
+
+  // Gets the list of category labels the ImageSegmenter can recognize. For
+  // CATEGORY_MASK type, the index in the category mask corresponds to the
+  // category in the label list. For CONFIDENCE_MASK type, the output mask
+  // at a given index corresponds to the category at that index in the list.
+  //
+  // If there is no labelmap provided in the model file, an empty label list
+  // is returned.
+  std::vector<std::string> GetLabels() { return labels_; }
+
+ private:
+  std::vector<std::string> labels_;
 };
 
 }  // namespace image_segmenter
diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
index c4a4065c6..fe6265b73 100644
--- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
+++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_graph.cc
@@ -40,6 +40,7 @@ limitations under the License.
 #include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h"
 #include "mediapipe/tasks/cc/vision/utils/image_tensor_specs.h"
+#include "mediapipe/tasks/metadata/image_segmenter_metadata_schema_generated.h"
 #include "mediapipe/tasks/metadata/metadata_schema_generated.h"
 #include "mediapipe/util/label_map.pb.h"
 #include "mediapipe/util/label_map_util.h"
@@ -74,6 +75,7 @@ constexpr char kImageGpuTag[] = "IMAGE_GPU";
 constexpr char kNormRectTag[] = "NORM_RECT";
 constexpr char kTensorsTag[] = "TENSORS";
 constexpr char kOutputSizeTag[] = "OUTPUT_SIZE";
+constexpr char kSegmentationMetadataName[] = "SEGMENTER_METADATA";
 
 // Struct holding the different output streams produced by the image segmenter
 // subgraph.
@@ -130,7 +132,49 @@ absl::Status ConfigureTensorsToSegmentationCalculator(
     const ImageSegmenterGraphOptions& segmenter_option,
     const core::ModelResources& model_resources,
     TensorsToSegmentationCalculatorOptions* options) {
-  *options->mutable_segmenter_options() = segmenter_option.segmenter_options();
+  // Set the default activation function to NONE.
+  options->mutable_segmenter_options()->set_output_type(
+      segmenter_option.segmenter_options().output_type());
+  options->mutable_segmenter_options()->set_activation(SegmenterOptions::NONE);
+  // Find the custom metadata of ImageSegmenterOptions type in model metadata.
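The labelmap consumed here is surfaced to callers through the new ImageSegmenter::GetLabels() accessor shown in the header hunk above. A minimal usage sketch follows; the helper name, the model path, and plain stdout printing are illustrative assumptions, not part of this change.

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h"

// Hypothetical helper: prints the labelmap exposed by the new GetLabels().
absl::Status PrintSegmenterLabels(const std::string& model_path) {
  using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenter;
  using ::mediapipe::tasks::vision::image_segmenter::ImageSegmenterOptions;

  auto options = std::make_unique<ImageSegmenterOptions>();
  options->base_options.model_asset_path = model_path;  // placeholder path
  options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK;

  auto segmenter_or = ImageSegmenter::Create(std::move(options));
  if (!segmenter_or.ok()) return segmenter_or.status();

  // Empty if the model metadata carries no labelmap. For CATEGORY_MASK, a
  // pixel value of i in the output mask corresponds to labels[i].
  const std::vector<std::string> labels = (*segmenter_or)->GetLabels();
  for (std::size_t i = 0; i < labels.size(); ++i) {
    std::cout << i << ": " << labels[i] << "\n";
  }
  return (*segmenter_or)->Close();
}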
+ const auto* metadata_extractor = model_resources.GetMetadataExtractor(); + bool found_activation_in_metadata = false; + if (metadata_extractor->GetCustomMetadataList() != nullptr && + metadata_extractor->GetCustomMetadataList()->size() > 0) { + for (const auto& custom_metadata : + *metadata_extractor->GetCustomMetadataList()) { + if (custom_metadata->name()->str() == kSegmentationMetadataName) { + found_activation_in_metadata = true; + auto activation_fb = + GetImageSegmenterOptions(custom_metadata->data()->data()) + ->activation(); + switch (activation_fb) { + case Activation_NONE: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::NONE); + break; + case Activation_SIGMOID: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::SIGMOID); + break; + case Activation_SOFTMAX: + options->mutable_segmenter_options()->set_activation( + SegmenterOptions::SOFTMAX); + break; + default: + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + "Invalid activation type found in CustomMetadata of " + "ImageSegmenterOptions type."); + } + } + } + } + if (!found_activation_in_metadata) { + LOG(WARNING) + << "No activation type is found in model metadata. Use NONE for " + "ImageSegmenterGraph."; + } const tflite::Model& model = *model_resources.GetTfLiteModel(); if (model.subgraphs()->size() != 1) { return CreateStatusWithPayload( @@ -146,8 +190,6 @@ absl::Status ConfigureTensorsToSegmentationCalculator( MediaPipeTasksStatus::kInvalidArgumentError); } - const ModelMetadataExtractor* metadata_extractor = - model_resources.GetMetadataExtractor(); ASSIGN_OR_RETURN( *options->mutable_label_items(), GetLabelItemsIfAny(*metadata_extractor, @@ -401,7 +443,7 @@ class ImageSegmenterGraph : public core::ModelTaskGraph { } else { ASSIGN_OR_RETURN(const tflite::Tensor* output_tensor, GetOutputTensor(model_resources)); - const int segmentation_streams_num = *output_tensor->shape()->rbegin(); + int segmentation_streams_num = *output_tensor->shape()->rbegin(); for (int i = 0; i < segmentation_streams_num; ++i) { segmented_masks.push_back(Source( tensor_to_images[Output::Multiple(kSegmentationTag)][i])); diff --git a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc index d1fe20182..1d75a3fb7 100644 --- a/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc +++ b/mediapipe/tasks/cc/vision/image_segmenter/image_segmenter_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mediapipe/tasks/cc/vision/image_segmenter/image_segmenter.h" +#include #include #include @@ -61,6 +62,11 @@ constexpr char kSelfie128x128WithMetadata[] = "selfie_segm_128_128_3.tflite"; constexpr char kSelfie144x256WithMetadata[] = "selfie_segm_144_256_3.tflite"; +constexpr char kSelfieSegmentation[] = "selfie_segmentation.tflite"; + +constexpr char kSelfieSegmentationLandscape[] = + "selfie_segmentation_landscape.tflite"; + constexpr char kHairSegmentationWithMetadata[] = "hair_segmentation.tflite"; constexpr float kGoldenMaskSimilarity = 0.98; @@ -71,6 +77,13 @@ constexpr float kGoldenMaskSimilarity = 0.98; // 20 means class index 2, etc. 
constexpr int kGoldenMaskMagnificationFactor = 10; +constexpr std::array kDeeplabLabelNames = { + "background", "aeroplane", "bicycle", "bird", "boat", + "bottle", "bus", "car", "cat", "chair", + "cow", "dining table", "dog", "horse", "motorbike", + "person", "potted plant", "sheep", "sofa", "train", + "tv"}; + // Intentionally converting output into CV_8UC1 and then again into CV_32FC1 // as expected outputs are stored in CV_8UC1, so this conversion allows to do // fair comparison. @@ -82,13 +95,8 @@ cv::Mat PostProcessResultMask(const cv::Mat& mask) { } Image GetSRGBImage(const std::string& image_path) { - // TODO: fix test so RGB really is used and not BGR/BGRA. - // mediapipe/app/aimatter/segmentation/segmenter_test_common.cc - // golden masks are generated with BGR image. To align with the unittest of - // aimatter segmenter, here reads image as BGR as well (opencv reads image as - // BGR). Once the correctness of mediapipe tasks segmenter is verified, change - // the golden masks to be generated by RGB image. cv::Mat image_mat = cv::imread(image_path); + cv::cvtColor(image_mat, image_mat, cv::COLOR_BGR2RGB); mediapipe::ImageFrame image_frame( mediapipe::ImageFormat::SRGB, image_mat.cols, image_mat.rows, image_mat.step, image_mat.data, [image_mat](uint8_t[]) {}); @@ -244,6 +252,22 @@ TEST_F(CreateFromOptionsTest, FailsWithInputChannelOneModel) { "channels = 3 or 4.")); } +TEST(GetLabelsTest, SucceedsWithLabelsInModel) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + const auto& labels = segmenter->GetLabels(); + ASSERT_FALSE(labels.empty()); + ASSERT_EQ(labels.size(), kDeeplabLabelNames.size()); + for (int i = 0; i < labels.size(); ++i) { + EXPECT_EQ(labels[i], kDeeplabLabelNames[i]); + } +} + class ImageModeTest : public tflite_shims::testing::Test {}; TEST_F(ImageModeTest, SucceedsWithCategoryMask) { @@ -280,7 +304,6 @@ TEST_F(ImageModeTest, SucceedsWithConfidenceMask) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -309,7 +332,6 @@ TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -340,7 +362,6 @@ TEST_F(ImageModeTest, FailsWithRegionOfInterest) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kDeeplabV3WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -364,7 +385,6 @@ TEST_F(ImageModeTest, SucceedsSelfie128x128Segmentation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie128x128WithMetadata); 
options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); @@ -392,7 +412,6 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kSelfie144x256WithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::NONE; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); @@ -411,6 +430,82 @@ TEST_F(ImageModeTest, SucceedsSelfie144x256Segmentations) { SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); } +TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationConfidenceMask) { + Image image = + GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentation); + options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); + EXPECT_EQ(confidence_masks.size(), 1); + MP_ASSERT_OK(segmenter->Close()); + + cv::Mat expected_mask = cv::imread( + JoinPath("./", kTestDataDirectory, + "portrait_selfie_segmentation_expected_confidence_mask.jpg"), + cv::IMREAD_GRAYSCALE); + cv::Mat expected_mask_float; + expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); + + cv::Mat selfie_mask = mediapipe::formats::MatView( + confidence_masks[0].GetImageFrameSharedPtr().get()); + EXPECT_THAT(selfie_mask, + SimilarToFloatMask(expected_mask_float, kGoldenMaskSimilarity)); +} + +TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationCategoryMask) { + Image image = + GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentation); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); + EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK(segmenter->Close()); + + cv::Mat selfie_mask = mediapipe::formats::MatView( + category_mask[0].GetImageFrameSharedPtr().get()); + cv::Mat expected_mask = cv::imread( + JoinPath("./", kTestDataDirectory, + "portrait_selfie_segmentation_expected_category_mask.jpg"), + cv::IMREAD_GRAYSCALE); + EXPECT_THAT(selfie_mask, + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); +} + +TEST_F(ImageModeTest, SucceedsPortraitSelfieSegmentationLandscapeCategoryMask) { + Image image = + GetSRGBImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kSelfieSegmentationLandscape); + options->output_type = ImageSegmenterOptions::OutputType::CATEGORY_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + ImageSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto category_mask, segmenter->Segment(image)); + 
EXPECT_EQ(category_mask.size(), 1); + MP_ASSERT_OK(segmenter->Close()); + + cv::Mat selfie_mask = mediapipe::formats::MatView( + category_mask[0].GetImageFrameSharedPtr().get()); + cv::Mat expected_mask = cv::imread( + JoinPath( + "./", kTestDataDirectory, + "portrait_selfie_segmentation_landscape_expected_category_mask.jpg"), + cv::IMREAD_GRAYSCALE); + EXPECT_THAT(selfie_mask, + SimilarToUint8Mask(expected_mask, kGoldenMaskSimilarity, 255)); +} + TEST_F(ImageModeTest, SucceedsHairSegmentation) { Image image = GetSRGBAImage(JoinPath("./", kTestDataDirectory, "portrait.jpg")); @@ -418,7 +513,6 @@ TEST_F(ImageModeTest, SucceedsHairSegmentation) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kHairSegmentationWithMetadata); options->output_type = ImageSegmenterOptions::OutputType::CONFIDENCE_MASK; - options->activation = ImageSegmenterOptions::Activation::SOFTMAX; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, ImageSegmenter::Create(std::move(options))); MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, segmenter->Segment(image)); diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD new file mode 100644 index 000000000..ea72d3d99 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/BUILD @@ -0,0 +1,76 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# Docs for Mediapipe Tasks Interactive Segmenter +# TODO: add doc link. 
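As a rough orientation for the targets below, here is an end-to-end usage sketch of the InteractiveSegmenter C++ API added by this change; the helper name, the model file, and the click coordinates are placeholder assumptions, and interactive_segmenter.h later in this change remains the authoritative API reference.

#include <memory>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/tasks/cc/components/containers/keypoint.h"
#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h"

// Hypothetical helper: segments the object under a normalized (x, y) point of
// interest and returns the single category mask produced by the task.
absl::StatusOr<mediapipe::Image> SegmentAtPoint(mediapipe::Image image,
                                                float x, float y) {
  using ::mediapipe::tasks::components::containers::NormalizedKeypoint;
  using ::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenter;
  using ::mediapipe::tasks::vision::interactive_segmenter::
      InteractiveSegmenterOptions;
  using ::mediapipe::tasks::vision::interactive_segmenter::RegionOfInterest;

  auto options = std::make_unique<InteractiveSegmenterOptions>();
  // Placeholder: the point-to-mask model used by the tests in this change.
  options->base_options.model_asset_path = "ptm_512_hdt_ptm_woid.tflite";
  options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK;

  auto segmenter_or = InteractiveSegmenter::Create(std::move(options));
  if (!segmenter_or.ok()) return segmenter_or.status();

  // The ROI must be expressed as a keypoint; UNSPECIFIED is rejected.
  RegionOfInterest roi;
  roi.format = RegionOfInterest::KEYPOINT;
  roi.keypoint = NormalizedKeypoint{x, y};

  auto masks_or = (*segmenter_or)->Segment(image, roi);
  if (!masks_or.ok()) return masks_or.status();
  if (masks_or->empty()) {
    return absl::InternalError("Expected one category mask.");
  }
  // CATEGORY_MASK output yields a single uint8 mask.
  return (*masks_or)[0];
}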
+cc_library( + name = "interactive_segmenter", + srcs = ["interactive_segmenter.cc"], + hdrs = ["interactive_segmenter.h"], + deps = [ + ":interactive_segmenter_graph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/containers:keypoint", + "//mediapipe/tasks/cc/core:base_options", + "//mediapipe/tasks/cc/vision/core:base_vision_task_api", + "//mediapipe/tasks/cc/vision/core:image_processing_options", + "//mediapipe/tasks/cc/vision/core:running_mode", + "//mediapipe/tasks/cc/vision/core:vision_task_api_factory", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:render_data_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "interactive_segmenter_graph", + srcs = ["interactive_segmenter_graph.cc"], + deps = [ + "@com_google_absl//absl/strings", + "//mediapipe/calculators/image:set_alpha_calculator", + "//mediapipe/calculators/util:annotation_overlay_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator", + "//mediapipe/calculators/util:flat_color_image_calculator_cc_proto", + "//mediapipe/calculators/util:from_image_calculator", + "//mediapipe/calculators/util:to_image_calculator", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/api2:port", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_cc_proto", + "//mediapipe/util:color_cc_proto", + "//mediapipe/util:label_map_cc_proto", + "//mediapipe/util:render_data_cc_proto", + ] + select({ + "//mediapipe/gpu:disable_gpu": [], + "//conditions:default": [ + "//mediapipe/gpu:gpu_buffer_to_image_frame_calculator", + "//mediapipe/gpu:image_frame_to_gpu_buffer_calculator", + ], + }), + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc new file mode 100644 index 000000000..4298d4a19 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.cc @@ -0,0 +1,163 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/core/running_mode.h" +#include "mediapipe/tasks/cc/vision/core/vision_task_api_factory.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/segmenter_options.pb.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { +namespace { + +constexpr char kSegmentationStreamName[] = "segmented_mask_out"; +constexpr char kImageInStreamName[] = "image_in"; +constexpr char kImageOutStreamName[] = "image_out"; +constexpr char kRoiStreamName[] = "roi_in"; +constexpr char kNormRectStreamName[] = "norm_rect_in"; + +constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kRoiTag[] = "ROI"; +constexpr char kNormRectTag[] = "NORM_RECT"; + +constexpr char kSubgraphTypeName[] = + "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; + +using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::vision::image_segmenter::proto::SegmenterOptions; +using ImageSegmenterGraphOptionsProto = ::mediapipe::tasks::vision:: + image_segmenter::proto::ImageSegmenterGraphOptions; + +// Creates a MediaPipe graph config that only contains a single subgraph node of +// "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph". +CalculatorGraphConfig CreateGraphConfig( + std::unique_ptr options) { + api2::builder::Graph graph; + auto& task_subgraph = graph.AddNode(kSubgraphTypeName); + task_subgraph.GetOptions().Swap( + options.get()); + graph.In(kImageTag).SetName(kImageInStreamName); + graph.In(kRoiTag).SetName(kRoiStreamName); + graph.In(kNormRectTag).SetName(kNormRectStreamName); + task_subgraph.Out(kGroupedSegmentationTag).SetName(kSegmentationStreamName) >> + graph.Out(kGroupedSegmentationTag); + task_subgraph.Out(kImageTag).SetName(kImageOutStreamName) >> + graph.Out(kImageTag); + graph.In(kImageTag) >> task_subgraph.In(kImageTag); + graph.In(kRoiTag) >> task_subgraph.In(kRoiTag); + graph.In(kNormRectTag) >> task_subgraph.In(kNormRectTag); + return graph.GetConfig(); +} + +// Converts the user-facing InteractiveSegmenterOptions struct to the internal +// ImageSegmenterOptions proto. 
+std::unique_ptr +ConvertImageSegmenterOptionsToProto(InteractiveSegmenterOptions* options) { + auto options_proto = std::make_unique(); + auto base_options_proto = std::make_unique( + tasks::core::ConvertBaseOptionsToProto(&(options->base_options))); + options_proto->mutable_base_options()->Swap(base_options_proto.get()); + switch (options->output_type) { + case InteractiveSegmenterOptions::OutputType::CATEGORY_MASK: + options_proto->mutable_segmenter_options()->set_output_type( + SegmenterOptions::CATEGORY_MASK); + break; + case InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK: + options_proto->mutable_segmenter_options()->set_output_type( + SegmenterOptions::CONFIDENCE_MASK); + break; + } + return options_proto; +} + +// Converts the user-facing RegionOfInterest struct to the RenderData proto that +// is used in subgraph. +absl::StatusOr ConvertRoiToRenderData(const RegionOfInterest& roi) { + RenderData result; + switch (roi.format) { + case RegionOfInterest::UNSPECIFIED: + return absl::InvalidArgumentError( + "RegionOfInterest format not specified"); + case RegionOfInterest::KEYPOINT: + RET_CHECK(roi.keypoint.has_value()); + auto* annotation = result.add_render_annotations(); + annotation->mutable_color()->set_r(255); + auto* point = annotation->mutable_point(); + point->set_normalized(true); + point->set_x(roi.keypoint->x); + point->set_y(roi.keypoint->y); + return result; + } + return absl::UnimplementedError("Unrecognized format"); +} + +} // namespace + +absl::StatusOr> +InteractiveSegmenter::Create( + std::unique_ptr options) { + auto options_proto = ConvertImageSegmenterOptionsToProto(options.get()); + return core::VisionTaskApiFactory::Create( + CreateGraphConfig(std::move(options_proto)), + std::move(options->base_options.op_resolver), core::RunningMode::IMAGE, + /*packets_callback=*/nullptr); +} + +absl::StatusOr> InteractiveSegmenter::Segment( + mediapipe::Image image, const RegionOfInterest& roi, + std::optional image_processing_options) { + if (image.UsesGpu()) { + return CreateStatusWithPayload( + absl::StatusCode::kInvalidArgument, + absl::StrCat("GPU input images are currently not supported."), + MediaPipeTasksStatus::kRunnerUnexpectedInputError); + } + ASSIGN_OR_RETURN( + NormalizedRect norm_rect, + ConvertToNormalizedRect(image_processing_options, /*roi_allowed=*/false)); + ASSIGN_OR_RETURN(RenderData roi_as_render_data, ConvertRoiToRenderData(roi)); + ASSIGN_OR_RETURN( + auto output_packets, + ProcessImageData( + {{kImageInStreamName, mediapipe::MakePacket(std::move(image))}, + {kRoiStreamName, + mediapipe::MakePacket(std::move(roi_as_render_data))}, + {kNormRectStreamName, + MakePacket(std::move(norm_rect))}})); + return output_packets[kSegmentationStreamName].Get>(); +} + +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h new file mode 100644 index 000000000..420b22462 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h @@ -0,0 +1,136 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_ +#define MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/core/base_options.h" +#include "mediapipe/tasks/cc/vision/core/base_vision_task_api.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { + +// The options for configuring a mediapipe interactive segmenter task. +struct InteractiveSegmenterOptions { + // Base options for configuring MediaPipe Tasks, such as specifying the model + // file with metadata, accelerator options, op resolver, etc. + tasks::core::BaseOptions base_options; + + // The output type of segmentation results. + enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK = 0, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK = 1, + }; + + OutputType output_type = OutputType::CATEGORY_MASK; +}; + +// The Region-Of-Interest (ROI) to interact with. +struct RegionOfInterest { + enum Format { + UNSPECIFIED = 0, // Format not specified. + KEYPOINT = 1, // Using keypoint to represent ROI. + }; + + // Specifies the format used to specify the region-of-interest. Note that + // using `UNSPECIFIED` is invalid and will lead to an `InvalidArgument` status + // being returned. + Format format = Format::UNSPECIFIED; + + // Represents the ROI in keypoint format, this should be non-nullopt if + // `format` is `KEYPOINT`. + std::optional keypoint; +}; + +// Performs interactive segmentation on images. +// +// Users can represent user interaction through `RegionOfInterest`, which gives +// a hint to InteractiveSegmenter to perform segmentation focusing on the given +// region of interest. +// +// The API expects a TFLite model with mandatory TFLite Model Metadata. +// +// Input tensor: +// (kTfLiteUInt8/kTfLiteFloat32) +// - image input of size `[batch x height x width x channels]`. +// - batch inference is not supported (`batch` is required to be 1). +// - RGB inputs is supported (`channels` is required to be 3). +// - if type is kTfLiteFloat32, NormalizationOptions are required to be +// attached to the metadata for input normalization. +// Output tensors: +// (kTfLiteUInt8/kTfLiteFloat32) +// - list of segmented masks. +// - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. +// - if `output_type` is CONFIDENCE_MASK, float32 Image list of size +// `channels`. 
+// - batch is always 1 +class InteractiveSegmenter : tasks::vision::core::BaseVisionTaskApi { + public: + using BaseVisionTaskApi::BaseVisionTaskApi; + + // Creates an InteractiveSegmenter from the provided options. A non-default + // OpResolver can be specified in the BaseOptions of + // InteractiveSegmenterOptions, to support custom Ops of the segmentation + // model. + static absl::StatusOr> Create( + std::unique_ptr options); + + // Performs image segmentation on the provided single image. + // + // The image can be of any size with format RGB. + // + // The `roi` parameter is used to represent user's region of interest for + // segmentation. + // + // The optional 'image_processing_options' parameter can be used to specify + // the rotation to apply to the image before performing segmentation, by + // setting its 'rotation_degrees' field. Note that specifying a + // region-of-interest using the 'region_of_interest' field is NOT supported + // and will result in an invalid argument error being returned. + // + // If the output_type is CATEGORY_MASK, the returned vector of images is + // per-category segmented image mask. + // If the output_type is CONFIDENCE_MASK, the returned vector of images + // contains only one confidence image mask. + absl::StatusOr> Segment( + mediapipe::Image image, const RegionOfInterest& roi, + std::optional image_processing_options = + std::nullopt); + + // Shuts down the InteractiveSegmenter when all works are done. + absl::Status Close() { return runner_->Close(); } +}; + +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe + +#endif // MEDIAPIPE_TASKS_CC_VISION_INTERACTIVE_SEGMENTER_INTERACTIVE_SEGMENTER_H_ diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc new file mode 100644 index 000000000..4c0cd2a88 --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_graph.cc @@ -0,0 +1,198 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/util/color.pb.h" +#include "mediapipe/util/label_map.pb.h" +#include "mediapipe/util/render_data.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { + +namespace { + +using image_segmenter::proto::ImageSegmenterGraphOptions; +using ::mediapipe::Image; +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; + +constexpr char kSegmentationTag[] = "SEGMENTATION"; +constexpr char kGroupedSegmentationTag[] = "GROUPED_SEGMENTATION"; +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageCpuTag[] = "IMAGE_CPU"; +constexpr char kImageGpuTag[] = "IMAGE_GPU"; +constexpr char kAlphaTag[] = "ALPHA"; +constexpr char kAlphaGpuTag[] = "ALPHA_GPU"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kRoiTag[] = "ROI"; +constexpr char kVideoTag[] = "VIDEO"; + +// Updates the graph to return `roi` stream which has same dimension as +// `image`, and rendered with `roi`. If `use_gpu` is true, returned `Source` is +// in GpuBuffer format, otherwise using ImageFrame. +Source<> RoiToAlpha(Source image, Source roi, bool use_gpu, + Graph& graph) { + // TODO: Replace with efficient implementation. + const absl::string_view image_tag_with_suffix = + use_gpu ? kImageGpuTag : kImageCpuTag; + + // Generates a blank canvas with same size as input image. + auto& flat_color = graph.AddNode("FlatColorImageCalculator"); + auto& flat_color_options = + flat_color.GetOptions(); + // SetAlphaCalculator only takes 1st channel. + flat_color_options.mutable_color()->set_r(0); + image >> flat_color.In(kImageTag)[0]; + auto blank_canvas = flat_color.Out(kImageTag)[0]; + + auto& from_mp_image = graph.AddNode("FromImageCalculator"); + blank_canvas >> from_mp_image.In(kImageTag); + auto blank_canvas_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix); + + auto& roi_to_alpha = graph.AddNode("AnnotationOverlayCalculator"); + blank_canvas_in_cpu_or_gpu >> + roi_to_alpha.In(use_gpu ? kImageGpuTag : kImageTag); + roi >> roi_to_alpha.In(0); + auto alpha = roi_to_alpha.Out(use_gpu ? kImageGpuTag : kImageTag); + + return alpha; +} + +} // namespace + +// An "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph" +// performs semantic segmentation given user's region-of-interest. Two kinds of +// outputs are provided: SEGMENTATION and GROUPED_SEGMENTATION. Users can +// retrieve segmented mask of only particular category/channel from +// SEGMENTATION, and users can also get all segmented masks from +// GROUPED_SEGMENTATION. +// - Accepts CPU input images and outputs segmented masks on CPU. +// +// Inputs: +// IMAGE - Image +// Image to perform segmentation on. +// ROI - RenderData proto +// Region of interest based on user interaction. 
Currently only support +// Point format, and Color has to be (255, 255, 255). +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection +// on. +// @Optional: rect covering the whole image is used if not specified. +// +// Outputs: +// SEGMENTATION - mediapipe::Image @Multiple +// Segmented masks for individual category. Segmented mask of single +// category can be accessed by index based output stream. +// GROUPED_SEGMENTATION - std::vector +// The output segmented masks grouped in a vector. +// IMAGE - mediapipe::Image +// The image that image segmenter runs on. +// +// Example: +// node { +// calculator: +// "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph" +// input_stream: "IMAGE:image" +// input_stream: "ROI:region_of_interest" +// output_stream: "SEGMENTATION:segmented_masks" +// options { +// [mediapipe.tasks.vision.image_segmenter.proto.ImageSegmenterGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "/path/to/model.tflite" +// } +// } +// segmenter_options { +// output_type: CONFIDENCE_MASK +// } +// } +// } +// } +class InteractiveSegmenterGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + mediapipe::SubgraphContext* sc) override { + Graph graph; + const auto& task_options = sc->Options(); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + task_options.base_options().acceleration()); + + Source image = graph[Input(kImageTag)]; + Source roi = graph[Input(kRoiTag)]; + Source norm_rect = + graph[Input(kNormRectTag)]; + const absl::string_view image_tag_with_suffix = + use_gpu ? kImageGpuTag : kImageCpuTag; + const absl::string_view alpha_tag_with_suffix = + use_gpu ? kAlphaGpuTag : kAlphaTag; + + auto& from_mp_image = graph.AddNode("FromImageCalculator"); + image >> from_mp_image.In(kImageTag); + auto image_in_cpu_or_gpu = from_mp_image.Out(image_tag_with_suffix); + + auto alpha_in_cpu_or_gpu = RoiToAlpha(image, roi, use_gpu, graph); + + auto& set_alpha = graph.AddNode("SetAlphaCalculator"); + image_in_cpu_or_gpu >> set_alpha.In(use_gpu ? kImageGpuTag : kImageTag); + alpha_in_cpu_or_gpu >> set_alpha.In(alpha_tag_with_suffix); + auto image_in_cpu_or_gpu_with_set_alpha = + set_alpha.Out(use_gpu ? kImageGpuTag : kImageTag); + + auto& to_mp_image = graph.AddNode("ToImageCalculator"); + image_in_cpu_or_gpu_with_set_alpha >> to_mp_image.In(image_tag_with_suffix); + auto image_with_set_alpha = to_mp_image.Out(kImageTag); + + auto& image_segmenter = graph.AddNode( + "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"); + image_segmenter.GetOptions() = task_options; + image_with_set_alpha >> image_segmenter.In(kImageTag); + norm_rect >> image_segmenter.In(kNormRectTag); + + image_segmenter.Out(kSegmentationTag) >> + graph[Output(kSegmentationTag)]; + image_segmenter.Out(kGroupedSegmentationTag) >> + graph[Output>(kGroupedSegmentationTag)]; + image_segmenter.Out(kImageTag) >> graph[Output(kImageTag)]; + + return graph.GetConfig(); + } +}; + +// REGISTER_MEDIAPIPE_GRAPH argument has to fit on one line to work properly. 
+// clang-format off +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::interactive_segmenter::InteractiveSegmenterGraph); +// clang-format on + +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc new file mode 100644 index 000000000..dbc3bbe4c --- /dev/null +++ b/mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter_test.cc @@ -0,0 +1,306 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mediapipe/tasks/cc/vision/interactive_segmenter/interactive_segmenter.h" + +#include +#include + +#include "absl/flags/flag.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/image_frame.h" +#include "mediapipe/framework/formats/image_frame_opencv.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/opencv_core_inc.h" +#include "mediapipe/framework/port/opencv_imgcodecs_inc.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/tasks/cc/components/containers/keypoint.h" +#include "mediapipe/tasks/cc/components/containers/rect.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/vision/core/image_processing_options.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator.pb.h" +#include "mediapipe/tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" +#include "tensorflow/lite/core/shims/cc/shims_test_util.h" +#include "tensorflow/lite/kernels/builtin_op_kernels.h" +#include "tensorflow/lite/mutable_op_resolver.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace interactive_segmenter { +namespace { + +using ::mediapipe::Image; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::components::containers::NormalizedKeypoint; +using ::mediapipe::tasks::components::containers::RectF; +using ::mediapipe::tasks::vision::core::ImageProcessingOptions; +using ::testing::HasSubstr; +using ::testing::Optional; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPtmModel[] = "ptm_512_hdt_ptm_woid.tflite"; +constexpr char kCatsAndDogsJpg[] = "cats_and_dogs.jpg"; +// Golden mask for the dogs in cats_and_dogs.jpg. 
+constexpr char kCatsAndDogsMaskDog1[] = "cats_and_dogs_mask_dog1.png"; +constexpr char kCatsAndDogsMaskDog2[] = "cats_and_dogs_mask_dog2.png"; + +constexpr float kGoldenMaskSimilarity = 0.97; + +// Magnification factor used when creating the golden category masks to make +// them more human-friendly. Since interactive segmenter has only 2 categories, +// the golden mask uses 0 or 255 for each pixel. +constexpr int kGoldenMaskMagnificationFactor = 255; + +// Intentionally converting output into CV_8UC1 and then again into CV_32FC1 +// as expected outputs are stored in CV_8UC1, so this conversion allows to do +// fair comparison. +cv::Mat PostProcessResultMask(const cv::Mat& mask) { + cv::Mat mask_float; + mask.convertTo(mask_float, CV_8UC1, 255); + mask_float.convertTo(mask_float, CV_32FC1, 1 / 255.f); + return mask_float; +} + +double CalculateSum(const cv::Mat& m) { + double sum = 0.0; + cv::Scalar s = cv::sum(m); + for (int i = 0; i < m.channels(); ++i) { + sum += s.val[i]; + } + return sum; +} + +double CalculateSoftIOU(const cv::Mat& m1, const cv::Mat& m2) { + cv::Mat intersection; + cv::multiply(m1, m2, intersection); + double intersection_value = CalculateSum(intersection); + double union_value = + CalculateSum(m1.mul(m1)) + CalculateSum(m2.mul(m2)) - intersection_value; + return union_value > 0.0 ? intersection_value / union_value : 0.0; +} + +MATCHER_P2(SimilarToFloatMask, expected_mask, similarity_threshold, "") { + cv::Mat actual_mask = PostProcessResultMask(arg); + return arg.rows == expected_mask.rows && arg.cols == expected_mask.cols && + CalculateSoftIOU(arg, expected_mask) > similarity_threshold; +} + +MATCHER_P3(SimilarToUint8Mask, expected_mask, similarity_threshold, + magnification_factor, "") { + if (arg.rows != expected_mask.rows || arg.cols != expected_mask.cols) { + return false; + } + int consistent_pixels = 0; + const int num_pixels = expected_mask.rows * expected_mask.cols; + for (int i = 0; i < num_pixels; ++i) { + consistent_pixels += + (arg.data[i] * magnification_factor == expected_mask.data[i]); + } + return static_cast(consistent_pixels) / num_pixels >= + similarity_threshold; +} + +class CreateFromOptionsTest : public tflite_shims::testing::Test {}; + +class DeepLabOpResolverMissingOps : public ::tflite::MutableOpResolver { + public: + DeepLabOpResolverMissingOps() { + AddBuiltin(::tflite::BuiltinOperator_ADD, + ::tflite::ops::builtin::Register_ADD()); + } + + DeepLabOpResolverMissingOps(const DeepLabOpResolverMissingOps& r) = delete; +}; + +TEST_F(CreateFromOptionsTest, FailsWithSelectiveOpResolverMissingOps) { + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->base_options.op_resolver = + absl::make_unique(); + auto segmenter_or = InteractiveSegmenter::Create(std::move(options)); + // TODO: Make MediaPipe InferenceCalculator report the detailed + // interpreter errors (e.g., "Encountered unresolved custom op"). 
+ EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInternal); + EXPECT_THAT( + segmenter_or.status().message(), + testing::HasSubstr("interpreter_builder(&interpreter) == kTfLiteOk")); +} + +TEST_F(CreateFromOptionsTest, FailsWithMissingModel) { + absl::StatusOr> segmenter_or = + InteractiveSegmenter::Create( + std::make_unique()); + + EXPECT_EQ(segmenter_or.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT( + segmenter_or.status().message(), + HasSubstr("ExternalFile must specify at least one of 'file_content', " + "'file_name', 'file_pointer_meta' or 'file_descriptor_meta'.")); + EXPECT_THAT(segmenter_or.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kRunnerInitializationError)))); +} + +struct InteractiveSegmenterTestParams { + std::string test_name; + RegionOfInterest::Format format; + NormalizedKeypoint roi; + std::string golden_mask_file; + float similarity_threshold; +}; + +using SucceedSegmentationWithRoi = + ::testing::TestWithParam; + +TEST_P(SucceedSegmentationWithRoi, SucceedsWithCategoryMask) { + const InteractiveSegmenterTestParams& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = params.format; + interaction_roi.keypoint = params.roi; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = InteractiveSegmenterOptions::OutputType::CATEGORY_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto category_masks, + segmenter->Segment(image, interaction_roi)); + EXPECT_EQ(category_masks.size(), 1); + + cv::Mat actual_mask = mediapipe::formats::MatView( + category_masks[0].GetImageFrameSharedPtr().get()); + + cv::Mat expected_mask = + cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), + cv::IMREAD_GRAYSCALE); + EXPECT_THAT(actual_mask, + SimilarToUint8Mask(expected_mask, params.similarity_threshold, + kGoldenMaskMagnificationFactor)); +} + +TEST_P(SucceedSegmentationWithRoi, SucceedsWithConfidenceMask) { + const auto& params = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = params.format; + interaction_roi.keypoint = params.roi; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = + InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + MP_ASSERT_OK_AND_ASSIGN(auto confidence_masks, + segmenter->Segment(image, interaction_roi)); + EXPECT_EQ(confidence_masks.size(), 2); + + cv::Mat expected_mask = + cv::imread(JoinPath("./", kTestDataDirectory, params.golden_mask_file), + cv::IMREAD_GRAYSCALE); + cv::Mat expected_mask_float; + expected_mask.convertTo(expected_mask_float, CV_32FC1, 1 / 255.f); + + cv::Mat actual_mask = mediapipe::formats::MatView( + confidence_masks[1].GetImageFrameSharedPtr().get()); + EXPECT_THAT(actual_mask, SimilarToFloatMask(expected_mask_float, + params.similarity_threshold)); +} + +INSTANTIATE_TEST_SUITE_P( + SucceedSegmentationWithRoiTest, 
SucceedSegmentationWithRoi, + ::testing::ValuesIn( + {{"PointToDog1", RegionOfInterest::KEYPOINT, + NormalizedKeypoint{0.44, 0.70}, kCatsAndDogsMaskDog1, 0.84f}, + {"PointToDog2", RegionOfInterest::KEYPOINT, + NormalizedKeypoint{0.66, 0.66}, kCatsAndDogsMaskDog2, + kGoldenMaskSimilarity}}), + [](const ::testing::TestParamInfo& + info) { return info.param.test_name; }); + +class ImageModeTest : public tflite_shims::testing::Test {}; + +// TODO: fix this unit test after image segmenter handled post +// processing correctly with rotated image. +TEST_F(ImageModeTest, DISABLED_SucceedsWithRotation) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = + InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + ImageProcessingOptions image_processing_options; + image_processing_options.rotation_degrees = -90; + MP_ASSERT_OK_AND_ASSIGN( + auto confidence_masks, + segmenter->Segment(image, interaction_roi, image_processing_options)); + EXPECT_EQ(confidence_masks.size(), 2); +} + +TEST_F(ImageModeTest, FailsWithRegionOfInterest) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, + DecodeImageFromFile(JoinPath("./", kTestDataDirectory, kCatsAndDogsJpg))); + RegionOfInterest interaction_roi; + interaction_roi.format = RegionOfInterest::KEYPOINT; + interaction_roi.keypoint = NormalizedKeypoint{0.66, 0.66}; + auto options = std::make_unique(); + options->base_options.model_asset_path = + JoinPath("./", kTestDataDirectory, kPtmModel); + options->output_type = + InteractiveSegmenterOptions::OutputType::CONFIDENCE_MASK; + + MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr segmenter, + InteractiveSegmenter::Create(std::move(options))); + RectF roi{/*left=*/0.1, /*top=*/0, /*right=*/0.9, /*bottom=*/1}; + ImageProcessingOptions image_processing_options{roi, /*rotation_degrees=*/0}; + + auto results = + segmenter->Segment(image, interaction_roi, image_processing_options); + EXPECT_EQ(results.status().code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(results.status().message(), + HasSubstr("This task doesn't support region-of-interest")); + EXPECT_THAT( + results.status().GetPayload(kMediaPipeTasksPayload), + Optional(absl::Cord(absl::StrCat( + MediaPipeTasksStatus::kImageProcessingInvalidArgumentError)))); +} + +} // namespace +} // namespace interactive_segmenter +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc index 1b0dcf2fb..cb37bbead 100644 --- a/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc +++ b/mediapipe/tasks/cc/vision/object_detector/object_detector_test.cc @@ -108,31 +108,31 @@ std::vector GenerateMobileSsdNoImageResizingFullExpectedResults() { return {ParseTextProtoOrDie(R"pb( label: "cat" - score: 0.6328125 + score: 0.6210937 location_data { format: BOUNDING_BOX - bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } + bounding_box { xmin: 15 ymin: 197 width: 98 height: 99 } })pb"), ParseTextProtoOrDie(R"pb( label: "cat" - score: 0.59765625 
+ score: 0.609375 location_data { format: BOUNDING_BOX - bounding_box { xmin: 151 ymin: 78 width: 104 height: 223 } + bounding_box { xmin: 150 ymin: 78 width: 104 height: 223 } })pb"), ParseTextProtoOrDie(R"pb( label: "cat" score: 0.5 location_data { format: BOUNDING_BOX - bounding_box { xmin: 65 ymin: 199 width: 41 height: 101 } + bounding_box { xmin: 64 ymin: 199 width: 42 height: 101 } })pb"), ParseTextProtoOrDie(R"pb( label: "dog" - score: 0.48828125 + score: 0.5 location_data { format: BOUNDING_BOX - bounding_box { xmin: 12 ymin: 110 width: 153 height: 193 } + bounding_box { xmin: 14 ymin: 110 width: 153 height: 193 } })pb")}; } @@ -268,7 +268,7 @@ TEST_F(CreateFromOptionsTest, FailsWithIllegalCallbackInImageOrVideoMode) { options->running_mode = running_mode; options->result_callback = [](absl::StatusOr detections, const Image& image, - int64 timestamp_ms) {}; + int64_t timestamp_ms) {}; absl::StatusOr> object_detector = ObjectDetector::Create(std::move(options)); EXPECT_EQ(object_detector.status().code(), @@ -381,28 +381,28 @@ TEST_F(ImageModeTest, Succeeds) { score: 0.69921875 location_data { format: BOUNDING_BOX - bounding_box { xmin: 608 ymin: 161 width: 381 height: 439 } + bounding_box { xmin: 608 ymin: 164 width: 381 height: 432 } })pb"), ParseTextProtoOrDie(R"pb( label: "cat" - score: 0.64453125 + score: 0.65625 location_data { format: BOUNDING_BOX - bounding_box { xmin: 60 ymin: 398 width: 386 height: 196 } + bounding_box { xmin: 57 ymin: 398 width: 386 height: 196 } })pb"), ParseTextProtoOrDie(R"pb( label: "cat" score: 0.51171875 location_data { format: BOUNDING_BOX - bounding_box { xmin: 256 ymin: 395 width: 173 height: 202 } + bounding_box { xmin: 256 ymin: 394 width: 173 height: 202 } })pb"), ParseTextProtoOrDie(R"pb( label: "cat" score: 0.48828125 location_data { format: BOUNDING_BOX - bounding_box { xmin: 362 ymin: 191 width: 325 height: 419 } + bounding_box { xmin: 360 ymin: 195 width: 330 height: 412 } })pb")})); } @@ -484,10 +484,10 @@ TEST_F(ImageModeTest, SucceedsWithScoreCalibration) { results, ConvertToDetectionResult({ParseTextProtoOrDie(R"pb( label: "cat" - score: 0.6531269142 + score: 0.650467276 location_data { format: BOUNDING_BOX - bounding_box { xmin: 14 ymin: 197 width: 98 height: 99 } + bounding_box { xmin: 15 ymin: 197 width: 98 height: 99 } })pb")})); } @@ -507,9 +507,9 @@ TEST_F(ImageModeTest, SucceedsWithScoreThresholdOption) { GenerateMobileSsdNoImageResizingFullExpectedResults(); ExpectApproximatelyEqual( - results, ConvertToDetectionResult({full_expected_results[0], - full_expected_results[1], - full_expected_results[2]})); + results, ConvertToDetectionResult( + {full_expected_results[0], full_expected_results[1], + full_expected_results[2], full_expected_results[3]})); } TEST_F(ImageModeTest, SucceedsWithMaxResultsOption) { @@ -685,7 +685,7 @@ TEST_F(LiveStreamModeTest, FailsWithCallingWrongMethod) { JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->running_mode = core::RunningMode::LIVE_STREAM; options->result_callback = [](absl::StatusOr detections, - const Image& image, int64 timestamp_ms) {}; + const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); @@ -716,7 +716,7 @@ TEST_F(LiveStreamModeTest, FailsWithOutOfOrderInputTimestamps) { options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [](absl::StatusOr detections, - const Image& image, int64 timestamp_ms) {}; 
+ const Image& image, int64_t timestamp_ms) {}; MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr object_detector, ObjectDetector::Create(std::move(options))); MP_ASSERT_OK(object_detector->DetectAsync(image, 1)); @@ -742,13 +742,13 @@ TEST_F(LiveStreamModeTest, Succeeds) { options->running_mode = core::RunningMode::LIVE_STREAM; std::vector detection_results; std::vector> image_sizes; - std::vector timestamps; + std::vector timestamps; options->base_options.model_asset_path = JoinPath("./", kTestDataDirectory, kMobileSsdWithMetadata); options->result_callback = [&detection_results, &image_sizes, ×tamps]( absl::StatusOr detections, const Image& image, - int64 timestamp_ms) { + int64_t timestamp_ms) { MP_ASSERT_OK(detections.status()); detection_results.push_back(std::move(detections).value()); image_sizes.push_back({image.width(), image.height()}); @@ -775,7 +775,7 @@ TEST_F(LiveStreamModeTest, Succeeds) { EXPECT_EQ(image_size.first, image.width()); EXPECT_EQ(image_size.second, image.height()); } - int64 timestamp_ms = -1; + int64_t timestamp_ms = -1; for (const auto& timestamp : timestamps) { EXPECT_GT(timestamp, timestamp_ms); timestamp_ms = timestamp; diff --git a/mediapipe/tasks/cc/vision/pose_detector/BUILD b/mediapipe/tasks/cc/vision/pose_detector/BUILD new file mode 100644 index 000000000..1e361afbe --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/BUILD @@ -0,0 +1,53 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +cc_library( + name = "pose_detector_graph", + srcs = ["pose_detector_graph.cc"], + deps = [ + "//mediapipe/calculators/core:clip_vector_size_calculator_cc_proto", + "//mediapipe/calculators/tensor:image_to_tensor_calculator_cc_proto", + "//mediapipe/calculators/tensor:inference_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator", + "//mediapipe/calculators/tensor:tensors_to_detections_calculator_cc_proto", + "//mediapipe/calculators/tflite:ssd_anchors_calculator", + "//mediapipe/calculators/tflite:ssd_anchors_calculator_cc_proto", + "//mediapipe/calculators/util:detection_projection_calculator", + "//mediapipe/calculators/util:detection_transformation_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator", + "//mediapipe/calculators/util:detections_to_rects_calculator_cc_proto", + "//mediapipe/calculators/util:non_max_suppression_calculator", + "//mediapipe/calculators/util:non_max_suppression_calculator_cc_proto", + "//mediapipe/calculators/util:rect_transformation_calculator", + "//mediapipe/calculators/util:rect_transformation_calculator_cc_proto", + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:subgraph", + "//mediapipe/framework/api2:builder", + "//mediapipe/framework/formats:detection_cc_proto", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/framework/formats:tensor", + "//mediapipe/tasks/cc/components/processors:image_preprocessing_graph", + "//mediapipe/tasks/cc/core:model_task_graph", + "//mediapipe/tasks/cc/vision/pose_detector/proto:pose_detector_graph_options_cc_proto", + "@com_google_absl//absl/status:statusor", + ], + alwayslink = 1, +) diff --git a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc new file mode 100644 index 000000000..2a888ca83 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph.cc @@ -0,0 +1,354 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/status/statusor.h" +#include "mediapipe/calculators/core/clip_vector_size_calculator.pb.h" +#include "mediapipe/calculators/tensor/image_to_tensor_calculator.pb.h" +#include "mediapipe/calculators/tensor/tensors_to_detections_calculator.pb.h" +#include "mediapipe/calculators/tflite/ssd_anchors_calculator.pb.h" +#include "mediapipe/calculators/util/detections_to_rects_calculator.pb.h" +#include "mediapipe/calculators/util/non_max_suppression_calculator.pb.h" +#include "mediapipe/calculators/util/rect_transformation_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/formats/tensor.h" +#include "mediapipe/framework/subgraph.h" +#include "mediapipe/tasks/cc/components/processors/image_preprocessing_graph.h" +#include "mediapipe/tasks/cc/core/model_task_graph.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_detector { + +using ::mediapipe::NormalizedRect; +using ::mediapipe::Tensor; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::tasks::vision::pose_detector::proto:: + PoseDetectorGraphOptions; + +namespace { +constexpr char kImageTag[] = "IMAGE"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kTensorsTag[] = "TENSORS"; +constexpr char kImageSizeTag[] = "IMAGE_SIZE"; +constexpr char kAnchorsTag[] = "ANCHORS"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kNormRectsTag[] = "NORM_RECTS"; +constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS"; +constexpr char kPoseRectsTag[] = "POSE_RECTS"; +constexpr char kExpandedPoseRectsTag[] = "EXPANDED_POSE_RECTS"; +constexpr char kMatrixTag[] = "MATRIX"; +constexpr char kProjectionMatrixTag[] = "PROJECTION_MATRIX"; + +struct PoseDetectionOuts { + Source> pose_detections; + Source> pose_rects; + Source> expanded_pose_rects; + Source image; +}; + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. +void ConfigureSsdAnchorsCalculator( + mediapipe::SsdAnchorsCalculatorOptions* options) { + // Derived from + // mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt + options->set_num_layers(5); + options->set_min_scale(0.1484375); + options->set_max_scale(0.75); + options->set_input_size_height(224); + options->set_input_size_width(224); + options->set_anchor_offset_x(0.5); + options->set_anchor_offset_y(0.5); + options->add_strides(8); + options->add_strides(16); + options->add_strides(32); + options->add_strides(32); + options->add_strides(32); + options->add_aspect_ratios(1.0); + options->set_fixed_anchor_size(true); +} + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. 
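As an aside, the fixed box count configured for tensor decoding just below can be cross-checked against the anchor layout configured above. The following is only an illustrative sketch of that arithmetic, assuming the SsdAnchorsCalculator's usual behavior of emitting one extra interpolated-scale anchor per cell and of the three stride-32 layers sharing a single 7x7 grid:

// 224x224 input, strides {8, 16, 32, 32, 32}, one aspect ratio per layer.
constexpr int kAnchors28 = 28 * 28 * 2;  // stride 8  -> 1568 anchors
constexpr int kAnchors14 = 14 * 14 * 2;  // stride 16 ->  392 anchors
constexpr int kAnchors7 = 7 * 7 * 6;     // three stride-32 layers -> 294 anchors
static_assert(kAnchors28 + kAnchors14 + kAnchors7 == 2254,
              "Anchor count should match num_boxes in "
              "ConfigureTensorsToDetectionsCalculator.");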
+void ConfigureTensorsToDetectionsCalculator( + const PoseDetectorGraphOptions& tasks_options, + mediapipe::TensorsToDetectionsCalculatorOptions* options) { + // Derived from + // mediapipe/modules/pose_detection/pose_detection_gpu.pbtxt + options->set_num_classes(1); + options->set_num_boxes(2254); + options->set_num_coords(12); + options->set_box_coord_offset(0); + options->set_keypoint_coord_offset(4); + options->set_num_keypoints(4); + options->set_num_values_per_keypoint(2); + options->set_sigmoid_score(true); + options->set_score_clipping_thresh(100.0); + options->set_reverse_output_order(true); + options->set_min_score_thresh(tasks_options.min_detection_confidence()); + options->set_x_scale(224.0); + options->set_y_scale(224.0); + options->set_w_scale(224.0); + options->set_h_scale(224.0); +} + +void ConfigureNonMaxSuppressionCalculator( + const PoseDetectorGraphOptions& tasks_options, + mediapipe::NonMaxSuppressionCalculatorOptions* options) { + options->set_min_suppression_threshold( + tasks_options.min_suppression_threshold()); + options->set_overlap_type( + mediapipe::NonMaxSuppressionCalculatorOptions::INTERSECTION_OVER_UNION); + options->set_algorithm( + mediapipe::NonMaxSuppressionCalculatorOptions::WEIGHTED); +} + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. +void ConfigureDetectionsToRectsCalculator( + mediapipe::DetectionsToRectsCalculatorOptions* options) { + options->set_rotation_vector_start_keypoint_index(0); + options->set_rotation_vector_end_keypoint_index(2); + options->set_rotation_vector_target_angle(90); + options->set_output_zero_rect_for_empty_detections(true); +} + +// TODO: Configuration detection related calculators in pose +// detector with model metadata. +void ConfigureRectTransformationCalculator( + mediapipe::RectTransformationCalculatorOptions* options) { + options->set_scale_x(2.6); + options->set_scale_y(2.6); + options->set_shift_y(-0.5); + options->set_square_long(true); +} + +} // namespace + +// A "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph" performs pose +// detection. +// +// Inputs: +// IMAGE - Image +// Image to perform detection on. +// NORM_RECT - NormalizedRect @Optional +// Describes image rotation and region of image to perform detection on. If +// not provided, whole image is used for pose detection. +// +// Outputs: +// DETECTIONS - std::vector +// Detected pose with maximum `num_poses` specified in options. +// POSE_RECTS - std::vector +// Detected pose bounding boxes in normalized coordinates. +// EXPANDED_POSE_RECTS - std::vector +// Expanded pose bounding boxes in normalized coordinates so that bounding +// boxes likely contain the whole pose. This is usually used as RoI for pose +// landmarks detection to run on. +// IMAGE - Image +// The input image that the pose detector runs on and has the pixel data +// stored on the target storage (CPU vs GPU). +// All returned coordinates are in the unrotated and uncropped input image +// coordinates system. 
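// Note: as wired in the implementation below, the DETECTIONS output carries
// pixel-space bounding boxes, while POSE_RECTS and EXPANDED_POSE_RECTS are
// normalized to the input image dimensions.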
+// +// Example: +// node { +// calculator: "mediapipe.tasks.vision.pose_detector.PoseDetectorGraph" +// input_stream: "IMAGE:image" +// input_stream: "NORM_RECT:norm_rect" +// output_stream: "DETECTIONS:palm_detections" +// output_stream: "POSE_RECTS:pose_rects" +// output_stream: "EXPANDED_POSE_RECTS:expanded_pose_rects" +// output_stream: "IMAGE:image_out" +// options { +// [mediapipe.tasks.vision.pose_detector.proto.PoseDetectorGraphOptions.ext] +// { +// base_options { +// model_asset { +// file_name: "pose_detection.tflite" +// } +// } +// min_detection_confidence: 0.5 +// num_poses: 2 +// } +// } +// } +class PoseDetectorGraph : public core::ModelTaskGraph { + public: + absl::StatusOr GetConfig( + SubgraphContext* sc) override { + ASSIGN_OR_RETURN(const auto* model_resources, + CreateModelResources(sc)); + Graph graph; + ASSIGN_OR_RETURN(auto outs, + BuildPoseDetectionSubgraph( + sc->Options(), + *model_resources, graph[Input(kImageTag)], + graph[Input(kNormRectTag)], graph)); + + outs.pose_detections >> + graph.Out(kDetectionsTag).Cast>(); + outs.pose_rects >> + graph.Out(kPoseRectsTag).Cast>(); + outs.expanded_pose_rects >> + graph.Out(kExpandedPoseRectsTag).Cast>(); + outs.image >> graph.Out(kImageTag).Cast(); + + return graph.GetConfig(); + } + + private: + absl::StatusOr BuildPoseDetectionSubgraph( + const PoseDetectorGraphOptions& subgraph_options, + const core::ModelResources& model_resources, Source image_in, + Source norm_rect_in, Graph& graph) { + // Image preprocessing subgraph to convert image to tensor for the tflite + // model. + auto& preprocessing = graph.AddNode( + "mediapipe.tasks.components.processors.ImagePreprocessingGraph"); + bool use_gpu = + components::processors::DetermineImagePreprocessingGpuBackend( + subgraph_options.base_options().acceleration()); + MP_RETURN_IF_ERROR(components::processors::ConfigureImagePreprocessingGraph( + model_resources, use_gpu, + &preprocessing.GetOptions< + components::processors::proto::ImagePreprocessingGraphOptions>())); + auto& image_to_tensor_options = + *preprocessing + .GetOptions() + .mutable_image_to_tensor_options(); + image_to_tensor_options.set_keep_aspect_ratio(true); + image_to_tensor_options.set_border_mode( + mediapipe::ImageToTensorCalculatorOptions::BORDER_ZERO); + image_in >> preprocessing.In(kImageTag); + norm_rect_in >> preprocessing.In(kNormRectTag); + auto preprocessed_tensors = preprocessing.Out(kTensorsTag); + auto matrix = preprocessing.Out(kMatrixTag); + auto image_size = preprocessing.Out(kImageSizeTag); + + // Pose detection model inferece. + auto& inference = AddInference( + model_resources, subgraph_options.base_options().acceleration(), graph); + preprocessed_tensors >> inference.In(kTensorsTag); + auto model_output_tensors = + inference.Out(kTensorsTag).Cast>(); + + // Generates a single side packet containing a vector of SSD anchors. + auto& ssd_anchor = graph.AddNode("SsdAnchorsCalculator"); + ConfigureSsdAnchorsCalculator( + &ssd_anchor.GetOptions()); + auto anchors = ssd_anchor.SideOut(""); + + // Converts output tensors to Detections. + auto& tensors_to_detections = + graph.AddNode("TensorsToDetectionsCalculator"); + ConfigureTensorsToDetectionsCalculator( + subgraph_options, + &tensors_to_detections + .GetOptions()); + model_output_tensors >> tensors_to_detections.In(kTensorsTag); + anchors >> tensors_to_detections.SideIn(kAnchorsTag); + auto detections = tensors_to_detections.Out(kDetectionsTag); + + // Non maximum suppression removes redundant face detections. 
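    // (Applied here to the candidate pose detections decoded above.)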
+ auto& non_maximum_suppression = + graph.AddNode("NonMaxSuppressionCalculator"); + ConfigureNonMaxSuppressionCalculator( + subgraph_options, + &non_maximum_suppression + .GetOptions()); + detections >> non_maximum_suppression.In(""); + auto nms_detections = non_maximum_suppression.Out(""); + + // Projects detections back into the input image coordinates system. + auto& detection_projection = graph.AddNode("DetectionProjectionCalculator"); + nms_detections >> detection_projection.In(kDetectionsTag); + matrix >> detection_projection.In(kProjectionMatrixTag); + Source> pose_detections = + detection_projection.Out(kDetectionsTag).Cast>(); + + if (subgraph_options.has_num_poses()) { + // Clip face detections to maximum number of poses. + auto& clip_detection_vector_size = + graph.AddNode("ClipDetectionVectorSizeCalculator"); + clip_detection_vector_size + .GetOptions() + .set_max_vec_size(subgraph_options.num_poses()); + pose_detections >> clip_detection_vector_size.In(""); + pose_detections = + clip_detection_vector_size.Out("").Cast>(); + } + + // Converts results of pose detection into a rectangle (normalized by image + // size) that encloses the face and is rotated such that the line connecting + // left eye and right eye is aligned with the X-axis of the rectangle. + auto& detections_to_rects = graph.AddNode("DetectionsToRectsCalculator"); + ConfigureDetectionsToRectsCalculator( + &detections_to_rects + .GetOptions()); + image_size >> detections_to_rects.In(kImageSizeTag); + pose_detections >> detections_to_rects.In(kDetectionsTag); + auto pose_rects = detections_to_rects.Out(kNormRectsTag) + .Cast>(); + + // Expands and shifts the rectangle that contains the pose so that it's + // likely to cover the entire pose. + auto& rect_transformation = graph.AddNode("RectTransformationCalculator"); + ConfigureRectTransformationCalculator( + &rect_transformation + .GetOptions()); + pose_rects >> rect_transformation.In(kNormRectsTag); + image_size >> rect_transformation.In(kImageSizeTag); + auto expanded_pose_rects = + rect_transformation.Out("").Cast>(); + + // Calculator to convert relative detection bounding boxes to pixel + // detection bounding boxes. + auto& detection_transformation = + graph.AddNode("DetectionTransformationCalculator"); + detection_projection.Out(kDetectionsTag) >> + detection_transformation.In(kDetectionsTag); + preprocessing.Out(kImageSizeTag) >> + detection_transformation.In(kImageSizeTag); + auto pose_pixel_detections = + detection_transformation.Out(kPixelDetectionsTag) + .Cast>(); + + return PoseDetectionOuts{ + /* pose_detections= */ pose_pixel_detections, + /* pose_rects= */ pose_rects, + /* expanded_pose_rects= */ expanded_pose_rects, + /* image= */ preprocessing.Out(kImageTag).Cast()}; + } +}; + +REGISTER_MEDIAPIPE_GRAPH( + ::mediapipe::tasks::vision::pose_detector::PoseDetectorGraph); + +} // namespace pose_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc new file mode 100644 index 000000000..5706bd9d7 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/pose_detector_graph_test.cc @@ -0,0 +1,165 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/flags/flag.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mediapipe/calculators/tensor/inference_calculator.pb.h" +#include "mediapipe/framework/api2/builder.h" +#include "mediapipe/framework/api2/port.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/deps/file_path.h" +#include "mediapipe/framework/formats/detection.pb.h" +#include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" +#include "mediapipe/framework/port/file_helpers.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/tasks/cc/core/mediapipe_builtin_op_resolver.h" +#include "mediapipe/tasks/cc/core/proto/base_options.pb.h" +#include "mediapipe/tasks/cc/core/proto/external_file.pb.h" +#include "mediapipe/tasks/cc/core/task_runner.h" +#include "mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.pb.h" +#include "mediapipe/tasks/cc/vision/utils/image_utils.h" + +namespace mediapipe { +namespace tasks { +namespace vision { +namespace pose_detector { +namespace { + +using ::file::Defaults; +using ::file::GetTextProto; +using ::mediapipe::NormalizedRect; +using ::mediapipe::api2::Input; +using ::mediapipe::api2::Output; +using ::mediapipe::api2::builder::Graph; +using ::mediapipe::api2::builder::Source; +using ::mediapipe::file::JoinPath; +using ::mediapipe::tasks::core::TaskRunner; +using ::mediapipe::tasks::vision::pose_detector::proto:: + PoseDetectorGraphOptions; +using ::testing::EqualsProto; +using ::testing::Pointwise; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::Values; +using ::testing::proto::Approximately; +using ::testing::proto::Partially; + +constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; +constexpr char kPoseDetectionModel[] = "pose_detection.tflite"; +constexpr char kPortraitImage[] = "pose.jpg"; +constexpr char kPoseExpectedDetection[] = "pose_expected_detection.pbtxt"; + +constexpr char kImageTag[] = "IMAGE"; +constexpr char kImageName[] = "image"; +constexpr char kNormRectTag[] = "NORM_RECT"; +constexpr char kNormRectName[] = "norm_rect"; +constexpr char kDetectionsTag[] = "DETECTIONS"; +constexpr char kDetectionsName[] = "detections"; + +constexpr float kPoseDetectionMaxDiff = 0.01; + +// Helper function to create a TaskRunner. 
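// It wires the IMAGE and NORM_RECT graph inputs into the PoseDetectorGraph
// under test and exposes its DETECTIONS stream as the runner's only output.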
+absl::StatusOr> CreateTaskRunner( + absl::string_view model_name) { + Graph graph; + + auto& pose_detector_graph = + graph.AddNode("mediapipe.tasks.vision.pose_detector.PoseDetectorGraph"); + + auto options = std::make_unique(); + options->mutable_base_options()->mutable_model_asset()->set_file_name( + JoinPath("./", kTestDataDirectory, model_name)); + options->set_min_detection_confidence(0.6); + options->set_min_suppression_threshold(0.3); + pose_detector_graph.GetOptions().Swap( + options.get()); + + graph[Input(kImageTag)].SetName(kImageName) >> + pose_detector_graph.In(kImageTag); + graph[Input(kNormRectTag)].SetName(kNormRectName) >> + pose_detector_graph.In(kNormRectTag); + + pose_detector_graph.Out(kDetectionsTag).SetName(kDetectionsName) >> + graph[Output>(kDetectionsTag)]; + + return TaskRunner::Create( + graph.GetConfig(), std::make_unique()); +} + +Detection GetExpectedPoseDetectionResult(absl::string_view file_name) { + Detection detection; + CHECK_OK(GetTextProto(file::JoinPath("./", kTestDataDirectory, file_name), + &detection, Defaults())) + << "Expected pose detection result does not exist."; + return detection; +} + +struct TestParams { + // The name of this test, for convenience when displaying test results. + std::string test_name; + // The filename of pose landmark detection model. + std::string pose_detection_model_name; + // The filename of test image. + std::string test_image_name; + // Expected pose detection results. + std::vector expected_result; +}; + +class PoseDetectorGraphTest : public testing::TestWithParam {}; + +TEST_P(PoseDetectorGraphTest, Succeed) { + MP_ASSERT_OK_AND_ASSIGN( + Image image, DecodeImageFromFile(JoinPath("./", kTestDataDirectory, + GetParam().test_image_name))); + NormalizedRect input_norm_rect; + input_norm_rect.set_x_center(0.5); + input_norm_rect.set_y_center(0.5); + input_norm_rect.set_width(1.0); + input_norm_rect.set_height(1.0); + MP_ASSERT_OK_AND_ASSIGN( + auto task_runner, CreateTaskRunner(GetParam().pose_detection_model_name)); + auto output_packets = task_runner->Process( + {{kImageName, MakePacket(std::move(image))}, + {kNormRectName, + MakePacket(std::move(input_norm_rect))}}); + MP_ASSERT_OK(output_packets); + const std::vector& pose_detections = + (*output_packets)[kDetectionsName].Get>(); + EXPECT_THAT(pose_detections, Pointwise(Approximately(Partially(EqualsProto()), + kPoseDetectionMaxDiff), + GetParam().expected_result)); +} + +INSTANTIATE_TEST_SUITE_P( + PoseDetectorGraphTest, PoseDetectorGraphTest, + Values(TestParams{.test_name = "DetectPose", + .pose_detection_model_name = kPoseDetectionModel, + .test_image_name = kPortraitImage, + .expected_result = {GetExpectedPoseDetectionResult( + kPoseExpectedDetection)}}), + [](const TestParamInfo& info) { + return info.param.test_name; + }); + +} // namespace +} // namespace pose_detector +} // namespace vision +} // namespace tasks +} // namespace mediapipe diff --git a/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD b/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD new file mode 100644 index 000000000..287ed0183 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/proto/BUILD @@ -0,0 +1,31 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") + +package(default_visibility = [ + "//mediapipe/tasks:internal", +]) + +licenses(["notice"]) + +mediapipe_proto_library( + name = "pose_detector_graph_options_proto", + srcs = ["pose_detector_graph_options.proto"], + deps = [ + "//mediapipe/framework:calculator_options_proto", + "//mediapipe/framework:calculator_proto", + "//mediapipe/tasks/cc/core/proto:base_options_proto", + ], +) diff --git a/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto b/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto new file mode 100644 index 000000000..693f95262 --- /dev/null +++ b/mediapipe/tasks/cc/vision/pose_detector/proto/pose_detector_graph_options.proto @@ -0,0 +1,45 @@ +/* Copyright 2023 The MediaPipe Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mediapipe.tasks.vision.pose_detector.proto; + +import "mediapipe/framework/calculator.proto"; +import "mediapipe/framework/calculator_options.proto"; +import "mediapipe/tasks/cc/core/proto/base_options.proto"; + +option java_package = "com.google.mediapipe.tasks.vision.posedetector.proto"; +option java_outer_classname = "PoseDetectorGraphOptionsProto"; + +message PoseDetectorGraphOptions { + extend mediapipe.CalculatorOptions { + optional PoseDetectorGraphOptions ext = 514774813; + } + // Base options for configuring Task library, such as specifying the TfLite + // model file with metadata, accelerator options, etc. + optional core.proto.BaseOptions base_options = 1; + + // Minimum confidence value ([0.0, 1.0]) for confidence score to be considered + // successfully detecting a pose in the image. + optional float min_detection_confidence = 2 [default = 0.5]; + + // IoU threshold ([0,0, 1.0]) for non-maximu-suppression to be considered + // duplicate detections. + optional float min_suppression_threshold = 3 [default = 0.5]; + + // Maximum number of poses to detect in the image. + optional int32 num_poses = 4; +} diff --git a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h index 69c28b916..36d90f223 100644 --- a/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h +++ b/mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h @@ -70,7 +70,7 @@ extern NSString *const MPPTasksErrorDomain; * @param error Pointer to the memory location where errors if any should be saved. If `nil`, no * error will be saved. 
* - * @return Pointer to the allocated block of memory on successfull allocation. `nil` in case as + * @return Pointer to the allocated block of memory on successful allocation. `nil` in case as * error is encountered because of invalid `memSize`. If failure is due to any other reason, method * terminates program execution. */ diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h index b94e704d1..15f2ee0c1 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskInfo.h @@ -22,7 +22,7 @@ NS_ASSUME_NONNULL_BEGIN /** - * Holds all needed informaton to initialize a MediaPipe Task. + * Holds all needed information to initialize a MediaPipe Task. */ @interface MPPTaskInfo : NSObject diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h index 704fc453f..41515571a 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.h @@ -17,6 +17,8 @@ #include "mediapipe/framework/calculator.pb.h" #include "mediapipe/tasks/cc/core/task_runner.h" +#include + NS_ASSUME_NONNULL_BEGIN /** @@ -30,7 +32,7 @@ NS_ASSUME_NONNULL_BEGIN * additional functionality. For eg:, vision tasks must create an `MPPVisionTaskRunner` and provide * additional functionality. An instance of `MPPVisionTaskRunner` can in turn be used by the each * vision task for creation and execution of the task. Please see the documentation for the C++ Task - * Runner for more details on how the taks runner operates. + * Runner for more details on how the tasks runner operates. */ @interface MPPTaskRunner : NSObject @@ -62,24 +64,57 @@ NS_ASSUME_NONNULL_BEGIN error:(NSError **)error NS_DESIGNATED_INITIALIZER; /** - * A synchronous method for processing batch data or offline streaming data. This method is designed - * for processing either batch data such as unrelated images and texts or offline streaming data - * such as the decoded frames from a video file or audio file. The call blocks the current - * thread until a failure status or a successful result is returned. If the input packets have no - * timestamp, an internal timestamp will be assigend per invocation. Otherwise, when the timestamp - * is set in the input packets, the caller must ensure that the input packet timestamps are greater - * than the timestamps of the previous invocation. This method is thread-unsafe and it is the - * caller's responsibility to synchronize access to this method across multiple threads and to - * ensure that the input packet timestamps are in order. + * A synchronous method for invoking the C++ task runner for processing batch data or offline + * streaming data. This method is designed for processing either batch data such as unrelated images + * and texts or offline streaming data such as the decoded frames from a video file or audio file. + * The call blocks the current thread until a failure status or a successful result is returned. If + * the input packets have no timestamp, an internal timestamp will be assigned per invocation. + * Otherwise, when the timestamp is set in the input packets, the caller must ensure that the input + * packet timestamps are greater than the timestamps of the previous invocation. This method is + * thread-unsafe and it is the caller's responsibility to synchronize access to this method across + * multiple threads and to ensure that the input packet timestamps are in order. 
+ * + * @param packetMap A `PacketMap` containing pairs of input stream name and data packet which are to + * be sent to the C++ task runner for processing synchronously. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return An optional output `PacketMap` containing pairs of output stream name and data packet + * which holds the results of processing the input packet map, if there are no errors. */ -- (absl::StatusOr)process: - (const mediapipe::tasks::core::PacketMap &)packetMap; +- (std::optional) + processPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + +/** + * An asynchronous method that is designed for handling live streaming data such as live camera. A + * user-defined PacketsCallback function must be provided in the constructor to receive the output + * packets. The caller must ensure that the input packet timestamps are monotonically increasing. + * This method is thread-unsafe and it is the caller's responsibility to synchronize access to this + * method across multiple threads and to ensure that the input packet timestamps are in order. + * + * @param packetMap A `PacketMap` containing pairs of input stream name and data packet that are to + * be sent to the C++ task runner for processing asynchronously. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return A `BOOL` indicating if the live stream data was sent to the C++ task runner successfully. + * Please note that any errors during processing of the live stream packet map will only be + * available in the user-defined `packetsCallback` that was provided during initialization of the + * `MPPVisionTaskRunner`. + */ +- (BOOL)sendPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap error:(NSError **)error; /** * Shuts down the C++ task runner. After the runner is closed, any calls that send input data to the * runner are illegal and will receive errors. + * + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return A `BOOL` indicating if the C++ task runner was shutdown successfully. 
*/ -- (absl::Status)close; +- (BOOL)closeWithError:(NSError **)error; - (instancetype)init NS_UNAVAILABLE; diff --git a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm index eb777679a..0813760c2 100644 --- a/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm +++ b/mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm @@ -50,12 +50,22 @@ using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner; return self; } -- (absl::StatusOr)process:(const PacketMap &)packetMap { - return _cppTaskRunner->Process(packetMap); +- (std::optional)processPacketMap:(const PacketMap &)packetMap error:(NSError **)error { + absl::StatusOr resultPacketMap = _cppTaskRunner->Process(packetMap); + if (![MPPCommonUtils checkCppError:resultPacketMap.status() toError:error]) { + return std::nullopt; + } + return resultPacketMap.value(); } -- (absl::Status)close { - return _cppTaskRunner->Close(); +- (BOOL)sendPacketMap:(const PacketMap &)packetMap error:(NSError **)error { + absl::Status sendStatus = _cppTaskRunner->Send(packetMap); + return [MPPCommonUtils checkCppError:sendStatus toError:error]; +} + +- (BOOL)closeWithError:(NSError **)error { + absl::Status closeStatus = _cppTaskRunner->Close(); + return [MPPCommonUtils checkCppError:closeStatus toError:error]; } @end diff --git a/mediapipe/tasks/ios/text/text_classifier/BUILD b/mediapipe/tasks/ios/text/text_classifier/BUILD index c56d51e5f..7913340ac 100644 --- a/mediapipe/tasks/ios/text/text_classifier/BUILD +++ b/mediapipe/tasks/ios/text/text_classifier/BUILD @@ -58,6 +58,5 @@ objc_library( "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierOptionsHelpers", "//mediapipe/tasks/ios/text/text_classifier/utils:MPPTextClassifierResultHelpers", - "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm index 52e4d92ac..f0e1e4152 100644 --- a/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm +++ b/mediapipe/tasks/ios/text/text_classifier/sources/MPPTextClassifier.mm @@ -22,7 +22,6 @@ #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_classifier/utils/sources/MPPTextClassifierResult+Helpers.h" -#include "absl/status/statusor.h" #include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" namespace { @@ -83,15 +82,16 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_classifier.T Packet packet = [MPPTextPacketCreator createWithText:text]; std::map packetMap = {{kTextInStreamName.cppString, packet}}; - absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; + std::optional outputPacketMap = [_textTaskRunner processPacketMap:packetMap + error:error]; - if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + if (!outputPacketMap.has_value()) { return nil; } - return [MPPTextClassifierResult - textClassifierResultWithClassificationsPacket:statusOrOutputPacketMap.value() - [kClassificationsStreamName.cppString]]; + return + [MPPTextClassifierResult textClassifierResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; } @end diff --git a/mediapipe/tasks/ios/text/text_embedder/BUILD b/mediapipe/tasks/ios/text/text_embedder/BUILD index 74aefdf77..a600d5366 100644 --- 
a/mediapipe/tasks/ios/text/text_embedder/BUILD +++ b/mediapipe/tasks/ios/text/text_embedder/BUILD @@ -58,6 +58,5 @@ objc_library( "//mediapipe/tasks/ios/text/core:MPPTextTaskRunner", "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderOptionsHelpers", "//mediapipe/tasks/ios/text/text_embedder/utils:MPPTextEmbedderResultHelpers", - "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm index 62eb882d3..e0f0d549d 100644 --- a/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm +++ b/mediapipe/tasks/ios/text/text_embedder/sources/MPPTextEmbedder.mm @@ -23,8 +23,6 @@ #import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderOptions+Helpers.h" #import "mediapipe/tasks/ios/text/text_embedder/utils/sources/MPPTextEmbedderResult+Helpers.h" -#include "absl/status/statusor.h" - namespace { using ::mediapipe::Packet; using ::mediapipe::tasks::core::PacketMap; @@ -83,14 +81,15 @@ static NSString *const kTaskGraphName = @"mediapipe.tasks.text.text_embedder.Tex Packet packet = [MPPTextPacketCreator createWithText:text]; std::map packetMap = {{kTextInStreamName.cppString, packet}}; - absl::StatusOr statusOrOutputPacketMap = [_textTaskRunner process:packetMap]; - if (![MPPCommonUtils checkCppError:statusOrOutputPacketMap.status() toError:error]) { + std::optional outputPacketMap = [_textTaskRunner processPacketMap:packetMap + error:error]; + + if (!outputPacketMap.has_value()) { return nil; } - return [MPPTextEmbedderResult - textEmbedderResultWithOutputPacket:statusOrOutputPacketMap + textEmbedderResultWithOutputPacket:outputPacketMap .value()[kEmbeddingsOutStreamName.cppString]]; } diff --git a/mediapipe/tasks/ios/vision/core/BUILD b/mediapipe/tasks/ios/vision/core/BUILD index 1961ca6b0..a8164d674 100644 --- a/mediapipe/tasks/ios/vision/core/BUILD +++ b/mediapipe/tasks/ios/vision/core/BUILD @@ -26,6 +26,24 @@ objc_library( module_name = "MPPRunningMode", ) +objc_library( + name = "MPPVisionPacketCreator", + srcs = ["sources/MPPVisionPacketCreator.mm"], + hdrs = ["sources/MPPVisionPacketCreator.h"], + copts = [ + "-ObjC++", + "-std=c++17", + ], + deps = [ + ":MPPImage", + "//mediapipe/framework:packet", + "//mediapipe/framework:timestamp", + "//mediapipe/framework/formats:image", + "//mediapipe/framework/formats:rect_cc_proto", + "//mediapipe/tasks/ios/vision/core/utils:MPPImageUtils", + ], +) + objc_library( name = "MPPVisionTaskRunner", srcs = ["sources/MPPVisionTaskRunner.mm"], @@ -36,8 +54,11 @@ objc_library( ], deps = [ ":MPPRunningMode", + "//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/tasks/ios/common:MPPCommon", "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", "//mediapipe/tasks/ios/core:MPPTaskRunner", + "//third_party/apple_frameworks:UIKit", + "@com_google_absl//absl/status:statusor", ], ) diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h b/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h index 5cc57b88a..ab76546df 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h @@ -38,4 +38,17 @@ typedef NS_ENUM(NSUInteger, MPPRunningMode) { } NS_SWIFT_NAME(RunningMode); +NS_INLINE NSString *MPPRunningModeDisplayName(MPPRunningMode runningMode) { + if (runningMode > MPPRunningModeLiveStream) { + return nil; + } + + NSString *displayNameMap[MPPRunningModeLiveStream + 1] = { + [MPPRunningModeImage] = 
@"#MPPRunningModeImage", + [MPPRunningModeVideo] = @ "#MPPRunningModeVideo", + [MPPRunningModeLiveStream] = @ "#MPPRunningModeLiveStream"}; + + return displayNameMap[runningMode]; +} + NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h index cf597ec24..eaf059ad2 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h @@ -14,14 +14,63 @@ #import -#include "mediapipe/framework/packet.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" +#include "mediapipe/framework/formats/rect.pb.h" +#include "mediapipe/framework/packet.h" + /** * This class helps create various kinds of packets for Mediapipe Vision Tasks. */ @interface MPPVisionPacketCreator : NSObject +/** + * Creates a MediapPipe Packet wrapping an `MPPImage` that can be send to a graph. + * + * @param image The image to send to the MediaPipe graph. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return The MediaPipe packet containing the image. An empty packet is returned if an error + * occurred during the conversion. + */ + (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error; +/** + * Creates a MediapPipe Packet wrapping an `MPPImage` that can be send to a graph at the specified + * timestamp. + * + * @param image The image to send to the MediaPipe graph. + * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return The MediaPipe packet containing the image. An empty packet is returned if an error + * occurred during the conversion. + */ ++ (mediapipe::Packet)createPacketWithMPPImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error; + +/** + * Creates a MediapPipe Packet wrapping a `NormalizedRect` that can be send to a graph. + * + * @param image The `NormalizedRect` to send to the MediaPipe graph. + * + * @return The MediaPipe packet containing the normalized rect. + */ ++ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect; + +/** + * Creates a MediapPipe Packet wrapping a `NormalizedRect` that can be send to a graph at the + * specified timestamp. + * + * @param image The `NormalizedRect` to send to the MediaPipe graph. + * @param timestampMs The timestamp (in milliseconds) to assign to the packet. + * + * @return The MediaPipe packet containing the normalized rect. 
+ */ ++ (mediapipe::Packet)createPacketWithNormalizedRect:(mediapipe::NormalizedRect &)normalizedRect + timestampMs:(NSInteger)timestampMs; + @end diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm index 01e583e62..bf136a759 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.mm @@ -16,18 +16,19 @@ #import "mediapipe/tasks/ios/vision/core/utils/sources/MPPImage+Utils.h" #include "mediapipe/framework/formats/image.h" +#include "mediapipe/framework/timestamp.h" + +static const NSUInteger kMicroSecondsPerMilliSecond = 1000; namespace { using ::mediapipe::Image; using ::mediapipe::ImageFrame; using ::mediapipe::MakePacket; +using ::mediapipe::NormalizedRect; using ::mediapipe::Packet; +using ::mediapipe::Timestamp; } // namespace -struct freeDeleter { - void operator()(void *ptr) { free(ptr); } -}; - @implementation MPPVisionPacketCreator + (Packet)createPacketWithMPPImage:(MPPImage *)image error:(NSError **)error { @@ -40,4 +41,27 @@ struct freeDeleter { return MakePacket(std::move(imageFrame)); } ++ (Packet)createPacketWithMPPImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error { + std::unique_ptr imageFrame = [image imageFrameWithError:error]; + + if (!imageFrame) { + return Packet(); + } + + return MakePacket(std::move(imageFrame)) + .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); +} + ++ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect { + return MakePacket(std::move(normalizedRect)); +} + ++ (Packet)createPacketWithNormalizedRect:(NormalizedRect &)normalizedRect + timestampMs:(NSInteger)timestampMs { + return MakePacket(std::move(normalizedRect)) + .At(Timestamp(int64(timestampMs * kMicroSecondsPerMilliSecond))); +} + @end diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h index 84b657305..f19e4ca75 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h @@ -13,10 +13,13 @@ // limitations under the License. #import +#import #import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h" #import "mediapipe/tasks/ios/vision/core/sources/MPPRunningMode.h" +#include "mediapipe/framework/formats/rect.pb.h" + NS_ASSUME_NONNULL_BEGIN /** @@ -54,6 +57,82 @@ NS_ASSUME_NONNULL_BEGIN (mediapipe::tasks::core::PacketsCallback)packetsCallback error:(NSError **)error NS_DESIGNATED_INITIALIZER; +/** + * Creates a `NormalizedRect` from a region of interest and an image orientation, performing + * sanity checks on-the-fly. + * If the input region of interest equals `CGRectZero`, returns a default `NormalizedRect` covering + * the whole image with rotation set according `imageOrientation`. If `ROIAllowed` is NO, an error + * will be returned if the input region of interest is not equal to `CGRectZero`. Mirrored + * orientations (`UIImageOrientationUpMirrored`,`UIImageOrientationDownMirrored`, + * `UIImageOrientationLeftMirrored`,`UIImageOrientationRightMirrored`) are not supported. An error + * will be returned if `imageOrientation` is equal to any one of them. + * + * @param roi A `CGRect` specifying the region of interest. If the input region of interest equals + * `CGRectZero`, the returned `NormalizedRect` covers the whole image. 
Make sure that `roi` equals + * `CGRectZero` if `ROIAllowed` is NO. Otherwise, an error will be returned. + * @param imageOrientation A `UIImageOrientation` indicating the rotation to be applied to the + * image. The resulting `NormalizedRect` will convert the `imageOrientation` to degrees clockwise. + * Mirrored orientations (`UIImageOrientationUpMirrored`, `UIImageOrientationDownMirrored`, + * `UIImageOrientationLeftMirrored`, `UIImageOrientationRightMirrored`) are not supported. An error + * will be returned if `imageOrientation` is equal to any one of them. + * @param ROIAllowed Indicates if the `roi` field is allowed to be a value other than `CGRectZero`. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return An optional `NormalizedRect` from the given region of interest and image orientation. + */ +- (std::optional) + normalizedRectFromRegionOfInterest:(CGRect)roi + imageOrientation:(UIImageOrientation)imageOrientation + ROIAllowed:(BOOL)ROIAllowed + error:(NSError **)error; + +/** + * A synchronous method to invoke the C++ task runner to process single image inputs. The call + * blocks the current thread until a failure status or a successful result is returned. + * + * @param packetMap A `PackeMap` containing pairs of input stream name and data packet. + * @param error Pointer to the memory location where errors if any should be + * saved. If @c NULL, no error will be saved. + * + * @return An optional `PacketMap` containing pairs of output stream name and data packet. + */ +- (std::optional) + processImagePacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + +/** + * A synchronous method to invoke the C++ task runner to process continuous video frames. The call + * blocks the current thread until a failure status or a successful result is returned. + * + * @param packetMap A `PackeMap` containing pairs of input stream name and data packet. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return An optional `PacketMap` containing pairs of output stream name and data packet. + */ +- (std::optional) + processVideoFramePacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + +/** + * An asynchronous method to send live stream data to the C++ task runner. The call blocks the + * current thread until a failure status or a successful result is returned. The results will be + * available in the user-defined `packetsCallback` that was provided during initialization of the + * `MPPVisionTaskRunner`. + * + * @param packetMap A `PackeMap` containing pairs of input stream name and data packet. + * @param error Pointer to the memory location where errors if any should be saved. If @c NULL, no + * error will be saved. + * + * @return A `BOOL` indicating if the live stream data was sent to the C++ task runner successfully. + * Please note that any errors during processing of the live stream packet map will only be + * available in the user-defined `packetsCallback` that was provided during initialization of the + * `MPPVisionTaskRunner`. 
+ */ +- (BOOL)processLiveStreamPacketMap:(const mediapipe::tasks::core::PacketMap &)packetMap + error:(NSError **)error; + - (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig packetsCallback: (mediapipe::tasks::core::PacketsCallback)packetsCallback diff --git a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm index bfa9e34e5..492d29a8b 100644 --- a/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm +++ b/mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.mm @@ -17,11 +17,26 @@ #import "mediapipe/tasks/ios/common/sources/MPPCommon.h" #import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#include "absl/status/statusor.h" + +#include + namespace { using ::mediapipe::CalculatorGraphConfig; +using ::mediapipe::NormalizedRect; +using ::mediapipe::tasks::core::PacketMap; using ::mediapipe::tasks::core::PacketsCallback; } // namespace +/** Rotation degress for a 90 degree rotation to the right. */ +static const NSInteger kMPPOrientationDegreesRight = -90; + +/** Rotation degress for a 180 degree rotation. */ +static const NSInteger kMPPOrientationDegreesDown = -180; + +/** Rotation degress for a 90 degree rotation to the left. */ +static const NSInteger kMPPOrientationDegreesLeft = -270; + @interface MPPVisionTaskRunner () { MPPRunningMode _runningMode; } @@ -70,4 +85,100 @@ using ::mediapipe::tasks::core::PacketsCallback; return self; } +- (std::optional)normalizedRectFromRegionOfInterest:(CGRect)roi + imageOrientation: + (UIImageOrientation)imageOrientation + ROIAllowed:(BOOL)ROIAllowed + error:(NSError **)error { + if (CGRectEqualToRect(roi, CGRectZero) && !ROIAllowed) { + [MPPCommonUtils createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:@"This task doesn't support region-of-interest."]; + return std::nullopt; + } + + CGRect calculatedRoi = CGRectEqualToRect(roi, CGRectZero) ? roi : CGRectMake(0.0, 0.0, 1.0, 1.0); + + NormalizedRect normalizedRect; + normalizedRect.set_x_center(CGRectGetMidX(calculatedRoi)); + normalizedRect.set_y_center(CGRectGetMidY(calculatedRoi)); + normalizedRect.set_width(CGRectGetWidth(calculatedRoi)); + normalizedRect.set_height(CGRectGetHeight(calculatedRoi)); + + int rotationDegrees = 0; + switch (imageOrientation) { + case UIImageOrientationUp: + break; + case UIImageOrientationRight: { + rotationDegrees = kMPPOrientationDegreesRight; + break; + } + case UIImageOrientationDown: { + rotationDegrees = kMPPOrientationDegreesDown; + break; + } + case UIImageOrientationLeft: { + rotationDegrees = kMPPOrientationDegreesLeft; + break; + } + default: + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description: + @"Unsupported UIImageOrientation. `imageOrientation` cannot be equal to " + @"any of the mirrored orientations " + @"(`UIImageOrientationUpMirrored`,`UIImageOrientationDownMirrored`,`" + @"UIImageOrientationLeftMirrored`,`UIImageOrientationRightMirrored`)"]; + } + + normalizedRect.set_rotation(rotationDegrees * M_PI / kMPPOrientationDegreesDown); + + return normalizedRect; +} + +- (std::optional)processImagePacketMap:(const PacketMap &)packetMap + error:(NSError **)error { + if (_runningMode != MPPRunningModeImage) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:[NSString stringWithFormat:@"The vision task is not initialized with " + @"image mode. 
Current Running Mode: %@", + MPPRunningModeDisplayName(_runningMode)]]; + return std::nullopt; + } + + return [self processPacketMap:packetMap error:error]; +} + +- (std::optional)processVideoFramePacketMap:(const PacketMap &)packetMap + error:(NSError **)error { + if (_runningMode != MPPRunningModeVideo) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:[NSString stringWithFormat:@"The vision task is not initialized with " + @"video mode. Current Running Mode: %@", + MPPRunningModeDisplayName(_runningMode)]]; + return std::nullopt; + } + + return [self processPacketMap:packetMap error:error]; +} + +- (BOOL)processLiveStreamPacketMap:(const PacketMap &)packetMap error:(NSError **)error { + if (_runningMode != MPPRunningModeLiveStream) { + [MPPCommonUtils + createCustomError:error + withCode:MPPTasksErrorCodeInvalidArgumentError + description:[NSString stringWithFormat:@"The vision task is not initialized with " + @"live stream mode. Current Running Mode: %@", + MPPRunningModeDisplayName(_runningMode)]]; + return NO; + } + + return [self sendPacketMap:packetMap error:error]; +} + @end diff --git a/mediapipe/tasks/ios/vision/image_classifier/BUILD b/mediapipe/tasks/ios/vision/image_classifier/BUILD index 45e6e2156..4ebcd2b29 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/BUILD +++ b/mediapipe/tasks/ios/vision/image_classifier/BUILD @@ -36,3 +36,30 @@ objc_library( "//mediapipe/tasks/ios/vision/core:MPPRunningMode", ], ) + +objc_library( + name = "MPPImageClassifier", + srcs = ["sources/MPPImageClassifier.mm"], + hdrs = ["sources/MPPImageClassifier.h"], + copts = [ + "-ObjC++", + "-std=c++17", + "-x objective-c++", + ], + module_name = "MPPImageClassifier", + deps = [ + ":MPPImageClassifierOptions", + ":MPPImageClassifierResult", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", + "//mediapipe/tasks/ios/common/utils:MPPCommonUtils", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskInfo", + "//mediapipe/tasks/ios/core:MPPTaskOptions", + "//mediapipe/tasks/ios/vision/core:MPPImage", + "//mediapipe/tasks/ios/vision/core:MPPVisionPacketCreator", + "//mediapipe/tasks/ios/vision/core:MPPVisionTaskRunner", + "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierOptionsHelpers", + "//mediapipe/tasks/ios/vision/image_classifier/utils:MPPImageClassifierResultHelpers", + ], +) diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h new file mode 100644 index 000000000..581c8d95b --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h @@ -0,0 +1,219 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
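Before moving on to the image classifier sources, the rotation bookkeeping performed by `normalizedRectFromRegionOfInterest:imageOrientation:ROIAllowed:error:` above is worth spelling out. The following is an illustrative sketch only: the helper function is hypothetical, and the constants simply mirror the ones defined in MPPVisionTaskRunner.mm.

// Hypothetical helper mirroring the switch in MPPVisionTaskRunner.mm: maps a supported
// UIImageOrientation to the rotation, in radians, stored on the NormalizedRect.
static const NSInteger kOrientationDegreesRight = -90;   // 90° clockwise rotation
static const NSInteger kOrientationDegreesDown = -180;   // 180° rotation
static const NSInteger kOrientationDegreesLeft = -270;   // 90° counter-clockwise rotation

static float RotationInRadiansForOrientation(UIImageOrientation orientation) {
  NSInteger rotationDegrees = 0;
  switch (orientation) {
    case UIImageOrientationRight: rotationDegrees = kOrientationDegreesRight; break;
    case UIImageOrientationDown:  rotationDegrees = kOrientationDegreesDown;  break;
    case UIImageOrientationLeft:  rotationDegrees = kOrientationDegreesLeft;  break;
    default: break;  // UIImageOrientationUp needs no rotation; mirrored orientations are rejected.
  }
  // Same formula as the task runner, e.g. UIImageOrientationLeft -> (-270 * M_PI) / -180 = 1.5π.
  return rotationDegrees * M_PI / kOrientationDegreesDown;
}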
+ +#import + +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPImage.h" +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h" +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h" + +NS_ASSUME_NONNULL_BEGIN + +/** + * @brief Performs classification on images. + * + * The API expects a TFLite model with optional, but strongly recommended, + * [TFLite Model Metadata.](https://www.tensorflow.org/lite/convert/metadata"). + * + * The API supports models with one image input tensor and one or more output tensors. To be more + * specific, here are the requirements. + * + * Input tensor + * (kTfLiteUInt8/kTfLiteFloat32) + * - image input of size `[batch x height x width x channels]`. + * - batch inference is not supported (`batch` is required to be 1). + * - only RGB inputs are supported (`channels` is required to be 3). + * - if type is kTfLiteFloat32, NormalizationOptions are required to be attached to the metadata + * for input normalization. + * + * At least one output tensor with: + * (kTfLiteUInt8/kTfLiteFloat32) + * - `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or `[1 x 1 x 1 x N]` + * - optional (but recommended) label map(s) as AssociatedFiles with type TENSOR_AXIS_LABELS, + * containing one label per line. The first such AssociatedFile (if any) is used to fill the + * `class_name` field of the results. The `display_name` field is filled from the AssociatedFile + * (if any) whose locale matches the `display_names_locale` field of the `ImageClassifierOptions` + * used at creation time ("en" by default, i.e. English). If none of these are available, only + * the `index` field of the results will be filled. + * - optional score calibration can be attached using ScoreCalibrationOptions and an AssociatedFile + * with type TENSOR_AXIS_SCORE_CALIBRATION. See metadata_schema.fbs [1] for more details. + */ +NS_SWIFT_NAME(ImageClassifier) +@interface MPPImageClassifier : NSObject + +/** + * Creates a new instance of `MPPImageClassifier` from an absolute path to a TensorFlow Lite model + * file stored locally on the device and the default `MPPImageClassifierOptions`. + * + * @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device. + * @param error An optional error parameter populated when there is an error in initializing the + * image classifier. + * + * @return A new instance of `MPPImageClassifier` with the given model path. `nil` if there is an + * error in initializing the image classifier. + */ +- (nullable instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error; + +/** + * Creates a new instance of `MPPImageClassifier` from the given `MPPImageClassifierOptions`. + * + * @param options The options of type `MPPImageClassifierOptions` to use for configuring the + * `MPPImageClassifier`. + * @param error An optional error parameter populated when there is an error in initializing the + * image classifier. + * + * @return A new instance of `MPPImageClassifier` with the given options. `nil` if there is an error + * in initializing the image classifier. + */ +- (nullable instancetype)initWithOptions:(MPPImageClassifierOptions *)options + error:(NSError **)error NS_DESIGNATED_INITIALIZER; + +/** + * Performs image classification on the provided MPPImage using the whole image as region of + * interest. 
Rotation will be applied according to the `orientation` property of the provided + * `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeImage`. + * + * @param image The `MPPImage` on which image classification is to be performed. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input image. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image + error:(NSError **)error + NS_SWIFT_NAME(classify(image:)); + +/** + * Performs image classification on the provided `MPPImage` cropped to the specified region of + * interest. Rotation will be applied on the cropped image according to the `orientation` property + * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeImage`. + * + * @param image The `MPPImage` on which image classification is to be performed. + * @param roi A `CGRect` specifying the region of interest within the given `MPPImage`, on which + * image classification should be performed. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input image. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classify(image:regionOfInterest:)); + +/** + * Performs image classification on the provided video frame of type `MPPImage` using the whole + * image as region of interest. Rotation will be applied according to the `orientation` property of + * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeVideo`. + * + * @param image The `MPPImage` on which image classification is to be performed. + * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input video frame. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error + NS_SWIFT_NAME(classify(videoFrame:timestampMs:)); + +/** + * Performs image classification on the provided video frame of type `MPPImage` cropped to the + * specified region of interest. Rotation will be applied according to the `orientation` property of + * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeVideo`. + * + * It's required to provide the video frame's timestamp (in milliseconds). The input timestamps must + * be monotonically increasing. + * + * @param image A live stream image data of type `MPPImage` on which image classification is to be + * performed. + * @param timestampMs The video frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * @param roi A `CGRect` specifying the region of interest within the video frame of type + * `MPPImage`, on which image classification should be performed. 
+ * @param error An optional error parameter populated when there is an error in performing image + * classification on the input video frame. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classify(videoFrame:timestampMs:regionOfInterest:)); + +/** + * Sends live stream image data of type `MPPImage` to perform image classification using the whole + * image as region of interest. Rotation will be applied according to the `orientation` property of + * the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback + * provided in the `MPPImageClassifierOptions`. + * + * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * + * @param image A live stream image data of type `MPPImage` on which image classification is to be + * performed. + * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input live stream image data. + * + * @return `YES` if the image was sent to the task successfully, otherwise `NO`. + */ +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error NS_SWIFT_NAME(classifyAsync(image:timestampMs:)); + +/** + * Sends live stream image data of type `MPPImage` to perform image classification, cropped to the + * specified region of interest.. Rotation will be applied according to the `orientation` property + * of the provided `MPPImage`. Only use this method when the `MPPImageClassifier` is created with + * `MPPRunningModeLiveStream`. Results are provided asynchronously via the `completion` callback + * provided in the `MPPImageClassifierOptions`. + * + * It's required to provide a timestamp (in milliseconds) to indicate when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * + * @param image A live stream image data of type `MPPImage` on which image classification is to be + * performed. + * @param timestampMs The timestamp (in milliseconds) which indicates when the input image is sent + * to the image classifier. The input timestamps must be monotonically increasing. + * @param roi A `CGRect` specifying the region of interest within the given live stream image data + * of type `MPPImage`, on which image classification should be performed. + * @param error An optional error parameter populated when there is an error in performing image + * classification on the input live stream image data. + * + * @return `YES` if the image was sent to the task successfully, otherwise `NO`. 
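+ *
+ * For example (illustrative only; the `classifier`, `image`, `frameTimestampMs` and `roi`
+ * variables are placeholders, and the classifier is assumed to have been created with
+ * `MPPRunningModeLiveStream` and a `completion` block set on its `MPPImageClassifierOptions`):
+ *
+ *   [classifier classifyAsyncImage:image
+ *                      timestampMs:frameTimestampMs
+ *                 regionOfInterest:roi
+ *                            error:&error];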
+ */ +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error + NS_SWIFT_NAME(classifyAsync(image:timestampMs:regionOfInterest:)); + +- (instancetype)init NS_UNAVAILABLE; + ++ (instancetype)new NS_UNAVAILABLE; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm new file mode 100644 index 000000000..0ad79003f --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.mm @@ -0,0 +1,232 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifier.h" + +#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h" +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskInfo.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionPacketCreator.h" +#import "mediapipe/tasks/ios/vision/core/sources/MPPVisionTaskRunner.h" +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h" +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +namespace { +using ::mediapipe::NormalizedRect; +using ::mediapipe::Packet; +using ::mediapipe::tasks::core::PacketMap; +using ::mediapipe::tasks::core::PacketsCallback; +} // namespace + +static NSString *const kClassificationsStreamName = @"classifications_out"; +static NSString *const kClassificationsTag = @"CLASSIFICATIONS"; +static NSString *const kImageInStreamName = @"image_in"; +static NSString *const kImageOutStreamName = @"image_out"; +static NSString *const kImageTag = @"IMAGE"; +static NSString *const kNormRectName = @"norm_rect_in"; +static NSString *const kNormRectTag = @"NORM_RECT"; + +static NSString *const kTaskGraphName = + @"mediapipe.tasks.vision.image_classifier.ImageClassifierGraph"; + +#define InputPacketMap(imagePacket, normalizedRectPacket) \ + { \ + {kImageInStreamName.cppString, imagePacket}, { kNormRectName.cppString, normalizedRectPacket } \ + } + +@interface MPPImageClassifier () { + /** iOS Vision Task Runner */ + MPPVisionTaskRunner *_visionTaskRunner; +} +@end + +@implementation MPPImageClassifier + +- (instancetype)initWithOptions:(MPPImageClassifierOptions *)options error:(NSError **)error { + self = [super init]; + if (self) { + MPPTaskInfo *taskInfo = [[MPPTaskInfo alloc] + initWithTaskGraphName:kTaskGraphName + inputStreams:@[ [NSString + stringWithFormat:@"%@:%@", kImageTag, kImageInStreamName] ] + outputStreams:@[ [NSString stringWithFormat:@"%@:%@", kClassificationsTag, + kClassificationsStreamName] ] + taskOptions:options + enableFlowLimiting:NO + error:error]; + + if (!taskInfo) { + return nil; + } + 
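+    // If the caller set a completion handler on the options, wrap it in a C++
+    // PacketsCallback that converts the classifications output packet into an
+    // MPPImageClassifierResult before handing it back to the caller.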
+ PacketsCallback packetsCallback = nullptr; + + if (options.completion) { + packetsCallback = [=](absl::StatusOr status_or_packets) { + NSError *callbackError = nil; + MPPImageClassifierResult *result; + if ([MPPCommonUtils checkCppError:status_or_packets.status() toError:&callbackError]) { + result = [MPPImageClassifierResult + imageClassifierResultWithClassificationsPacket: + status_or_packets.value()[kClassificationsStreamName.cppString]]; + } + options.completion(result, callbackError); + }; + } + + _visionTaskRunner = + [[MPPVisionTaskRunner alloc] initWithCalculatorGraphConfig:[taskInfo generateGraphConfig] + runningMode:options.runningMode + packetsCallback:std::move(packetsCallback) + error:error]; + + if (!_visionTaskRunner) { + return nil; + } + } + return self; +} + +- (instancetype)initWithModelPath:(NSString *)modelPath error:(NSError **)error { + MPPImageClassifierOptions *options = [[MPPImageClassifierOptions alloc] init]; + + options.baseOptions.modelAssetPath = modelPath; + + return [self initWithOptions:options error:error]; +} + +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional rect = + [_visionTaskRunner normalizedRectFromRegionOfInterest:roi + imageOrientation:image.orientation + ROIAllowed:YES + error:error]; + if (!rect.has_value()) { + return nil; + } + + Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image error:error]; + if (imagePacket.IsEmpty()) { + return nil; + } + + Packet normalizedRectPacket = + [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value()]; + + PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); + + std::optional outputPacketMap = [_visionTaskRunner processPacketMap:inputPacketMap + error:error]; + if (!outputPacketMap.has_value()) { + return nil; + } + + return + [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; +} + +- (std::optional)inputPacketMapWithMPPImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional rect = + [_visionTaskRunner normalizedRectFromRegionOfInterest:roi + imageOrientation:image.orientation + ROIAllowed:YES + error:error]; + if (!rect.has_value()) { + return std::nullopt; + } + + Packet imagePacket = [MPPVisionPacketCreator createPacketWithMPPImage:image + timestampMs:timestampMs + error:error]; + if (imagePacket.IsEmpty()) { + return std::nullopt; + } + + Packet normalizedRectPacket = [MPPVisionPacketCreator createPacketWithNormalizedRect:rect.value() + timestampMs:timestampMs]; + + PacketMap inputPacketMap = InputPacketMap(imagePacket, normalizedRectPacket); + return inputPacketMap; +} + +- (nullable MPPImageClassifierResult *)classifyImage:(MPPImage *)image error:(NSError **)error { + return [self classifyImage:image regionOfInterest:CGRectZero error:error]; +} + +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image + timestampMs:timestampMs + regionOfInterest:roi + error:error]; + if (!inputPacketMap.has_value()) { + return nil; + } + + std::optional outputPacketMap = + [_visionTaskRunner processVideoFramePacketMap:inputPacketMap.value() error:error]; + + if (!outputPacketMap.has_value()) { + return nil; + } 
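+  // Convert the classifications packet from the output stream into an
+  // MPPImageClassifierResult before returning it to the caller.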
+ + return + [MPPImageClassifierResult imageClassifierResultWithClassificationsPacket: + outputPacketMap.value()[kClassificationsStreamName.cppString]]; +} + +- (nullable MPPImageClassifierResult *)classifyVideoFrame:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error { + return [self classifyVideoFrame:image + timestampMs:timestampMs + regionOfInterest:CGRectZero + error:error]; +} + +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + regionOfInterest:(CGRect)roi + error:(NSError **)error { + std::optional inputPacketMap = [self inputPacketMapWithMPPImage:image + timestampMs:timestampMs + regionOfInterest:roi + error:error]; + if (!inputPacketMap.has_value()) { + return NO; + } + + return [_visionTaskRunner processLiveStreamPacketMap:inputPacketMap.value() error:error]; +} + +- (BOOL)classifyAsyncImage:(MPPImage *)image + timestampMs:(NSInteger)timestampMs + error:(NSError **)error { + return [self classifyAsyncImage:image + timestampMs:timestampMs + regionOfInterest:CGRectZero + error:error]; +} + +@end diff --git a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h index f7e9a6297..2e6022041 100644 --- a/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h +++ b/mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h @@ -31,6 +31,7 @@ NS_SWIFT_NAME(ImageClassifierOptions) /** * The user-defined result callback for processing live stream data. The result callback should only * be specified when the running mode is set to the live stream mode. + * TODO: Add parameter `MPPImage` in the callback. */ @property(nonatomic, copy) void (^completion)(MPPImageClassifierResult *result, NSError *error); diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/BUILD b/mediapipe/tasks/ios/vision/image_classifier/utils/BUILD new file mode 100644 index 000000000..c1928b6ff --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/BUILD @@ -0,0 +1,44 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
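With the classifier's public interface and implementation in place above, a minimal Objective-C usage sketch may be helpful. It is illustrative only: the model path is a placeholder, `GetInputImage()` stands in for however the app obtains its `MPPImage` (the concrete initializers are not shown in this diff), and the video and live stream modes additionally require configuring the options' running mode as documented in MPPImageClassifier.h.

NSError *error = nil;

// Image running mode with the default options; the model path is a placeholder.
MPPImageClassifier *classifier =
    [[MPPImageClassifier alloc] initWithModelPath:@"/path/to/classifier.tflite" error:&error];

// Placeholder for however the app builds its MPPImage.
MPPImage *image = GetInputImage();

// Whole-image classification; classifyImage:regionOfInterest:error: restricts the ROI.
MPPImageClassifierResult *result = [classifier classifyImage:image error:&error];

// For video and live stream use, create the classifier with options configured for
// MPPRunningModeVideo or MPPRunningModeLiveStream and call
// classifyVideoFrame:timestampMs:error: or classifyAsyncImage:timestampMs:error: with
// monotonically increasing millisecond timestamps.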
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +objc_library( + name = "MPPImageClassifierOptionsHelpers", + srcs = ["sources/MPPImageClassifierOptions+Helpers.mm"], + hdrs = ["sources/MPPImageClassifierOptions+Helpers.h"], + deps = [ + "//mediapipe/framework:calculator_options_cc_proto", + "//mediapipe/tasks/cc/components/processors/proto:classifier_options_cc_proto", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_cc_proto", + "//mediapipe/tasks/ios/common/utils:NSStringHelpers", + "//mediapipe/tasks/ios/core:MPPTaskOptionsProtocol", + "//mediapipe/tasks/ios/core/utils:MPPBaseOptionsHelpers", + "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifierOptions", + ], +) + +objc_library( + name = "MPPImageClassifierResultHelpers", + srcs = ["sources/MPPImageClassifierResult+Helpers.mm"], + hdrs = ["sources/MPPImageClassifierResult+Helpers.h"], + deps = [ + "//mediapipe/framework:packet", + "//mediapipe/tasks/cc/components/containers/proto:classifications_cc_proto", + "//mediapipe/tasks/ios/components/containers/utils:MPPClassificationResultHelpers", + "//mediapipe/tasks/ios/vision/image_classifier:MPPImageClassifierResult", + ], +) diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h new file mode 100644 index 000000000..c3a3b2fec --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h @@ -0,0 +1,32 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/framework/calculator_options.pb.h" +#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h" +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierOptions.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPImageClassifierOptions (Helpers) + +/** + * Populates the provided `CalculatorOptions` proto container with the current settings. + * + * @param optionsProto The `CalculatorOptions` proto object to copy the settings to. + */ +- (void)copyToProto:(::mediapipe::CalculatorOptions *)optionsProto; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.mm new file mode 100644 index 000000000..36ecf9093 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.mm @@ -0,0 +1,56 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierOptions+Helpers.h" + +#import "mediapipe/tasks/ios/common/utils/sources/NSString+Helpers.h" +#import "mediapipe/tasks/ios/core/utils/sources/MPPBaseOptions+Helpers.h" + +#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h" +#include "mediapipe/tasks/cc/vision/image_classifier/proto/image_classifier_graph_options.pb.h" + +namespace { +using CalculatorOptionsProto = ::mediapipe::CalculatorOptions; +using ImageClassifierGraphOptionsProto = + ::mediapipe::tasks::vision::image_classifier::proto::ImageClassifierGraphOptions; +using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto::ClassifierOptions; +} // namespace + +@implementation MPPImageClassifierOptions (Helpers) + +- (void)copyToProto:(CalculatorOptionsProto *)optionsProto { + ImageClassifierGraphOptionsProto *graphOptions = + optionsProto->MutableExtension(ImageClassifierGraphOptionsProto::ext); + [self.baseOptions copyToProto:graphOptions->mutable_base_options()]; + + ClassifierOptionsProto *classifierOptionsProto = graphOptions->mutable_classifier_options(); + classifierOptionsProto->Clear(); + + if (self.displayNamesLocale) { + classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString); + } + + classifierOptionsProto->set_max_results((int)self.maxResults); + classifierOptionsProto->set_score_threshold(self.scoreThreshold); + + for (NSString *category in self.categoryAllowlist) { + classifierOptionsProto->add_category_allowlist(category.cppString); + } + + for (NSString *category in self.categoryDenylist) { + classifierOptionsProto->add_category_denylist(category.cppString); + } +} + +@end diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h new file mode 100644 index 000000000..0375ac2a5 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h @@ -0,0 +1,36 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/vision/image_classifier/sources/MPPImageClassifierResult.h" + +#include "mediapipe/framework/packet.h" + +NS_ASSUME_NONNULL_BEGIN + +@interface MPPImageClassifierResult (Helpers) + +/** + * Creates an `MPPImageClassifierResult` from a MediaPipe packet containing an + * `ClassificationResultProto`. 
+ * + * @param packet a MediaPipe packet wrapping a ClassificationResultProto. + * + * @return An `MPPImageClassifierResult` object that contains a list of image classifications. + */ ++ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: + (const mediapipe::Packet &)packet; + +@end + +NS_ASSUME_NONNULL_END diff --git a/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm new file mode 100644 index 000000000..09e21b278 --- /dev/null +++ b/mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.mm @@ -0,0 +1,41 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "mediapipe/tasks/ios/components/containers/utils/sources/MPPClassificationResult+Helpers.h" +#import "mediapipe/tasks/ios/vision/image_classifier/utils/sources/MPPImageClassifierResult+Helpers.h" + +#include "mediapipe/tasks/cc/components/containers/proto/classifications.pb.h" + +static const int kMicroSecondsPerMilliSecond = 1000; + +namespace { +using ClassificationResultProto = + ::mediapipe::tasks::components::containers::proto::ClassificationResult; +using ::mediapipe::Packet; +} // namespace + +@implementation MPPImageClassifierResult (Helpers) + ++ (MPPImageClassifierResult *)imageClassifierResultWithClassificationsPacket: + (const Packet &)packet { + MPPClassificationResult *classificationResult = [MPPClassificationResult + classificationResultWithProto:packet.Get()]; + + return [[MPPImageClassifierResult alloc] + initWithClassificationResult:classificationResult + timestampMs:(NSInteger)(packet.Timestamp().Value() / + kMicroSecondsPerMilliSecond)]; +} + +@end diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java index 7abde72d5..f577a361b 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/audio/core/BaseAudioTaskApi.java @@ -164,12 +164,14 @@ public class BaseAudioTaskApi implements AutoCloseable { * * @param numChannels the number of audio channels. * @param sampleRate the audio sample rate. + * @param requiredInputBufferSize the required input buffer size in number of float elements. 
* @return an {@link android.media.AudioRecord} instance in {@link * android.media.AudioRecord#STATE_INITIALIZED} * @throws IllegalArgumentException if the model required channel count is unsupported * @throws IllegalStateException if AudioRecord instance failed to initialize */ - public static AudioRecord createAudioRecord(int numChannels, int sampleRate) { + public AudioRecord createAudioRecord( + int numChannels, int sampleRate, int requiredInputBufferSize) { int channelConfig = 0; switch (numChannels) { case 1: @@ -190,6 +192,11 @@ public class BaseAudioTaskApi implements AutoCloseable { throw new IllegalStateException( String.format("AudioRecord.getMinBufferSize failed. Returned: %d", bufferSizeInBytes)); } + int bufferSizeMultiplier = 2; + int modelRequiredBufferSize = requiredInputBufferSize * Float.BYTES * bufferSizeMultiplier; + if (bufferSizeInBytes < modelRequiredBufferSize) { + bufferSizeInBytes = modelRequiredBufferSize; + } AudioRecord audioRecord = new AudioRecord( // including MIC, UNPROCESSED, and CAMCORDER. @@ -215,8 +222,8 @@ public class BaseAudioTaskApi implements AutoCloseable { * @throws IllegalArgumentException if the model required channel count is unsupported * @throws IllegalStateException if AudioRecord instance failed to initialize */ - public static AudioRecord createAudioRecord() { + public AudioRecord createAudioRecord() { // TODO: Support creating AudioRecord based on the model specifications. - return createAudioRecord(1, 16000); + return createAudioRecord(1, 16000, 16000); } } diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java index 1a128c538..11d385890 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/core/TaskRunner.java @@ -16,6 +16,7 @@ package com.google.mediapipe.tasks.core; import android.content.Context; import android.util.Log; +import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; import com.google.mediapipe.framework.AndroidAssetUtil; import com.google.mediapipe.framework.AndroidPacketCreator; import com.google.mediapipe.framework.Graph; @@ -201,6 +202,10 @@ public class TaskRunner implements AutoCloseable { } } + public CalculatorGraphConfig getCalculatorGraphConfig() { + return graph.getCalculatorGraphConfig(); + } + private synchronized void addPackets(Map inputs, long inputTimestamp) { if (!graphStarted.get()) { reportError( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl index 89fdb32dc..f9e8ed907 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/mediapipe_tasks_aar.bzl @@ -34,18 +34,19 @@ _AUDIO_TASKS_JAVA_PROTO_LITE_TARGETS = [ ] _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ - "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", - "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", 
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", ] _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ @@ -104,6 +105,11 @@ def mediapipe_tasks_core_aar(name, srcs, manifest): src_out = "com/google/mediapipe/calculator/proto/InferenceCalculatorProto.java", )) + mediapipe_tasks_java_proto_srcs.append(mediapipe_java_proto_src_extractor( + target = "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", + src_out = "com/google/mediapipe/tasks/TensorsToSegmentationCalculatorOptionsProto.java", + )) + android_library( name = name, srcs = srcs + [ @@ -136,6 +142,7 @@ def mediapipe_tasks_core_aar(name, srcs, manifest): "//mediapipe/framework/formats:rect_java_proto_lite", "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", "//third_party:androidx_annotation", "//third_party:autovalue", @@ -221,7 +228,9 @@ EOF name = name, srcs = srcs, manifest = "AndroidManifest.xml", - java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS, + java_proto_lite_targets = _CORE_TASKS_JAVA_PROTO_LITE_TARGETS + _VISION_TASKS_JAVA_PROTO_LITE_TARGETS + [ + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", + ], native_library = native_library, ) @@ -308,10 +317,14 @@ def _mediapipe_tasks_aar(name, srcs, manifest, java_proto_lite_targets, native_l "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embedding", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:embeddingresult", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:landmark", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalizedkeypoint", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalized_landmark", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:classifieroptions", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:cosinesimilarity", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/util:color_java_proto_lite", + "//mediapipe/util:label_map_java_proto_lite", + "//mediapipe/util:render_data_java_proto_lite", "//third_party:androidx_annotation", "//third_party:autovalue", "@maven//:com_google_guava_guava", diff --git 
a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD index 32518725a..ddff069af 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/BUILD @@ -45,10 +45,12 @@ cc_binary( deps = [ "//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/java/com/google/mediapipe/framework/jni:mediapipe_framework_jni", + "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", "//mediapipe/tasks/cc/vision/gesture_recognizer:gesture_recognizer_graph", "//mediapipe/tasks/cc/vision/image_classifier:image_classifier_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", + "//mediapipe/tasks/cc/vision/interactive_segmenter:interactive_segmenter_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/java:version_script.lds", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core/jni:model_resources_cache_jni", @@ -196,6 +198,7 @@ android_library( "//mediapipe/java/com/google/mediapipe/framework:android_framework", "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", @@ -204,6 +207,35 @@ android_library( ], ) +android_library( + name = "interactivesegmenter", + srcs = [ + "imagesegmenter/ImageSegmenterResult.java", + "interactivesegmenter/InteractiveSegmenter.java", + ], + javacopts = [ + "-Xep:AndroidJdkLibsChecker:OFF", + ], + manifest = "interactivesegmenter/AndroidManifest.xml", + deps = [ + ":core", + "//mediapipe/framework:calculator_options_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", + "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:normalizedkeypoint", + "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", + "//mediapipe/util:color_java_proto_lite", + "//mediapipe/util:render_data_java_proto_lite", + "//third_party:autovalue", + "@maven//:androidx_annotation_annotation", + "@maven//:com_google_guava_guava", + ], +) + android_library( name = "imageembedder", srcs = [ @@ -235,6 +267,7 @@ android_library( android_library( name = "facedetector", srcs = [ + "facedetector/FaceDetector.java", "facedetector/FaceDetectorResult.java", ], javacopts = [ @@ -245,7 +278,10 @@ android_library( ":core", "//mediapipe/framework:calculator_options_java_proto_lite", "//mediapipe/framework/formats:detection_java_proto_lite", + "//mediapipe/java/com/google/mediapipe/framework:android_framework", + "//mediapipe/java/com/google/mediapipe/framework/image", "//mediapipe/tasks/cc/core/proto:base_options_java_proto_lite", + 
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", "//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:detection", "//mediapipe/tasks/java/com/google/mediapipe/tasks/core", "//third_party:autovalue", diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java new file mode 100644 index 000000000..c23432c1b --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/facedetector/FaceDetector.java @@ -0,0 +1,463 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.facedetector; + +import android.content.Context; +import android.os.ParcelFileDescriptor; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.facedetector.proto.FaceDetectorGraphOptionsProto; +import com.google.mediapipe.formats.proto.DetectionProto.Detection; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Performs face detection on images. + * + *

The API expects a TFLite model with TFLite Model Metadata. + * + *
  • Input image {@link MPImage} + *
      • The image that the face detector runs on. + *
  • Output FaceDetectorResult {@link FaceDetectorResult} + *
      • A FaceDetectorResult containing detected faces. + *
+ */ +public final class FaceDetector extends BaseVisionTaskApi { + private static final String TAG = FaceDetector.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + + @SuppressWarnings("ConstantCaseForConstants") + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList("IMAGE:" + IMAGE_IN_STREAM_NAME, "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + + @SuppressWarnings("ConstantCaseForConstants") + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList(Arrays.asList("DETECTIONS:detections_out", "IMAGE:image_out")); + + private static final int DETECTIONS_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.face_detector.FaceDetectorGraph"; + + /** + * Creates a {@link FaceDetector} instance from a model file and the default {@link + * FaceDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param modelPath path to the detection model with metadata in the assets. + * @throws MediaPipeException if there is an error during {@link FaceDetector} creation. + */ + public static FaceDetector createFromFile(Context context, String modelPath) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetPath(modelPath).build(); + return createFromOptions( + context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link FaceDetector} instance from a model file and the default {@link + * FaceDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param modelFile the detection model {@link File} instance. + * @throws IOException if an I/O error occurs when opening the tflite model file. + * @throws MediaPipeException if there is an error during {@link FaceDetector} creation. + */ + public static FaceDetector createFromFile(Context context, File modelFile) throws IOException { + try (ParcelFileDescriptor descriptor = + ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) { + BaseOptions baseOptions = + BaseOptions.builder().setModelAssetFileDescriptor(descriptor.getFd()).build(); + return createFromOptions( + context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build()); + } + } + + /** + * Creates a {@link FaceDetector} instance from a model buffer and the default {@link + * FaceDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param modelBuffer a direct {@link ByteBuffer} or a {@link MappedByteBuffer} of the detection + * model. + * @throws MediaPipeException if there is an error during {@link FaceDetector} creation. + */ + public static FaceDetector createFromBuffer(Context context, final ByteBuffer modelBuffer) { + BaseOptions baseOptions = BaseOptions.builder().setModelAssetBuffer(modelBuffer).build(); + return createFromOptions( + context, FaceDetectorOptions.builder().setBaseOptions(baseOptions).build()); + } + + /** + * Creates a {@link FaceDetector} instance from a {@link FaceDetectorOptions}. + * + * @param context an Android {@link Context}. + * @param detectorOptions a {@link FaceDetectorOptions} instance. + * @throws MediaPipeException if there is an error during {@link FaceDetector} creation. + */ + public static FaceDetector createFromOptions( + Context context, FaceDetectorOptions detectorOptions) { + // TODO: Consolidate OutputHandler and TaskRunner. 
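+    // Set up an OutputHandler that converts the raw output packets into a FaceDetectorResult
+    // (an empty result when no faces are detected) and the image packet back into an MPImage.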
+ OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public FaceDetectorResult convertToTaskResult(List packets) { + // If there is no faces detected in the image, just returns empty lists. + if (packets.get(DETECTIONS_OUT_STREAM_INDEX).isEmpty()) { + return FaceDetectorResult.create( + new ArrayList<>(), + BaseVisionTaskApi.generateResultTimestampMs( + detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX))); + } + return FaceDetectorResult.create( + PacketGetter.getProtoVector( + packets.get(DETECTIONS_OUT_STREAM_INDEX), Detection.parser()), + BaseVisionTaskApi.generateResultTimestampMs( + detectorOptions.runningMode(), packets.get(DETECTIONS_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + detectorOptions.resultListener().ifPresent(handler::setResultListener); + detectorOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(FaceDetector.class.getSimpleName()) + .setTaskRunningModeName(detectorOptions.runningMode().name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(detectorOptions) + .setEnableFlowLimiting(detectorOptions.runningMode() == RunningMode.LIVE_STREAM) + .build(), + handler); + return new FaceDetector(runner, detectorOptions.runningMode()); + } + + /** + * Constructor to initialize a {@link FaceDetector} from a {@link TaskRunner} and a {@link + * RunningMode}. + * + * @param taskRunner a {@link TaskRunner}. + * @param runningMode a mediapipe vision task {@link RunningMode}. + */ + private FaceDetector(TaskRunner taskRunner, RunningMode runningMode) { + super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + } + + /** + * Performs face detection on the provided single image with default image processing options, + * i.e. without any rotation applied. Only use this method when the {@link FaceDetector} is + * created with {@link RunningMode.IMAGE}. + * + *

{@link FaceDetector} supports the following color space types: + * + *
  • {@link Bitmap.Config.ARGB_8888} + *
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detect(MPImage image) { + return detect(image, ImageProcessingOptions.builder().build()); + } + + /** + * Performs face detection on the provided single image. Only use this method when the {@link + * FaceDetector} is created with {@link RunningMode.IMAGE}. + * + *

<p>{@link FaceDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detect(MPImage image, ImageProcessingOptions imageProcessingOptions) { + validateImageProcessingOptions(imageProcessingOptions); + return (FaceDetectorResult) processImageData(image, imageProcessingOptions); + } + + /** + * Performs face detection on the provided video frame with default image processing options, i.e. + * without any rotation applied. Only use this method when the {@link FaceDetector} is created + * with {@link RunningMode.VIDEO}. + * + *

<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + * <p>{@link FaceDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detectForVideo(MPImage image, long timestampMs) { + return detectForVideo(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Performs face detection on the provided video frame. Only use this method when the {@link + * FaceDetector} is created with {@link RunningMode.VIDEO}. + * + *

<p>It's required to provide the video frame's timestamp (in milliseconds). The input timestamps + * must be monotonically increasing. + * + * <p>{@link FaceDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public FaceDetectorResult detectForVideo( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + return (FaceDetectorResult) processVideoData(image, imageProcessingOptions, timestampMs); + } + + /** + * Sends live image data to perform face detection with default image processing options, i.e. + * without any rotation applied, and the results will be available via the {@link ResultListener} + * provided in the {@link FaceDetectorOptions}. Only use this method when the {@link FaceDetector} + * is created with {@link RunningMode.LIVE_STREAM}. + * + *

<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the face detector. The input timestamps must be monotonically increasing. + * + * <p>{@link FaceDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param timestampMs the input timestamp (in milliseconds). + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync(MPImage image, long timestampMs) { + detectAsync(image, ImageProcessingOptions.builder().build(), timestampMs); + } + + /** + * Sends live image data to perform face detection, and the results will be available via the + * {@link ResultListener} provided in the {@link FaceDetectorOptions}. Only use this method when + * the {@link FaceDetector} is created with {@link RunningMode.LIVE_STREAM}. + * + *

<p>It's required to provide a timestamp (in milliseconds) to indicate when the input image is + * sent to the face detector. The input timestamps must be monotonically increasing. + * + * <p>{@link FaceDetector} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @param timestampMs the input timestamp (in milliseconds). + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. + */ + public void detectAsync( + MPImage image, ImageProcessingOptions imageProcessingOptions, long timestampMs) { + validateImageProcessingOptions(imageProcessingOptions); + sendLiveStreamData(image, imageProcessingOptions, timestampMs); + } + + /** Options for setting up a {@link FaceDetector}. */ + @AutoValue + public abstract static class FaceDetectorOptions extends TaskOptions { + + /** Builder for {@link FaceDetectorOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the {@link BaseOptions} for the face detector task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** + * Sets the {@link RunningMode} for the face detector task. Default to the image mode. face + * detector has three modes: + * + *
<ul> + *   <li>IMAGE: The mode for detecting faces on single image inputs. + *   <li>VIDEO: The mode for detecting faces on the decoded frames of a video. + *   <li>LIVE_STREAM: The mode for detecting faces on a live stream of input data, such as + *       from camera. In this mode, {@code setResultListener} must be called to set up a + *       listener to receive the detection results asynchronously. + * </ul>
+ */ + public abstract Builder setRunningMode(RunningMode value); + + /** + * Sets the minimum confidence score for the face detection to be considered successful. The + * default minDetectionConfidence is 0.5. + */ + public abstract Builder setMinDetectionConfidence(Float value); + + /** + * Sets the minimum non-maximum-suppression threshold for face detection to be considered + * overlapped. The default minSuppressionThreshold is 0.3. + */ + public abstract Builder setMinSuppressionThreshold(Float value); + + /** + * Sets the {@link ResultListener} to receive the detection results asynchronously when the + * face detector is in the live stream mode. + */ + public abstract Builder setResultListener(ResultListener value); + + /** Sets an optional {@link ErrorListener}}. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract FaceDetectorOptions autoBuild(); + + /** + * Validates and builds the {@link FaceDetectorOptions} instance. + * + * @throws IllegalArgumentException if the result listener and the running mode are not + * properly configured. The result listener should only be set when the face detector is + * in the live stream mode. + */ + public final FaceDetectorOptions build() { + FaceDetectorOptions options = autoBuild(); + if (options.runningMode() == RunningMode.LIVE_STREAM) { + if (!options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The face detector is in the live stream mode, a user-defined result listener" + + " must be provided in FaceDetectorOptions."); + } + } else if (options.resultListener().isPresent()) { + throw new IllegalArgumentException( + "The face detector is in the image or the video mode, a user-defined result" + + " listener shouldn't be provided in FaceDetectorOptions."); + } + return options; + } + } + + abstract BaseOptions baseOptions(); + + abstract RunningMode runningMode(); + + abstract float minDetectionConfidence(); + + abstract float minSuppressionThreshold(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + public static Builder builder() { + return new AutoValue_FaceDetector_FaceDetectorOptions.Builder() + .setRunningMode(RunningMode.IMAGE) + .setMinDetectionConfidence(0.5f) + .setMinSuppressionThreshold(0.3f); + } + + /** Converts a {@link FaceDetectorOptions} to a {@link CalculatorOptions} protobuf message. */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + BaseOptionsProto.BaseOptions.Builder baseOptionsBuilder = + BaseOptionsProto.BaseOptions.newBuilder(); + baseOptionsBuilder.setUseStreamMode(runningMode() != RunningMode.IMAGE); + baseOptionsBuilder.mergeFrom(convertBaseOptionsToProto(baseOptions())); + FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.Builder taskOptionsBuilder = + FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.newBuilder() + .setBaseOptions(baseOptionsBuilder); + taskOptionsBuilder.setMinDetectionConfidence(minDetectionConfidence()); + taskOptionsBuilder.setMinSuppressionThreshold(minSuppressionThreshold()); + return CalculatorOptions.newBuilder() + .setExtension( + FaceDetectorGraphOptionsProto.FaceDetectorGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. 
+ */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException("FaceDetector doesn't support region-of-interest."); + } + } +} diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java index 76b33fb97..931740c8e 100644 --- a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenter.java @@ -17,6 +17,7 @@ package com.google.mediapipe.tasks.vision.imagesegmenter; import android.content.Context; import com.google.auto.value.AutoValue; import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; import com.google.mediapipe.framework.AndroidPacketGetter; import com.google.mediapipe.framework.MediaPipeException; import com.google.mediapipe.framework.Packet; @@ -24,6 +25,7 @@ import com.google.mediapipe.framework.PacketGetter; import com.google.mediapipe.framework.image.BitmapImageBuilder; import com.google.mediapipe.framework.image.ByteBufferImageBuilder; import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto; import com.google.mediapipe.tasks.core.BaseOptions; import com.google.mediapipe.tasks.core.ErrorListener; import com.google.mediapipe.tasks.core.OutputHandler; @@ -88,8 +90,10 @@ public final class ImageSegmenter extends BaseVisionTaskApi { private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; private static final String TASK_GRAPH_NAME = "mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph"; - + private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = + "mediapipe.tasks.TensorsToSegmentationCalculator"; private boolean hasResultListener = false; + private List labels = new ArrayList<>(); /** * Creates an {@link ImageSegmenter} instance from an {@link ImageSegmenterOptions}. @@ -190,6 +194,41 @@ public final class ImageSegmenter extends BaseVisionTaskApi { TaskRunner taskRunner, RunningMode runningMode, boolean hasResultListener) { super(taskRunner, runningMode, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); this.hasResultListener = hasResultListener; + populateLabels(); + } + /** + * Populate the labelmap in TensorsToSegmentationCalculator to labels field. + * + * @throws MediaPipeException if there is an error during finding TensorsToSegmentationCalculator. 
+ */ + private void populateLabels() { + CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig(); + + boolean foundTensorsToSegmentation = false; + for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) { + if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) { + if (foundTensorsToSegmentation) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator."); + } + foundTensorsToSegmentation = true; + TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options = + node.getOptions() + .getExtension( + TensorsToSegmentationCalculatorOptionsProto + .TensorsToSegmentationCalculatorOptions.ext); + for (int i = 0; i < options.getLabelItemsMap().size(); i++) { + Long labelKey = Long.valueOf(i); + if (!options.getLabelItemsMap().containsKey(labelKey)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "The lablemap have no expected key: " + labelKey); + } + labels.add(options.getLabelItemsMap().get(labelKey).getName()); + } + } + } } /** @@ -473,6 +512,17 @@ public final class ImageSegmenter extends BaseVisionTaskApi { sendLiveStreamData(image, imageProcessingOptions, timestampMs); } + /** + * Get the category label list of the ImageSegmenter can recognize. For CATEGORY_MASK type, the + * index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK + * type, the output mask list at index corresponds to the category in the label list. + * + *

If there is no labelmap provided in the model file, empty label list is returned. + */ + List getLabels() { + return labels; + } + /** Options for setting up an {@link ImageSegmenter}. */ @AutoValue public abstract static class ImageSegmenterOptions extends TaskOptions { @@ -591,9 +641,6 @@ public final class ImageSegmenter extends BaseVisionTaskApi { SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); } - // TODO: remove this once activation is handled in metadata and grpah level. - segmenterOptionsBuilder.setActivation( - SegmenterOptionsProto.SegmenterOptions.Activation.SOFTMAX); taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); return CalculatorOptions.newBuilder() .setExtension( diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/AndroidManifest.xml b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/AndroidManifest.xml new file mode 100644 index 000000000..1bde79182 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + diff --git a/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java new file mode 100644 index 000000000..8ee6951f8 --- /dev/null +++ b/mediapipe/tasks/java/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenter.java @@ -0,0 +1,556 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
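// Example usage of the InteractiveSegmenter API added in this file (a minimal sketch; the
// "model.tflite" asset path, the `context`, the `bitmap`, and the click location are assumed):
InteractiveSegmenter.InteractiveSegmenterOptions options =
    InteractiveSegmenter.InteractiveSegmenterOptions.builder()
        .setBaseOptions(BaseOptions.builder().setModelAssetPath("model.tflite").build())
        .setOutputType(InteractiveSegmenter.InteractiveSegmenterOptions.OutputType.CATEGORY_MASK)
        .build();
InteractiveSegmenter segmenter = InteractiveSegmenter.createFromOptions(context, options);
MPImage image = new BitmapImageBuilder(bitmap).build();
// A single normalized keypoint marks the object the user wants segmented.
InteractiveSegmenter.RegionOfInterest roi =
    InteractiveSegmenter.RegionOfInterest.create(NormalizedKeypoint.create(0.5f, 0.5f));
ImageSegmenterResult result = segmenter.segment(image, roi); // result holds the segmented masks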
+ +package com.google.mediapipe.tasks.vision.interactivesegmenter; + +import android.content.Context; +import com.google.auto.value.AutoValue; +import com.google.mediapipe.proto.CalculatorOptionsProto.CalculatorOptions; +import com.google.mediapipe.proto.CalculatorProto.CalculatorGraphConfig; +import com.google.mediapipe.framework.AndroidPacketGetter; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.Packet; +import com.google.mediapipe.framework.PacketGetter; +import com.google.mediapipe.framework.ProtoUtil; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.ByteBufferImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.TensorsToSegmentationCalculatorOptionsProto; +import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.ErrorListener; +import com.google.mediapipe.tasks.core.OutputHandler; +import com.google.mediapipe.tasks.core.OutputHandler.ResultListener; +import com.google.mediapipe.tasks.core.TaskInfo; +import com.google.mediapipe.tasks.core.TaskOptions; +import com.google.mediapipe.tasks.core.TaskRunner; +import com.google.mediapipe.tasks.core.proto.BaseOptionsProto; +import com.google.mediapipe.tasks.vision.core.BaseVisionTaskApi; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult; +import com.google.mediapipe.tasks.vision.imagesegmenter.proto.ImageSegmenterGraphOptionsProto; +import com.google.mediapipe.tasks.vision.imagesegmenter.proto.SegmenterOptionsProto; +import com.google.mediapipe.util.proto.ColorProto.Color; +import com.google.mediapipe.util.proto.RenderDataProto.RenderAnnotation; +import com.google.mediapipe.util.proto.RenderDataProto.RenderData; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Performs interactive segmentation on images. + * + *

<p>Note that the standard segmentation API, {@link segment}, takes an input image and returns the + * outputs, but involves a deep copy of the results. InteractiveSegmenter also + * supports the callback API, {@link segmentWithResultListener}, which allows you to access the + * outputs through zero copy. Set {@link ResultListener} in {@link InteractiveSegmenterOptions} + * properly to use the callback API. + * + * <p>The API expects a TFLite model with TFLite Model Metadata. The model + * expects input with 4 channels, where the first 3 channels represent the RGB image, and the last + * channel represents the user's region of interest. + * + * <ul> + *   <li>Input image {@link MPImage} + *       <ul> + *         <li>The image that image segmenter runs on. + *       </ul> + *   <li>Input roi {@link RegionOfInterest} + *       <ul> + *         <li>Region of interest based on user interaction. + *       </ul> + *   <li>Output ImageSegmenterResult {@link ImageSegmenterResult} + *       <ul> + *         <li>An ImageSegmenterResult containing segmented masks. + *       </ul> + * </ul>
+ */ +public final class InteractiveSegmenter extends BaseVisionTaskApi { + private static final String TAG = InteractiveSegmenter.class.getSimpleName(); + private static final String IMAGE_IN_STREAM_NAME = "image_in"; + private static final String ROI_IN_STREAM_NAME = "roi_in"; + private static final String NORM_RECT_IN_STREAM_NAME = "norm_rect_in"; + private static final List INPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "IMAGE:" + IMAGE_IN_STREAM_NAME, + "ROI:" + ROI_IN_STREAM_NAME, + "NORM_RECT:" + NORM_RECT_IN_STREAM_NAME)); + private static final List OUTPUT_STREAMS = + Collections.unmodifiableList( + Arrays.asList( + "GROUPED_SEGMENTATION:segmented_mask_out", + "IMAGE:image_out", + "SEGMENTATION:0:segmentation")); + private static final int GROUPED_SEGMENTATION_OUT_STREAM_INDEX = 0; + private static final int IMAGE_OUT_STREAM_INDEX = 1; + private static final int SEGMENTATION_OUT_STREAM_INDEX = 2; + private static final String TASK_GRAPH_NAME = + "mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph"; + private static final String TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = + "mediapipe.tasks.TensorsToSegmentationCalculator"; + private boolean hasResultListener = false; + private List labels = new ArrayList<>(); + + static { + ProtoUtil.registerTypeName(RenderData.class, "mediapipe.RenderData"); + } + + /** + * Creates an {@link InteractiveSegmenter} instance from an {@link InteractiveSegmenterOptions}. + * + * @param context an Android {@link Context}. + * @param segmenterOptions an {@link InteractiveSegmenterOptions} instance. + * @throws MediaPipeException if there is an error during {@link InteractiveSegmenter} creation. + */ + public static InteractiveSegmenter createFromOptions( + Context context, InteractiveSegmenterOptions segmenterOptions) { + // TODO: Consolidate OutputHandler and TaskRunner. + OutputHandler handler = new OutputHandler<>(); + handler.setOutputPacketConverter( + new OutputHandler.OutputPacketConverter() { + @Override + public ImageSegmenterResult convertToTaskResult(List packets) + throws MediaPipeException { + if (packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).isEmpty()) { + return ImageSegmenterResult.create( + new ArrayList<>(), + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX).getTimestamp()); + } + List segmentedMasks = new ArrayList<>(); + int width = PacketGetter.getImageWidth(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int height = PacketGetter.getImageHeight(packets.get(SEGMENTATION_OUT_STREAM_INDEX)); + int imageFormat = + segmenterOptions.outputType() + == InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK + ? MPImage.IMAGE_FORMAT_VEC32F1 + : MPImage.IMAGE_FORMAT_ALPHA; + int imageListSize = + PacketGetter.getImageListSize(packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX)); + ByteBuffer[] buffersArray = new ByteBuffer[imageListSize]; + // If resultListener is not provided, the resulted MPImage is deep copied from mediapipe + // graph. If provided, the result MPImage is wrapping the mediapipe packet memory. + if (!segmenterOptions.resultListener().isPresent()) { + for (int i = 0; i < imageListSize; i++) { + buffersArray[i] = + ByteBuffer.allocateDirect( + width * height * (imageFormat == MPImage.IMAGE_FORMAT_VEC32F1 ? 
4 : 1)); + } + } + if (!PacketGetter.getImageList( + packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX), + buffersArray, + !segmenterOptions.resultListener().isPresent())) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "There is an error getting segmented masks. It usually results from incorrect" + + " options of unsupported OutputType of given model."); + } + for (ByteBuffer buffer : buffersArray) { + ByteBufferImageBuilder builder = + new ByteBufferImageBuilder(buffer, width, height, imageFormat); + segmentedMasks.add(builder.build()); + } + + return ImageSegmenterResult.create( + segmentedMasks, + BaseVisionTaskApi.generateResultTimestampMs( + RunningMode.IMAGE, packets.get(GROUPED_SEGMENTATION_OUT_STREAM_INDEX))); + } + + @Override + public MPImage convertToTaskInput(List packets) { + return new BitmapImageBuilder( + AndroidPacketGetter.getBitmapFromRgb(packets.get(IMAGE_OUT_STREAM_INDEX))) + .build(); + } + }); + segmenterOptions.resultListener().ifPresent(handler::setResultListener); + segmenterOptions.errorListener().ifPresent(handler::setErrorListener); + TaskRunner runner = + TaskRunner.create( + context, + TaskInfo.builder() + .setTaskName(InteractiveSegmenter.class.getSimpleName()) + .setTaskRunningModeName(RunningMode.IMAGE.name()) + .setTaskGraphName(TASK_GRAPH_NAME) + .setInputStreams(INPUT_STREAMS) + .setOutputStreams(OUTPUT_STREAMS) + .setTaskOptions(segmenterOptions) + .setEnableFlowLimiting(false) + .build(), + handler); + return new InteractiveSegmenter(runner, segmenterOptions.resultListener().isPresent()); + } + + /** + * Constructor to initialize an {@link InteractiveSegmenter} from a {@link TaskRunner}. + * + * @param taskRunner a {@link TaskRunner}. + */ + private InteractiveSegmenter(TaskRunner taskRunner, boolean hasResultListener) { + super(taskRunner, RunningMode.IMAGE, IMAGE_IN_STREAM_NAME, NORM_RECT_IN_STREAM_NAME); + this.hasResultListener = hasResultListener; + populateLabels(); + } + + /** + * Populate the labelmap in TensorsToSegmentationCalculator to labels field. + * + * @throws MediaPipeException if there is an error during finding TensorsToSegmentationCalculator. + */ + private void populateLabels() { + CalculatorGraphConfig graphConfig = this.runner.getCalculatorGraphConfig(); + + boolean foundTensorsToSegmentation = false; + for (CalculatorGraphConfig.Node node : graphConfig.getNodeList()) { + if (node.getName().contains(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)) { + if (foundTensorsToSegmentation) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "The graph has more than one mediapipe.tasks.TensorsToSegmentationCalculator."); + } + foundTensorsToSegmentation = true; + TensorsToSegmentationCalculatorOptionsProto.TensorsToSegmentationCalculatorOptions options = + node.getOptions() + .getExtension( + TensorsToSegmentationCalculatorOptionsProto + .TensorsToSegmentationCalculatorOptions.ext); + for (int i = 0; i < options.getLabelItemsMap().size(); i++) { + Long labelKey = Long.valueOf(i); + if (!options.getLabelItemsMap().containsKey(labelKey)) { + throw new MediaPipeException( + MediaPipeException.StatusCode.INTERNAL.ordinal(), + "The lablemap have no expected key: " + labelKey); + } + labels.add(options.getLabelItemsMap().get(labelKey).getName()); + } + } + } + } + + /** + * Performs segmentation on the provided single image with default image processing options, given + * user's region-of-interest, i.e. without any rotation applied. 
TODO update java doc + * for input image format. + * + *

<p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to + * perform segmentation focusing on the given region of interest. + * + * <p>{@link InteractiveSegmenter} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param roi a {@link RegionOfInterest} object to represent user interaction. + * @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is + * created with a {@link ResultListener}. + */ + public ImageSegmenterResult segment(MPImage image, RegionOfInterest roi) { + return segment(image, roi, ImageProcessingOptions.builder().build()); + } + + /** + * Performs segmentation on the provided single image, given user's region-of-interest. + * TODO update java doc for input image format. + * + *

<p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to + * perform segmentation focusing on the given region of interest. + * + * <p>{@link InteractiveSegmenter} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param roi a {@link RegionOfInterest} object to represent user interaction. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is + * created with a {@link ResultListener}. + */ + public ImageSegmenterResult segment( + MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) { + if (hasResultListener) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "ResultListener is provided in the InteractiveSegmenterOptions, but this method will" + + " return an ImageSegmentationResult."); + } + validateImageProcessingOptions(imageProcessingOptions); + return processImageWithRoi(image, roi, imageProcessingOptions); + } + + /** + * Performs segmentation on the provided single image with default image processing options, given + * user's region-of-interest, i.e. without any rotation applied, and provides zero-copied results + * via {@link ResultListener} in {@link InteractiveSegmenterOptions}. + * + *

<p>TODO update java doc for input image format. + * + * <p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to + * perform segmentation focusing on the given region of interest. + * + * <p>{@link InteractiveSegmenter} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param roi a {@link RegionOfInterest} object to represent user interaction. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is + * not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}. + */ + public void segmentWithResultListener(MPImage image, RegionOfInterest roi) { + segmentWithResultListener(image, roi, ImageProcessingOptions.builder().build()); + } + + /** + * Performs segmentation on the provided single image given user's region-of-interest, and + * provides zero-copied results via {@link ResultListener} in {@link InteractiveSegmenterOptions}. + * + *

<p>TODO update java doc for input image format. + * + * <p>Users can represent user interaction through {@link RegionOfInterest}, which gives a hint to + * perform segmentation focusing on the given region of interest. + * + * <p>{@link InteractiveSegmenter} supports the following color space types: + * + * <ul> + *   <li>{@link Bitmap.Config.ARGB_8888} + * </ul>
+ * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param roi a {@link RegionOfInterest} object to represent user interaction. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. Note that region-of-interest is not supported + * by this task: specifying {@link ImageProcessingOptions#regionOfInterest()} will result in + * this method throwing an IllegalArgumentException. + * @throws IllegalArgumentException if the {@link ImageProcessingOptions} specify a + * region-of-interest. + * @throws MediaPipeException if there is an internal error. Or if {@link InteractiveSegmenter} is + * not created wtih {@link ResultListener} set in {@link InteractiveSegmenterOptions}. + */ + public void segmentWithResultListener( + MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) { + if (!hasResultListener) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "ResultListener is not set in the InteractiveSegmenterOptions, but this method expects a" + + " ResultListener to process ImageSegmentationResult."); + } + validateImageProcessingOptions(imageProcessingOptions); + ImageSegmenterResult unused = processImageWithRoi(image, roi, imageProcessingOptions); + } + + /** + * Get the category label list of the ImageSegmenter can recognize. For CATEGORY_MASK type, the + * index in the category mask corresponds to the category in the label list. For CONFIDENCE_MASK + * type, the output mask list at index corresponds to the category in the label list. + * + *

If there is no labelmap provided in the model file, empty label list is returned. + */ + List getLabels() { + return labels; + } + + /** Options for setting up an {@link InteractiveSegmenter}. */ + @AutoValue + public abstract static class InteractiveSegmenterOptions extends TaskOptions { + + /** Builder for {@link InteractiveSegmenterOptions}. */ + @AutoValue.Builder + public abstract static class Builder { + /** Sets the base options for the image segmenter task. */ + public abstract Builder setBaseOptions(BaseOptions value); + + /** The output type from image segmenter. */ + public abstract Builder setOutputType(OutputType value); + + /** + * Sets an optional {@link ResultListener} to receive the segmentation results when the graph + * pipeline is done processing an image. + */ + public abstract Builder setResultListener( + ResultListener value); + + /** Sets an optional {@link ErrorListener}}. */ + public abstract Builder setErrorListener(ErrorListener value); + + abstract InteractiveSegmenterOptions autoBuild(); + + /** Builds the {@link InteractiveSegmenterOptions} instance. */ + public final InteractiveSegmenterOptions build() { + return autoBuild(); + } + } + + abstract BaseOptions baseOptions(); + + abstract OutputType outputType(); + + abstract Optional> resultListener(); + + abstract Optional errorListener(); + + /** The output type of segmentation results. */ + public enum OutputType { + // Gives a single output mask where each pixel represents the class which + // the pixel in the original image was predicted to belong to. + CATEGORY_MASK, + // Gives a list of output masks where, for each mask, each pixel represents + // the prediction confidence, usually in the [0, 1] range. + CONFIDENCE_MASK + } + + public static Builder builder() { + return new AutoValue_InteractiveSegmenter_InteractiveSegmenterOptions.Builder() + .setOutputType(OutputType.CATEGORY_MASK); + } + + /** + * Converts an {@link InteractiveSegmenterOptions} to a {@link CalculatorOptions} protobuf + * message. + */ + @Override + public CalculatorOptions convertToCalculatorOptionsProto() { + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.Builder taskOptionsBuilder = + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.newBuilder() + .setBaseOptions( + BaseOptionsProto.BaseOptions.newBuilder() + .setUseStreamMode(false) + .mergeFrom(convertBaseOptionsToProto(baseOptions())) + .build()); + + SegmenterOptionsProto.SegmenterOptions.Builder segmenterOptionsBuilder = + SegmenterOptionsProto.SegmenterOptions.newBuilder(); + if (outputType() == OutputType.CONFIDENCE_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CONFIDENCE_MASK); + } else if (outputType() == OutputType.CATEGORY_MASK) { + segmenterOptionsBuilder.setOutputType( + SegmenterOptionsProto.SegmenterOptions.OutputType.CATEGORY_MASK); + } + + taskOptionsBuilder.setSegmenterOptions(segmenterOptionsBuilder); + return CalculatorOptions.newBuilder() + .setExtension( + ImageSegmenterGraphOptionsProto.ImageSegmenterGraphOptions.ext, + taskOptionsBuilder.build()) + .build(); + } + } + + /** + * Validates that the provided {@link ImageProcessingOptions} doesn't contain a + * region-of-interest. 
+ */ + private static void validateImageProcessingOptions( + ImageProcessingOptions imageProcessingOptions) { + if (imageProcessingOptions.regionOfInterest().isPresent()) { + throw new IllegalArgumentException( + "InteractiveSegmenter doesn't support region-of-interest."); + } + } + + /** The Region-Of-Interest (ROI) to interact with. */ + public static class RegionOfInterest { + private NormalizedKeypoint keypoint; + + private RegionOfInterest() {} + + /** + * Creates a {@link RegionOfInterest} instance representing a single point pointing to the + * object that the user wants to segment. + */ + public static RegionOfInterest create(NormalizedKeypoint keypoint) { + RegionOfInterest roi = new RegionOfInterest(); + roi.keypoint = keypoint; + return roi; + } + } + + /** + * Converts a {@link RegionOfInterest} instance into a {@link RenderData} protobuf message + * + * @param roi a {@link RegionOfInterest} object to represent user interaction. + * @throws IllegalArgumentException if {@link RegionOfInterest} does not represent a valid user + * interaction. + */ + private static RenderData convertToRenderData(RegionOfInterest roi) { + RenderData.Builder builder = RenderData.newBuilder(); + if (roi.keypoint != null) { + return builder + .addRenderAnnotations( + RenderAnnotation.newBuilder() + .setColor(Color.newBuilder().setR(255)) + .setPoint( + RenderAnnotation.Point.newBuilder() + .setX(roi.keypoint.x()) + .setY(roi.keypoint.y()))) + .build(); + } + + throw new IllegalArgumentException( + "RegionOfInterest does not include a valid user interaction"); + } + + /** + * A synchronous method to process single image inputs. The call blocks the current thread until a + * failure status or a successful result is returned. + * + *

This is almost the same as {@link BaseVisionTaskApi.processImageData} except accepting an + * additional {@link RegionOfInterest}. + * + * @param image a MediaPipe {@link MPImage} object for processing. + * @param roi a {@link RegionOfInterest} object to represent user interaction. + * @param imageProcessingOptions the {@link ImageProcessingOptions} specifying how to process the + * input image before running inference. + * @throws MediaPipeException if the task is not in the image mode. + */ + private ImageSegmenterResult processImageWithRoi( + MPImage image, RegionOfInterest roi, ImageProcessingOptions imageProcessingOptions) { + if (runningMode != RunningMode.IMAGE) { + throw new MediaPipeException( + MediaPipeException.StatusCode.FAILED_PRECONDITION.ordinal(), + "Task is not initialized with the image mode. Current running mode:" + + runningMode.name()); + } + Map inputPackets = new HashMap<>(); + inputPackets.put(IMAGE_IN_STREAM_NAME, runner.getPacketCreator().createImage(image)); + RenderData renderData = convertToRenderData(roi); + inputPackets.put(ROI_IN_STREAM_NAME, runner.getPacketCreator().createProto(renderData)); + inputPackets.put( + NORM_RECT_IN_STREAM_NAME, + runner.getPacketCreator().createProto(convertToNormalizedRect(imageProcessingOptions))); + return (ImageSegmenterResult) runner.process(inputPackets); + } +} diff --git a/mediapipe/tasks/java/version_script.lds b/mediapipe/tasks/java/version_script.lds index 08577b101..13f36f21e 100644 --- a/mediapipe/tasks/java/version_script.lds +++ b/mediapipe/tasks/java/version_script.lds @@ -7,6 +7,7 @@ VERS_1.0 { Java_com_google_mediapipe_framework_Graph_nativeAddPacketToInputStream; Java_com_google_mediapipe_framework_Graph_nativeCloseAllPacketSources; Java_com_google_mediapipe_framework_Graph_nativeCreateGraph; + Java_com_google_mediapipe_framework_Graph_nativeGetCalculatorGraphConfig; Java_com_google_mediapipe_framework_Graph_nativeLoadBinaryGraph*; Java_com_google_mediapipe_framework_Graph_nativeMovePacketToInputStream; Java_com_google_mediapipe_framework_Graph_nativeReleaseGraph; diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml new file mode 100644 index 000000000..01cbc3a6f --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD new file mode 100644 index 000000000..c14486766 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java new file mode 100644 index 000000000..d995accd5 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/facedetector/FaceDetectorTest.java @@ -0,0 +1,455 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.facedetector; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import android.graphics.RectF; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.MediaPipeException; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.core.TestUtils; +import com.google.mediapipe.tasks.vision.core.ImageProcessingOptions; +import com.google.mediapipe.tasks.vision.core.RunningMode; +import com.google.mediapipe.tasks.vision.facedetector.FaceDetector.FaceDetectorOptions; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link FaceDetector}. 
*/ +@RunWith(Suite.class) +@SuiteClasses({FaceDetectorTest.General.class, FaceDetectorTest.RunningModeTest.class}) +public class FaceDetectorTest { + private static final String MODEL_FILE = "face_detection_short_range.tflite"; + private static final String CAT_IMAGE = "cat.jpg"; + private static final String PORTRAIT_IMAGE = "portrait.jpg"; + private static final String PORTRAIT_ROTATED_IMAGE = "portrait_rotated.jpg"; + private static final float KEYPOINTS_DIFF_TOLERANCE = 0.01f; + private static final float PIXEL_DIFF_TOLERANCE = 5.0f; + private static final RectF PORTRAIT_FACE_BOUNDING_BOX = new RectF(283, 115, 514, 349); + private static final List PORTRAIT_FACE_KEYPOINTS = + Collections.unmodifiableList( + Arrays.asList( + NormalizedKeypoint.create(0.44416f, 0.17643f), + NormalizedKeypoint.create(0.55514f, 0.17731f), + NormalizedKeypoint.create(0.50467f, 0.22657f), + NormalizedKeypoint.create(0.50227f, 0.27199f), + NormalizedKeypoint.create(0.36063f, 0.20143f), + NormalizedKeypoint.create(0.60841f, 0.20409f))); + private static final RectF PORTRAIT_ROTATED_FACE_BOUNDING_BOX = new RectF(674, 283, 910, 519); + private static final List PORTRAIT_ROTATED_FACE_KEYPOINTS = + Collections.unmodifiableList( + Arrays.asList( + NormalizedKeypoint.create(0.82075f, 0.44679f), + NormalizedKeypoint.create(0.81965f, 0.56261f), + NormalizedKeypoint.create(0.76194f, 0.51719f), + NormalizedKeypoint.create(0.71993f, 0.51360f), + NormalizedKeypoint.create(0.80700f, 0.36298f), + NormalizedKeypoint.create(0.80882f, 0.61204f))); + + @RunWith(AndroidJUnit4.class) + public static final class General extends FaceDetectorTest { + + @Test + public void detect_successWithValidModels() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void detect_succeedsWithMinDetectionConfidence() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMinDetectionConfidence(1.0f) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + // Set minDetectionConfidence to 1.0, so the detected face should be all filtered out. 
+ assertThat(results.detections().isEmpty()).isTrue(); + } + + @Test + public void detect_succeedsWithEmptyFace() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setMinDetectionConfidence(1.0f) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(CAT_IMAGE)); + assertThat(results.detections().isEmpty()).isTrue(); + } + + @Test + public void detect_succeedsWithModelFileObject() throws Exception { + FaceDetector faceDetector = + FaceDetector.createFromFile( + ApplicationProvider.getApplicationContext(), + TestUtils.loadFile(ApplicationProvider.getApplicationContext(), MODEL_FILE)); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void detect_succeedsWithModelBuffer() throws Exception { + FaceDetector faceDetector = + FaceDetector.createFromBuffer( + ApplicationProvider.getApplicationContext(), + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE)); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void detect_succeedsWithModelBufferAndOptions() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions( + BaseOptions.builder() + .setModelAssetBuffer( + TestUtils.loadToDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE)) + .build()) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void create_failsWithMissingModel() throws Exception { + String nonexistentFile = "/path/to/non/existent/file"; + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + FaceDetector.createFromFile( + ApplicationProvider.getApplicationContext(), nonexistentFile)); + assertThat(exception).hasMessageThat().contains(nonexistentFile); + } + + @Test + public void create_failsWithInvalidModelBuffer() throws Exception { + // Create a non-direct model ByteBuffer. 
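+ // createFromBuffer only accepts a direct ByteBuffer or a MappedByteBuffer, so creation should fail here.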
+ ByteBuffer modelBuffer = + TestUtils.loadToNonDirectByteBuffer( + ApplicationProvider.getApplicationContext(), MODEL_FILE); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceDetector.createFromBuffer( + ApplicationProvider.getApplicationContext(), modelBuffer)); + + assertThat(exception) + .hasMessageThat() + .contains("The model buffer should be either a direct ByteBuffer or a MappedByteBuffer."); + } + + @Test + public void detect_succeedsWithRotation() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRotationDegrees(-90).build(); + FaceDetectorResult results = + faceDetector.detect(getImageFromAsset(PORTRAIT_ROTATED_IMAGE), imageProcessingOptions); + assertContainsSinglePortraitFace( + results, PORTRAIT_ROTATED_FACE_BOUNDING_BOX, PORTRAIT_ROTATED_FACE_KEYPOINTS); + } + + @Test + public void detect_failsWithRegionOfInterest() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + ImageProcessingOptions imageProcessingOptions = + ImageProcessingOptions.builder().setRegionOfInterest(new RectF(0, 0, 1, 1)).build(); + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE), imageProcessingOptions)); + assertThat(exception) + .hasMessageThat() + .contains("FaceDetector doesn't support region-of-interest"); + } + } + + @RunWith(AndroidJUnit4.class) + public static final class RunningModeTest extends FaceDetectorTest { + + @Test + public void create_failsWithIllegalResultListenerInNonLiveStreamMode() throws Exception { + for (RunningMode mode : new RunningMode[] {RunningMode.IMAGE, RunningMode.VIDEO}) { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(mode) + .setResultListener((faceDetectorResult, inputImage) -> {}) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener shouldn't be provided"); + } + } + + @Test + public void create_failsWithMissingResultListenerInLiveSteamMode() throws Exception { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .build()); + assertThat(exception) + .hasMessageThat() + .contains("a user-defined result listener must be provided"); + } + + @Test + public void detect_failsWithCallingWrongApiInImageMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + 
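// The detector was created in the image mode, so the video and live-stream APIs below should be rejected. +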
MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectForVideo( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectAsync( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInVideoMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectAsync( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the live stream mode"); + } + + @Test + public void detect_failsWithCallingWrongApiInLiveSteamMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener((faceDetectorResult, inputImage) -> {}) + .build(); + + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE))); + assertThat(exception).hasMessageThat().contains("not initialized with the image mode"); + exception = + assertThrows( + MediaPipeException.class, + () -> + faceDetector.detectForVideo( + getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ 0)); + assertThat(exception).hasMessageThat().contains("not initialized with the video mode"); + } + + @Test + public void detect_successWithImageMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.IMAGE) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + FaceDetectorResult results = faceDetector.detect(getImageFromAsset(PORTRAIT_IMAGE)); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + } + + @Test + public void detect_successWithVideoMode() throws Exception { + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.VIDEO) + .build(); + FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options); + for (int i = 0; i < 3; i++) { + FaceDetectorResult results = + faceDetector.detectForVideo(getImageFromAsset(PORTRAIT_IMAGE), /* timestampsMs= */ i); + assertContainsSinglePortraitFace( + results, PORTRAIT_FACE_BOUNDING_BOX, 
PORTRAIT_FACE_KEYPOINTS); + } + } + + @Test + public void detect_failsWithOutOfOrderInputTimestamps() throws Exception { + MPImage image = getImageFromAsset(PORTRAIT_IMAGE); + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (faceDetectorResult, inputImage) -> { + assertContainsSinglePortraitFace( + faceDetectorResult, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + }) + .build(); + try (FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + faceDetector.detectAsync(image, /* timestampsMs= */ 1); + MediaPipeException exception = + assertThrows( + MediaPipeException.class, + () -> faceDetector.detectAsync(image, /* timestampsMs= */ 0)); + assertThat(exception) + .hasMessageThat() + .contains("having a smaller timestamp than the processed timestamp"); + } + } + + @Test + public void detect_successWithLiveSteamMode() throws Exception { + MPImage image = getImageFromAsset(PORTRAIT_IMAGE); + FaceDetectorOptions options = + FaceDetectorOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(MODEL_FILE).build()) + .setRunningMode(RunningMode.LIVE_STREAM) + .setResultListener( + (faceDetectorResult, inputImage) -> { + assertContainsSinglePortraitFace( + faceDetectorResult, PORTRAIT_FACE_BOUNDING_BOX, PORTRAIT_FACE_KEYPOINTS); + }) + .build(); + try (FaceDetector faceDetector = + FaceDetector.createFromOptions(ApplicationProvider.getApplicationContext(), options)) { + for (int i = 0; i < 3; i++) { + faceDetector.detectAsync(image, /* timestampsMs= */ i); + } + } + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } + + private static void assertContainsSinglePortraitFace( + FaceDetectorResult results, + RectF expectedboundingBox, + List expectedKeypoints) { + assertThat(results.detections()).hasSize(1); + assertApproximatelyEqualBoundingBoxes( + results.detections().get(0).boundingBox(), expectedboundingBox); + assertThat(results.detections().get(0).keypoints().isPresent()).isTrue(); + assertApproximatelyEqualKeypoints( + results.detections().get(0).keypoints().get(), expectedKeypoints); + } + + private static void assertApproximatelyEqualBoundingBoxes( + RectF boundingBox1, RectF boundingBox2) { + assertThat(boundingBox1.left).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.left); + assertThat(boundingBox1.top).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.top); + assertThat(boundingBox1.right).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.right); + assertThat(boundingBox1.bottom).isWithin(PIXEL_DIFF_TOLERANCE).of(boundingBox2.bottom); + } + + private static void assertApproximatelyEqualKeypoints( + List keypoints1, List keypoints2) { + assertThat(keypoints1.size()).isEqualTo(keypoints2.size()); + for (int i = 0; i < keypoints1.size(); i++) { + assertThat(keypoints1.get(i).x()) + .isWithin(KEYPOINTS_DIFF_TOLERANCE) + .of(keypoints2.get(i).x()); + assertThat(keypoints1.get(i).y()) + .isWithin(KEYPOINTS_DIFF_TOLERANCE) + .of(keypoints2.get(i).y()); + } + } +} diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java 
b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java index 16f591c40..3b35c21bc 100644 --- a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/imagesegmenter/ImageSegmenterTest.java @@ -34,6 +34,7 @@ import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenter.ImageSegm import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.FloatBuffer; +import java.util.Arrays; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -135,6 +136,45 @@ public class ImageSegmenterTest { // MPImage expectedMaskBuffer = getImageFromAsset(goldenImageName); // verifyConfidenceMask(actualMaskBuffer, expectedMaskBuffer, GOLDEN_MASK_SIMILARITY); // } + + @Test + public void getLabels_success() throws Exception { + final List expectedLabels = + Arrays.asList( + "background", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "dining table", + "dog", + "horse", + "motorbike", + "person", + "potted plant", + "sheep", + "sofa", + "train", + "tv"); + ImageSegmenterOptions options = + ImageSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(ImageSegmenterOptions.OutputType.CONFIDENCE_MASK) + .build(); + ImageSegmenter imageSegmenter = + ImageSegmenter.createFromOptions(ApplicationProvider.getApplicationContext(), options); + List actualLabels = imageSegmenter.getLabels(); + assertThat(actualLabels.size()).isEqualTo(expectedLabels.size()); + for (int i = 0; i < actualLabels.size(); i++) { + assertThat(actualLabels.get(i)).isEqualTo(expectedLabels.get(i)); + } + } } @RunWith(AndroidJUnit4.class) diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/AndroidManifest.xml b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/AndroidManifest.xml new file mode 100644 index 000000000..97280f5e4 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/AndroidManifest.xml @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/BUILD b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/BUILD new file mode 100644 index 000000000..c14486766 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/BUILD @@ -0,0 +1,19 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
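Aside on the getLabels_success test added above: the label list it verifies corresponds to the label file packed into the DeepLab model's metadata. A minimal sketch of reading that file from Python, assuming the model is available locally as deeplabv3.tflite and its associated label file is named labels.txt:

from mediapipe.tasks.python.metadata import metadata

# Load the model and open its metadata (with_model_buffer and
# get_associated_file_buffer are existing MetadataDisplayer APIs).
with open("deeplabv3.tflite", "rb") as f:
    displayer = metadata.MetadataDisplayer.with_model_buffer(f.read())

# The packed label file is plain text, one label per line.
labels = (
    displayer.get_associated_file_buffer("labels.txt").decode("utf-8").splitlines()
)
print(labels[:3])  # expected: ['background', 'aeroplane', 'bicycle']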
+ +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +# TODO: Enable this in OSS diff --git a/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java new file mode 100644 index 000000000..0d9581437 --- /dev/null +++ b/mediapipe/tasks/javatests/com/google/mediapipe/tasks/vision/interactivesegmenter/InteractiveSegmenterTest.java @@ -0,0 +1,92 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.mediapipe.tasks.vision.interactivesegmenter; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.res.AssetManager; +import android.graphics.BitmapFactory; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.mediapipe.framework.image.BitmapImageBuilder; +import com.google.mediapipe.framework.image.MPImage; +import com.google.mediapipe.tasks.components.containers.NormalizedKeypoint; +import com.google.mediapipe.tasks.core.BaseOptions; +import com.google.mediapipe.tasks.vision.imagesegmenter.ImageSegmenterResult; +import com.google.mediapipe.tasks.vision.interactivesegmenter.InteractiveSegmenter.InteractiveSegmenterOptions; +import java.io.InputStream; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; +import org.junit.runners.Suite.SuiteClasses; + +/** Test for {@link InteractiveSegmenter}. 
*/ +@RunWith(Suite.class) +@SuiteClasses({ + InteractiveSegmenterTest.General.class, +}) +public class InteractiveSegmenterTest { + private static final String DEEPLAB_MODEL_FILE = "ptm_512_hdt_ptm_woid.tflite"; + private static final String CATS_AND_DOGS_IMAGE = "cats_and_dogs.jpg"; + private static final int MAGNIFICATION_FACTOR = 10; + + @RunWith(AndroidJUnit4.class) + public static final class General extends InteractiveSegmenterTest { + @Test + public void segment_successWithCategoryMask() throws Exception { + final String inputImageName = CATS_AND_DOGS_IMAGE; + final InteractiveSegmenter.RegionOfInterest roi = + InteractiveSegmenter.RegionOfInterest.create(NormalizedKeypoint.create(0.25f, 0.9f)); + InteractiveSegmenterOptions options = + InteractiveSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(InteractiveSegmenterOptions.OutputType.CATEGORY_MASK) + .build(); + InteractiveSegmenter imageSegmenter = + InteractiveSegmenter.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + MPImage image = getImageFromAsset(inputImageName); + ImageSegmenterResult actualResult = imageSegmenter.segment(image, roi); + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(1); + } + + @Test + public void segment_successWithConfidenceMask() throws Exception { + final String inputImageName = CATS_AND_DOGS_IMAGE; + final InteractiveSegmenter.RegionOfInterest roi = + InteractiveSegmenter.RegionOfInterest.create(NormalizedKeypoint.create(0.25f, 0.9f)); + InteractiveSegmenterOptions options = + InteractiveSegmenterOptions.builder() + .setBaseOptions(BaseOptions.builder().setModelAssetPath(DEEPLAB_MODEL_FILE).build()) + .setOutputType(InteractiveSegmenterOptions.OutputType.CONFIDENCE_MASK) + .build(); + InteractiveSegmenter imageSegmenter = + InteractiveSegmenter.createFromOptions( + ApplicationProvider.getApplicationContext(), options); + ImageSegmenterResult actualResult = + imageSegmenter.segment(getImageFromAsset(inputImageName), roi); + List segmentations = actualResult.segmentations(); + assertThat(segmentations.size()).isEqualTo(2); + } + } + + private static MPImage getImageFromAsset(String filePath) throws Exception { + AssetManager assetManager = ApplicationProvider.getApplicationContext().getAssets(); + InputStream istr = assetManager.open(filePath); + return new BitmapImageBuilder(BitmapFactory.decodeStream(istr)).build(); + } +} diff --git a/mediapipe/tasks/metadata/BUILD b/mediapipe/tasks/metadata/BUILD index abd948809..de6350685 100644 --- a/mediapipe/tasks/metadata/BUILD +++ b/mediapipe/tasks/metadata/BUILD @@ -7,7 +7,9 @@ package( licenses = ["notice"], # Apache 2.0 ) -exports_files(["metadata_schema.fbs"]) +exports_files(glob([ + "*.fbs", +])) # Generic schema for model metadata. 
flatbuffer_cc_library( @@ -24,3 +26,13 @@ flatbuffer_py_library( name = "metadata_schema_py", srcs = ["metadata_schema.fbs"], ) + +flatbuffer_cc_library( + name = "image_segmenter_metadata_schema_cc", + srcs = ["image_segmenter_metadata_schema.fbs"], +) + +flatbuffer_py_library( + name = "image_segmenter_metadata_schema_py", + srcs = ["image_segmenter_metadata_schema.fbs"], +) diff --git a/mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs b/mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs new file mode 100644 index 000000000..e120c64aa --- /dev/null +++ b/mediapipe/tasks/metadata/image_segmenter_metadata_schema.fbs @@ -0,0 +1,59 @@ +// Copyright 2023 The MediaPipe Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace mediapipe.tasks; + +// Image segmenter metadata contains information specific for the image +// segmentation task. The metadata can be added in +// SubGraphMetadata.custom_metadata [1] in model metadata. +// [1]: https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L685 + +// ImageSegmenterOptions.min_parser_version indicates the minimum necessary +// image segmenter metadata parser version to fully understand all fields in a +// given metadata flatbuffer. This min_parser_version is specific for the +// image segmenter metadata defined in this schema file. +// +// New fields and types will have associated comments with the schema version +// for which they were added. +// +// Schema Semantic version: 1.0.0 + +// This indicates the flatbuffer compatibility. The number will bump up when a +// break change is applied to the schema, such as removing fields or adding new +// fields to the middle of a table. +file_identifier "V001"; + +// History: +// 1.0.0 - Initial version. + +// Supported activation functions. +enum Activation: byte { + NONE = 0, + SIGMOID = 1, + SOFTMAX = 2 +} + +table ImageSegmenterOptions { + // The activation function of the output layer in the image segmenter. + activation: Activation; + + // The minimum necessary image segmenter metadata parser version to fully + // understand all fields in a given metadata flatbuffer. This field is + // automaticaly populated by the MetadataPopulator when the metadata is + // populated into a TFLite model. This min_parser_version is specific for the + // image segmenter metadata defined in this schema file. + min_parser_version:string; +} + +root_type ImageSegmenterOptions; diff --git a/mediapipe/tasks/metadata/metadata_schema.fbs b/mediapipe/tasks/metadata/metadata_schema.fbs index 3253e1ea8..8fe7a08fa 100644 --- a/mediapipe/tasks/metadata/metadata_schema.fbs +++ b/mediapipe/tasks/metadata/metadata_schema.fbs @@ -233,7 +233,7 @@ table ImageProperties { // // : // Input image tensors: NA. -// Output image tensors: parses the values into a data stucture that represents +// Output image tensors: parses the values into a data structure that represents // bounding boxes. 
For example, in the generated wrapper for Android, it returns // the output as android.graphics.Rect objects. enum BoundingBoxType : byte { @@ -389,7 +389,7 @@ table NormalizationOptions{ // mean and std are normalization parameters. Tensor values are normalized // on a per-channel basis, by the formula // (x - mean) / std. - // If there is only one value in mean or std, we'll propogate the value to + // If there is only one value in mean or std, we'll propagate the value to // all channels. // // Quantized models share the same normalization parameters as their @@ -526,7 +526,7 @@ table Stats { // Max and min are not currently used in tflite.support codegen. They mainly // serve as references for users to better understand the model. They can also // be used to validate model pre/post processing results. - // If there is only one value in max or min, we'll propogate the value to + // If there is only one value in max or min, we'll propagate the value to // all channels. // Per-channel maximum value of the tensor. @@ -542,7 +542,7 @@ table Stats { // has four outputs: classes, scores, bounding boxes, and number of detections. // If the four outputs are bundled together using TensorGroup (for example, // named as "detection result"), the codegen tool will generate the class, -// `DetectionResult`, which contains the class, score, and bouding box. And the +// `DetectionResult`, which contains the class, score, and bounding box. And the // outputs of the model will be converted to a list of `DetectionResults` and // the number of detection. Note that the number of detection is a single // number, therefore is inappropriate for the list of `DetectionResult`. @@ -624,7 +624,7 @@ table SubGraphMetadata { // A description explains details about what the subgraph does. description:string; - // Metadata of all input tensors used in this subgraph. It matches extactly + // Metadata of all input tensors used in this subgraph. It matches exactly // the input tensors specified by `SubGraph.inputs` in the TFLite // schema.fbs file[2]. The number of `TensorMetadata` in the array should // equal to the number of indices in `SubGraph.inputs`. @@ -634,7 +634,7 @@ table SubGraphMetadata { // Determines how to process the inputs. input_tensor_metadata:[TensorMetadata]; - // Metadata of all output tensors used in this subgraph. It matches extactly + // Metadata of all output tensors used in this subgraph. It matches exactly // the output tensors specified by `SubGraph.outputs` in the TFLite // schema.fbs file[2]. The number of `TensorMetadata` in the array should // equal to the number of indices in `SubGraph.outputs`. @@ -724,7 +724,7 @@ table ModelMetadata { // number among the versions of all the fields populated and the smallest // compatible version indicated by the file identifier. // - // This field is automaticaly populated by the MetadataPopulator when + // This field is automatically populated by the MetadataPopulator when // the metadata is populated into a TFLite model. 
min_parser_version:string; } diff --git a/mediapipe/tasks/python/components/containers/BUILD b/mediapipe/tasks/python/components/containers/BUILD index 7108617ff..b84ab744d 100644 --- a/mediapipe/tasks/python/components/containers/BUILD +++ b/mediapipe/tasks/python/components/containers/BUILD @@ -73,12 +73,22 @@ py_library( ], ) +py_library( + name = "keypoint", + srcs = ["keypoint.py"], + deps = [ + "//mediapipe/framework/formats:location_data_py_pb2", + "//mediapipe/tasks/python/core:optional_dependencies", + ], +) + py_library( name = "detections", srcs = ["detections.py"], deps = [ ":bounding_box", ":category", + ":keypoint", "//mediapipe/framework/formats:detection_py_pb2", "//mediapipe/framework/formats:location_data_py_pb2", "//mediapipe/tasks/python/core:optional_dependencies", diff --git a/mediapipe/tasks/python/components/containers/detections.py b/mediapipe/tasks/python/components/containers/detections.py index b4d550633..935d294a6 100644 --- a/mediapipe/tasks/python/components/containers/detections.py +++ b/mediapipe/tasks/python/components/containers/detections.py @@ -14,12 +14,13 @@ """Detections data class.""" import dataclasses -from typing import Any, List +from typing import Any, List, Optional from mediapipe.framework.formats import detection_pb2 from mediapipe.framework.formats import location_data_pb2 from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import keypoint as keypoint_module from mediapipe.tasks.python.core.optional_dependencies import doc_controls _DetectionListProto = detection_pb2.DetectionList @@ -34,10 +35,12 @@ class Detection: Attributes: bounding_box: A BoundingBox object. categories: A list of Category objects. + keypoints: A list of NormalizedKeypoint objects. 
""" bounding_box: bounding_box_module.BoundingBox categories: List[category_module.Category] + keypoints: Optional[List[keypoint_module.NormalizedKeypoint]] = None @doc_controls.do_not_generate_docs def to_pb2(self) -> _DetectionProto: @@ -46,6 +49,8 @@ class Detection: label_ids = [] scores = [] display_names = [] + relative_keypoints = [] + for category in self.categories: scores.append(category.score) if category.index: @@ -54,6 +59,20 @@ class Detection: labels.append(category.category_name) if category.display_name: display_names.append(category.display_name) + + if self.keypoints: + for keypoint in self.keypoints: + relative_keypoint_proto = _LocationDataProto.RelativeKeypoint() + if keypoint.x: + relative_keypoint_proto.x = keypoint.x + if keypoint.y: + relative_keypoint_proto.y = keypoint.y + if keypoint.label: + relative_keypoint_proto.keypoint_label = keypoint.label + if keypoint.score: + relative_keypoint_proto.score = keypoint.score + relative_keypoints.append(relative_keypoint_proto) + return _DetectionProto( label=labels, label_id=label_ids, @@ -61,28 +80,52 @@ class Detection: display_name=display_names, location_data=_LocationDataProto( format=_LocationDataProto.Format.BOUNDING_BOX, - bounding_box=self.bounding_box.to_pb2())) + bounding_box=self.bounding_box.to_pb2(), + relative_keypoints=relative_keypoints, + ), + ) @classmethod @doc_controls.do_not_generate_docs def create_from_pb2(cls, pb2_obj: _DetectionProto) -> 'Detection': """Creates a `Detection` object from the given protobuf object.""" categories = [] + keypoints = [] + for idx, score in enumerate(pb2_obj.score): categories.append( category_module.Category( score=score, index=pb2_obj.label_id[idx] - if idx < len(pb2_obj.label_id) else None, + if idx < len(pb2_obj.label_id) + else None, category_name=pb2_obj.label[idx] - if idx < len(pb2_obj.label) else None, + if idx < len(pb2_obj.label) + else None, display_name=pb2_obj.display_name[idx] - if idx < len(pb2_obj.display_name) else None)) + if idx < len(pb2_obj.display_name) + else None, + ) + ) + + if pb2_obj.location_data.relative_keypoints: + for idx, elem in enumerate(pb2_obj.location_data.relative_keypoints): + keypoints.append( + keypoint_module.NormalizedKeypoint( + x=elem.x, + y=elem.y, + label=elem.keypoint_label, + score=elem.score, + ) + ) return Detection( bounding_box=bounding_box_module.BoundingBox.create_from_pb2( - pb2_obj.location_data.bounding_box), - categories=categories) + pb2_obj.location_data.bounding_box + ), + categories=categories, + keypoints=keypoints, + ) def __eq__(self, other: Any) -> bool: """Checks if this object is equal to the given object. diff --git a/mediapipe/tasks/python/components/containers/keypoint.py b/mediapipe/tasks/python/components/containers/keypoint.py new file mode 100644 index 000000000..3d957f3d8 --- /dev/null +++ b/mediapipe/tasks/python/components/containers/keypoint.py @@ -0,0 +1,77 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Keypoint data class.""" + +import dataclasses +from typing import Any, Optional + +from mediapipe.framework.formats import location_data_pb2 +from mediapipe.tasks.python.core.optional_dependencies import doc_controls + +_RelativeKeypointProto = location_data_pb2.LocationData.RelativeKeypoint + + +@dataclasses.dataclass +class NormalizedKeypoint: + """A normalized keypoint. + + Normalized keypoint represents a point in 2D space with x, y coordinates. + x and y are normalized to [0.0, 1.0] by the image width and height + respectively. + + Attributes: + x: The x coordinates of the normalized keypoint. + y: The y coordinates of the normalized keypoint. + label: The optional label of the keypoint. + score: The score of the keypoint. + """ + + x: Optional[float] = None + y: Optional[float] = None + label: Optional[str] = None + score: Optional[float] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _RelativeKeypointProto: + """Generates a RelativeKeypoint protobuf object.""" + return _RelativeKeypointProto( + x=self.x, y=self.y, keypoint_label=self.label, score=self.score + ) + + @classmethod + @doc_controls.do_not_generate_docs + def create_from_pb2( + cls, pb2_obj: _RelativeKeypointProto + ) -> 'NormalizedKeypoint': + """Creates a `NormalizedKeypoint` object from the given protobuf object.""" + return NormalizedKeypoint( + x=pb2_obj.x, + y=pb2_obj.y, + label=pb2_obj.keypoint_label, + score=pb2_obj.score, + ) + + def __eq__(self, other: Any) -> bool: + """Checks if this object is equal to the given object. + + Args: + other: The object to be compared with. + + Returns: + True if the objects are equal. + """ + if not isinstance(other, NormalizedKeypoint): + return False + + return self.to_pb2().__eq__(other.to_pb2()) diff --git a/mediapipe/tasks/python/components/utils/cosine_similarity.py b/mediapipe/tasks/python/components/utils/cosine_similarity.py index ff8979458..a6245a579 100644 --- a/mediapipe/tasks/python/components/utils/cosine_similarity.py +++ b/mediapipe/tasks/python/components/utils/cosine_similarity.py @@ -20,20 +20,20 @@ from mediapipe.tasks.python.components.containers import embedding_result _Embedding = embedding_result.Embedding -def _compute_cosine_similarity(u, v): +def _compute_cosine_similarity(u: np.ndarray, v: np.ndarray): """Computes cosine similarity between two embeddings.""" - if len(u.embedding) <= 0: + if len(u) <= 0: raise ValueError("Cannot compute cosing similarity on empty embeddings.") - norm_u = np.linalg.norm(u.embedding) - norm_v = np.linalg.norm(v.embedding) + norm_u = np.linalg.norm(u) + norm_v = np.linalg.norm(v) if norm_u <= 0 or norm_v <= 0: raise ValueError( "Cannot compute cosine similarity on embedding with 0 norm.") - return np.dot(u.embedding, v.embedding.T) / (norm_u * norm_v) + return u.dot(v) / (norm_u * norm_v) def cosine_similarity(u: _Embedding, v: _Embedding) -> float: @@ -56,10 +56,13 @@ def cosine_similarity(u: _Embedding, v: _Embedding) -> float: f"({len(u.embedding)} vs. 
{len(v.embedding)}).") if u.embedding.dtype == float and v.embedding.dtype == float: - return _compute_cosine_similarity(u, v) + return _compute_cosine_similarity(u.embedding, v.embedding) if u.embedding.dtype == np.uint8 and v.embedding.dtype == np.uint8: - return _compute_cosine_similarity(u, v) + return _compute_cosine_similarity( + u.embedding.view("int8").astype("float"), + v.embedding.view("int8").astype("float"), + ) raise ValueError("Cannot compute cosine similarity between quantized and " "float embeddings.") diff --git a/mediapipe/tasks/python/core/pybind/task_runner.cc b/mediapipe/tasks/python/core/pybind/task_runner.cc index cb13787c3..aa48a1a9a 100644 --- a/mediapipe/tasks/python/core/pybind/task_runner.cc +++ b/mediapipe/tasks/python/core/pybind/task_runner.cc @@ -96,7 +96,7 @@ Args: Raises: RuntimeError: Any of the following: a) The graph config proto is invalid. - b) The underlying medipaipe graph fails to initilize and start. + b) The underlying medipaipe graph fails to initialize and start. )doc", py::arg("graph_config"), py::arg("packets_callback") = py::none()); @@ -120,7 +120,7 @@ This method is designed for processing either batch data such as unrelated images and texts or offline streaming data such as the decoded frames from a video file and an audio file. The call blocks the current thread until a failure status or a successful result is returned. -If the input packets have no timestamp, an internal timestamp will be assigend +If the input packets have no timestamp, an internal timestamp will be assigned per invocation. Otherwise, when the timestamp is set in the input packets, the caller must ensure that the input packet timestamps are greater than the timestamps of the previous invocation. This method is thread-unsafe and it is diff --git a/mediapipe/tasks/python/metadata/metadata.py b/mediapipe/tasks/python/metadata/metadata.py index 6afb5a3fa..25d83cae8 100644 --- a/mediapipe/tasks/python/metadata/metadata.py +++ b/mediapipe/tasks/python/metadata/metadata.py @@ -17,10 +17,13 @@ import copy import inspect import io +import json +import logging import os import shutil import sys import tempfile +from typing import Dict, Optional import warnings import zipfile @@ -109,10 +112,10 @@ class MetadataPopulator(object): mediapipe/tasks/metadata/metadata_schema.fbs Example usage: - Populate matadata and label file into an image classifier model. + Populate metadata and label file into an image classifier model. First, based on metadata_schema.fbs, generate the metadata for this image - classifer model using Flatbuffers API. Attach the label file onto the ouput + classifier model using Flatbuffers API. Attach the label file onto the output tensor (the tensor of probabilities) in the metadata. Then, pack the metadata and label file into the model as follows. @@ -170,7 +173,7 @@ class MetadataPopulator(object): Raises: IOError: File not found. - ValueError: the model does not have the expected flatbuffer identifer. + ValueError: the model does not have the expected flatbuffer identifier. """ _assert_model_file_identifier(model_file) self._model_file = model_file @@ -190,7 +193,7 @@ class MetadataPopulator(object): Raises: IOError: File not found. - ValueError: the model does not have the expected flatbuffer identifer. + ValueError: the model does not have the expected flatbuffer identifier. """ return cls(model_file) @@ -207,7 +210,7 @@ class MetadataPopulator(object): A MetadataPopulator(_MetadataPopulatorWithBuffer) object. 
Raises: - ValueError: the model does not have the expected flatbuffer identifer. + ValueError: the model does not have the expected flatbuffer identifier. """ return _MetadataPopulatorWithBuffer(model_buf) @@ -290,7 +293,7 @@ class MetadataPopulator(object): Raises: ValueError: The metadata to be populated is empty. - ValueError: The metadata does not have the expected flatbuffer identifer. + ValueError: The metadata does not have the expected flatbuffer identifier. ValueError: Cannot get minimum metadata parser version. ValueError: The number of SubgraphMetadata is not 1. ValueError: The number of input/output tensors does not match the number @@ -643,7 +646,7 @@ class MetadataPopulator(object): class _MetadataPopulatorWithBuffer(MetadataPopulator): - """Subclass of MetadtaPopulator that populates metadata to a model buffer. + """Subclass of MetadataPopulator that populates metadata to a model buffer. This class is used to populate metadata into a in-memory model buffer. As we use Zip API to concatenate associated files after tflite model file, the @@ -661,7 +664,7 @@ class _MetadataPopulatorWithBuffer(MetadataPopulator): Raises: ValueError: model_buf is empty. - ValueError: model_buf does not have the expected flatbuffer identifer. + ValueError: model_buf does not have the expected flatbuffer identifier. """ if not model_buf: raise ValueError("model_buf cannot be empty.") @@ -789,21 +792,50 @@ class MetadataDisplayer(object): return [] +def _get_custom_metadata(metadata_buffer: bytes, name: str): + """Gets the custom metadata in metadata_buffer based on the name. + + Args: + metadata_buffer: valid metadata buffer in bytes. + name: custom metadata name. + + Returns: + Index of custom metadata, custom metadata flatbuffer. Returns (None, None) + if the custom metadata is not found. + """ + model_metadata = _metadata_fb.ModelMetadata.GetRootAs(metadata_buffer) + subgraph = model_metadata.SubgraphMetadata(0) + if subgraph is None or subgraph.CustomMetadataIsNone(): + return None, None + + for i in range(subgraph.CustomMetadataLength()): + custom_metadata = subgraph.CustomMetadata(i) + if custom_metadata.Name().decode("utf-8") == name: + return i, custom_metadata.DataAsNumpy().tobytes() + return None, None + + # Create an individual method for getting the metadata json file, so that it can # be used as a standalone util. -def convert_to_json(metadata_buffer): +def convert_to_json( + metadata_buffer, custom_metadata_schema: Optional[Dict[str, str]] = None +) -> str: """Converts the metadata into a json string. Args: metadata_buffer: valid metadata buffer in bytes. + custom_metadata_schema: A dict of custom metadata schema, in which key is + custom metadata name [1], value is the filepath that defines custom + metadata schema. For instance, custom_metadata_schema = + {"SEGMENTER_METADATA": "metadata/vision_tasks_metadata_schema.fbs"}. [1]: + https://github.com/google/mediapipe/blob/46b5c4012d2ef76c9d92bb0d88a6b107aee83814/mediapipe/tasks/metadata/metadata_schema.fbs#L612 Returns: Metadata in JSON format. Raises: - ValueError: error occured when parsing the metadata schema file. + ValueError: error occurred when parsing the metadata schema file. """ - opt = _pywrap_flatbuffers.IDLOptions() opt.strict_json = True parser = _pywrap_flatbuffers.Parser(opt) @@ -811,7 +843,35 @@ def convert_to_json(metadata_buffer): metadata_schema_content = f.read() if not parser.parse(metadata_schema_content): raise ValueError("Cannot parse metadata schema. 
Reason: " + parser.error) - return _pywrap_flatbuffers.generate_text(parser, metadata_buffer) + # Json content which may contain binary custom metadata. + raw_json_content = _pywrap_flatbuffers.generate_text(parser, metadata_buffer) + if not custom_metadata_schema: + return raw_json_content + + json_data = json.loads(raw_json_content) + # Gets the custom metadata by name and parse the binary custom metadata into + # human readable json content. + for name, schema_file in custom_metadata_schema.items(): + idx, custom_metadata = _get_custom_metadata(metadata_buffer, name) + if not custom_metadata: + logging.info( + "No custom metadata with name %s in metadata flatbuffer.", name + ) + continue + _assert_file_exist(schema_file) + with _open_file(schema_file, "rb") as f: + custom_metadata_schema_content = f.read() + if not parser.parse(custom_metadata_schema_content): + raise ValueError( + "Cannot parse custom metadata schema. Reason: " + parser.error + ) + custom_metadata_json = _pywrap_flatbuffers.generate_text( + parser, custom_metadata + ) + json_meta = json_data["subgraph_metadata"][0]["custom_metadata"][idx] + json_meta["name"] = name + json_meta["data"] = json.loads(custom_metadata_json) + return json.dumps(json_data, indent=2) def _assert_file_exist(filename): diff --git a/mediapipe/tasks/python/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/metadata/metadata_writers/BUILD index ce572283f..1f126c30b 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/metadata/metadata_writers/BUILD @@ -50,6 +50,20 @@ py_library( deps = [":metadata_writer"], ) +py_library( + name = "image_segmenter", + srcs = ["image_segmenter.py"], + data = ["//mediapipe/tasks/metadata:image_segmenter_metadata_schema.fbs"], + deps = [ + ":metadata_info", + ":metadata_writer", + "//mediapipe/tasks/metadata:image_segmenter_metadata_schema_py", + "//mediapipe/tasks/metadata:metadata_schema_py", + "//mediapipe/tasks/python/metadata", + "@flatbuffers//:runtime_py", + ], +) + py_library( name = "object_detector", srcs = ["object_detector.py"], diff --git a/mediapipe/tasks/python/metadata/metadata_writers/image_segmenter.py b/mediapipe/tasks/python/metadata/metadata_writers/image_segmenter.py new file mode 100644 index 000000000..3268f3b1f --- /dev/null +++ b/mediapipe/tasks/python/metadata/metadata_writers/image_segmenter.py @@ -0,0 +1,161 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Writes metadata and label file to the image segmenter models.""" +import enum +from typing import List, Optional + +import flatbuffers +from mediapipe.tasks.metadata import image_segmenter_metadata_schema_py_generated as _segmenter_metadata_fb +from mediapipe.tasks.metadata import metadata_schema_py_generated as _metadata_fb +from mediapipe.tasks.python.metadata import metadata +from mediapipe.tasks.python.metadata.metadata_writers import metadata_info +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer + + +_MODEL_NAME = "ImageSegmenter" +_MODEL_DESCRIPTION = ( + "Semantic image segmentation predicts whether each pixel " + "of an image is associated with a certain class." +) + +# Metadata Schema file for image segmenter. +_FLATC_METADATA_SCHEMA_FILE = metadata.get_path_to_datafile( + "../../../metadata/image_segmenter_metadata_schema.fbs", +) + +# Metadata name in custom metadata field. The metadata name is used to get +# image segmenter metadata from SubGraphMetadata.custom_metadata and +# shouldn't be changed. +_METADATA_NAME = "SEGMENTER_METADATA" + + +class Activation(enum.Enum): + NONE = 0 + SIGMOID = 1 + SOFTMAX = 2 + + +# Create an individual method for getting the metadata json file, so that it can +# be used as a standalone util. +def convert_to_json(metadata_buffer: bytearray) -> str: + """Converts the metadata into a json string. + + Args: + metadata_buffer: valid metadata buffer in bytes. + + Returns: + Metadata in JSON format. + + Raises: + ValueError: error occurred when parsing the metadata schema file. + """ + return metadata.convert_to_json( + metadata_buffer, + custom_metadata_schema={_METADATA_NAME: _FLATC_METADATA_SCHEMA_FILE}, + ) + + +class ImageSegmenterOptionsMd(metadata_info.CustomMetadataMd): + """Image segmenter options metadata.""" + + _METADATA_FILE_IDENTIFIER = b"V001" + + def __init__(self, activation: Activation) -> None: + """Creates an ImageSegmenterOptionsMd object. + + Args: + activation: activation function of the output layer in the image + segmenter. + """ + self.activation = activation + super().__init__(name=_METADATA_NAME) + + def create_metadata(self) -> _metadata_fb.CustomMetadataT: + """Creates the image segmenter options metadata. + + Returns: + A Flatbuffers Python object of the custom metadata including image + segmenter options metadata. + """ + segmenter_options = _segmenter_metadata_fb.ImageSegmenterOptionsT() + segmenter_options.activation = self.activation.value + + # Get the image segmenter options flatbuffer. + b = flatbuffers.Builder(0) + b.Finish(segmenter_options.Pack(b), self._METADATA_FILE_IDENTIFIER) + segmenter_options_buf = b.Output() + + # Add the image segmenter options flatbuffer in custom metadata. + custom_metadata = _metadata_fb.CustomMetadataT() + custom_metadata.name = self.name + custom_metadata.data = segmenter_options_buf + return custom_metadata + + +class MetadataWriter(metadata_writer.MetadataWriterBase): + """MetadataWriter to write the metadata for image segmenter.""" + + @classmethod + def create( + cls, + model_buffer: bytearray, + input_norm_mean: List[float], + input_norm_std: List[float], + labels: Optional[metadata_writer.Labels] = None, + activation: Optional[Activation] = None, + ) -> "MetadataWriter": + """Creates MetadataWriter to write the metadata for image segmenter. + + The parameters required in this method are mandatory when using MediaPipe + Tasks. 
+ + Example usage: + metadata_writer = image_segmenter.MetadataWriter.create(model_buffer, ...) + tflite_content, json_content = metadata_writer.populate() + + When calling the `populate` function in this class, it returns TfLite content + and JSON content. Note that only the output TFLite is used for deployment. + The output JSON content is used to interpret the metadata content. + + Args: + model_buffer: A valid flatbuffer loaded from the TFLite model file. + input_norm_mean: the mean value used in the input tensor normalization + [1]. + input_norm_std: the std value used in the input tensor normalization [1]. + labels: an instance of Labels helper class used in the output category + tensor [2]. + activation: activation function for the output layer. + [1]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L389 + [2]: + https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L116 + + Returns: + A MetadataWriter object. + """ + writer = metadata_writer.MetadataWriter(model_buffer) + writer.add_general_info(_MODEL_NAME, _MODEL_DESCRIPTION) + writer.add_image_input(input_norm_mean, input_norm_std) + writer.add_segmentation_output(labels=labels) + if activation is not None: + option_md = ImageSegmenterOptionsMd(activation) + writer.add_custom_metadata(option_md) + return cls(writer) + + def populate(self) -> tuple[bytearray, str]: + model_buf, _ = super().populate() + metadata_buf = metadata.get_metadata_buffer(model_buf) + json_content = convert_to_json(metadata_buf) + return model_buf, json_content diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py index 4794a12fc..f201ab7e0 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_info.py @@ -1030,6 +1030,52 @@ class TensorGroupMd: return group + +class SegmentationMaskMd(TensorMd): + """A container for the segmentation mask metadata information.""" + + # The output tensor is in the shape of [1, ImageHeight, ImageWidth, N], where + # N is the number of objects that the segmentation model can recognize. The + # output tensor is essentially a list of grayscale bitmaps, where each value + # is the probability of the corresponding pixel belonging to a certain object + # type. Therefore, the content dimension range of the output tensor is [1, 2]. + _CONTENT_DIM_MIN = 1 + _CONTENT_DIM_MAX = 2 + + def __init__( + self, + name: Optional[str] = None, + description: Optional[str] = None, + label_files: Optional[List[LabelFileMd]] = None, + ): + self.name = name + self.description = description + associated_files = label_files or [] + super().__init__( + name=name, description=description, associated_files=associated_files + ) + + def create_metadata(self) -> _metadata_fb.TensorMetadataT: + """Creates the metadata for the segmentation masks tensor.""" + masks_metadata = super().create_metadata() + + # Create tensor content information. + content = _metadata_fb.ContentT() + content.contentProperties = _metadata_fb.ImagePropertiesT() + content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.GRAYSCALE + content.contentPropertiesType = ( + _metadata_fb.ContentProperties.ImageProperties + ) + # Add the content range.
See + # https://github.com/google/mediapipe/blob/f8af41b1eb49ff4bdad756ff19d1d36f486be614/mediapipe/tasks/metadata/metadata_schema.fbs#L323-L385 + dim_range = _metadata_fb.ValueRangeT() + dim_range.min = self._CONTENT_DIM_MIN + dim_range.max = self._CONTENT_DIM_MAX + content.range = dim_range + masks_metadata.content = content + + return masks_metadata + + class CustomMetadataMd(abc.ABC): """An abstract class of a container for the custom metadata information.""" diff --git a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py index fda6a64d3..e0be9beea 100644 --- a/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py +++ b/mediapipe/tasks/python/metadata/metadata_writers/metadata_writer.py @@ -34,6 +34,10 @@ _INPUT_REGEX_TEXT_DESCRIPTION = ('Embedding vectors representing the input ' 'text to be processed.') _OUTPUT_CLASSIFICATION_NAME = 'score' _OUTPUT_CLASSIFICATION_DESCRIPTION = 'Score of the labels respectively.' +_OUTPUT_SEGMENTATION_MASKS_NAME = 'segmentation_masks' +_OUTPUT_SEGMENTATION_MASKS_DESCRIPTION = ( + 'Masks over the target objects with high accuracy.' +) # Detection tensor result to be grouped together. _DETECTION_GROUP_NAME = 'detection_result' # File name to export score calibration parameters. @@ -657,6 +661,32 @@ class MetadataWriter(object): self._output_group_mds.append(group_md) return self + def add_segmentation_output( + self, + labels: Optional[Labels] = None, + name: str = _OUTPUT_SEGMENTATION_MASKS_NAME, + description: str = _OUTPUT_SEGMENTATION_MASKS_DESCRIPTION, + ) -> 'MetadataWriter': + """Adds a segmentation head metadata for segmentation output tensor. + + Args: + labels: an instance of Labels helper class. + name: Metadata name of the tensor. Note that this is different from tensor + name in the flatbuffer. + description: human readable description of what the output is. + + Returns: + The current Writer instance to allow chained operation. 
+ """ + label_files = self._create_label_file_md(labels) + output_md = metadata_info.SegmentationMaskMd( + name=name, + description=description, + label_files=label_files, + ) + self._output_mds.append(output_md) + return self + def add_feature_output(self, name: Optional[str] = None, description: Optional[str] = None) -> 'MetadataWriter': diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD index 417e3e10c..976ddc9d2 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/BUILD @@ -91,3 +91,18 @@ py_test( "//mediapipe/tasks/python/test:test_utils", ], ) + +py_test( + name = "image_segmenter_test", + srcs = ["image_segmenter_test.py"], + data = [ + "//mediapipe/tasks/testdata/metadata:data_files", + "//mediapipe/tasks/testdata/metadata:model_files", + ], + deps = [ + "//mediapipe/tasks/python/metadata", + "//mediapipe/tasks/python/metadata/metadata_writers:image_segmenter", + "//mediapipe/tasks/python/metadata/metadata_writers:metadata_writer", + "//mediapipe/tasks/python/test:test_utils", + ], +) diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/image_segmenter_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/image_segmenter_test.py new file mode 100644 index 000000000..a12f009cd --- /dev/null +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/image_segmenter_test.py @@ -0,0 +1,98 @@ +# Copyright 2022 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for metadata_writer.image_segmenter.""" + +import os + +from absl.testing import absltest + +from mediapipe.tasks.python.metadata import metadata +from mediapipe.tasks.python.metadata.metadata_writers import image_segmenter +from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer +from mediapipe.tasks.python.test import test_utils + +_TEST_DATA_DIR = "mediapipe/tasks/testdata/metadata" +_MODEL_FILE = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "deeplabv3_without_metadata.tflite") +) +_LABEL_FILE_NAME = "labels.txt" +_LABEL_FILE = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "segmenter_labelmap.txt") +) +_NORM_MEAN = 127.5 +_NORM_STD = 127.5 +_JSON_FILE = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "deeplabv3.json") +) +_JSON_FILE_WITHOUT_LABELS = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "deeplabv3_without_labels.json") +) +_JSON_FILE_WITH_ACTIVATION = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "deeplabv3_with_activation.json") +) + + +class ImageSegmenterTest(absltest.TestCase): + + def test_write_metadata(self): + with open(_MODEL_FILE, "rb") as f: + model_buffer = f.read() + writer = image_segmenter.MetadataWriter.create( + bytearray(model_buffer), + [_NORM_MEAN], + [_NORM_STD], + labels=metadata_writer.Labels().add_from_file(_LABEL_FILE), + ) + tflite_content, metadata_json = writer.populate() + with open(_JSON_FILE, "r") as f: + expected_json = f.read().strip() + self.assertEqual(metadata_json, expected_json) + + displayer = metadata.MetadataDisplayer.with_model_buffer(tflite_content) + label_file_buffer = displayer.get_associated_file_buffer(_LABEL_FILE_NAME) + with open(_LABEL_FILE, "rb") as f: + expected_labelfile_buffer = f.read() + self.assertEqual(label_file_buffer, expected_labelfile_buffer) + + def test_write_metadata_without_labels(self): + with open(_MODEL_FILE, "rb") as f: + model_buffer = f.read() + writer = image_segmenter.MetadataWriter.create( + bytearray(model_buffer), + [_NORM_MEAN], + [_NORM_STD], + ) + _, metadata_json = writer.populate() + with open(_JSON_FILE_WITHOUT_LABELS, "r") as f: + expected_json = f.read().strip() + self.assertEqual(metadata_json, expected_json) + + def test_write_metadata_with_activation(self): + with open(_MODEL_FILE, "rb") as f: + model_buffer = f.read() + writer = image_segmenter.MetadataWriter.create( + bytearray(model_buffer), + [_NORM_MEAN], + [_NORM_STD], + activation=image_segmenter.Activation.SIGMOID, + ) + _, metadata_json = writer.populate() + with open(_JSON_FILE_WITH_ACTIVATION, "r") as f: + expected_json = f.read().strip() + self.assertEqual(metadata_json, expected_json) + + +if __name__ == "__main__": + absltest.main() diff --git a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py index bcb384a34..fd4462631 100644 --- a/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py +++ b/mediapipe/tasks/python/test/metadata/metadata_writers/metadata_info_test.py @@ -455,6 +455,27 @@ class TensorGroupMdMdTest(absltest.TestCase): self.assertEqual(metadata_json, expected_json) +class SegmentationMaskMdTest(absltest.TestCase): + _NAME = "segmentation_masks" + _DESCRIPTION = "Masks over the target objects." 
+ _EXPECTED_JSON = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, "segmentation_mask_meta.json") + ) + + def test_create_metadata_should_succeed(self): + segmentation_mask_md = metadata_info.SegmentationMaskMd( + name=self._NAME, description=self._DESCRIPTION + ) + metadata = segmentation_mask_md.create_metadata() + + metadata_json = _metadata.convert_to_json( + _create_dummy_model_metadata_with_tensor(metadata) + ) + with open(self._EXPECTED_JSON, "r") as f: + expected_json = f.read() + self.assertEqual(metadata_json, expected_json) + + def _create_dummy_model_metadata_with_tensor( tensor_metadata: _metadata_fb.TensorMetadataT) -> bytes: # Create a dummy model using the tensor metadata. diff --git a/mediapipe/tasks/python/test/vision/BUILD b/mediapipe/tasks/python/test/vision/BUILD index 19d592895..db0bd66b2 100644 --- a/mediapipe/tasks/python/test/vision/BUILD +++ b/mediapipe/tasks/python/test/vision/BUILD @@ -92,6 +92,53 @@ py_test( ], ) +py_test( + name = "face_detector_test", + srcs = ["face_detector_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + "//mediapipe/tasks/testdata/vision:test_protos", + ], + deps = [ + "//mediapipe/framework/formats:detection_py_pb2", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:bounding_box", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:face_detector", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@com_google_protobuf//:protobuf_python", + ], +) + +py_test( + name = "face_landmarker_test", + srcs = ["face_landmarker_test.py"], + data = [ + "//mediapipe/tasks/testdata/vision:test_images", + "//mediapipe/tasks/testdata/vision:test_models", + "//mediapipe/tasks/testdata/vision:test_protos", + ], + deps = [ + "//mediapipe/framework/formats:classification_py_pb2", + "//mediapipe/framework/formats:landmark_py_pb2", + "//mediapipe/python:_framework_bindings", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:landmark", + "//mediapipe/tasks/python/components/containers:rect", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/test:test_utils", + "//mediapipe/tasks/python/vision:face_landmarker", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + "@com_google_protobuf//:protobuf_python", + ], +) + py_test( name = "hand_landmarker_test", srcs = ["hand_landmarker_test.py"], diff --git a/mediapipe/tasks/python/test/vision/face_detector_test.py b/mediapipe/tasks/python/test/vision/face_detector_test.py new file mode 100644 index 000000000..4ae8101b7 --- /dev/null +++ b/mediapipe/tasks/python/test/vision/face_detector_test.py @@ -0,0 +1,523 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for face detector.""" + +import enum +import os +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized + +from google.protobuf import text_format +from mediapipe.framework.formats import detection_pb2 +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import bounding_box as bounding_box_module +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import detections as detections_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import face_detector +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +FaceDetectorResult = detections_module.DetectionResult +_BaseOptions = base_options_module.BaseOptions +_Category = category_module.Category +_BoundingBox = bounding_box_module.BoundingBox +_Detection = detections_module.Detection +_Image = image_module.Image +_FaceDetector = face_detector.FaceDetector +_FaceDetectorOptions = face_detector.FaceDetectorOptions +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions + +_SHORT_RANGE_BLAZE_FACE_MODEL = 'face_detection_short_range.tflite' +_PORTRAIT_IMAGE = 'portrait.jpg' +_PORTRAIT_EXPECTED_DETECTION = 'portrait_expected_detection.pbtxt' +_PORTRAIT_ROTATED_IMAGE = 'portrait_rotated.jpg' +_PORTRAIT_ROTATED_EXPECTED_DETECTION = ( + 'portrait_rotated_expected_detection.pbtxt' +) +_CAT_IMAGE = 'cat.jpg' +_KEYPOINT_ERROR_THRESHOLD = 1e-2 +_TEST_DATA_DIR = 'mediapipe/tasks/testdata/vision' + + +def _get_expected_face_detector_result(file_name: str) -> FaceDetectorResult: + face_detection_result_file_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, file_name) + ) + with open(face_detection_result_file_path, 'rb') as f: + face_detection_proto = detection_pb2.Detection() + text_format.Parse(f.read(), face_detection_proto) + face_detection = detections_module.Detection.create_from_pb2( + face_detection_proto + ) + return FaceDetectorResult(detections=[face_detection]) + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class FaceDetectorTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _PORTRAIT_IMAGE) + ) + ) + self.model_path = test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _SHORT_RANGE_BLAZE_FACE_MODEL) + ) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. 
+ with _FaceDetector.create_from_model_path(self.model_path) as detector: + self.assertIsInstance(detector, _FaceDetector) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceDetectorOptions(base_options=base_options) + with _FaceDetector.create_from_options(options) as detector: + self.assertIsInstance(detector, _FaceDetector) + + def test_create_from_options_fails_with_invalid_model_path(self): + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite' + ) + options = _FaceDetectorOptions(base_options=base_options) + _FaceDetector.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. + with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _FaceDetectorOptions(base_options=base_options) + detector = _FaceDetector.create_from_options(options) + self.assertIsInstance(detector, _FaceDetector) + + def _expect_keypoints_correct(self, actual_keypoints, expected_keypoints): + self.assertLen(actual_keypoints, len(expected_keypoints)) + for i in range(len(actual_keypoints)): + self.assertAlmostEqual( + actual_keypoints[i].x, + expected_keypoints[i].x, + delta=_KEYPOINT_ERROR_THRESHOLD, + ) + self.assertAlmostEqual( + actual_keypoints[i].y, + expected_keypoints[i].y, + delta=_KEYPOINT_ERROR_THRESHOLD, + ) + + def _expect_face_detector_results_correct( + self, actual_results, expected_results + ): + self.assertLen(actual_results.detections, len(expected_results.detections)) + for i in range(len(actual_results.detections)): + actual_bbox = actual_results.detections[i].bounding_box + expected_bbox = expected_results.detections[i].bounding_box + self.assertEqual(actual_bbox, expected_bbox) + self.assertNotEmpty(actual_results.detections[i].keypoints) + self._expect_keypoints_correct( + actual_results.detections[i].keypoints, + expected_results.detections[i].keypoints, + ) + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION), + (ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION), + ) + def test_detect(self, model_file_type, expected_detection_result_file): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceDetectorOptions(base_options=base_options) + detector = _FaceDetector.create_from_options(options) + + # Performs face detection on the input. + detection_result = detector.detect(self.test_image) + # Comparing results. + expected_detection_result = _get_expected_face_detector_result( + expected_detection_result_file + ) + self._expect_face_detector_results_correct( + detection_result, expected_detection_result + ) + # Closes the detector explicitly when the detector is not used in + # a context. 
+ detector.close() + + @parameterized.parameters( + (ModelFileType.FILE_NAME, _PORTRAIT_EXPECTED_DETECTION), + (ModelFileType.FILE_CONTENT, _PORTRAIT_EXPECTED_DETECTION), + ) + def test_detect_in_context( + self, model_file_type, expected_detection_result_file + ): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceDetectorOptions(base_options=base_options) + + with _FaceDetector.create_from_options(options) as detector: + # Performs face detection on the input. + detection_result = detector.detect(self.test_image) + # Comparing results. + expected_detection_result = _get_expected_face_detector_result( + expected_detection_result_file + ) + self._expect_face_detector_results_correct( + detection_result, expected_detection_result + ) + + def test_detect_succeeds_with_rotated_image(self): + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceDetectorOptions(base_options=base_options) + with _FaceDetector.create_from_options(options) as detector: + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, _PORTRAIT_ROTATED_IMAGE) + ) + ) + # Rotated input image. + image_processing_options = _ImageProcessingOptions(rotation_degrees=-90) + # Performs face detection on the input. + detection_result = detector.detect(test_image, image_processing_options) + # Comparing results. + expected_detection_result = _get_expected_face_detector_result( + _PORTRAIT_ROTATED_EXPECTED_DETECTION + ) + self._expect_face_detector_results_correct( + detection_result, expected_detection_result + ) + + def test_empty_detection_outputs(self): + # Load a test image with no faces. + test_image = _Image.create_from_file( + test_utils.get_test_data_path(os.path.join(_TEST_DATA_DIR, _CAT_IMAGE)) + ) + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path) + ) + with _FaceDetector.create_from_options(options) as detector: + # Performs face detection on the input. 
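+      # The cat image contains no faces, so an empty detection list is
+      # expected.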
+ detection_result = detector.detect(test_image) + self.assertEmpty(detection_result.detections) + + def test_missing_result_callback(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + ) + with self.assertRaisesRegex( + ValueError, r'result callback must be provided' + ): + with _FaceDetector.create_from_options(options) as unused_detector: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock(), + ) + with self.assertRaisesRegex( + ValueError, r'result callback should not be provided' + ): + with _FaceDetector.create_from_options(options) as unused_detector: + pass + + def test_calling_detect_for_video_in_image_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE, + ) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): + detector.detect_for_video(self.test_image, 0) + + def test_calling_detect_async_in_image_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE, + ) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): + detector.detect_async(self.test_image, 0) + + def test_calling_detect_in_video_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): + detector.detect(self.test_image) + + def test_calling_detect_async_in_video_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): + detector.detect_async(self.test_image, 0) + + def test_detect_for_video_with_out_of_order_timestamp(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _FaceDetector.create_from_options(options) as detector: + unused_result = detector.detect_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing' + ): + detector.detect_for_video(self.test_image, 0) + + @parameterized.parameters( + ( + ModelFileType.FILE_NAME, + _PORTRAIT_IMAGE, + 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION), + ), + ( + ModelFileType.FILE_CONTENT, + _PORTRAIT_IMAGE, + 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION), + ), + ( + ModelFileType.FILE_NAME, + _PORTRAIT_ROTATED_IMAGE, + -90, + _get_expected_face_detector_result( + _PORTRAIT_ROTATED_EXPECTED_DETECTION + ), + ), + ( + ModelFileType.FILE_CONTENT, + _PORTRAIT_ROTATED_IMAGE, + -90, + _get_expected_face_detector_result( + _PORTRAIT_ROTATED_EXPECTED_DETECTION + ), + ), + 
(ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])), + (ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])), + ) + def test_detect_for_video( + self, + model_file_type, + test_image_file_name, + rotation_degrees, + expected_detection_result, + ): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceDetectorOptions( + base_options=base_options, running_mode=_RUNNING_MODE.VIDEO + ) + + with _FaceDetector.create_from_options(options) as detector: + for timestamp in range(0, 300, 30): + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, test_image_file_name) + ) + ) + # Set the image processing options. + image_processing_options = _ImageProcessingOptions( + rotation_degrees=rotation_degrees + ) + # Performs face detection on the input. + detection_result = detector.detect_for_video( + test_image, timestamp, image_processing_options + ) + # Comparing results. + self._expect_face_detector_results_correct( + detection_result, expected_detection_result + ) + + def test_calling_detect_in_live_stream_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): + detector.detect(self.test_image) + + def test_calling_detect_for_video_in_live_stream_mode(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _FaceDetector.create_from_options(options) as detector: + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): + detector.detect_for_video(self.test_image, 0) + + def test_detect_async_calls_with_illegal_timestamp(self): + options = _FaceDetectorOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _FaceDetector.create_from_options(options) as detector: + detector.detect_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing' + ): + detector.detect_async(self.test_image, 0) + + @parameterized.parameters( + ( + ModelFileType.FILE_NAME, + _PORTRAIT_IMAGE, + 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION), + ), + ( + ModelFileType.FILE_CONTENT, + _PORTRAIT_IMAGE, + 0, + _get_expected_face_detector_result(_PORTRAIT_EXPECTED_DETECTION), + ), + ( + ModelFileType.FILE_NAME, + _PORTRAIT_ROTATED_IMAGE, + -90, + _get_expected_face_detector_result( + _PORTRAIT_ROTATED_EXPECTED_DETECTION + ), + ), + ( + ModelFileType.FILE_CONTENT, + _PORTRAIT_ROTATED_IMAGE, + -90, + _get_expected_face_detector_result( + _PORTRAIT_ROTATED_EXPECTED_DETECTION + ), + ), + (ModelFileType.FILE_NAME, _CAT_IMAGE, 0, FaceDetectorResult([])), + (ModelFileType.FILE_CONTENT, _CAT_IMAGE, 0, FaceDetectorResult([])), + ) + def 
test_detect_async_calls( + self, + model_file_type, + test_image_file_name, + rotation_degrees, + expected_detection_result, + ): + # Creates detector. + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=self.model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(self.model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + observed_timestamp_ms = -1 + + def check_result( + result: FaceDetectorResult, + unused_output_image: _Image, + timestamp_ms: int, + ): + self._expect_face_detector_results_correct( + result, expected_detection_result + ) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + options = _FaceDetectorOptions( + base_options=base_options, + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=check_result, + ) + + # Load the test image. + test_image = _Image.create_from_file( + test_utils.get_test_data_path( + os.path.join(_TEST_DATA_DIR, test_image_file_name) + ) + ) + + with _FaceDetector.create_from_options(options) as detector: + for timestamp in range(0, 300, 30): + # Set the image processing options. + image_processing_options = _ImageProcessingOptions( + rotation_degrees=rotation_degrees + ) + detector.detect_async(test_image, timestamp, image_processing_options) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/test/vision/face_landmarker_test.py b/mediapipe/tasks/python/test/vision/face_landmarker_test.py new file mode 100644 index 000000000..8e070064d --- /dev/null +++ b/mediapipe/tasks/python/test/vision/face_landmarker_test.py @@ -0,0 +1,564 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for face landmarker.""" + +import enum +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np + +from google.protobuf import text_format +from mediapipe.framework.formats import classification_pb2 +from mediapipe.framework.formats import landmark_pb2 +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.components.containers import rect as rect_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.test import test_utils +from mediapipe.tasks.python.vision import face_landmarker +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + + +FaceLandmarkerResult = face_landmarker.FaceLandmarkerResult +_BaseOptions = base_options_module.BaseOptions +_Category = category_module.Category +_Rect = rect_module.Rect +_Landmark = landmark_module.Landmark +_NormalizedLandmark = landmark_module.NormalizedLandmark +_Image = image_module.Image +_FaceLandmarker = face_landmarker.FaceLandmarker +_FaceLandmarkerOptions = face_landmarker.FaceLandmarkerOptions +_RUNNING_MODE = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions + +_FACE_LANDMARKER_BUNDLE_ASSET_FILE = 'face_landmarker.task' +_PORTRAIT_IMAGE = 'portrait.jpg' +_CAT_IMAGE = 'cat.jpg' +_PORTRAIT_EXPECTED_FACE_LANDMARKS = 'portrait_expected_face_landmarks.pbtxt' +_PORTRAIT_EXPECTED_FACE_LANDMARKS_WITH_ATTENTION = ( + 'portrait_expected_face_landmarks_with_attention.pbtxt' +) +_PORTRAIT_EXPECTED_BLENDSHAPES = ( + 'portrait_expected_blendshapes_with_attention.pbtxt' +) +_LANDMARKS_DIFF_MARGIN = 0.03 +_BLENDSHAPES_DIFF_MARGIN = 0.12 +_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN = 0.02 + + +def _get_expected_face_landmarks(file_path: str): + proto_file_path = test_utils.get_test_data_path(file_path) + with open(proto_file_path, 'rb') as f: + proto = landmark_pb2.NormalizedLandmarkList() + text_format.Parse(f.read(), proto) + face_landmarks = [] + for landmark in proto.landmark: + face_landmarks.append(_NormalizedLandmark.create_from_pb2(landmark)) + return face_landmarks + + +def _get_expected_face_blendshapes(file_path: str): + proto_file_path = test_utils.get_test_data_path(file_path) + with open(proto_file_path, 'rb') as f: + proto = classification_pb2.ClassificationList() + text_format.Parse(f.read(), proto) + face_blendshapes_categories = [] + face_blendshapes_classifications = classification_pb2.ClassificationList() + face_blendshapes_classifications.MergeFrom(proto) + for face_blendshapes in face_blendshapes_classifications.classification: + face_blendshapes_categories.append( + category_module.Category( + index=face_blendshapes.index, + score=face_blendshapes.score, + display_name=face_blendshapes.display_name, + category_name=face_blendshapes.label, + ) + ) + return face_blendshapes_categories + + +def _get_expected_facial_transformation_matrixes(): + matrix = np.array([ + [0.9995292, -0.005092691, 0.030254554, -0.37340546], + [0.0072318087, 0.99744856, -0.07102106, 22.212194], + [-0.029815676, 0.07120642, 0.9970159, -64.76358], + [0, 0, 0, 1], + ]) + 
facial_transformation_matrixes_results = [] + facial_transformation_matrixes_results.append(matrix) + return facial_transformation_matrixes_results + + +class ModelFileType(enum.Enum): + FILE_CONTENT = 1 + FILE_NAME = 2 + + +class FaceLandmarkerTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.test_image = _Image.create_from_file( + test_utils.get_test_data_path(_PORTRAIT_IMAGE) + ) + self.model_path = test_utils.get_test_data_path( + _FACE_LANDMARKER_BUNDLE_ASSET_FILE + ) + + def _expect_landmarks_correct(self, actual_landmarks, expected_landmarks): + # Expects to have the same number of faces detected. + self.assertLen(actual_landmarks, len(expected_landmarks)) + + for i, elem in enumerate(actual_landmarks): + self.assertAlmostEqual( + elem.x, expected_landmarks[i].x, delta=_LANDMARKS_DIFF_MARGIN + ) + self.assertAlmostEqual( + elem.y, expected_landmarks[i].y, delta=_LANDMARKS_DIFF_MARGIN + ) + + def _expect_blendshapes_correct( + self, actual_blendshapes, expected_blendshapes + ): + # Expects to have the same number of blendshapes. + self.assertLen(actual_blendshapes, len(expected_blendshapes)) + + for i, elem in enumerate(actual_blendshapes): + self.assertEqual(elem.index, expected_blendshapes[i].index) + self.assertAlmostEqual( + elem.score, + expected_blendshapes[i].score, + delta=_BLENDSHAPES_DIFF_MARGIN, + ) + + def _expect_facial_transformation_matrixes_correct( + self, actual_matrix_list, expected_matrix_list + ): + self.assertLen(actual_matrix_list, len(expected_matrix_list)) + + for i, elem in enumerate(actual_matrix_list): + self.assertEqual(elem.shape[0], expected_matrix_list[i].shape[0]) + self.assertEqual(elem.shape[1], expected_matrix_list[i].shape[1]) + self.assertSequenceAlmostEqual( + elem.flatten(), + expected_matrix_list[i].flatten(), + delta=_FACIAL_TRANSFORMATION_MATRIX_DIFF_MARGIN, + ) + + def test_create_from_file_succeeds_with_valid_model_path(self): + # Creates with default option and valid model file successfully. + with _FaceLandmarker.create_from_model_path(self.model_path) as landmarker: + self.assertIsInstance(landmarker, _FaceLandmarker) + + def test_create_from_options_succeeds_with_valid_model_path(self): + # Creates with options containing model file successfully. + base_options = _BaseOptions(model_asset_path=self.model_path) + options = _FaceLandmarkerOptions(base_options=base_options) + with _FaceLandmarker.create_from_options(options) as landmarker: + self.assertIsInstance(landmarker, _FaceLandmarker) + + def test_create_from_options_fails_with_invalid_model_path(self): + # Invalid empty model path. + with self.assertRaisesRegex( + RuntimeError, 'Unable to open file at /path/to/invalid/model.tflite' + ): + base_options = _BaseOptions( + model_asset_path='/path/to/invalid/model.tflite' + ) + options = _FaceLandmarkerOptions(base_options=base_options) + _FaceLandmarker.create_from_options(options) + + def test_create_from_options_succeeds_with_valid_model_content(self): + # Creates with options containing model content successfully. 
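+    # model_asset_buffer takes the raw bytes of the model bundle as an
+    # alternative to model_asset_path.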
+ with open(self.model_path, 'rb') as f: + base_options = _BaseOptions(model_asset_buffer=f.read()) + options = _FaceLandmarkerOptions(base_options=base_options) + landmarker = _FaceLandmarker.create_from_options(options) + self.assertIsInstance(landmarker, _FaceLandmarker) + + @parameterized.parameters( + ( + ModelFileType.FILE_NAME, + _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS), + None, + None, + ), + ( + ModelFileType.FILE_CONTENT, + _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS), + None, + None, + ), + ) + def test_detect( + self, + model_file_type, + model_name, + expected_face_landmarks, + expected_face_blendshapes, + expected_facial_transformation_matrixes, + ): + # Creates face landmarker. + model_path = test_utils.get_test_data_path(model_name) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceLandmarkerOptions( + base_options=base_options, + output_face_blendshapes=True if expected_face_blendshapes else False, + output_facial_transformation_matrixes=True + if expected_facial_transformation_matrixes + else False, + ) + landmarker = _FaceLandmarker.create_from_options(options) + + # Performs face landmarks detection on the input. + detection_result = landmarker.detect(self.test_image) + # Comparing results. + if expected_face_landmarks is not None: + self._expect_landmarks_correct( + detection_result.face_landmarks[0], expected_face_landmarks + ) + if expected_face_blendshapes is not None: + self._expect_blendshapes_correct( + detection_result.face_blendshapes[0], expected_face_blendshapes + ) + if expected_facial_transformation_matrixes is not None: + self._expect_facial_transformation_matrixes_correct( + detection_result.facial_transformation_matrixes, + expected_facial_transformation_matrixes, + ) + + # Closes the face landmarker explicitly when the face landmarker is not used + # in a context. + landmarker.close() + + @parameterized.parameters( + ( + ModelFileType.FILE_NAME, + _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS), + None, + None, + ), + ( + ModelFileType.FILE_CONTENT, + _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS), + None, + None, + ), + ) + def test_detect_in_context( + self, + model_file_type, + model_name, + expected_face_landmarks, + expected_face_blendshapes, + expected_facial_transformation_matrixes, + ): + # Creates face landmarker. 
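+    # The model bundle is loaded either by file path or as an in-memory
+    # buffer, depending on the parameterized model_file_type.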
+ model_path = test_utils.get_test_data_path(model_name) + if model_file_type is ModelFileType.FILE_NAME: + base_options = _BaseOptions(model_asset_path=model_path) + elif model_file_type is ModelFileType.FILE_CONTENT: + with open(model_path, 'rb') as f: + model_content = f.read() + base_options = _BaseOptions(model_asset_buffer=model_content) + else: + # Should never happen + raise ValueError('model_file_type is invalid.') + + options = _FaceLandmarkerOptions( + base_options=base_options, + output_face_blendshapes=True if expected_face_blendshapes else False, + output_facial_transformation_matrixes=True + if expected_facial_transformation_matrixes + else False, + ) + + with _FaceLandmarker.create_from_options(options) as landmarker: + # Performs face landmarks detection on the input. + detection_result = landmarker.detect(self.test_image) + # Comparing results. + if expected_face_landmarks is not None: + self._expect_landmarks_correct( + detection_result.face_landmarks[0], expected_face_landmarks + ) + if expected_face_blendshapes is not None: + self._expect_blendshapes_correct( + detection_result.face_blendshapes[0], expected_face_blendshapes + ) + if expected_facial_transformation_matrixes is not None: + self._expect_facial_transformation_matrixes_correct( + detection_result.facial_transformation_matrixes, + expected_facial_transformation_matrixes, + ) + + def test_empty_detection_outputs(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path) + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + # Load the image with no faces. + no_faces_test_image = _Image.create_from_file( + test_utils.get_test_data_path(_CAT_IMAGE) + ) + # Performs face landmarks detection on the input. + detection_result = landmarker.detect(no_faces_test_image) + self.assertEmpty(detection_result.face_landmarks) + self.assertEmpty(detection_result.face_blendshapes) + self.assertEmpty(detection_result.facial_transformation_matrixes) + + def test_missing_result_callback(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + ) + with self.assertRaisesRegex( + ValueError, r'result callback must be provided' + ): + with _FaceLandmarker.create_from_options(options) as unused_landmarker: + pass + + @parameterized.parameters((_RUNNING_MODE.IMAGE), (_RUNNING_MODE.VIDEO)) + def test_illegal_result_callback(self, running_mode): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=running_mode, + result_callback=mock.MagicMock(), + ) + with self.assertRaisesRegex( + ValueError, r'result callback should not be provided' + ): + with _FaceLandmarker.create_from_options(options) as unused_landmarker: + pass + + def test_calling_detect_for_video_in_image_mode(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE, + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): + landmarker.detect_for_video(self.test_image, 0) + + def test_calling_detect_async_in_image_mode(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.IMAGE, + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not 
initialized with the live stream mode' + ): + landmarker.detect_async(self.test_image, 0) + + def test_calling_detect_in_video_mode(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): + landmarker.detect(self.test_image) + + def test_calling_detect_async_in_video_mode(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the live stream mode' + ): + landmarker.detect_async(self.test_image, 0) + + def test_detect_for_video_with_out_of_order_timestamp(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.VIDEO, + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + unused_result = landmarker.detect_for_video(self.test_image, 1) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing' + ): + landmarker.detect_for_video(self.test_image, 0) + + @parameterized.parameters( + ( + _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS), + None, + None, + ), + ) + def test_detect_for_video( + self, + model_name, + expected_face_landmarks, + expected_face_blendshapes, + expected_facial_transformation_matrixes, + ): + # Creates face landmarker. + model_path = test_utils.get_test_data_path(model_name) + base_options = _BaseOptions(model_asset_path=model_path) + + options = _FaceLandmarkerOptions( + base_options=base_options, + running_mode=_RUNNING_MODE.VIDEO, + output_face_blendshapes=True if expected_face_blendshapes else False, + output_facial_transformation_matrixes=True + if expected_facial_transformation_matrixes + else False, + ) + + with _FaceLandmarker.create_from_options(options) as landmarker: + for timestamp in range(0, 300, 30): + # Performs face landmarks detection on the input. + detection_result = landmarker.detect_for_video( + self.test_image, timestamp + ) + # Comparing results. 
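+        # Only the outputs that were requested (and therefore have non-None
+        # expectations) are compared below.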
+ if expected_face_landmarks is not None: + self._expect_landmarks_correct( + detection_result.face_landmarks[0], expected_face_landmarks + ) + if expected_face_blendshapes is not None: + self._expect_blendshapes_correct( + detection_result.face_blendshapes[0], expected_face_blendshapes + ) + if expected_facial_transformation_matrixes is not None: + self._expect_facial_transformation_matrixes_correct( + detection_result.facial_transformation_matrixes, + expected_facial_transformation_matrixes, + ) + + def test_calling_detect_in_live_stream_mode(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the image mode' + ): + landmarker.detect(self.test_image) + + def test_calling_detect_for_video_in_live_stream_mode(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + with self.assertRaisesRegex( + ValueError, r'not initialized with the video mode' + ): + landmarker.detect_for_video(self.test_image, 0) + + def test_detect_async_calls_with_illegal_timestamp(self): + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=self.model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + result_callback=mock.MagicMock(), + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + landmarker.detect_async(self.test_image, 100) + with self.assertRaisesRegex( + ValueError, r'Input timestamp must be monotonically increasing' + ): + landmarker.detect_async(self.test_image, 0) + + @parameterized.parameters( + ( + _PORTRAIT_IMAGE, + _FACE_LANDMARKER_BUNDLE_ASSET_FILE, + _get_expected_face_landmarks(_PORTRAIT_EXPECTED_FACE_LANDMARKS), + None, + None, + ), + ) + def test_detect_async_calls( + self, + image_path, + model_name, + expected_face_landmarks, + expected_face_blendshapes, + expected_facial_transformation_matrixes, + ): + test_image = _Image.create_from_file( + test_utils.get_test_data_path(image_path) + ) + observed_timestamp_ms = -1 + + def check_result( + result: FaceLandmarkerResult, output_image: _Image, timestamp_ms: int + ): + # Comparing results. 
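+      # The live-stream callback receives the detection result, the input
+      # image, and the frame's timestamp in milliseconds.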
+ if expected_face_landmarks is not None: + self._expect_landmarks_correct( + result.face_landmarks[0], expected_face_landmarks + ) + if expected_face_blendshapes is not None: + self._expect_blendshapes_correct( + result.face_blendshapes[0], expected_face_blendshapes + ) + if expected_facial_transformation_matrixes is not None: + self._expect_facial_transformation_matrixes_correct( + result.facial_transformation_matrixes, + expected_facial_transformation_matrixes, + ) + self.assertTrue( + np.array_equal(output_image.numpy_view(), test_image.numpy_view()) + ) + self.assertLess(observed_timestamp_ms, timestamp_ms) + self.observed_timestamp_ms = timestamp_ms + + model_path = test_utils.get_test_data_path(model_name) + options = _FaceLandmarkerOptions( + base_options=_BaseOptions(model_asset_path=model_path), + running_mode=_RUNNING_MODE.LIVE_STREAM, + output_face_blendshapes=True if expected_face_blendshapes else False, + output_facial_transformation_matrixes=True + if expected_facial_transformation_matrixes + else False, + result_callback=check_result, + ) + with _FaceLandmarker.create_from_options(options) as landmarker: + for timestamp in range(0, 300, 30): + landmarker.detect_async(test_image, timestamp) + + +if __name__ == '__main__': + absltest.main() diff --git a/mediapipe/tasks/python/test/vision/image_embedder_test.py b/mediapipe/tasks/python/test/vision/image_embedder_test.py index 11c0cf002..8c7fb59a2 100644 --- a/mediapipe/tasks/python/test/vision/image_embedder_test.py +++ b/mediapipe/tasks/python/test/vision/image_embedder_test.py @@ -119,14 +119,43 @@ class ImageEmbedderTest(parameterized.TestCase): similarity, expected_similarity, delta=_SIMILARITY_TOLERANCE) @parameterized.parameters( - (False, False, False, ModelFileType.FILE_NAME, 0.925519, 1024, - (-0.2101883, -0.193027)), - (True, False, False, ModelFileType.FILE_NAME, 0.925519, 1024, - (-0.0142344, -0.0131606)), - # (False, True, False, ModelFileType.FILE_NAME, - # 0.926791, 1024, (229, 231)), - (False, False, True, ModelFileType.FILE_CONTENT, 0.999931, 1024, - (-0.195062, -0.193027))) + ( + False, + False, + False, + ModelFileType.FILE_NAME, + 0.925519, + 1024, + (-0.2101883, -0.193027), + ), + ( + True, + False, + False, + ModelFileType.FILE_NAME, + 0.925519, + 1024, + (-0.0142344, -0.0131606), + ), + ( + False, + True, + False, + ModelFileType.FILE_NAME, + 0.926791, + 1024, + (229, 231), + ), + ( + False, + False, + True, + ModelFileType.FILE_CONTENT, + 0.999931, + 1024, + (-0.195062, -0.193027), + ), + ) def test_embed(self, l2_normalize, quantize, with_roi, model_file_type, expected_similarity, expected_size, expected_first_values): # Creates embedder. 
diff --git a/mediapipe/tasks/python/test/vision/object_detector_test.py b/mediapipe/tasks/python/test/vision/object_detector_test.py index 5afa31459..2bb9b0214 100644 --- a/mediapipe/tasks/python/test/vision/object_detector_test.py +++ b/mediapipe/tasks/python/test/vision/object_detector_test.py @@ -42,48 +42,62 @@ _RUNNING_MODE = running_mode_module.VisionTaskRunningMode _MODEL_FILE = 'coco_ssd_mobilenet_v1_1.0_quant_2018_06_29.tflite' _IMAGE_FILE = 'cats_and_dogs.jpg' -_EXPECTED_DETECTION_RESULT = _DetectionResult(detections=[ - _Detection( - bounding_box=_BoundingBox( - origin_x=608, origin_y=161, width=381, height=439), - categories=[ - _Category( - index=None, - score=0.69921875, - display_name=None, - category_name='cat') - ]), - _Detection( - bounding_box=_BoundingBox( - origin_x=60, origin_y=398, width=386, height=196), - categories=[ - _Category( - index=None, - score=0.64453125, - display_name=None, - category_name='cat') - ]), - _Detection( - bounding_box=_BoundingBox( - origin_x=256, origin_y=395, width=173, height=202), - categories=[ - _Category( - index=None, - score=0.51171875, - display_name=None, - category_name='cat') - ]), - _Detection( - bounding_box=_BoundingBox( - origin_x=362, origin_y=191, width=325, height=419), - categories=[ - _Category( - index=None, - score=0.48828125, - display_name=None, - category_name='cat') - ]) -]) +_EXPECTED_DETECTION_RESULT = _DetectionResult( + detections=[ + _Detection( + bounding_box=_BoundingBox( + origin_x=608, origin_y=161, width=381, height=439 + ), + categories=[ + _Category( + index=None, + score=0.69921875, + display_name=None, + category_name='cat', + ) + ], + ), + _Detection( + bounding_box=_BoundingBox( + origin_x=60, origin_y=398, width=386, height=196 + ), + categories=[ + _Category( + index=None, + score=0.64453125, + display_name=None, + category_name='cat', + ) + ], + ), + _Detection( + bounding_box=_BoundingBox( + origin_x=256, origin_y=395, width=173, height=202 + ), + categories=[ + _Category( + index=None, + score=0.51171875, + display_name=None, + category_name='cat', + ) + ], + ), + _Detection( + bounding_box=_BoundingBox( + origin_x=362, origin_y=191, width=325, height=419 + ), + categories=[ + _Category( + index=None, + score=0.48828125, + display_name=None, + category_name='cat', + ) + ], + ), + ] +) _ALLOW_LIST = ['cat', 'dog'] _DENY_LIST = ['cat'] _SCORE_THRESHOLD = 0.3 diff --git a/mediapipe/tasks/python/vision/BUILD b/mediapipe/tasks/python/vision/BUILD index 8ce0ef96e..e21171fc2 100644 --- a/mediapipe/tasks/python/vision/BUILD +++ b/mediapipe/tasks/python/vision/BUILD @@ -171,3 +171,48 @@ py_library( "//mediapipe/tasks/python/vision/core:vision_task_running_mode", ], ) + +py_library( + name = "face_detector", + srcs = [ + "face_detector.py", + ], + deps = [ + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:detections", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) + +py_library( + name = "face_landmarker", + srcs = [ + "face_landmarker.py", + ], + deps = [ + 
"//mediapipe/framework/formats:classification_py_pb2", + "//mediapipe/framework/formats:landmark_py_pb2", + "//mediapipe/framework/formats:matrix_data_py_pb2", + "//mediapipe/python:_framework_bindings", + "//mediapipe/python:packet_creator", + "//mediapipe/python:packet_getter", + "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_py_pb2", + "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_py_pb2", + "//mediapipe/tasks/python/components/containers:category", + "//mediapipe/tasks/python/components/containers:landmark", + "//mediapipe/tasks/python/core:base_options", + "//mediapipe/tasks/python/core:optional_dependencies", + "//mediapipe/tasks/python/core:task_info", + "//mediapipe/tasks/python/vision/core:base_vision_task_api", + "//mediapipe/tasks/python/vision/core:image_processing_options", + "//mediapipe/tasks/python/vision/core:vision_task_running_mode", + ], +) diff --git a/mediapipe/tasks/python/vision/face_detector.py b/mediapipe/tasks/python/vision/face_detector.py new file mode 100644 index 000000000..f0ce4d1f1 --- /dev/null +++ b/mediapipe/tasks/python/vision/face_detector.py @@ -0,0 +1,332 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""MediaPipe face detector task.""" + +import dataclasses +from typing import Callable, Mapping, Optional + +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +from mediapipe.tasks.cc.vision.face_detector.proto import face_detector_graph_options_pb2 +from mediapipe.tasks.python.components.containers import detections as detections_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +FaceDetectorResult = detections_module.DetectionResult +_BaseOptions = base_options_module.BaseOptions +_FaceDetectorGraphOptionsProto = ( + face_detector_graph_options_pb2.FaceDetectorGraphOptions +) +_RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +_TaskInfo = task_info_module.TaskInfo + +_DETECTIONS_OUT_STREAM_NAME = 'detections' +_DETECTIONS_TAG = 'DETECTIONS' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_TASK_GRAPH_NAME = 'mediapipe.tasks.vision.face_detector.FaceDetectorGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +@dataclasses.dataclass +class FaceDetectorOptions: + """Options for the face detector task. + + Attributes: + base_options: Base options for the face detector task. + running_mode: The running mode of the task. Default to the image mode. Face + detector task has three running modes: 1) The image mode for detecting + faces on single image inputs. 2) The video mode for detecting faces on the + decoded frames of a video. 3) The live stream mode for detecting faces on + a live stream of input data, such as from camera. + min_detection_confidence: The minimum confidence score for the face + detection to be considered successful. + min_suppression_threshold: The minimum non-maximum-suppression threshold for + face detection to be considered overlapped. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. 
+ """ + + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + min_detection_confidence: Optional[float] = None + min_suppression_threshold: Optional[float] = None + result_callback: Optional[ + Callable[ + [detections_module.DetectionResult, image_module.Image, int], None + ] + ] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FaceDetectorGraphOptionsProto: + """Generates an FaceDetectorOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) + return _FaceDetectorGraphOptionsProto( + base_options=base_options_proto, + min_detection_confidence=self.min_detection_confidence, + min_suppression_threshold=self.min_suppression_threshold, + ) + + +class FaceDetector(base_vision_task_api.BaseVisionTaskApi): + """Class that performs face detection on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'FaceDetector': + """Creates an `FaceDetector` object from a TensorFlow Lite model and the default `FaceDetectorOptions`. + + Note that the created `FaceDetector` instance is in image mode, for + detecting faces on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `FaceDetector` object that's created from the model file and the default + `FaceDetectorOptions`. + + Raises: + ValueError: If failed to create `FaceDetector` object from the provided + file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = FaceDetectorOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE + ) + return cls.create_from_options(options) + + @classmethod + def create_from_options(cls, options: FaceDetectorOptions) -> 'FaceDetector': + """Creates the `FaceDetector` object from face detector options. + + Args: + options: Options for the face detector task. + + Returns: + `FaceDetector` object that's created from `options`. + + Raises: + ValueError: If failed to create `FaceDetector` object from + `FaceDetectorOptions` such as missing the model. + RuntimeError: If other types of error occurred. 
+ """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + empty_packet = output_packets[_DETECTIONS_OUT_STREAM_NAME] + options.result_callback( + FaceDetectorResult([]), + image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) + return + detection_proto_list = packet_getter.get_proto_list( + output_packets[_DETECTIONS_OUT_STREAM_NAME] + ) + detection_result = detections_module.DetectionResult( + [ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ] + ) + + timestamp = output_packets[_IMAGE_OUT_STREAM_NAME].timestamp + options.result_callback( + detection_result, + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=[ + ':'.join([_DETECTIONS_TAG, _DETECTIONS_OUT_STREAM_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ], + task_options=options, + ) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) + + def detect( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> FaceDetectorResult: + """Performs face detection on the provided MediaPipe Image. + + Only use this method when the FaceDetector is created with the image + running mode. + + Args: + image: MediaPipe Image. + image_processing_options: Options for image processing. + + Returns: + A face detection result object that contains a list of face detections, + each detection has a bounding box that is expressed in the unrotated input + frame of reference coordinates system, i.e. in `[0,image_width) x [0, + image_height)`, which are the dimensions of the underlying image data. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face detection failed to run. + """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, roi_allowed=False + ) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + return FaceDetectorResult([]) + detection_proto_list = packet_getter.get_proto_list( + output_packets[_DETECTIONS_OUT_STREAM_NAME] + ) + return detections_module.DetectionResult( + [ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ] + ) + + def detect_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> detections_module.DetectionResult: + """Performs face detection on the provided video frames. + + Only use this method when the FaceDetector is created with the video + running mode. It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. 
+ timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. + + Returns: + A face detection result object that contains a list of face detections, + each detection has a bounding box that is expressed in the unrotated input + frame of reference coordinates system, i.e. in `[0,image_width) x [0, + image_height)`, which are the dimensions of the underlying image data. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face detection failed to run. + """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, roi_allowed=False + ) + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) + if output_packets[_DETECTIONS_OUT_STREAM_NAME].is_empty(): + return FaceDetectorResult([]) + detection_proto_list = packet_getter.get_proto_list( + output_packets[_DETECTIONS_OUT_STREAM_NAME] + ) + return detections_module.DetectionResult( + [ + detections_module.Detection.create_from_pb2(result) + for result in detection_proto_list + ] + ) + + def detect_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> None: + """Sends live image data (an Image with a unique timestamp) to perform face detection. + + Only use this method when the FaceDetector is created with the live stream + running mode. The input timestamps should be monotonically increasing for + adjacent calls of this method. This method will return immediately after the + input image is accepted. The results will be available via the + `result_callback` provided in the `FaceDetectorOptions`. The + `detect_async` method is designed to process live stream data such as camera + input. To lower the overall latency, face detector may drop the input + images if needed. In other words, it's not guaranteed to have output per + input image. + + The `result_callback` provides: + - A face detection result object that contains a list of face detections, + each detection has a bounding box that is expressed in the unrotated + input frame of reference coordinates system, + i.e. in `[0,image_width) x [0,image_height)`, which are the dimensions + of the underlying image data. + - The input image that the face detector runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. + + Raises: + ValueError: If the current input timestamp is smaller than what the face + detector has already processed. 
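+
+    A minimal live-stream sketch (illustrative only; `camera_frames()` is a
+    hypothetical generator yielding (timestamp_ms, image) pairs, and the class
+    names are assumed to be imported from the MediaPipe Tasks vision API):
+
+      def on_result(result, image, timestamp_ms):
+        print(timestamp_ms, len(result.detections))
+
+      options = FaceDetectorOptions(
+          base_options=BaseOptions(model_asset_path='/path/to/model.tflite'),
+          running_mode=VisionTaskRunningMode.LIVE_STREAM,
+          result_callback=on_result,
+      )
+      with FaceDetector.create_from_options(options) as detector:
+        for timestamp_ms, frame in camera_frames():
+          detector.detect_async(frame, timestamp_ms)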
+ """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, roi_allowed=False + ) + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) diff --git a/mediapipe/tasks/python/vision/face_landmarker.py b/mediapipe/tasks/python/vision/face_landmarker.py new file mode 100644 index 000000000..41faf6d91 --- /dev/null +++ b/mediapipe/tasks/python/vision/face_landmarker.py @@ -0,0 +1,509 @@ +# Copyright 2023 The MediaPipe Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MediaPipe face landmarker task.""" + +import dataclasses +import enum +from typing import Callable, Mapping, Optional, List + +import numpy as np + +from mediapipe.framework.formats import classification_pb2 +from mediapipe.framework.formats import landmark_pb2 +from mediapipe.framework.formats import matrix_data_pb2 +from mediapipe.python import packet_creator +from mediapipe.python import packet_getter +from mediapipe.python._framework_bindings import image as image_module +from mediapipe.python._framework_bindings import packet as packet_module +# pylint: disable=unused-import +from mediapipe.tasks.cc.vision.face_geometry.proto import face_geometry_pb2 +# pylint: enable=unused-import +from mediapipe.tasks.cc.vision.face_landmarker.proto import face_landmarker_graph_options_pb2 +from mediapipe.tasks.python.components.containers import category as category_module +from mediapipe.tasks.python.components.containers import landmark as landmark_module +from mediapipe.tasks.python.core import base_options as base_options_module +from mediapipe.tasks.python.core import task_info as task_info_module +from mediapipe.tasks.python.core.optional_dependencies import doc_controls +from mediapipe.tasks.python.vision.core import base_vision_task_api +from mediapipe.tasks.python.vision.core import image_processing_options as image_processing_options_module +from mediapipe.tasks.python.vision.core import vision_task_running_mode as running_mode_module + +_BaseOptions = base_options_module.BaseOptions +_FaceLandmarkerGraphOptionsProto = ( + face_landmarker_graph_options_pb2.FaceLandmarkerGraphOptions +) +_LayoutEnum = matrix_data_pb2.MatrixData.Layout +_RunningMode = running_mode_module.VisionTaskRunningMode +_ImageProcessingOptions = image_processing_options_module.ImageProcessingOptions +_TaskInfo = task_info_module.TaskInfo + +_IMAGE_IN_STREAM_NAME = 'image_in' +_IMAGE_OUT_STREAM_NAME = 'image_out' +_IMAGE_TAG = 'IMAGE' +_NORM_RECT_STREAM_NAME = 'norm_rect_in' +_NORM_RECT_TAG = 'NORM_RECT' +_NORM_LANDMARKS_STREAM_NAME = 'norm_landmarks' +_NORM_LANDMARKS_TAG = 'NORM_LANDMARKS' +_BLENDSHAPES_STREAM_NAME = 'blendshapes' +_BLENDSHAPES_TAG = 'BLENDSHAPES' +_FACE_GEOMETRY_STREAM_NAME = 'face_geometry' +_FACE_GEOMETRY_TAG = 'FACE_GEOMETRY' +_TASK_GRAPH_NAME = 
'mediapipe.tasks.vision.face_landmarker.FaceLandmarkerGraph' +_MICRO_SECONDS_PER_MILLISECOND = 1000 + + +class Blendshapes(enum.IntEnum): + """The 52 blendshape coefficients.""" + + NEUTRAL = 0 + BROW_DOWN_LEFT = 1 + BROW_DOWN_RIGHT = 2 + BROW_INNER_UP = 3 + BROW_OUTER_UP_LEFT = 4 + BROW_OUTER_UP_RIGHT = 5 + CHEEK_PUFF = 6 + CHEEK_SQUINT_LEFT = 7 + CHEEK_SQUINT_RIGHT = 8 + EYE_BLINK_LEFT = 9 + EYE_BLINK_RIGHT = 10 + EYE_LOOK_DOWN_LEFT = 11 + EYE_LOOK_DOWN_RIGHT = 12 + EYE_LOOK_IN_LEFT = 13 + EYE_LOOK_IN_RIGHT = 14 + EYE_LOOK_OUT_LEFT = 15 + EYE_LOOK_OUT_RIGHT = 16 + EYE_LOOK_UP_LEFT = 17 + EYE_LOOK_UP_RIGHT = 18 + EYE_SQUINT_LEFT = 19 + EYE_SQUINT_RIGHT = 20 + EYE_WIDE_LEFT = 21 + EYE_WIDE_RIGHT = 22 + JAW_FORWARD = 23 + JAW_LEFT = 24 + JAW_OPEN = 25 + JAW_RIGHT = 26 + MOUTH_CLOSE = 27 + MOUTH_DIMPLE_LEFT = 28 + MOUTH_DIMPLE_RIGHT = 29 + MOUTH_FROWN_LEFT = 30 + MOUTH_FROWN_RIGHT = 31 + MOUTH_FUNNEL = 32 + MOUTH_LEFT = 33 + MOUTH_LOWER_DOWN_LEFT = 34 + MOUTH_LOWER_DOWN_RIGHT = 35 + MOUTH_PRESS_LEFT = 36 + MOUTH_PRESS_RIGHT = 37 + MOUTH_PUCKER = 38 + MOUTH_RIGHT = 39 + MOUTH_ROLL_LOWER = 40 + MOUTH_ROLL_UPPER = 41 + MOUTH_SHRUG_LOWER = 42 + MOUTH_SHRUG_UPPER = 43 + MOUTH_SMILE_LEFT = 44 + MOUTH_SMILE_RIGHT = 45 + MOUTH_STRETCH_LEFT = 46 + MOUTH_STRETCH_RIGHT = 47 + MOUTH_UPPER_UP_LEFT = 48 + MOUTH_UPPER_UP_RIGHT = 49 + NOSE_SNEER_LEFT = 50 + NOSE_SNEER_RIGHT = 51 + + +@dataclasses.dataclass +class FaceLandmarkerResult: + """The face landmarks detection result from FaceLandmarker, where each vector element represents a single face detected in the image. + + Attributes: + face_landmarks: Detected face landmarks in normalized image coordinates. + face_blendshapes: Optional face blendshapes results. + facial_transformation_matrixes: Optional facial transformation matrix. 
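+      Each matrix is a 4x4 NumPy array that maps the canonical face model
+      into the frame of the detected face.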
+ """ + + face_landmarks: List[List[landmark_module.NormalizedLandmark]] + face_blendshapes: List[List[category_module.Category]] + facial_transformation_matrixes: List[np.ndarray] + + +def _build_landmarker_result( + output_packets: Mapping[str, packet_module.Packet] +) -> FaceLandmarkerResult: + """Constructs a `FaceLandmarkerResult` from output packets.""" + face_landmarks_proto_list = packet_getter.get_proto_list( + output_packets[_NORM_LANDMARKS_STREAM_NAME] + ) + + face_landmarks_results = [] + for proto in face_landmarks_proto_list: + face_landmarks = landmark_pb2.NormalizedLandmarkList() + face_landmarks.MergeFrom(proto) + face_landmarks_list = [] + for face_landmark in face_landmarks.landmark: + face_landmarks_list.append( + landmark_module.NormalizedLandmark.create_from_pb2(face_landmark) + ) + face_landmarks_results.append(face_landmarks_list) + + face_blendshapes_results = [] + if _BLENDSHAPES_STREAM_NAME in output_packets: + face_blendshapes_proto_list = packet_getter.get_proto_list( + output_packets[_BLENDSHAPES_STREAM_NAME] + ) + for proto in face_blendshapes_proto_list: + face_blendshapes_categories = [] + face_blendshapes_classifications = classification_pb2.ClassificationList() + face_blendshapes_classifications.MergeFrom(proto) + for face_blendshapes in face_blendshapes_classifications.classification: + face_blendshapes_categories.append( + category_module.Category( + index=face_blendshapes.index, + score=face_blendshapes.score, + display_name=face_blendshapes.display_name, + category_name=face_blendshapes.label, + ) + ) + face_blendshapes_results.append(face_blendshapes_categories) + + facial_transformation_matrixes_results = [] + if _FACE_GEOMETRY_STREAM_NAME in output_packets: + facial_transformation_matrixes_proto_list = packet_getter.get_proto_list( + output_packets[_FACE_GEOMETRY_STREAM_NAME] + ) + for proto in facial_transformation_matrixes_proto_list: + if hasattr(proto, 'pose_transform_matrix'): + matrix_data = matrix_data_pb2.MatrixData() + matrix_data.MergeFrom(proto.pose_transform_matrix) + matrix = np.array(matrix_data.packed_data) + matrix = matrix.reshape((matrix_data.rows, matrix_data.cols)) + matrix = ( + matrix if matrix_data.layout == _LayoutEnum.ROW_MAJOR else matrix.T + ) + facial_transformation_matrixes_results.append(matrix) + + return FaceLandmarkerResult( + face_landmarks_results, + face_blendshapes_results, + facial_transformation_matrixes_results, + ) + + +@dataclasses.dataclass +class FaceLandmarkerOptions: + """Options for the face landmarker task. + + Attributes: + base_options: Base options for the face landmarker task. + running_mode: The running mode of the task. Default to the image mode. + HandLandmarker has three running modes: 1) The image mode for detecting + face landmarks on single image inputs. 2) The video mode for detecting + face landmarks on the decoded frames of a video. 3) The live stream mode + for detecting face landmarks on the live stream of input data, such as + from camera. In this mode, the "result_callback" below must be specified + to receive the detection results asynchronously. + num_faces: The maximum number of faces that can be detected by the + FaceLandmarker. + min_face_detection_confidence: The minimum confidence score for the face + detection to be considered successful. + min_face_presence_confidence: The minimum confidence score of face presence + score in the face landmark detection. + min_tracking_confidence: The minimum confidence score for the face tracking + to be considered successful. 
+ output_face_blendshapes: Whether FaceLandmarker outputs face blendshapes + classification. Face blendshapes are used for rendering the 3D face model. + output_facial_transformation_matrixes: Whether FaceLandmarker outputs the + facial transformation matrix. The facial transformation matrix is used to + transform the face landmarks from the canonical face to the detected face, + so that users can apply face effects on the detected landmarks. + result_callback: The user-defined result callback for processing live stream + data. The result callback should only be specified when the running mode + is set to the live stream mode. + """ + + base_options: _BaseOptions + running_mode: _RunningMode = _RunningMode.IMAGE + num_faces: Optional[int] = 1 + min_face_detection_confidence: Optional[float] = 0.5 + min_face_presence_confidence: Optional[float] = 0.5 + min_tracking_confidence: Optional[float] = 0.5 + output_face_blendshapes: Optional[bool] = False + output_facial_transformation_matrixes: Optional[bool] = False + result_callback: Optional[ + Callable[[FaceLandmarkerResult, image_module.Image, int], None] + ] = None + + @doc_controls.do_not_generate_docs + def to_pb2(self) -> _FaceLandmarkerGraphOptionsProto: + """Generates a FaceLandmarkerGraphOptions protobuf object.""" + base_options_proto = self.base_options.to_pb2() + base_options_proto.use_stream_mode = ( + False if self.running_mode == _RunningMode.IMAGE else True + ) + + # Initialize the face landmarker options from base options. + face_landmarker_options_proto = _FaceLandmarkerGraphOptionsProto( + base_options=base_options_proto + ) + + # Configure face detector options. + face_landmarker_options_proto.face_detector_graph_options.num_faces = ( + self.num_faces + ) + face_landmarker_options_proto.face_detector_graph_options.min_detection_confidence = ( + self.min_face_detection_confidence + ) + + # Configure face landmark detector options. + face_landmarker_options_proto.min_tracking_confidence = ( + self.min_tracking_confidence + ) + face_landmarker_options_proto.face_landmarks_detector_graph_options.min_detection_confidence = ( + self.min_face_presence_confidence + ) + return face_landmarker_options_proto + + +class FaceLandmarker(base_vision_task_api.BaseVisionTaskApi): + """Class that performs face landmarks detection on images.""" + + @classmethod + def create_from_model_path(cls, model_path: str) -> 'FaceLandmarker': + """Creates a `FaceLandmarker` object from a TensorFlow Lite model and the default `FaceLandmarkerOptions`. + + Note that the created `FaceLandmarker` instance is in image mode, for + detecting face landmarks on single image inputs. + + Args: + model_path: Path to the model. + + Returns: + `FaceLandmarker` object that's created from the model file and the + default `FaceLandmarkerOptions`. + + Raises: + ValueError: If failed to create `FaceLandmarker` object from the + provided file such as invalid file path. + RuntimeError: If other types of error occurred. + """ + base_options = _BaseOptions(model_asset_path=model_path) + options = FaceLandmarkerOptions( + base_options=base_options, running_mode=_RunningMode.IMAGE + ) + return cls.create_from_options(options) + + @classmethod + def create_from_options( + cls, options: FaceLandmarkerOptions + ) -> 'FaceLandmarker': + """Creates the `FaceLandmarker` object from face landmarker options. + + Args: + options: Options for the face landmarker task. + + Returns: + `FaceLandmarker` object that's created from `options`.
+ + Raises: + ValueError: If failed to create `FaceLandmarker` object from + `FaceLandmarkerOptions` such as missing the model. + RuntimeError: If other types of error occurred. + """ + + def packets_callback(output_packets: Mapping[str, packet_module.Packet]): + if output_packets[_IMAGE_OUT_STREAM_NAME].is_empty(): + return + + image = packet_getter.get_image(output_packets[_IMAGE_OUT_STREAM_NAME]) + + if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): + empty_packet = output_packets[_NORM_LANDMARKS_STREAM_NAME] + options.result_callback( + FaceLandmarkerResult([], [], []), + image, + empty_packet.timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) + return + + face_landmarks_result = _build_landmarker_result(output_packets) + timestamp = output_packets[_NORM_LANDMARKS_STREAM_NAME].timestamp + options.result_callback( + face_landmarks_result, + image, + timestamp.value // _MICRO_SECONDS_PER_MILLISECOND, + ) + + output_streams = [ + ':'.join([_NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME]), + ':'.join([_IMAGE_TAG, _IMAGE_OUT_STREAM_NAME]), + ] + + if options.output_face_blendshapes: + output_streams.append( + ':'.join([_BLENDSHAPES_TAG, _BLENDSHAPES_STREAM_NAME]) + ) + if options.output_facial_transformation_matrixes: + output_streams.append( + ':'.join([_FACE_GEOMETRY_TAG, _FACE_GEOMETRY_STREAM_NAME]) + ) + + task_info = _TaskInfo( + task_graph=_TASK_GRAPH_NAME, + input_streams=[ + ':'.join([_IMAGE_TAG, _IMAGE_IN_STREAM_NAME]), + ':'.join([_NORM_RECT_TAG, _NORM_RECT_STREAM_NAME]), + ], + output_streams=output_streams, + task_options=options, + ) + return cls( + task_info.generate_graph_config( + enable_flow_limiting=options.running_mode + == _RunningMode.LIVE_STREAM + ), + options.running_mode, + packets_callback if options.result_callback else None, + ) + + def detect( + self, + image: image_module.Image, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> FaceLandmarkerResult: + """Performs face landmarks detection on the given image. + + Only use this method when the FaceLandmarker is created with the image + running mode. + + The image can be of any size with format RGB or RGBA. + TODO: Describes how the input image will be preprocessed after the yuv + support is implemented. + + Args: + image: MediaPipe Image. + image_processing_options: Options for image processing. + + Returns: + The face landmarks detection results. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face landmarker detection failed to run. + """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, roi_allowed=False + ) + output_packets = self._process_image_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ), + }) + + if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): + return FaceLandmarkerResult([], [], []) + + return _build_landmarker_result(output_packets) + + def detect_for_video( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> FaceLandmarkerResult: + """Performs face landmarks detection on the provided video frame. + + Only use this method when the FaceLandmarker is created with the video + running mode.
It's required to provide the video frame's timestamp (in + milliseconds) along with the video frame. The input timestamps should be + monotonically increasing for adjacent calls of this method. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input video frame in milliseconds. + image_processing_options: Options for image processing. + + Returns: + The face landmarks detection results. + + Raises: + ValueError: If any of the input arguments is invalid. + RuntimeError: If face landmarker detection failed to run. + """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, roi_allowed=False + ) + output_packets = self._process_video_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) + + if output_packets[_NORM_LANDMARKS_STREAM_NAME].is_empty(): + return FaceLandmarkerResult([], [], []) + + return _build_landmarker_result(output_packets) + + def detect_async( + self, + image: image_module.Image, + timestamp_ms: int, + image_processing_options: Optional[_ImageProcessingOptions] = None, + ) -> None: + """Sends live image data to perform face landmarks detection. + + Only use this method when the FaceLandmarker is created with the live + stream running mode. The input timestamps should be monotonically increasing + for adjacent calls of this method. This method will return immediately after + the input image is accepted. The results will be available via the + `result_callback` provided in the `FaceLandmarkerOptions`. The + `detect_async` method is designed to process live stream data such as + camera input. To lower the overall latency, face landmarker may drop the + input images if needed. In other words, it's not guaranteed to have output + per input image. + + The `result_callback` provides: + - The face landmarks detection results. + - The input image that the face landmarker runs on. + - The input timestamp in milliseconds. + + Args: + image: MediaPipe Image. + timestamp_ms: The timestamp of the input image in milliseconds. + image_processing_options: Options for image processing. + + Raises: + ValueError: If the current input timestamp is smaller than what the + face landmarker has already processed.
+ """ + normalized_rect = self.convert_to_normalized_rect( + image_processing_options, roi_allowed=False + ) + self._send_live_stream_data({ + _IMAGE_IN_STREAM_NAME: packet_creator.create_image(image).at( + timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND + ), + _NORM_RECT_STREAM_NAME: packet_creator.create_proto( + normalized_rect.to_pb2() + ).at(timestamp_ms * _MICRO_SECONDS_PER_MILLISECOND), + }) diff --git a/mediapipe/tasks/testdata/metadata/BUILD b/mediapipe/tasks/testdata/metadata/BUILD index 0ac06caac..e335831aa 100644 --- a/mediapipe/tasks/testdata/metadata/BUILD +++ b/mediapipe/tasks/testdata/metadata/BUILD @@ -28,6 +28,10 @@ mediapipe_files(srcs = [ "category_tensor_float_meta.json", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", "coco_ssd_mobilenet_v1_score_calibration.json", + "deeplabv3.json", + "deeplabv3_with_activation.json", + "deeplabv3_without_labels.json", + "deeplabv3_without_metadata.tflite", "efficientdet_lite0_v1.json", "efficientdet_lite0_v1.tflite", "labelmap.txt", @@ -44,6 +48,8 @@ mediapipe_files(srcs = [ "mobilenet_v2_1.0_224_without_metadata.tflite", "movie_review.tflite", "score_calibration.csv", + "segmentation_mask_meta.json", + "segmenter_labelmap.txt", "ssd_mobilenet_v1_no_metadata.json", "ssd_mobilenet_v1_no_metadata.tflite", "tensor_group_meta.json", @@ -87,6 +93,7 @@ filegroup( "30k-clean.model", "bert_text_classifier_no_metadata.tflite", "coco_ssd_mobilenet_v1_1.0_quant_2018_06_29_no_metadata.tflite", + "deeplabv3_without_metadata.tflite", "efficientdet_lite0_v1.tflite", "mobile_ica_8bit-with-custom-metadata.tflite", "mobile_ica_8bit-with-large-min-parser-version.tflite", @@ -116,6 +123,9 @@ filegroup( "classification_tensor_uint8_meta.json", "classification_tensor_unsupported_meta.json", "coco_ssd_mobilenet_v1_score_calibration.json", + "deeplabv3.json", + "deeplabv3_with_activation.json", + "deeplabv3_without_labels.json", "efficientdet_lite0_v1.json", "external_file", "feature_tensor_meta.json", @@ -140,6 +150,8 @@ filegroup( "score_calibration_file_meta.json", "score_calibration_tensor_meta.json", "score_thresholding_meta.json", + "segmentation_mask_meta.json", + "segmenter_labelmap.txt", "sentence_piece_tokenizer_meta.json", "ssd_mobilenet_v1_no_metadata.json", "tensor_group_meta.json", diff --git a/mediapipe/tasks/testdata/metadata/deeplabv3.json b/mediapipe/tasks/testdata/metadata/deeplabv3.json new file mode 100644 index 000000000..1ae982200 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/deeplabv3.json @@ -0,0 +1,66 @@ +{ + "name": "ImageSegmenter", + "description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "segmentation_masks", + "description": "Masks over the target objects with high accuracy.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "GRAYSCALE" + }, + "range": { + "min": 1, + "max": 2 + } + }, + "stats": {}, + "associated_files": [ + { + "name": "labels.txt", + "description": "Labels for 
categories that the model can recognize.", + "type": "TENSOR_AXIS_LABELS" + } + ] + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/tasks/testdata/metadata/deeplabv3_with_activation.json b/mediapipe/tasks/testdata/metadata/deeplabv3_with_activation.json new file mode 100644 index 000000000..4fb32bab3 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/deeplabv3_with_activation.json @@ -0,0 +1,67 @@ +{ + "name": "ImageSegmenter", + "description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "segmentation_masks", + "description": "Masks over the target objects with high accuracy.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "GRAYSCALE" + }, + "range": { + "min": 1, + "max": 2 + } + }, + "stats": {} + } + ], + "custom_metadata": [ + { + "name": "SEGMENTER_METADATA", + "data": { + "activation": "SIGMOID" + } + } + ] + } + ], + "min_parser_version": "1.5.0" +} diff --git a/mediapipe/tasks/testdata/metadata/deeplabv3_without_labels.json b/mediapipe/tasks/testdata/metadata/deeplabv3_without_labels.json new file mode 100644 index 000000000..d7a1a1c25 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/deeplabv3_without_labels.json @@ -0,0 +1,59 @@ +{ + "name": "ImageSegmenter", + "description": "Semantic image segmentation predicts whether each pixel of an image is associated with a certain class.", + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "image", + "description": "Input image to be processed.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "RGB" + } + }, + "process_units": [ + { + "options_type": "NormalizationOptions", + "options": { + "mean": [ + 127.5 + ], + "std": [ + 127.5 + ] + } + } + ], + "stats": { + "max": [ + 1.0 + ], + "min": [ + -1.0 + ] + } + } + ], + "output_tensor_metadata": [ + { + "name": "segmentation_masks", + "description": "Masks over the target objects with high accuracy.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "GRAYSCALE" + }, + "range": { + "min": 1, + "max": 2 + } + }, + "stats": {} + } + ] + } + ], + "min_parser_version": "1.0.0" +} diff --git a/mediapipe/tasks/testdata/metadata/segmentation_mask_meta.json b/mediapipe/tasks/testdata/metadata/segmentation_mask_meta.json new file mode 100644 index 000000000..9252d9573 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/segmentation_mask_meta.json @@ -0,0 +1,24 @@ +{ + "subgraph_metadata": [ + { + "input_tensor_metadata": [ + { + "name": "segmentation_masks", + "description": "Masks over the target objects.", + "content": { + "content_properties_type": "ImageProperties", + "content_properties": { + "color_space": "GRAYSCALE" + }, + "range": { + "min": 1, + "max": 2 + } + }, + "stats": { + } + } + ] + } + ] +} diff --git a/mediapipe/tasks/testdata/metadata/segmenter_labelmap.txt 
b/mediapipe/tasks/testdata/metadata/segmenter_labelmap.txt new file mode 100644 index 000000000..204608ce5 --- /dev/null +++ b/mediapipe/tasks/testdata/metadata/segmenter_labelmap.txt @@ -0,0 +1,21 @@ +background +aeroplane +bicycle +bird +boat +bottle +bus +car +cat +chair +cow +dining table +dog +horse +motorbike +person +potted plant +sheep +sofa +train +tv diff --git a/mediapipe/tasks/testdata/vision/BUILD b/mediapipe/tasks/testdata/vision/BUILD index 63e3613e6..097acad43 100644 --- a/mediapipe/tasks/testdata/vision/BUILD +++ b/mediapipe/tasks/testdata/vision/BUILD @@ -31,6 +31,8 @@ mediapipe_files(srcs = [ "cat_rotated.jpg", "cat_rotated_mask.jpg", "cats_and_dogs.jpg", + "cats_and_dogs_mask_dog1.png", + "cats_and_dogs_mask_dog2.png", "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", "coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", @@ -70,6 +72,11 @@ mediapipe_files(srcs = [ "portrait.jpg", "portrait_hair_expected_mask.jpg", "portrait_rotated.jpg", + "portrait_selfie_segmentation_expected_category_mask.jpg", + "portrait_selfie_segmentation_expected_confidence_mask.jpg", + "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", + "pose.jpg", + "pose_detection.tflite", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -78,6 +85,8 @@ mediapipe_files(srcs = [ "selfie_segm_128_128_3_expected_mask.jpg", "selfie_segm_144_256_3.tflite", "selfie_segm_144_256_3_expected_mask.jpg", + "selfie_segmentation.tflite", + "selfie_segmentation_landscape.tflite", "thumb_up.jpg", "victory.jpg", ]) @@ -109,6 +118,8 @@ filegroup( "cat_rotated.jpg", "cat_rotated_mask.jpg", "cats_and_dogs.jpg", + "cats_and_dogs_mask_dog1.png", + "cats_and_dogs_mask_dog2.png", "cats_and_dogs_no_resizing.jpg", "cats_and_dogs_rotated.jpg", "fist.jpg", @@ -125,6 +136,10 @@ filegroup( "portrait.jpg", "portrait_hair_expected_mask.jpg", "portrait_rotated.jpg", + "portrait_selfie_segmentation_expected_category_mask.jpg", + "portrait_selfie_segmentation_expected_confidence_mask.jpg", + "portrait_selfie_segmentation_landscape_expected_category_mask.jpg", + "pose.jpg", "right_hands.jpg", "right_hands_rotated.jpg", "segmentation_golden_rotation0.png", @@ -169,8 +184,11 @@ filegroup( "mobilenet_v2_1.0_224.tflite", "mobilenet_v3_small_100_224_embedder.tflite", "palm_detection_full.tflite", + "pose_detection.tflite", "selfie_segm_128_128_3.tflite", "selfie_segm_144_256_3.tflite", + "selfie_segmentation.tflite", + "selfie_segmentation_landscape.tflite", ], ) @@ -195,6 +213,7 @@ filegroup( "portrait_expected_face_landmarks.pbtxt", "portrait_expected_face_landmarks_with_attention.pbtxt", "portrait_rotated_expected_detection.pbtxt", + "pose_expected_detection.pbtxt", "thumb_up_landmarks.pbtxt", "thumb_up_rotated_landmarks.pbtxt", "victory_landmarks.pbtxt", diff --git a/mediapipe/tasks/testdata/vision/face_landmarker_with_blendshapes.task b/mediapipe/tasks/testdata/vision/face_landmarker_with_blendshapes.task index d20846326..04adf1841 100644 Binary files a/mediapipe/tasks/testdata/vision/face_landmarker_with_blendshapes.task and b/mediapipe/tasks/testdata/vision/face_landmarker_with_blendshapes.task differ diff --git a/mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt b/mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt new file mode 100644 index 000000000..411a374d4 --- /dev/null +++ b/mediapipe/tasks/testdata/vision/pose_expected_detection.pbtxt @@ -0,0 +1,27 @@ +# proto-file: mediapipe/framework/formats/detection.proto +# proto-message: Detection 
+location_data { + format: BOUNDING_BOX + bounding_box { + xmin: 397 + ymin: 198 + width: 199 + height: 199 + } + relative_keypoints { + x: 0.4879558 + y: 0.7013345 + } + relative_keypoints { + x: 0.48453212 + y: 0.32265592 + } + relative_keypoints { + x: 0.4992165 + y: 0.4854874 + } + relative_keypoints { + x: 0.50227845 + y: 0.159788 + } +} diff --git a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts index a8a2b232b..9e37e5987 100644 --- a/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts +++ b/mediapipe/tasks/web/audio/audio_embedder/audio_embedder_test.ts @@ -161,7 +161,7 @@ describe('AudioEmbedder', () => { {floatEmbedding: [0.1, 0.9], headIndex: 1, headName: 'headName'}); } - it('from embeddings strem', async () => { + it('from embeddings stream', async () => { audioEmbedder.fakeWasmModule._waitUntilIdle.and.callFake(() => { verifyListenersRegistered(audioEmbedder); // Pass the test data to our listener diff --git a/mediapipe/tasks/web/audio/core/audio_task_runner.ts b/mediapipe/tasks/web/audio/core/audio_task_runner.ts index ff39185f2..2c327f1ab 100644 --- a/mediapipe/tasks/web/audio/core/audio_task_runner.ts +++ b/mediapipe/tasks/web/audio/core/audio_task_runner.ts @@ -36,11 +36,9 @@ export abstract class AudioTaskRunner extends TaskRunner { /** Sends a single audio clip to the graph and awaits results. */ protected processAudioClip(audioData: Float32Array, sampleRate?: number): T { - // Increment the timestamp by 1 millisecond to guarantee that we send - // monotonically increasing timestamps to the graph. - const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; return this.process( - audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp); + audioData, sampleRate ?? this.defaultSampleRate, + this.getSynctheticTimestamp()); } } diff --git a/mediapipe/tasks/web/components/containers/BUILD b/mediapipe/tasks/web/components/containers/BUILD index a0db59d0b..0126e83c9 100644 --- a/mediapipe/tasks/web/components/containers/BUILD +++ b/mediapipe/tasks/web/components/containers/BUILD @@ -15,6 +15,11 @@ mediapipe_ts_declaration( deps = [":category"], ) +mediapipe_ts_declaration( + name = "keypoint", + srcs = ["keypoint.d.ts"], +) + mediapipe_ts_declaration( name = "landmark", srcs = ["landmark.d.ts"], diff --git a/mediapipe/tasks/web/components/containers/keypoint.d.ts b/mediapipe/tasks/web/components/containers/keypoint.d.ts new file mode 100644 index 000000000..3aaf9eb06 --- /dev/null +++ b/mediapipe/tasks/web/components/containers/keypoint.d.ts @@ -0,0 +1,33 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * A keypoint, defined by the coordinates (x, y), normalized by the image + * dimensions. + */ +export declare interface NormalizedKeypoint { + /** X in normalized image coordinates. */ + x: number; + + /** Y in normalized image coordinates. 
*/ + y: number; + + /** Optional label of the keypoint. */ + label?: string; + + /** Optional score of the keypoint. */ + score?: number; +} diff --git a/mediapipe/tasks/web/core/BUILD b/mediapipe/tasks/web/core/BUILD index ec65548d4..a417d4d72 100644 --- a/mediapipe/tasks/web/core/BUILD +++ b/mediapipe/tasks/web/core/BUILD @@ -19,6 +19,7 @@ mediapipe_ts_library( deps = [ ":core", "//mediapipe/calculators/tensor:inference_calculator_jspb_proto", + "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/tasks/cc/core/proto:acceleration_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:external_file_jspb_proto", diff --git a/mediapipe/tasks/web/core/task_runner.ts b/mediapipe/tasks/web/core/task_runner.ts index 79b2ca173..68208c970 100644 --- a/mediapipe/tasks/web/core/task_runner.ts +++ b/mediapipe/tasks/web/core/task_runner.ts @@ -15,6 +15,7 @@ */ import {InferenceCalculatorOptions} from '../../../calculators/tensor/inference_calculator_pb'; +import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; import {Acceleration} from '../../../tasks/cc/core/proto/acceleration_pb'; import {BaseOptions as BaseOptionsProto} from '../../../tasks/cc/core/proto/base_options_pb'; import {ExternalFile} from '../../../tasks/cc/core/proto/external_file_pb'; @@ -120,11 +121,13 @@ export abstract class TaskRunner { .then(buffer => { this.setExternalFile(new Uint8Array(buffer)); this.refreshGraph(); + this.onGraphRefreshed(); }); } else { // Apply the setting synchronously. this.setExternalFile(baseOptions.modelAssetBuffer); this.refreshGraph(); + this.onGraphRefreshed(); return Promise.resolve(); } } @@ -132,6 +135,24 @@ export abstract class TaskRunner { /** Appliest the current options to the MediaPipe graph. */ protected abstract refreshGraph(): void; + /** + * Callback that gets invoked once a new graph configuration has been + * applied. + */ + protected onGraphRefreshed(): void {} + + /** Returns the current CalculatorGraphConfig. */ + protected getCalculatorGraphConfig(): CalculatorGraphConfig { + let config: CalculatorGraphConfig|undefined; + this.graphRunner.getCalculatorGraphConfig(binaryData => { + config = CalculatorGraphConfig.deserializeBinary(binaryData); + }); + if (!config) { + throw new Error('Failed to retrieve CalculatorGraphConfig'); + } + return config; + } + /** * Takes the raw data from a MediaPipe graph, and passes it to C++ to be run * over the video stream. Will replace the previously running MediaPipe graph, @@ -175,9 +196,13 @@ export abstract class TaskRunner { Math.max(this.latestOutputTimestamp, timestamp); } - /** Returns the latest output timestamp. */ - protected getLatestOutputTimestamp() { - return this.latestOutputTimestamp; + /** + * Gets a syncthethic timestamp in ms that can be used to send data to the + * next packet. The timestamp is one millisecond past the last timestamp + * received from the graph. + */ + protected getSynctheticTimestamp(): number { + return this.latestOutputTimestamp + 1; } /** Throws the error from the error listener if an error was raised. 
*/ diff --git a/mediapipe/tasks/web/core/task_runner_test_utils.ts b/mediapipe/tasks/web/core/task_runner_test_utils.ts index 62dd0463a..b0aa34095 100644 --- a/mediapipe/tasks/web/core/task_runner_test_utils.ts +++ b/mediapipe/tasks/web/core/task_runner_test_utils.ts @@ -16,7 +16,7 @@ import 'jasmine'; import {CalculatorGraphConfig} from '../../../framework/calculator_pb'; -import {WasmModule} from '../../../web/graph_runner/graph_runner'; +import {CALCULATOR_GRAPH_CONFIG_LISTENER_NAME, SimpleListener, WasmModule} from '../../../web/graph_runner/graph_runner'; import {WasmModuleRegisterModelResources} from '../../../web/graph_runner/register_model_resources_graph_service'; type SpyWasmModuleInternal = WasmModule&WasmModuleRegisterModelResources; @@ -36,8 +36,13 @@ export function createSpyWasmModule(): SpyWasmModule { '_setAutoRenderToScreen', 'stringToNewUTF8', '_attachProtoListener', '_attachProtoVectorListener', '_free', '_waitUntilIdle', '_addStringToInputStream', '_registerModelResourcesGraphService', - '_configureAudio', '_malloc', '_addProtoToInputStream' + '_configureAudio', '_malloc', '_addProtoToInputStream', '_getGraphConfig' ]); + spyWasmModule._getGraphConfig.and.callFake(() => { + (spyWasmModule.simpleListeners![CALCULATOR_GRAPH_CONFIG_LISTENER_NAME] as + SimpleListener)( + new CalculatorGraphConfig().serializeBinary(), 0); + }); spyWasmModule.HEAPU8 = jasmine.createSpyObj(['set']); return spyWasmModule; } diff --git a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts index b28817613..2495bf5a9 100644 --- a/mediapipe/tasks/web/text/text_classifier/text_classifier.ts +++ b/mediapipe/tasks/web/text/text_classifier/text_classifier.ts @@ -131,11 +131,9 @@ export class TextClassifier extends TaskRunner { * @return The classification result of the text */ classify(text: string): TextClassifierResult { - // Increment the timestamp by 1 millisecond to guarantee that we send - // monotonically increasing timestamps to the graph. - const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; this.classificationResult = {classifications: []}; - this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); + this.graphRunner.addStringToStream( + text, INPUT_STREAM, this.getSynctheticTimestamp()); this.finishProcessing(); return this.classificationResult; } diff --git a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts index 1034de033..3b7f4f7e4 100644 --- a/mediapipe/tasks/web/text/text_embedder/text_embedder.ts +++ b/mediapipe/tasks/web/text/text_embedder/text_embedder.ts @@ -135,10 +135,8 @@ export class TextEmbedder extends TaskRunner { * @return The embedding resuls of the text */ embed(text: string): TextEmbedderResult { - // Increment the timestamp by 1 millisecond to guarantee that we send - // monotonically increasing timestamps to the graph. 
- const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; - this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); + this.graphRunner.addStringToStream( + text, INPUT_STREAM, this.getSynctheticTimestamp()); this.finishProcessing(); return this.embeddingResult; } diff --git a/mediapipe/tasks/web/vision/BUILD b/mediapipe/tasks/web/vision/BUILD index a229cbd2a..67db27ddb 100644 --- a/mediapipe/tasks/web/vision/BUILD +++ b/mediapipe/tasks/web/vision/BUILD @@ -19,11 +19,13 @@ mediapipe_files(srcs = [ VISION_LIBS = [ "//mediapipe/tasks/web/core:fileset_resolver", + "//mediapipe/tasks/web/vision/face_stylizer", "//mediapipe/tasks/web/vision/gesture_recognizer", "//mediapipe/tasks/web/vision/hand_landmarker", "//mediapipe/tasks/web/vision/image_classifier", "//mediapipe/tasks/web/vision/image_embedder", "//mediapipe/tasks/web/vision/image_segmenter", + "//mediapipe/tasks/web/vision/interactive_segmenter", "//mediapipe/tasks/web/vision/object_detector", ] diff --git a/mediapipe/tasks/web/vision/README.md b/mediapipe/tasks/web/vision/README.md index 9e86eafd3..a1444e10b 100644 --- a/mediapipe/tasks/web/vision/README.md +++ b/mediapipe/tasks/web/vision/README.md @@ -2,23 +2,57 @@ This package contains the vision tasks for MediaPipe. -## Object Detection +## Face Stylizer -The MediaPipe Object Detector task lets you detect the presence and location of -multiple classes of objects within images or videos. +The MediaPipe Face Stylizer lets you perform face stylization on images. ``` const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); -const objectDetector = await ObjectDetector.createFromModelPath(vision, - "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" +const faceStylizer = await FaceStylizer.createFromModelPath(vision, + "model.tflite" ); const image = document.getElementById("image") as HTMLImageElement; -const detections = objectDetector.detect(image); +const stylizedImage = faceStylizer.stylize(image); ``` -For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation. +## Gesture Recognition + +The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real +time, and provides the recognized hand gesture results along with the landmarks +of the detected hands. You can use this task to recognize specific hand gestures +from a user, and invoke application features that correspond to those gestures. + +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task" +); +const image = document.getElementById("image") as HTMLImageElement; +const recognitions = gestureRecognizer.recognize(image); +``` + +## Hand Landmark Detection + +The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in +an image. You can use this Task to localize key points of the hands and render +visual effects over the hands. 
+ +``` +const vision = await FilesetResolver.forVisionTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" +); +const handLandmarker = await HandLandmarker.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task" +); +const image = document.getElementById("image") as HTMLImageElement; +const landmarks = handLandmarker.detect(image); +``` + +For more information, refer to the [Handlandmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation. ## Image Classification @@ -56,40 +90,39 @@ imageSegmenter.segment(image, (masks, width, height) => { }); ``` -## Gesture Recognition +## Interactive Segmentation -The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real -time, and provides the recognized hand gesture results along with the landmarks -of the detected hands. You can use this task to recognize specific hand gestures -from a user, and invoke application features that correspond to those gestures. +The MediaPipe Interactive Segmenter lets you select a region of interest to +segment an image by. ``` const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); -const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, - "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task" +const interactiveSegmenter = await InteractiveSegmenter.createFromModelPath( + vision, "model.tflite" ); const image = document.getElementById("image") as HTMLImageElement; -const recognitions = gestureRecognizer.recognize(image); +interactiveSegmenter.segment(image, { keypoint: { x: 0.1, y: 0.2 } }, + (masks, width, height) => { ... } +); ``` -## Handlandmark Detection +## Object Detection -The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in -an image. You can use this Task to localize key points of the hands and render -visual effects over the hands. +The MediaPipe Object Detector task lets you detect the presence and location of +multiple classes of objects within images or videos. ``` const vision = await FilesetResolver.forVisionTasks( "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" ); -const handLandmarker = await HandLandmarker.createFromModelPath(vision, - "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task" +const objectDetector = await ObjectDetector.createFromModelPath(vision, + "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" ); const image = document.getElementById("image") as HTMLImageElement; -const landmarks = handLandmarker.detect(image); +const detections = objectDetector.detect(image); ``` -For more information, refer to the [Handlandmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation. +For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation. 
diff --git a/mediapipe/tasks/web/vision/core/BUILD b/mediapipe/tasks/web/vision/core/BUILD index a0a008122..cd2954eb3 100644 --- a/mediapipe/tasks/web/vision/core/BUILD +++ b/mediapipe/tasks/web/vision/core/BUILD @@ -21,6 +21,14 @@ mediapipe_ts_declaration( ], ) +mediapipe_ts_declaration( + name = "types", + srcs = ["types.d.ts"], + deps = [ + "//mediapipe/tasks/web/components/containers:keypoint", + ], +) + mediapipe_ts_library( name = "vision_task_runner", srcs = ["vision_task_runner.ts"], @@ -51,6 +59,11 @@ mediapipe_ts_library( ], ) +mediapipe_ts_library( + name = "render_utils", + srcs = ["render_utils.ts"], +) + jasmine_node_test( name = "vision_task_runner_test", deps = [":vision_task_runner_test_lib"], diff --git a/mediapipe/tasks/web/vision/core/render_utils.ts b/mediapipe/tasks/web/vision/core/render_utils.ts new file mode 100644 index 000000000..879e23010 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/render_utils.ts @@ -0,0 +1,77 @@ +/** @fileoverview Utility functions used in the vision demos. */ + +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Pre-baked color table for a maximum of 12 classes. +const CM_ALPHA = 128; +const COLOR_MAP = [ + [0, 0, 0, CM_ALPHA], // class 0 is BG = transparent + [255, 0, 0, CM_ALPHA], // class 1 is red + [0, 255, 0, CM_ALPHA], // class 2 is light green + [0, 0, 255, CM_ALPHA], // class 3 is blue + [255, 255, 0, CM_ALPHA], // class 4 is yellow + [255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta + [0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua + [128, 128, 128, CM_ALPHA], // class 7 is gray + [255, 128, 0, CM_ALPHA], // class 8 is orange + [128, 0, 255, CM_ALPHA], // class 9 is dark purple + [0, 128, 0, CM_ALPHA], // class 10 is dark green + [255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead? +]; + + +/** Helper function to draw a confidence mask */ +export function drawConfidenceMask( + ctx: CanvasRenderingContext2D, image: Float32Array, width: number, + height: number): void { + const uint8ClampedArray = new Uint8ClampedArray(width * height * 4); + for (let i = 0; i < image.length; i++) { + uint8ClampedArray[4 * i] = 128; + uint8ClampedArray[4 * i + 1] = 0; + uint8ClampedArray[4 * i + 2] = 0; + uint8ClampedArray[4 * i + 3] = image[i] * 255; + } + ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0); +} + +/** + * Helper function to draw a category mask. For GPU, we only have F32Arrays + * for now. + */ +export function drawCategoryMask( + ctx: CanvasRenderingContext2D, image: Uint8ClampedArray|Float32Array, + width: number, height: number): void { + const rgbaArray = new Uint8ClampedArray(width * height * 4); + const isFloatArray = image instanceof Float32Array; + for (let i = 0; i < image.length; i++) { + const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i]; + const color = COLOR_MAP[colorIndex]; + + // When we're given a confidence mask by accident, we just log and return. 
+ // TODO: We should fix this. + if (!color) { + console.warn('No color for ', colorIndex); + return; + } + + rgbaArray[4 * i] = color[0]; + rgbaArray[4 * i + 1] = color[1]; + rgbaArray[4 * i + 2] = color[2]; + rgbaArray[4 * i + 3] = color[3]; + } + ctx.putImageData(new ImageData(rgbaArray, width, height), 0, 0); +} diff --git a/mediapipe/tasks/web/vision/core/types.d.ts b/mediapipe/tasks/web/vision/core/types.d.ts new file mode 100644 index 000000000..f0ac08627 --- /dev/null +++ b/mediapipe/tasks/web/vision/core/types.d.ts @@ -0,0 +1,53 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint'; + +/** + * The segmentation tasks return the segmentation result as a Uint8ClampedArray + * (when the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for + * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved + * for future usage. + */ +export type SegmentationMask = Uint8ClampedArray|Float32Array|WebGLTexture; + +/** + * A callback that receives the computed masks from the segmentation tasks. The + * callback either receives a single element array with a category mask (as a + * `[Uint8ClampedArray]`) or multiple confidence masks (as a `Float32Array[]`). + * The returned data is only valid for the duration of the callback. If + * asynchronous processing is needed, all data needs to be copied before the + * callback returns. + */ +export type SegmentationMaskCallback = + (masks: SegmentationMask[], width: number, height: number) => void; + +/** + * A callback that receives an `ImageData` object from a Vision task. The + * lifetime of the underlying data is limited to the duration of the callback. + * If asynchronous processing is needed, all data needs to be copied before the + * callback returns. + * + * The `WebGLTexture` output type is reserved for future usage. + */ +export type ImageCallback = + (image: ImageData|WebGLTexture, width: number, height: number) => void; + +/** A Region-Of-Interest (ROI) to represent a region within an image. */ +export declare interface RegionOfInterest { + /** The ROI in keypoint format. 
*/ + keypoint: NormalizedKeypoint; +} diff --git a/mediapipe/tasks/web/vision/core/vision_task_runner.ts b/mediapipe/tasks/web/vision/core/vision_task_runner.ts index b3e8ed4db..f19b9f2df 100644 --- a/mediapipe/tasks/web/vision/core/vision_task_runner.ts +++ b/mediapipe/tasks/web/vision/core/vision_task_runner.ts @@ -18,7 +18,7 @@ import {NormalizedRect} from '../../../../framework/formats/rect_pb'; import {TaskRunner} from '../../../../tasks/web/core/task_runner'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; import {GraphRunner, ImageSource} from '../../../../web/graph_runner/graph_runner'; -import {SupportImage} from '../../../../web/graph_runner/graph_runner_image_lib'; +import {SupportImage, WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; import {SupportModelResourcesGraphService} from '../../../../web/graph_runner/register_model_resources_graph_service'; import {VisionTaskOptions} from './vision_task_options'; @@ -74,11 +74,7 @@ export abstract class VisionTaskRunner extends TaskRunner { 'Task is not initialized with image mode. ' + '\'runningMode\' must be set to \'IMAGE\'.'); } - - // Increment the timestamp by 1 millisecond to guarantee that we send - // monotonically increasing timestamps to the graph. - const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; - this.process(image, imageProcessingOptions, syntheticTimestamp); + this.process(image, imageProcessingOptions, this.getSynctheticTimestamp()); } /** Sends a single video frame to the graph and awaits results. */ @@ -152,6 +148,31 @@ export abstract class VisionTaskRunner extends TaskRunner { imageSource, this.imageStreamName, timestamp ?? performance.now()); this.finishProcessing(); } + + /** Converts the RGB or RGBA Uint8Array of a WasmImage to ImageData. */ + protected convertToImageData(wasmImage: WasmImage): ImageData { + const {data, width, height} = wasmImage; + if (!(data instanceof Uint8ClampedArray)) { + throw new Error( + 'Only Uint8ClampedArray-based images can be converted to ImageData'); + } + + if (data.length === width * height * 4) { + return new ImageData(data, width, height); + } else if (data.length === width * height * 3) { + const rgba = new Uint8ClampedArray(width * height * 4); + for (let i = 0; i < width * height; ++i) { + rgba[4 * i] = data[3 * i]; + rgba[4 * i + 1] = data[3 * i + 1]; + rgba[4 * i + 2] = data[3 * i + 2]; + rgba[4 * i + 3] = 255; + } + return new ImageData(rgba, width, height); + } else { + throw new Error( + `Unsupported channel count: ${data.length / width / height}`); + } + } } diff --git a/mediapipe/tasks/web/vision/face_stylizer/BUILD b/mediapipe/tasks/web/vision/face_stylizer/BUILD new file mode 100644 index 000000000..7716d617f --- /dev/null +++ b/mediapipe/tasks/web/vision/face_stylizer/BUILD @@ -0,0 +1,57 @@ +# This contains the MediaPipe Face Stylizer Task. 
+ +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "face_stylizer", + srcs = ["face_stylizer.ts"], + deps = [ + ":face_stylizer_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:types", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "face_stylizer_types", + srcs = ["face_stylizer_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "face_stylizer_test_lib", + testonly = True, + srcs = [ + "face_stylizer_test.ts", + ], + deps = [ + ":face_stylizer", + ":face_stylizer_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + ], +) + +jasmine_node_test( + name = "face_stylizer_test", + tags = ["nomsan"], + deps = [":face_stylizer_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts new file mode 100644 index 000000000..47a4ffdfd --- /dev/null +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer.ts @@ -0,0 +1,298 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {FaceStylizerGraphOptions as FaceStylizerGraphOptionsProto} from '../../../../tasks/cc/vision/face_stylizer/proto/face_stylizer_graph_options_pb'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {ImageCallback} from '../../../../tasks/web/vision/core/types'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {FaceStylizerOptions} from './face_stylizer_options'; + +export * from './face_stylizer_options'; +export {ImageSource}; // Used in the public API + +const IMAGE_STREAM = 'image_in'; +const NORM_RECT_STREAM = 'norm_rect'; +const STYLIZED_IMAGE_STREAM = 'stylized_image'; +const FACE_STYLIZER_GRAPH = + 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +export {ImageCallback}; + +/** Performs face stylization on images. */ +export class FaceStylizer extends VisionTaskRunner { + private userCallback: ImageCallback = () => {}; + private readonly options: FaceStylizerGraphOptionsProto; + + /** + * Initializes the Wasm runtime and creates a new Face Stylizer from the + * provided options. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param faceStylizerOptions The options for the Face Stylizer. Note + * that either a path to the model asset or a model buffer needs to be + * provided (via `baseOptions`). + */ + static createFromOptions( + wasmFileset: WasmFileset, + faceStylizerOptions: FaceStylizerOptions): Promise { + return VisionTaskRunner.createInstance( + FaceStylizer, /* initializeCanvas= */ true, wasmFileset, + faceStylizerOptions); + } + + /** + * Initializes the Wasm runtime and creates a new Face Stylizer based on + * the provided model asset buffer. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param modelAssetBuffer A binary representation of the model. + */ + static createFromModelBuffer( + wasmFileset: WasmFileset, + modelAssetBuffer: Uint8Array): Promise { + return VisionTaskRunner.createInstance( + FaceStylizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetBuffer}}); + } + + /** + * Initializes the Wasm runtime and creates a new Face Stylizer based on + * the path to the model asset. + * @param wasmFileset A configuration object that provides the location of + * the Wasm binary and its loader. + * @param modelAssetPath The path to the model asset. 
+ */ + static createFromModelPath( + wasmFileset: WasmFileset, + modelAssetPath: string): Promise { + return VisionTaskRunner.createInstance( + FaceStylizer, /* initializeCanvas= */ true, wasmFileset, + {baseOptions: {modelAssetPath}}); + } + + /** @hideconstructor */ + constructor( + wasmModule: WasmModule, + glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) { + super( + new VisionGraphRunner(wasmModule, glCanvas), IMAGE_STREAM, + NORM_RECT_STREAM, /* roiAllowed= */ true); + this.options = new FaceStylizerGraphOptionsProto(); + this.options.setBaseOptions(new BaseOptionsProto()); + } + + + protected override get baseOptions(): BaseOptionsProto { + return this.options.getBaseOptions()!; + } + + protected override set baseOptions(proto: BaseOptionsProto) { + this.options.setBaseOptions(proto); + } + + /** + * Sets new options for the Face Stylizer. + * + * Calling `setOptions()` with a subset of options only affects those + * options. You can reset an option back to its default value by + * explicitly setting it to `undefined`. + * + * @param options The options for the Face Stylizer. + */ + override setOptions(options: FaceStylizerOptions): Promise { + return super.applyOptions(options); + } + + + /** + * Performs face stylization on the provided single image. The method returns + * synchronously once the callback returns. Only use this method when the + * FaceStylizer is created with the image running mode. + * + * The input image can be of any size. To ensure that the output image has + * reasonable quailty, the stylized output image size is determined by the + * model output size. + * + * @param image An image to process. + * @param callback The callback that is invoked with the stylized image. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. + */ + stylize(image: ImageSource, callback: ImageCallback): void; + /** + * Performs face stylization on the provided single image. The method returns + * synchronously once the callback returns. Only use this method when the + * FaceStylizer is created with the image running mode. + * + * The 'imageProcessingOptions' parameter can be used to specify one or all + * of: + * - the rotation to apply to the image before performing stylization, by + * setting its 'rotationDegrees' property. + * - the region-of-interest on which to perform stylization, by setting its + * 'regionOfInterest' property. If not specified, the full image is used. + * If both are specified, the crop around the region-of-interest is extracted + * first, then the specified rotation is applied to the crop. + * + * The input image can be of any size. To ensure that the output image has + * reasonable quailty, the stylized output image size is the smaller of the + * model output size and the size of the 'regionOfInterest' specified in + * 'imageProcessingOptions'. + * + * @param image An image to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param callback The callback that is invoked with the stylized image. The + * lifetime of the returned data is only guaranteed for the duration of the + * callback. 
+ */ + stylize( + image: ImageSource, imageProcessingOptions: ImageProcessingOptions, + callback: ImageCallback): void; + stylize( + image: ImageSource, + imageProcessingOptionsOrCallback: ImageProcessingOptions|ImageCallback, + callback?: ImageCallback): void { + const imageProcessingOptions = + typeof imageProcessingOptionsOrCallback !== 'function' ? + imageProcessingOptionsOrCallback : + {}; + + this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ? + imageProcessingOptionsOrCallback : + callback!; + this.processImageData(image, imageProcessingOptions ?? {}); + this.userCallback = () => {}; + } + + /** + * Performs face stylization on the provided video frame. Only use this method + * when the FaceStylizer is created with the video running mode. + * + * The input frame can be of any size. It's required to provide the video + * frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * + * To ensure that the output image has reasonable quality, the stylized + * output image size is determined by the model output size. + * + * @param videoFrame A video frame to process. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the stylized image. The + * lifetime of the returned data is only guaranteed for the duration of + * the callback. + */ + stylizeForVideo( + videoFrame: ImageSource, timestamp: number, + callback: ImageCallback): void; + /** + * Performs face stylization on the provided video frame. Only use this + * method when the FaceStylizer is created with the video running mode. + * + * The 'imageProcessingOptions' parameter can be used to specify one or all + * of: + * - the rotation to apply to the image before performing stylization, by + * setting its 'rotationDegrees' property. + * - the region-of-interest on which to perform stylization, by setting its + * 'regionOfInterest' property. If not specified, the full image is used. + * If both are specified, the crop around the region-of-interest is + * extracted first, then the specified rotation is applied to the crop. + * + * The input frame can be of any size. It's required to provide the video + * frame's timestamp (in milliseconds). The input timestamps must be + * monotonically increasing. + * + * To ensure that the output image has reasonable quailty, the stylized + * output image size is the smaller of the model output size and the size of + * the 'regionOfInterest' specified in 'imageProcessingOptions'. + * + * @param videoFrame A video frame to process. + * @param imageProcessingOptions the `ImageProcessingOptions` specifying how + * to process the input image before running inference. + * @param timestamp The timestamp of the current frame, in ms. + * @param callback The callback that is invoked with the stylized image. The + * lifetime of the returned data is only guaranteed for the duration of + * the callback. + */ + stylizeForVideo( + videoFrame: ImageSource, imageProcessingOptions: ImageProcessingOptions, + timestamp: number, callback: ImageCallback): void; + stylizeForVideo( + videoFrame: ImageSource, + timestampOrImageProcessingOptions: number|ImageProcessingOptions, + timestampOrCallback: number|ImageCallback, + callback?: ImageCallback): void { + const imageProcessingOptions = + typeof timestampOrImageProcessingOptions !== 'number' ? + timestampOrImageProcessingOptions : + {}; + const timestamp = typeof timestampOrImageProcessingOptions === 'number' ? 
+ timestampOrImageProcessingOptions : + timestampOrCallback as number; + + this.userCallback = typeof timestampOrCallback === 'function' ? + timestampOrCallback : + callback!; + this.processVideoData(videoFrame, imageProcessingOptions, timestamp); + this.userCallback = () => {}; + } + + /** Updates the MediaPipe graph configuration. */ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_STREAM); + graphConfig.addInputStream(NORM_RECT_STREAM); + graphConfig.addOutputStream(STYLIZED_IMAGE_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + FaceStylizerGraphOptionsProto.ext, this.options); + + const segmenterNode = new CalculatorGraphConfig.Node(); + segmenterNode.setCalculator(FACE_STYLIZER_GRAPH); + segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM); + segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM); + segmenterNode.addOutputStream('STYLIZED_IMAGE:' + STYLIZED_IMAGE_STREAM); + segmenterNode.setOptions(calculatorOptions); + + graphConfig.addNode(segmenterNode); + + this.graphRunner.attachImageListener( + STYLIZED_IMAGE_STREAM, (image, timestamp) => { + const imageData = this.convertToImageData(image); + this.userCallback(imageData, image.width, image.height); + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener( + STYLIZED_IMAGE_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } +} + + diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_options.d.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_options.d.ts new file mode 100644 index 000000000..38f5028c0 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_options.d.ts @@ -0,0 +1,20 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {VisionTaskOptions} from '../../../../tasks/web/vision/core/vision_task_options'; + +/** Options to configure the MediaPipe Face Stylizer Task */ +export interface FaceStylizerOptions extends VisionTaskOptions {} diff --git a/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts new file mode 100644 index 000000000..72d540797 --- /dev/null +++ b/mediapipe/tasks/web/vision/face_stylizer/face_stylizer_test.ts @@ -0,0 +1,114 @@ +/** + * Copyright 2022 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import 'jasmine'; + +// Placeholder for internal dependency on encodeByteArray +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils'; +import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib'; + +import {FaceStylizer} from './face_stylizer'; + +class FaceStylizerFake extends FaceStylizer implements MediapipeTasksFake { + calculatorName = 'mediapipe.tasks.vision.face_stylizer.FaceStylizerGraph'; + attachListenerSpies: jasmine.Spy[] = []; + graph: CalculatorGraphConfig|undefined; + + fakeWasmModule: SpyWasmModule; + imageListener: ((images: WasmImage, timestamp: number) => void)|undefined; + + constructor() { + super(createSpyWasmModule(), /* glCanvas= */ null); + this.fakeWasmModule = + this.graphRunner.wasmModule as unknown as SpyWasmModule; + + this.attachListenerSpies[0] = + spyOn(this.graphRunner, 'attachImageListener') + .and.callFake((stream, listener) => { + expect(stream).toEqual('stylized_image'); + this.imageListener = listener; + }); + spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => { + this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph); + }); + spyOn(this.graphRunner, 'addGpuBufferAsImageToStream'); + } +} + +describe('FaceStylizer', () => { + let faceStylizer: FaceStylizerFake; + + beforeEach(async () => { + addJasmineCustomFloatEqualityTester(); + faceStylizer = new FaceStylizerFake(); + await faceStylizer.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(faceStylizer); + verifyListenersRegistered(faceStylizer); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await faceStylizer.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + faceStylizer, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + it('invokes callback', (done) => { + if (typeof ImageData === 'undefined') { + console.log('ImageData tests are not supported on Node'); + done(); + return; + } + + // Pass the test data to our listener + faceStylizer.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(faceStylizer); + faceStylizer.imageListener! 
+ ({data: new Uint8ClampedArray([1, 1, 1, 1]), width: 1, height: 1}, + /* timestamp= */ 1337); + }); + + // Invoke the face stylizeer + faceStylizer.stylize({} as HTMLImageElement, (image, width, height) => { + expect(faceStylizer.fakeWasmModule._waitUntilIdle).toHaveBeenCalled(); + expect(image).toBeInstanceOf(ImageData); + expect(width).toEqual(1); + expect(height).toEqual(1); + done(); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts index dd8fc9548..9e9728af8 100644 --- a/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts +++ b/mediapipe/tasks/web/vision/gesture_recognizer/gesture_recognizer_options.d.ts @@ -44,7 +44,7 @@ export declare interface GestureRecognizerOptions extends VisionTaskOptions { minTrackingConfidence?: number|undefined; /** - * Sets the optional `ClassifierOptions` controling the canned gestures + * Sets the optional `ClassifierOptions` controlling the canned gestures * classifier, such as score threshold, allow list and deny list of gestures. * The categories for canned gesture * classifiers are: ["None", "Closed_Fist", "Open_Palm", "Pointing_Up", diff --git a/mediapipe/tasks/web/vision/image_segmenter/BUILD b/mediapipe/tasks/web/vision/image_segmenter/BUILD index d15fe63f1..a4b9008dd 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/BUILD +++ b/mediapipe/tasks/web/vision/image_segmenter/BUILD @@ -15,12 +15,14 @@ mediapipe_ts_library( "//mediapipe/framework:calculator_jspb_proto", "//mediapipe/framework:calculator_options_jspb_proto", "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/calculators:tensors_to_segmentation_calculator_jspb_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto", "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", "//mediapipe/tasks/web/core", "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:types", "//mediapipe/tasks/web/vision/core:vision_task_runner", - "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + "//mediapipe/util:label_map_jspb_proto", "//mediapipe/web/graph_runner:graph_runner_ts", ], ) diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts index 3a2e9f2af..cb192b0ce 100644 --- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts +++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter.ts @@ -17,43 +17,30 @@ import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {TensorsToSegmentationCalculatorOptions} from '../../../../tasks/cc/vision/image_segmenter/calculators/tensors_to_segmentation_calculator_pb'; import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb'; import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import 
{SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {LabelMapItem} from '../../../../util/label_map_pb'; import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; // Placeholder for internal dependency on trusted resource url import {ImageSegmenterOptions} from './image_segmenter_options'; export * from './image_segmenter_options'; +export {SegmentationMask, SegmentationMaskCallback}; export {ImageSource}; // Used in the public API -/** - * The ImageSegmenter returns the segmentation result as a Uint8Array (when - * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for - * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved - * for future usage. - */ -export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture; - -/** - * A callback that receives the computed masks from the image segmenter. The - * callback either receives a single element array with a category mask (as a - * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`). - * The returned data is only valid for the duration of the callback. If - * asynchronous processing is needed, all data needs to be copied before the - * callback returns. - */ -export type SegmentationMaskCallback = - (masks: SegmentationMask[], width: number, height: number) => void; - const IMAGE_STREAM = 'image_in'; const NORM_RECT_STREAM = 'norm_rect'; const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; -const IMAGEA_SEGMENTER_GRAPH = +const IMAGE_SEGMENTER_GRAPH = 'mediapipe.tasks.vision.image_segmenter.ImageSegmenterGraph'; +const TENSORS_TO_SEGMENTATION_CALCULATOR_NAME = + 'mediapipe.tasks.TensorsToSegmentationCalculator'; // The OSS JS API does not support the builder pattern. // tslint:disable:jspb-use-builder-pattern @@ -61,6 +48,7 @@ const IMAGEA_SEGMENTER_GRAPH = /** Performs image segmentation on images. */ export class ImageSegmenter extends VisionTaskRunner { private userCallback: SegmentationMaskCallback = () => {}; + private labels: string[] = []; private readonly options: ImageSegmenterGraphOptionsProto; private readonly segmenterOptions: SegmenterOptionsProto; @@ -163,6 +151,39 @@ export class ImageSegmenter extends VisionTaskRunner { return super.applyOptions(options); } + protected override onGraphRefreshed(): void { + this.populateLabels(); + } + + /** + * Populate the labelMap in TensorsToSegmentationCalculator to labels field. + * @throws Exception if there is an error during finding + * TensorsToSegmentationCalculator. + */ + private populateLabels(): void { + const graphConfig = this.getCalculatorGraphConfig(); + const tensorsToSegmentationCalculators = graphConfig.getNodeList().filter( + (n: CalculatorGraphConfig.Node) => + n.getName().includes(TENSORS_TO_SEGMENTATION_CALCULATOR_NAME)); + + this.labels = []; + if (tensorsToSegmentationCalculators.length > 1) { + throw new Error(`The graph has more than one ${ + TENSORS_TO_SEGMENTATION_CALCULATOR_NAME}.`); + } else if (tensorsToSegmentationCalculators.length === 1) { + const labelItems = + tensorsToSegmentationCalculators[0] + .getOptions() + ?.getExtension(TensorsToSegmentationCalculatorOptions.ext) + ?.getLabelItemsMap() ?? 
+            new Map();
+      labelItems.forEach((value, index) => {
+        // tslint:disable-next-line:no-unnecessary-type-assertion
+        this.labels[Number(index)] = value.getName()!;
+      });
+    }
+  }
+
   /**
    * Performs image segmentation on the provided single image and invokes the
    * callback with the response. The method returns synchronously once the
@@ -208,6 +229,21 @@ export class ImageSegmenter extends VisionTaskRunner {
     this.userCallback = () => {};
   }
 
+  /**
+   * Gets the category label list that the ImageSegmenter can recognize. For
+   * the `CATEGORY_MASK` output type, the index in the category mask
+   * corresponds to the category in the label list. For the `CONFIDENCE_MASK`
+   * output type, the mask at each index in the output mask list corresponds
+   * to the category at the same index in the label list.
+   *
+   * If no label map is provided in the model file, an empty label array is
+   * returned.
+   *
+   * @return The labels used by the current model.
+   */
+  getLabels(): string[] {
+    return this.labels;
+  }
+
   /**
    * Performs image segmentation on the provided video frame and invokes the
    * callback with the response. The method returns synchronously once the
@@ -272,7 +308,7 @@ export class ImageSegmenter extends VisionTaskRunner {
         ImageSegmenterGraphOptionsProto.ext, this.options);
 
     const segmenterNode = new CalculatorGraphConfig.Node();
-    segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH);
+    segmenterNode.setCalculator(IMAGE_SEGMENTER_GRAPH);
     segmenterNode.addInputStream('IMAGE:' + IMAGE_STREAM);
     segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_STREAM);
     segmenterNode.addOutputStream(
diff --git a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
index aa81be025..4cf27b9a5 100644
--- a/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
+++ b/mediapipe/tasks/web/vision/image_segmenter/image_segmenter_test.ts
@@ -159,7 +159,7 @@ describe('ImageSegmenter', () => {
   });
 
   it('supports category masks', (done) => {
-    const mask = new Uint8Array([1, 2, 3, 4]);
+    const mask = new Uint8ClampedArray([1, 2, 3, 4]);
 
     // Pass the test data to our listener
     imageSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => {
diff --git a/mediapipe/tasks/web/vision/index.ts b/mediapipe/tasks/web/vision/index.ts
index 5a87c7a82..7fca725ec 100644
--- a/mediapipe/tasks/web/vision/index.ts
+++ b/mediapipe/tasks/web/vision/index.ts
@@ -15,29 +15,35 @@
  */
 
 import {FilesetResolver as FilesetResolverImpl} from '../../../tasks/web/core/fileset_resolver';
+import {FaceStylizer as FaceStylizerImpl} from '../../../tasks/web/vision/face_stylizer/face_stylizer';
 import {GestureRecognizer as GestureRecognizerImpl} from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer';
 import {HandLandmarker as HandLandmarkerImpl} from '../../../tasks/web/vision/hand_landmarker/hand_landmarker';
 import {ImageClassifier as ImageClassifierImpl} from '../../../tasks/web/vision/image_classifier/image_classifier';
 import {ImageEmbedder as ImageEmbedderImpl} from '../../../tasks/web/vision/image_embedder/image_embedder';
 import {ImageSegmenter as ImageSegementerImpl} from '../../../tasks/web/vision/image_segmenter/image_segmenter';
+import {InteractiveSegmenter as InteractiveSegmenterImpl} from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter';
 import {ObjectDetector as ObjectDetectorImpl} from '../../../tasks/web/vision/object_detector/object_detector';
 
 // Declare the variables locally so that Rollup in OSS includes them explicitly
 // as exports.
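For orientation, here is a minimal usage sketch of the tasks newly exported from this module. It assumes the published `@mediapipe/tasks-vision` bundle, a served wasm directory, and a hosted model file; the paths and DOM elements below are placeholders, not part of this change.

import {FaceStylizer, FilesetResolver} from '@mediapipe/tasks-vision';

async function runFaceStylizer(): Promise<void> {
  // Placeholder asset locations: the wasm loader files and the model can be
  // served from anywhere the application chooses.
  const vision = await FilesetResolver.forVisionTasks('/wasm');
  const stylizer = await FaceStylizer.createFromModelPath(
      vision, '/models/face_stylizer.task');
  const input = document.getElementById('photo') as HTMLImageElement;
  stylizer.stylize(input, (stylized, width, height) => {
    // The ImageData is only valid for the duration of the callback; draw it
    // (or copy it) before returning.
    const canvas = document.createElement('canvas');
    canvas.width = width;
    canvas.height = height;
    canvas.getContext('2d')!.putImageData(stylized, 0, 0);
    document.body.appendChild(canvas);
  });
}

The same `FilesetResolver` result can be reused to construct the other exported tasks, e.g. `InteractiveSegmenter.createFromModelPath(vision, ...)`.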
const FilesetResolver = FilesetResolverImpl; +const FaceStylizer = FaceStylizerImpl; const GestureRecognizer = GestureRecognizerImpl; const HandLandmarker = HandLandmarkerImpl; const ImageClassifier = ImageClassifierImpl; const ImageEmbedder = ImageEmbedderImpl; const ImageSegmenter = ImageSegementerImpl; +const InteractiveSegmenter = InteractiveSegmenterImpl; const ObjectDetector = ObjectDetectorImpl; export { FilesetResolver, + FaceStylizer, GestureRecognizer, HandLandmarker, ImageClassifier, ImageEmbedder, ImageSegmenter, + InteractiveSegmenter, ObjectDetector }; diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/BUILD b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD new file mode 100644 index 000000000..a4a3f27c9 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/BUILD @@ -0,0 +1,62 @@ +# This contains the MediaPipe Interactive Segmenter Task. + +load("//mediapipe/framework/port:build_config.bzl", "mediapipe_ts_declaration", "mediapipe_ts_library") +load("@npm//@bazel/jasmine:index.bzl", "jasmine_node_test") + +package(default_visibility = ["//mediapipe/tasks:internal"]) + +licenses(["notice"]) + +mediapipe_ts_library( + name = "interactive_segmenter", + srcs = ["interactive_segmenter.ts"], + deps = [ + ":interactive_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/framework:calculator_options_jspb_proto", + "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_jspb_proto", + "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", + "//mediapipe/tasks/web/components/containers:keypoint", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/vision/core:image_processing_options", + "//mediapipe/tasks/web/vision/core:types", + "//mediapipe/tasks/web/vision/core:vision_task_runner", + "//mediapipe/util:color_jspb_proto", + "//mediapipe/util:render_data_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_ts", + ], +) + +mediapipe_ts_declaration( + name = "interactive_segmenter_types", + srcs = ["interactive_segmenter_options.d.ts"], + deps = [ + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:classifier_options", + "//mediapipe/tasks/web/vision/core:vision_task_options", + ], +) + +mediapipe_ts_library( + name = "interactive_segmenter_test_lib", + testonly = True, + srcs = [ + "interactive_segmenter_test.ts", + ], + deps = [ + ":interactive_segmenter", + ":interactive_segmenter_types", + "//mediapipe/framework:calculator_jspb_proto", + "//mediapipe/tasks/web/core", + "//mediapipe/tasks/web/core:task_runner_test_utils", + "//mediapipe/util:render_data_jspb_proto", + "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", + ], +) + +jasmine_node_test( + name = "interactive_segmenter_test", + tags = ["nomsan"], + deps = [":interactive_segmenter_test_lib"], +) diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts new file mode 100644 index 000000000..1499a4c0c --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter.ts @@ -0,0 +1,306 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import {CalculatorGraphConfig} from '../../../../framework/calculator_pb'; +import {CalculatorOptions} from '../../../../framework/calculator_options_pb'; +import {BaseOptions as BaseOptionsProto} from '../../../../tasks/cc/core/proto/base_options_pb'; +import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/image_segmenter_graph_options_pb'; +import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; +import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; +import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; +import {RegionOfInterest, SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; +import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; +import {Color as ColorProto} from '../../../../util/color_pb'; +import {RenderAnnotation as RenderAnnotationProto, RenderData as RenderDataProto} from '../../../../util/render_data_pb'; +import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; +// Placeholder for internal dependency on trusted resource url + +import {InteractiveSegmenterOptions} from './interactive_segmenter_options'; + +export * from './interactive_segmenter_options'; +export {SegmentationMask, SegmentationMaskCallback, RegionOfInterest}; +export {ImageSource}; + +const IMAGE_IN_STREAM = 'image_in'; +const NORM_RECT_IN_STREAM = 'norm_rect_in'; +const ROI_IN_STREAM = 'roi_in'; +const IMAGE_OUT_STREAM = 'image_out'; +const IMAGEA_SEGMENTER_GRAPH = + 'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph'; + +// The OSS JS API does not support the builder pattern. +// tslint:disable:jspb-use-builder-pattern + +/** + * Performs interactive segmentation on images. + * + * Users can represent user interaction through `RegionOfInterest`, which gives + * a hint to InteractiveSegmenter to perform segmentation focusing on the given + * region of interest. + * + * The API expects a TFLite model with mandatory TFLite Model Metadata. + * + * Input tensor: + * (kTfLiteUInt8/kTfLiteFloat32) + * - image input of size `[batch x height x width x channels]`. + * - batch inference is not supported (`batch` is required to be 1). + * - RGB inputs is supported (`channels` is required to be 3). + * - if type is kTfLiteFloat32, NormalizationOptions are required to be + * attached to the metadata for input normalization. + * Output tensors: + * (kTfLiteUInt8/kTfLiteFloat32) + * - list of segmented masks. + * - if `output_type` is CATEGORY_MASK, uint8 Image, Image vector of size 1. + * - if `output_type` is CONFIDENCE_MASK, float32 Image list of size + * `channels`. 
+ *   - batch is always 1
+ */
+export class InteractiveSegmenter extends VisionTaskRunner {
+  private userCallback: SegmentationMaskCallback = () => {};
+  private readonly options: ImageSegmenterGraphOptionsProto;
+  private readonly segmenterOptions: SegmenterOptionsProto;
+
+  /**
+   * Initializes the Wasm runtime and creates a new interactive segmenter from
+   * the provided options.
+   * @param wasmFileset A configuration object that provides the location of
+   *     the Wasm binary and its loader.
+   * @param interactiveSegmenterOptions The options for the Interactive
+   *     Segmenter. Note that either a path to the model asset or a model buffer
+   *     needs to be provided (via `baseOptions`).
+   * @return A new `InteractiveSegmenter`.
+   */
+  static createFromOptions(
+      wasmFileset: WasmFileset,
+      interactiveSegmenterOptions: InteractiveSegmenterOptions):
+      Promise<InteractiveSegmenter> {
+    return VisionTaskRunner.createInstance(
+        InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
+        interactiveSegmenterOptions);
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new interactive segmenter based
+   * on the provided model asset buffer.
+   * @param wasmFileset A configuration object that provides the location of
+   *     the Wasm binary and its loader.
+   * @param modelAssetBuffer A binary representation of the model.
+   * @return A new `InteractiveSegmenter`.
+   */
+  static createFromModelBuffer(
+      wasmFileset: WasmFileset,
+      modelAssetBuffer: Uint8Array): Promise<InteractiveSegmenter> {
+    return VisionTaskRunner.createInstance(
+        InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetBuffer}});
+  }
+
+  /**
+   * Initializes the Wasm runtime and creates a new interactive segmenter based
+   * on the path to the model asset.
+   * @param wasmFileset A configuration object that provides the location of
+   *     the Wasm binary and its loader.
+   * @param modelAssetPath The path to the model asset.
+   * @return A new `InteractiveSegmenter`.
+   */
+  static createFromModelPath(
+      wasmFileset: WasmFileset,
+      modelAssetPath: string): Promise<InteractiveSegmenter> {
+    return VisionTaskRunner.createInstance(
+        InteractiveSegmenter, /* initializeCanvas= */ true, wasmFileset,
+        {baseOptions: {modelAssetPath}});
+  }
+
+  /** @hideconstructor */
+  constructor(
+      wasmModule: WasmModule,
+      glCanvas?: HTMLCanvasElement|OffscreenCanvas|null) {
+    super(
+        new VisionGraphRunner(wasmModule, glCanvas), IMAGE_IN_STREAM,
+        NORM_RECT_IN_STREAM, /* roiAllowed= */ false);
+    this.options = new ImageSegmenterGraphOptionsProto();
+    this.segmenterOptions = new SegmenterOptionsProto();
+    this.options.setSegmenterOptions(this.segmenterOptions);
+    this.options.setBaseOptions(new BaseOptionsProto());
+  }
+
+
+  protected override get baseOptions(): BaseOptionsProto {
+    return this.options.getBaseOptions()!;
+  }
+
+  protected override set baseOptions(proto: BaseOptionsProto) {
+    this.options.setBaseOptions(proto);
+  }
+
+  /**
+   * Sets new options for the interactive segmenter.
+   *
+   * Calling `setOptions()` with a subset of options only affects those
+   * options. You can reset an option back to its default value by
+   * explicitly setting it to `undefined`.
+   *
+   * @param options The options for the interactive segmenter.
+   * @return A Promise that resolves when the settings have been applied.
+   */
+  override setOptions(options: InteractiveSegmenterOptions): Promise<void> {
+    if (options.outputType === 'CONFIDENCE_MASK') {
+      this.segmenterOptions.setOutputType(
+          SegmenterOptionsProto.OutputType.CONFIDENCE_MASK);
+    } else {
+      this.segmenterOptions.setOutputType(
+          SegmenterOptionsProto.OutputType.CATEGORY_MASK);
+    }
+
+    return super.applyOptions(options);
+  }
+
+  /**
+   * Performs interactive segmentation on the provided single image and invokes
+   * the callback with the response. The `roi` parameter is used to represent a
+   * user's region of interest for segmentation.
+   *
+   * If the output_type is `CATEGORY_MASK`, the callback is invoked with a
+   * vector of images that represent per-category segmented image masks. If the
+   * output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
+   * images that contains only one confidence image mask. The method returns
+   * synchronously once the callback returns.
+   *
+   * @param image An image to process.
+   * @param roi The region of interest for segmentation.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segment(
+      image: ImageSource, roi: RegionOfInterest,
+      callback: SegmentationMaskCallback): void;
+  /**
+   * Performs interactive segmentation on the provided single image and invokes
+   * the callback with the response. The `roi` parameter is used to represent a
+   * user's region of interest for segmentation.
+   *
+   * The 'imageProcessingOptions' parameter can be used to specify the
+   * rotation to apply to the image before performing segmentation, by setting
+   * its 'rotationDegrees' field. Note that specifying a region-of-interest
+   * using the 'regionOfInterest' field is NOT supported and will result in an
+   * error.
+   *
+   * If the output_type is `CATEGORY_MASK`, the callback is invoked with a
+   * vector of images that represent per-category segmented image masks. If the
+   * output_type is `CONFIDENCE_MASK`, the callback is invoked with a vector of
+   * images that contains only one confidence image mask. The method returns
+   * synchronously once the callback returns.
+   *
+   * @param image An image to process.
+   * @param roi The region of interest for segmentation.
+   * @param imageProcessingOptions the `ImageProcessingOptions` specifying how
+   *    to process the input image before running inference.
+   * @param callback The callback that is invoked with the segmented masks. The
+   *    lifetime of the returned data is only guaranteed for the duration of the
+   *    callback.
+   */
+  segment(
+      image: ImageSource, roi: RegionOfInterest,
+      imageProcessingOptions: ImageProcessingOptions,
+      callback: SegmentationMaskCallback): void;
+  segment(
+      image: ImageSource, roi: RegionOfInterest,
+      imageProcessingOptionsOrCallback: ImageProcessingOptions|
+      SegmentationMaskCallback,
+      callback?: SegmentationMaskCallback): void {
+    const imageProcessingOptions =
+        typeof imageProcessingOptionsOrCallback !== 'function' ?
+        imageProcessingOptionsOrCallback :
+        {};
+
+    this.userCallback = typeof imageProcessingOptionsOrCallback === 'function' ?
+        imageProcessingOptionsOrCallback :
+        callback!;
+
+    this.processRenderData(roi, this.getSynctheticTimestamp());
+    this.processImageData(image, imageProcessingOptions);
+    this.userCallback = () => {};
+  }
+
+  /** Updates the MediaPipe graph configuration.
*/ + protected override refreshGraph(): void { + const graphConfig = new CalculatorGraphConfig(); + graphConfig.addInputStream(IMAGE_IN_STREAM); + graphConfig.addInputStream(ROI_IN_STREAM); + graphConfig.addInputStream(NORM_RECT_IN_STREAM); + graphConfig.addOutputStream(IMAGE_OUT_STREAM); + + const calculatorOptions = new CalculatorOptions(); + calculatorOptions.setExtension( + ImageSegmenterGraphOptionsProto.ext, this.options); + + const segmenterNode = new CalculatorGraphConfig.Node(); + segmenterNode.setCalculator(IMAGEA_SEGMENTER_GRAPH); + segmenterNode.addInputStream('IMAGE:' + IMAGE_IN_STREAM); + segmenterNode.addInputStream('ROI:' + ROI_IN_STREAM); + segmenterNode.addInputStream('NORM_RECT:' + NORM_RECT_IN_STREAM); + segmenterNode.addOutputStream('GROUPED_SEGMENTATION:' + IMAGE_OUT_STREAM); + segmenterNode.setOptions(calculatorOptions); + + graphConfig.addNode(segmenterNode); + + this.graphRunner.attachImageVectorListener( + IMAGE_OUT_STREAM, (masks, timestamp) => { + if (masks.length === 0) { + this.userCallback([], 0, 0); + } else { + this.userCallback( + masks.map(m => m.data), masks[0].width, masks[0].height); + } + this.setLatestOutputTimestamp(timestamp); + }); + this.graphRunner.attachEmptyPacketListener(IMAGE_OUT_STREAM, timestamp => { + this.setLatestOutputTimestamp(timestamp); + }); + + const binaryGraph = graphConfig.serializeBinary(); + this.setGraph(new Uint8Array(binaryGraph), /* isBinary= */ true); + } + + /** + * Converts the user-facing RegionOfInterest message to the RenderData proto + * and sends it to the graph + */ + private processRenderData(roi: RegionOfInterest, timestamp: number): void { + const renderData = new RenderDataProto(); + + const renderAnnotation = new RenderAnnotationProto(); + + const color = new ColorProto(); + color.setR(255); + renderAnnotation.setColor(color); + + const point = new RenderAnnotationProto.Point(); + point.setNormalized(true); + point.setX(roi.keypoint.x); + point.setY(roi.keypoint.y); + renderAnnotation.setPoint(point); + + renderData.addRenderAnnotations(renderAnnotation); + + this.graphRunner.addProtoToStream( + renderData.serializeBinary(), 'mediapipe.RenderData', ROI_IN_STREAM, + timestamp); + } +} + + diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts new file mode 100644 index 000000000..beb43cd81 --- /dev/null +++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_options.d.ts @@ -0,0 +1,36 @@ +/** + * Copyright 2023 The MediaPipe Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +import {TaskRunnerOptions} from '../../../../tasks/web/core/task_runner_options'; + +/** Options to configure the MediaPipe Interactive Segmenter Task */ +export interface InteractiveSegmenterOptions extends TaskRunnerOptions { + /** + * The output type of segmentation results. 
+   *
+   * The two supported modes are:
+   *   - Category Mask: Gives a single output mask where each pixel represents
+   *     the class which the pixel in the original image was
+   *     predicted to belong to.
+   *   - Confidence Mask: Gives a list of output masks (one for each class). For
+   *     each mask, the pixel represents the prediction
+   *     confidence, usually in the [0.0, 1.0] range.
+   *
+   * Defaults to `CATEGORY_MASK`.
+   */
+  outputType?: 'CATEGORY_MASK'|'CONFIDENCE_MASK'|undefined;
+}
diff --git a/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts
new file mode 100644
index 000000000..d6e3a97a5
--- /dev/null
+++ b/mediapipe/tasks/web/vision/interactive_segmenter/interactive_segmenter_test.ts
@@ -0,0 +1,214 @@
+/**
+ * Copyright 2023 The MediaPipe Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import 'jasmine';
+
+// Placeholder for internal dependency on encodeByteArray
+import {CalculatorGraphConfig} from '../../../../framework/calculator_pb';
+import {addJasmineCustomFloatEqualityTester, createSpyWasmModule, MediapipeTasksFake, SpyWasmModule, verifyGraph, verifyListenersRegistered} from '../../../../tasks/web/core/task_runner_test_utils';
+import {RenderData as RenderDataProto} from '../../../../util/render_data_pb';
+import {WasmImage} from '../../../../web/graph_runner/graph_runner_image_lib';
+
+import {InteractiveSegmenter, RegionOfInterest} from './interactive_segmenter';
+
+
+const ROI: RegionOfInterest = {
+  keypoint: {x: 0.1, y: 0.2}
+};
+
+class InteractiveSegmenterFake extends InteractiveSegmenter implements
+    MediapipeTasksFake {
+  calculatorName =
+      'mediapipe.tasks.vision.interactive_segmenter.InteractiveSegmenterGraph';
+  attachListenerSpies: jasmine.Spy[] = [];
+  graph: CalculatorGraphConfig|undefined;
+
+  fakeWasmModule: SpyWasmModule;
+  imageVectorListener:
+      ((images: WasmImage[], timestamp: number) => void)|undefined;
+  lastRoi?: RenderDataProto;
+
+  constructor() {
+    super(createSpyWasmModule(), /* glCanvas= */ null);
+    this.fakeWasmModule =
+        this.graphRunner.wasmModule as unknown as SpyWasmModule;
+
+    this.attachListenerSpies[0] =
+        spyOn(this.graphRunner, 'attachImageVectorListener')
+            .and.callFake((stream, listener) => {
+              expect(stream).toEqual('image_out');
+              this.imageVectorListener = listener;
+            });
+    spyOn(this.graphRunner, 'setGraph').and.callFake(binaryGraph => {
+      this.graph = CalculatorGraphConfig.deserializeBinary(binaryGraph);
+    });
+    spyOn(this.graphRunner, 'addGpuBufferAsImageToStream');
+
+    spyOn(this.graphRunner, 'addProtoToStream')
+        .and.callFake((data, protoName, stream) => {
+          if (stream === 'roi_in') {
+            expect(protoName).toEqual('mediapipe.RenderData');
+            this.lastRoi = RenderDataProto.deserializeBinary(data);
+          }
+        });
+  }
+}
+
+describe('InteractiveSegmenter', () => {
+  let interactiveSegmenter: InteractiveSegmenterFake;
+
+  beforeEach(async () => {
+    addJasmineCustomFloatEqualityTester();
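+    // InteractiveSegmenterFake (defined above) stubs out the wasm module and
+    // records the attached image-vector listener, the deserialized graph
+    // config, and the last ROI RenderData proto, so the tests below can
+    // assert on them without loading a real model.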
interactiveSegmenter = new InteractiveSegmenterFake(); + await interactiveSegmenter.setOptions( + {baseOptions: {modelAssetBuffer: new Uint8Array([])}}); + }); + + it('initializes graph', async () => { + verifyGraph(interactiveSegmenter); + verifyListenersRegistered(interactiveSegmenter); + }); + + it('reloads graph when settings are changed', async () => { + await interactiveSegmenter.setOptions({outputType: 'CATEGORY_MASK'}); + verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 1]); + verifyListenersRegistered(interactiveSegmenter); + + await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + verifyGraph(interactiveSegmenter, [['segmenterOptions', 'outputType'], 2]); + verifyListenersRegistered(interactiveSegmenter); + }); + + it('can use custom models', async () => { + const newModel = new Uint8Array([0, 1, 2, 3, 4]); + const newModelBase64 = Buffer.from(newModel).toString('base64'); + await interactiveSegmenter.setOptions({ + baseOptions: { + modelAssetBuffer: newModel, + } + }); + + verifyGraph( + interactiveSegmenter, + /* expectedCalculatorOptions= */ undefined, + /* expectedBaseOptions= */ + [ + 'modelAsset', { + fileContent: newModelBase64, + fileName: undefined, + fileDescriptorMeta: undefined, + filePointerMeta: undefined + } + ]); + }); + + + describe('setOptions()', () => { + const fieldPath = ['segmenterOptions', 'outputType']; + + it(`can set outputType`, async () => { + await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + verifyGraph(interactiveSegmenter, [fieldPath, 2]); + }); + + it(`can clear outputType`, async () => { + await interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + verifyGraph(interactiveSegmenter, [fieldPath, 2]); + await interactiveSegmenter.setOptions({outputType: undefined}); + verifyGraph(interactiveSegmenter, [fieldPath, 1]); + }); + }); + + it('doesn\'t support region of interest', () => { + expect(() => { + interactiveSegmenter.segment( + {} as HTMLImageElement, ROI, + {regionOfInterest: {left: 0, right: 0, top: 0, bottom: 0}}, () => {}); + }).toThrowError('This task doesn\'t support region-of-interest.'); + }); + + it('sends region-of-interest', (done) => { + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + expect(interactiveSegmenter.lastRoi).toBeDefined(); + expect(interactiveSegmenter.lastRoi!.toObject().renderAnnotationsList![0]) + .toEqual(jasmine.objectContaining({ + color: {r: 255, b: undefined, g: undefined}, + })); + done(); + }); + + interactiveSegmenter.segment({} as HTMLImageElement, ROI, () => {}); + }); + + it('supports category masks', (done) => { + const mask = new Uint8ClampedArray([1, 2, 3, 4]); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(interactiveSegmenter); + interactiveSegmenter.imageVectorListener!( + [ + {data: mask, width: 2, height: 2}, + ], + /* timestamp= */ 1337); + }); + + // Invoke the image segmenter + interactiveSegmenter.segment( + {} as HTMLImageElement, ROI, (masks, width, height) => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(masks).toHaveSize(1); + expect(masks[0]).toEqual(mask); + expect(width).toEqual(2); + expect(height).toEqual(2); + done(); + }); + }); + + it('supports confidence masks', async () => { + const mask1 = new Float32Array([0.1, 0.2, 0.3, 0.4]); + const mask2 = new Float32Array([0.5, 0.6, 0.7, 0.8]); + + await 
interactiveSegmenter.setOptions({outputType: 'CONFIDENCE_MASK'}); + + // Pass the test data to our listener + interactiveSegmenter.fakeWasmModule._waitUntilIdle.and.callFake(() => { + verifyListenersRegistered(interactiveSegmenter); + interactiveSegmenter.imageVectorListener!( + [ + {data: mask1, width: 2, height: 2}, + {data: mask2, width: 2, height: 2}, + ], + 1337); + }); + + return new Promise(resolve => { + // Invoke the image segmenter + interactiveSegmenter.segment( + {} as HTMLImageElement, ROI, (masks, width, height) => { + expect(interactiveSegmenter.fakeWasmModule._waitUntilIdle) + .toHaveBeenCalled(); + expect(masks).toHaveSize(2); + expect(masks[0]).toEqual(mask1); + expect(masks[1]).toEqual(mask2); + expect(width).toEqual(2); + expect(height).toEqual(2); + resolve(); + }); + }); + }); +}); diff --git a/mediapipe/tasks/web/vision/types.ts b/mediapipe/tasks/web/vision/types.ts index b9d951f60..7836192a0 100644 --- a/mediapipe/tasks/web/vision/types.ts +++ b/mediapipe/tasks/web/vision/types.ts @@ -15,9 +15,11 @@ */ export * from '../../../tasks/web/core/fileset_resolver'; +export * from '../../../tasks/web/vision/face_stylizer/face_stylizer'; export * from '../../../tasks/web/vision/gesture_recognizer/gesture_recognizer'; export * from '../../../tasks/web/vision/hand_landmarker/hand_landmarker'; export * from '../../../tasks/web/vision/image_classifier/image_classifier'; export * from '../../../tasks/web/vision/image_embedder/image_embedder'; export * from '../../../tasks/web/vision/image_segmenter/image_segmenter'; +export * from '../../../tasks/web/vision/interactive_segmenter/interactive_segmenter'; export * from '../../../tasks/web/vision/object_detector/object_detector'; diff --git a/mediapipe/util/BUILD b/mediapipe/util/BUILD index 59d03361e..9c9c19b2b 100644 --- a/mediapipe/util/BUILD +++ b/mediapipe/util/BUILD @@ -198,15 +198,15 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":resource_util_custom", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "//mediapipe/framework/deps:file_path", + "//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:logging", "//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:singleton", "//mediapipe/framework/port:status", "//mediapipe/framework/port:statusor", - "//mediapipe/framework/port:file_helpers", - "@com_google_absl//absl/strings", ] + select({ "//conditions:default": [ "@com_google_absl//absl/flags:flag", @@ -253,13 +253,13 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:variant", "//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location_opencv", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:variant", - "//mediapipe/framework/port:status", "//mediapipe/framework/port:map_util", + "//mediapipe/framework/port:status", ] + select({ "//conditions:default": [ "@org_tensorflow//tensorflow/core:framework", diff --git a/mediapipe/util/frame_buffer/BUILD b/mediapipe/util/frame_buffer/BUILD new file mode 100644 index 000000000..27343d6df --- /dev/null +++ b/mediapipe/util/frame_buffer/BUILD @@ -0,0 +1,106 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//mediapipe/util/frame_buffer:__subpackages__"]) + +cc_library( + name = "frame_buffer_util", + srcs = ["frame_buffer_util.cc"], + hdrs = ["frame_buffer_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":buffer", + "//mediapipe/framework/formats:frame_buffer", + "//mediapipe/framework/port:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "frame_buffer_util_test", + srcs = [ + "frame_buffer_util_test.cc", + ], + deps = [ + ":frame_buffer_util", + "//mediapipe/framework/formats:frame_buffer", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/port:status", + ], +) + +cc_library( + name = "buffer", + srcs = [ + "buffer_common.cc", + "gray_buffer.cc", + "rgb_buffer.cc", + "yuv_buffer.cc", + ], + hdrs = [ + "buffer_common.h", + "gray_buffer.h", + "rgb_buffer.h", + "yuv_buffer.h", + ], + deps = [ + "//mediapipe/util/frame_buffer/halide:gray_flip_halide", + "//mediapipe/util/frame_buffer/halide:gray_resize_halide", + "//mediapipe/util/frame_buffer/halide:gray_rotate_halide", + "//mediapipe/util/frame_buffer/halide:rgb_flip_halide", + "//mediapipe/util/frame_buffer/halide:rgb_gray_halide", + "//mediapipe/util/frame_buffer/halide:rgb_resize_halide", + "//mediapipe/util/frame_buffer/halide:rgb_rgb_halide", + "//mediapipe/util/frame_buffer/halide:rgb_rotate_halide", + "//mediapipe/util/frame_buffer/halide:rgb_yuv_halide", + "//mediapipe/util/frame_buffer/halide:yuv_flip_halide", + "//mediapipe/util/frame_buffer/halide:yuv_resize_halide", + "//mediapipe/util/frame_buffer/halide:yuv_rgb_halide", + "//mediapipe/util/frame_buffer/halide:yuv_rotate_halide", + "@halide//:runtime", + ], +) + +# Tests: +cc_test( + name = "rgb_buffer_test", + srcs = ["rgb_buffer_test.cc"], + deps = [ + ":buffer", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/log", + ], +) + +cc_test( + name = "yuv_buffer_test", + srcs = ["yuv_buffer_test.cc"], + deps = [ + ":buffer", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/log", + ], +) + +cc_test( + name = "gray_buffer_test", + srcs = ["gray_buffer_test.cc"], + deps = [ + ":buffer", + "//mediapipe/framework/port:gtest_main", + "@com_google_absl//absl/log", + ], +) diff --git a/mediapipe/util/frame_buffer/buffer_common.cc b/mediapipe/util/frame_buffer/buffer_common.cc new file mode 100644 index 000000000..180f52c03 --- /dev/null +++ b/mediapipe/util/frame_buffer/buffer_common.cc @@ -0,0 +1,40 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/frame_buffer/buffer_common.h" + +namespace mediapipe { +namespace frame_buffer { +namespace common { + +bool crop_buffer(int x0, int y0, int x1, int y1, halide_buffer_t* buffer) { + if (x0 < 0 || x1 >= buffer->dim[0].extent) { + return false; + } + if (y0 < 0 || y1 >= buffer->dim[1].extent) { + return false; + } + + // Move the start pointer so that it points at (x0, y0) and set the new + // extents. Leave the strides unchanged; we simply skip over the cropped + // image data. + buffer->host += y0 * buffer->dim[1].stride + x0 * buffer->dim[0].stride; + buffer->dim[0].extent = x1 - x0 + 1; + buffer->dim[1].extent = y1 - y0 + 1; + return true; +} + +} // namespace common +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/buffer_common.h b/mediapipe/util/frame_buffer/buffer_common.h new file mode 100644 index 000000000..9e0e891d3 --- /dev/null +++ b/mediapipe/util/frame_buffer/buffer_common.h @@ -0,0 +1,32 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_FRAME_BUFFER_BUFFER_COMMON_H_ +#define MEDIAPIPE_UTIL_FRAME_BUFFER_BUFFER_COMMON_H_ + +#include "HalideRuntime.h" + +namespace mediapipe { +namespace frame_buffer { +namespace common { + +// Performs in-place cropping on the given buffer; the provided rectangle +// becomes the full extent of the buffer upon success. Returns false on error. +bool crop_buffer(int x0, int y0, int x1, int y1, halide_buffer_t* buffer); + +} // namespace common +} // namespace frame_buffer +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_FRAME_BUFFER_BUFFER_COMMON_H_ diff --git a/mediapipe/util/frame_buffer/frame_buffer_util.cc b/mediapipe/util/frame_buffer/frame_buffer_util.cc new file mode 100644 index 000000000..b18e8cb13 --- /dev/null +++ b/mediapipe/util/frame_buffer/frame_buffer_util.cc @@ -0,0 +1,794 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
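+// Utility functions for cropping, resizing, rotating, flipping and converting
+// FrameBuffers; the per-pixel work is delegated to the Halide-backed
+// GrayBuffer, RgbBuffer and YuvBuffer wrappers included below.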
+ +#include "mediapipe/util/frame_buffer/frame_buffer_util.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/util/frame_buffer/gray_buffer.h" +#include "mediapipe/util/frame_buffer/rgb_buffer.h" +#include "mediapipe/util/frame_buffer/yuv_buffer.h" + +namespace mediapipe { +namespace frame_buffer { + +namespace { + +constexpr int kRgbaChannels = 4; +constexpr int kRgbaPixelBytes = 4; +constexpr int kRgbChannels = 3; +constexpr int kRgbPixelBytes = 3; +constexpr int kGrayChannel = 1; +constexpr int kGrayPixelBytes = 1; + +// YUV helpers. +//------------------------------------------------------------------------------ + +// Returns whether the buffer is part of the supported Yuv format. +bool IsSupportedYuvBuffer(const FrameBuffer& buffer) { + return buffer.format() == FrameBuffer::Format::kNV21 || + buffer.format() == FrameBuffer::Format::kNV12 || + buffer.format() == FrameBuffer::Format::kYV12 || + buffer.format() == FrameBuffer::Format::kYV21; +} + +// Shared validation functions. +//------------------------------------------------------------------------------ + +// Indicates whether the given buffers have the same dimensions. +bool AreBufferDimsEqual(const FrameBuffer& buffer1, + const FrameBuffer& buffer2) { + return buffer1.dimension() == buffer2.dimension(); +} + +// Indicates whether the given buffers formats are compatible. Same formats are +// compatible and all YUV family formats (e.g. NV21, NV12, YV12, YV21, etc) are +// compatible. +bool AreBufferFormatsCompatible(const FrameBuffer& buffer1, + const FrameBuffer& buffer2) { + switch (buffer1.format()) { + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return (buffer2.format() == FrameBuffer::Format::kRGBA || + buffer2.format() == FrameBuffer::Format::kRGB); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return (buffer2.format() == FrameBuffer::Format::kNV12 || + buffer2.format() == FrameBuffer::Format::kNV21 || + buffer2.format() == FrameBuffer::Format::kYV12 || + buffer2.format() == FrameBuffer::Format::kYV21); + case FrameBuffer::Format::kGRAY: + default: + return buffer1.format() == buffer2.format(); + } +} + +absl::Status ValidateBufferFormat(const FrameBuffer& buffer) { + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + case FrameBuffer::Format::kRGB: + case FrameBuffer::Format::kRGBA: + if (buffer.plane_count() == 1) return absl::OkStatus(); + return absl::InvalidArgumentError( + "Plane count must be 1 for grayscale and RGB[a] buffers."); + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kYV21: + case FrameBuffer::Format::kYV12: + return absl::OkStatus(); + default: + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", buffer.format())); + } +} + +absl::Status ValidateBufferFormats(const FrameBuffer& buffer1, + const FrameBuffer& buffer2) { + MP_RETURN_IF_ERROR(ValidateBufferFormat(buffer1)); + MP_RETURN_IF_ERROR(ValidateBufferFormat(buffer2)); + return absl::OkStatus(); +} + +absl::Status ValidateResizeBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer) { + bool valid_format = false; + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + case 
FrameBuffer::Format::kRGB: + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + valid_format = (buffer.format() == output_buffer.format()); + break; + case FrameBuffer::Format::kRGBA: + valid_format = (output_buffer.format() == FrameBuffer::Format::kRGBA || + output_buffer.format() == FrameBuffer::Format::kRGB); + break; + default: + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", buffer.format())); + } + if (!valid_format) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + return ValidateBufferFormats(buffer, output_buffer); +} + +absl::Status ValidateRotateBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer, + int angle_deg) { + if (!AreBufferFormatsCompatible(buffer, output_buffer)) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + + const bool is_dimension_change = (angle_deg / 90) % 2 == 1; + const bool are_dimensions_rotated = + (buffer.dimension().width == output_buffer.dimension().height) && + (buffer.dimension().height == output_buffer.dimension().width); + const bool are_dimensions_equal = + buffer.dimension() == output_buffer.dimension(); + + if (angle_deg >= 360 || angle_deg <= 0 || angle_deg % 90 != 0) { + return absl::InvalidArgumentError( + "Rotation angle must be between 0 and 360, in multiples of 90 " + "degrees."); + } else if ((is_dimension_change && !are_dimensions_rotated) || + (!is_dimension_change && !are_dimensions_equal)) { + return absl::InvalidArgumentError( + "Output buffer has invalid dimensions for rotation."); + } + return absl::OkStatus(); +} + +absl::Status ValidateCropBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer, int x0, + int y0, int x1, int y1) { + if (!AreBufferFormatsCompatible(buffer, output_buffer)) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + + bool is_buffer_size_valid = + ((x1 < buffer.dimension().width) && y1 < buffer.dimension().height); + bool are_points_valid = (x0 >= 0) && (y0 >= 0) && (x1 >= x0) && (y1 >= y0); + + if (!is_buffer_size_valid || !are_points_valid) { + return absl::InvalidArgumentError("Invalid crop coordinates."); + } + return absl::OkStatus(); +} + +absl::Status ValidateFlipBufferInputs(const FrameBuffer& buffer, + const FrameBuffer& output_buffer) { + if (!AreBufferFormatsCompatible(buffer, output_buffer)) { + return absl::InvalidArgumentError( + "Input and output buffer formats must match."); + } + return AreBufferDimsEqual(buffer, output_buffer) + ? absl::OkStatus() + : absl::InvalidArgumentError( + "Input and output buffers must have the same dimensions."); +} + +absl::Status ValidateConvertFormats(FrameBuffer::Format from_format, + FrameBuffer::Format to_format) { + if (from_format == to_format) { + return absl::InvalidArgumentError("Formats must be different."); + } + + switch (from_format) { + case FrameBuffer::Format::kGRAY: + return absl::InvalidArgumentError( + "Grayscale format does not convert to other formats."); + case FrameBuffer::Format::kRGB: + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return absl::OkStatus(); + default: + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", from_format)); + } +} + +// Construct buffer helper functions. 
+//------------------------------------------------------------------------------ + +// Creates NV12 / NV21 / YV12 / YV21 YuvBuffer from the input `buffer`. The +// output YuvBuffer is agnostic to the YUV format since the YUV buffers are +// managed individually. +absl::StatusOr CreateYuvBuffer(const FrameBuffer& buffer) { + ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + return YuvBuffer(const_cast(yuv_data.y_buffer), + const_cast(yuv_data.u_buffer), + const_cast(yuv_data.v_buffer), + buffer.dimension().width, buffer.dimension().height, + yuv_data.y_row_stride, yuv_data.uv_row_stride, + yuv_data.uv_pixel_stride); +} + +absl::StatusOr CreateGrayBuffer(const FrameBuffer& buffer) { + if (buffer.plane_count() != 1) { + return absl::InternalError("Unsupported grayscale planar format."); + } + return GrayBuffer(const_cast(buffer.plane(0).buffer()), + buffer.dimension().width, buffer.dimension().height); +} + +absl::StatusOr CreateRgbBuffer(const FrameBuffer& buffer) { + if (buffer.plane_count() != 1) { + return absl::InternalError("Unsupported rgb[a] planar format."); + } + bool alpha = buffer.format() == FrameBuffer::Format::kRGBA ? true : false; + return RgbBuffer(const_cast(buffer.plane(0).buffer()), + buffer.dimension().width, buffer.dimension().height, + buffer.plane(0).stride().row_stride_bytes, alpha); +} + +// Grayscale transformation functions. +//------------------------------------------------------------------------------ + +absl::Status CropGrayscale(const FrameBuffer& buffer, int x0, int y0, int x1, + int y1, FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateGrayBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateGrayBuffer(*output_buffer)); + bool success_crop = input.Crop(x0, y0, x1, y1); + if (!success_crop) { + return absl::UnknownError("Halide grayscale crop operation failed."); + } + bool success_resize = input.Resize(&output); + if (!success_resize) { + return absl::UnknownError("Halide grayscale resize operation failed."); + } + return absl::OkStatus(); +} + +absl::Status ResizeGrayscale(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateGrayBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateGrayBuffer(*output_buffer)); + return input.Resize(&output) + ? absl::OkStatus() + : absl::UnknownError("Halide grayscale resize operation failed."); +} + +absl::Status RotateGrayscale(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateGrayBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateGrayBuffer(*output_buffer)); + return input.Rotate(angle_deg % 360, &output) + ? absl::OkStatus() + : absl::UnknownError("Halide grayscale rotate operation failed."); +} + +absl::Status FlipHorizontallyGrayscale(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateGrayBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateGrayBuffer(*output_buffer)); + return input.FlipHorizontally(&output) + ? absl::OkStatus() + : absl::UnknownError( + "Halide grayscale horizontal flip operation failed."); +} + +absl::Status FlipVerticallyGrayscale(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateGrayBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateGrayBuffer(*output_buffer)); + return input.FlipVertically(&output) + ? 
absl::OkStatus() + : absl::UnknownError( + "Halide grayscale vertical flip operation failed."); +} + +// Rgb transformation functions. +//------------------------------------------------------------------------------ + +absl::Status ResizeRgb(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateRgbBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateRgbBuffer(*output_buffer)); + return input.Resize(&output) + ? absl::OkStatus() + : absl::UnknownError("Halide rgb[a] resize operation failed."); +} + +absl::Status ConvertRgb(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateRgbBuffer(buffer)); + bool result = false; + if (output_buffer->format() == FrameBuffer::Format::kGRAY) { + ASSIGN_OR_RETURN(auto output, CreateGrayBuffer(*output_buffer)); + result = input.Convert(&output); + } else if (IsSupportedYuvBuffer(*output_buffer)) { + ASSIGN_OR_RETURN(auto output, CreateYuvBuffer(*output_buffer)); + result = input.Convert(&output); + } else if (output_buffer->format() == FrameBuffer::Format::kRGBA || + output_buffer->format() == FrameBuffer::Format::kRGB) { + ASSIGN_OR_RETURN(auto output, CreateRgbBuffer(*output_buffer)); + result = input.Convert(&output); + } + return result ? absl::OkStatus() + : absl::UnknownError("Halide rgb[a] convert operation failed."); +} + +absl::Status CropRgb(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateRgbBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateRgbBuffer(*output_buffer)); + bool success_crop = input.Crop(x0, y0, x1, y1); + if (!success_crop) { + return absl::UnknownError("Halide rgb[a] crop operation failed."); + } + bool success_resize = input.Resize(&output); + if (!success_resize) { + return absl::UnknownError("Halide rgb resize operation failed."); + } + return absl::OkStatus(); +} + +absl::Status FlipHorizontallyRgb(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateRgbBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateRgbBuffer(*output_buffer)); + return input.FlipHorizontally(&output) + ? absl::OkStatus() + : absl::UnknownError( + "Halide rgb[a] horizontal flip operation failed."); +} + +absl::Status FlipVerticallyRgb(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateRgbBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateRgbBuffer(*output_buffer)); + return input.FlipVertically(&output) + ? absl::OkStatus() + : absl::UnknownError( + "Halide rgb[a] vertical flip operation failed."); +} + +absl::Status RotateRgb(const FrameBuffer& buffer, int angle, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateRgbBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateRgbBuffer(*output_buffer)); + return input.Rotate(angle % 360, &output) + ? absl::OkStatus() + : absl::UnknownError("Halide rgb[a] rotate operation failed."); +} + +// Yuv transformation functions. 
+//------------------------------------------------------------------------------ + +absl::Status CropYuv(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateYuvBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateYuvBuffer(*output_buffer)); + bool success_crop = input.Crop(x0, y0, x1, y1); + if (!success_crop) { + return absl::UnknownError("Halide YUV crop operation failed."); + } + bool success_resize = input.Resize(&output); + if (!success_resize) { + return absl::UnknownError("Halide YUV resize operation failed."); + } + return absl::OkStatus(); +} + +absl::Status ResizeYuv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateYuvBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateYuvBuffer(*output_buffer)); + return input.Resize(&output) + ? absl::OkStatus() + : absl::UnknownError("Halide YUV resize operation failed."); +} + +absl::Status RotateYuv(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateYuvBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateYuvBuffer(*output_buffer)); + return input.Rotate(angle_deg % 360, &output) + ? absl::OkStatus() + : absl::UnknownError("Halide YUV rotate operation failed."); +} + +absl::Status FlipHorizontallyYuv(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateYuvBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateYuvBuffer(*output_buffer)); + return input.FlipHorizontally(&output) + ? absl::OkStatus() + : absl::UnknownError( + "Halide YUV horizontal flip operation failed."); +} + +absl::Status FlipVerticallyYuv(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + ASSIGN_OR_RETURN(auto input, CreateYuvBuffer(buffer)); + ASSIGN_OR_RETURN(auto output, CreateYuvBuffer(*output_buffer)); + return input.FlipVertically(&output) + ? absl::OkStatus() + : absl::UnknownError("Halide YUV vertical flip operation failed."); +} + +// Converts input YUV `buffer` into the `output_buffer` in RGB, RGBA or gray +// scale format. +absl::Status ConvertYuv(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + bool success_convert = false; + ASSIGN_OR_RETURN(auto input, CreateYuvBuffer(buffer)); + if (output_buffer->format() == FrameBuffer::Format::kRGBA || + output_buffer->format() == FrameBuffer::Format::kRGB) { + ASSIGN_OR_RETURN(auto output, CreateRgbBuffer(*output_buffer)); + bool half_sampling = false; + if (buffer.dimension().width / 2 == output_buffer->dimension().width && + buffer.dimension().height / 2 == output_buffer->dimension().height) { + half_sampling = true; + } + success_convert = input.Convert(half_sampling, &output); + } else if (output_buffer->format() == FrameBuffer::Format::kGRAY) { + if (buffer.plane(0).stride().row_stride_bytes == buffer.dimension().width) { + std::copy(input.y_buffer()->host, + input.y_buffer()->host + buffer.dimension().Size(), + const_cast(output_buffer->plane(0).buffer())); + } else { + // The y_buffer is padded. The conversion removes the padding. 
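// For illustration (matching the PaddedYuvConvertGray test further below):
// with width = 6 and a Y row stride of 8 bytes, each row carries 2 trailing
// padding bytes, so the loop below copies source bytes [0, 6) to output
// offset 0, source bytes [8, 14) to output offset 6, and so on.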
+ uint8_t* gray_buffer = + const_cast(output_buffer->plane(0).buffer()); + for (int i = 0; i < buffer.dimension().height; i++) { + int src_address = i * buffer.plane(0).stride().row_stride_bytes; + int dest_address = i * buffer.dimension().width; + std::memcpy(&gray_buffer[dest_address], + &buffer.plane(0).buffer()[src_address], + buffer.dimension().width); + } + } + success_convert = true; + } else if (IsSupportedYuvBuffer(*output_buffer)) { + ASSIGN_OR_RETURN(auto output, CreateYuvBuffer(*output_buffer)); + success_convert = input.Resize(&output); + } + return success_convert + ? absl::OkStatus() + : absl::UnknownError("Halide YUV convert operation failed."); +} + +} // namespace + +// Public methods. +//------------------------------------------------------------------------------ + +std::shared_ptr CreateFromRgbaRawBuffer( + uint8_t* input, FrameBuffer::Dimension dimension, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kRgbaChannels; + stride.pixel_stride_bytes = kRgbaChannels; + } + FrameBuffer::Plane input_plane(/*buffer=*/input, + /*stride=*/stride); + std::vector planes{input_plane}; + return std::make_shared(planes, dimension, + FrameBuffer::Format::kRGBA); +} + +std::shared_ptr CreateFromRgbRawBuffer( + uint8_t* input, FrameBuffer::Dimension dimension, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kRgbChannels; + stride.pixel_stride_bytes = kRgbChannels; + } + FrameBuffer::Plane input_plane(/*buffer=*/input, + /*stride=*/stride); + std::vector planes{input_plane}; + return std::make_shared(planes, dimension, + FrameBuffer::Format::kRGB); +} + +std::shared_ptr CreateFromGrayRawBuffer( + uint8_t* input, FrameBuffer::Dimension dimension, + FrameBuffer::Stride stride) { + if (stride == kDefaultStride) { + stride.row_stride_bytes = dimension.width * kGrayChannel; + stride.pixel_stride_bytes = kGrayChannel; + } + FrameBuffer::Plane input_plane(/*buffer=*/input, + /*stride=*/stride); + std::vector planes{input_plane}; + return std::make_shared(planes, dimension, + FrameBuffer::Format::kGRAY); +} + +absl::StatusOr> CreateFromYuvRawBuffer( + uint8_t* y_plane, uint8_t* u_plane, uint8_t* v_plane, + FrameBuffer::Format format, FrameBuffer::Dimension dimension, + int row_stride_y, int row_stride_uv, int pixel_stride_uv) { + const int pixel_stride_y = 1; + std::vector planes; + if (format == FrameBuffer::Format::kNV21 || + format == FrameBuffer::Format::kYV12) { + planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}}, + {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}, + {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}}; + } else if (format == FrameBuffer::Format::kNV12 || + format == FrameBuffer::Format::kYV21) { + planes = {{y_plane, /*stride=*/{row_stride_y, pixel_stride_y}}, + {u_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}, + {v_plane, /*stride=*/{row_stride_uv, pixel_stride_uv}}}; + } else { + return absl::InvalidArgumentError( + absl::StrFormat("Input format is not YUV-like: %i.", format)); + } + return std::make_shared(planes, dimension, format); +} + +absl::StatusOr> CreateFromRawBuffer( + uint8_t* buffer, FrameBuffer::Dimension dimension, + const FrameBuffer::Format target_format) { + switch (target_format) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: { + FrameBuffer::Plane plane(/*buffer=*/buffer, + /*stride=*/{dimension.width, kGrayChannel}); + std::vector planes{plane}; + return 
std::make_shared(planes, dimension, target_format); + } + case FrameBuffer::Format::kYV12: { + ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension, + GetUvPlaneDimension(dimension, target_format)); + return CreateFromYuvRawBuffer( + /*y_plane=*/buffer, + /*u_plane=*/buffer + dimension.Size() + uv_dimension.Size(), + /*v_plane=*/buffer + dimension.Size(), target_format, dimension, + /*row_stride_y=*/dimension.width, uv_dimension.width, + /*pixel_stride_uv=*/1); + } + case FrameBuffer::Format::kYV21: { + ASSIGN_OR_RETURN(const FrameBuffer::Dimension uv_dimension, + GetUvPlaneDimension(dimension, target_format)); + return CreateFromYuvRawBuffer( + /*y_plane=*/buffer, /*u_plane=*/buffer + dimension.Size(), + /*v_plane=*/buffer + dimension.Size() + uv_dimension.Size(), + target_format, dimension, /*row_stride_y=*/dimension.width, + uv_dimension.width, + /*pixel_stride_uv=*/1); + } + case FrameBuffer::Format::kRGBA: + return CreateFromRgbaRawBuffer(buffer, dimension); + case FrameBuffer::Format::kRGB: + return CreateFromRgbRawBuffer(buffer, dimension); + case FrameBuffer::Format::kGRAY: + return CreateFromGrayRawBuffer(buffer, dimension); + default: + return absl::InternalError( + absl::StrFormat("Unsupported buffer format: %i.", target_format)); + } +} + +absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1, + FrameBuffer* output_buffer) { + MP_RETURN_IF_ERROR( + ValidateCropBufferInputs(buffer, *output_buffer, x0, y0, x1, y1)); + MP_RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer)); + + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + return CropGrayscale(buffer, x0, y0, x1, y1, output_buffer); + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return CropRgb(buffer, x0, y0, x1, y1, output_buffer); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return CropYuv(buffer, x0, y0, x1, y1, output_buffer); + default: + return absl::InternalError( + absl::StrFormat("Format %i is not supported.", buffer.format())); + } +} + +absl::Status Resize(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + MP_RETURN_IF_ERROR(ValidateResizeBufferInputs(buffer, *output_buffer)); + + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + return ResizeGrayscale(buffer, output_buffer); + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return ResizeRgb(buffer, output_buffer); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return ResizeYuv(buffer, output_buffer); + default: + return absl::InternalError( + absl::StrFormat("Format %i is not supported.", buffer.format())); + } +} + +absl::Status Rotate(const FrameBuffer& buffer, int angle_deg, + FrameBuffer* output_buffer) { + MP_RETURN_IF_ERROR( + ValidateRotateBufferInputs(buffer, *output_buffer, angle_deg)); + MP_RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer)); + + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + return RotateGrayscale(buffer, angle_deg, output_buffer); + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return RotateRgb(buffer, angle_deg, output_buffer); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return RotateYuv(buffer, angle_deg, output_buffer); + default: + return absl::InternalError( + 
absl::StrFormat("Format %i is not supported.", buffer.format())); + } +} + +absl::Status FlipHorizontally(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + MP_RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer)); + MP_RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer)); + + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + return FlipHorizontallyGrayscale(buffer, output_buffer); + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return FlipHorizontallyRgb(buffer, output_buffer); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return FlipHorizontallyYuv(buffer, output_buffer); + default: + return absl::InternalError( + absl::StrFormat("Format %i is not supported.", buffer.format())); + } +} + +absl::Status FlipVertically(const FrameBuffer& buffer, + FrameBuffer* output_buffer) { + MP_RETURN_IF_ERROR(ValidateFlipBufferInputs(buffer, *output_buffer)); + MP_RETURN_IF_ERROR(ValidateBufferFormats(buffer, *output_buffer)); + + switch (buffer.format()) { + case FrameBuffer::Format::kGRAY: + return FlipVerticallyGrayscale(buffer, output_buffer); + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return FlipVerticallyRgb(buffer, output_buffer); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return FlipVerticallyYuv(buffer, output_buffer); + default: + return absl::InternalError( + absl::StrFormat("Format %i is not supported.", buffer.format())); + } +} + +absl::Status Convert(const FrameBuffer& buffer, FrameBuffer* output_buffer) { + MP_RETURN_IF_ERROR( + ValidateConvertFormats(buffer.format(), output_buffer->format())); + + switch (buffer.format()) { + case FrameBuffer::Format::kRGBA: + case FrameBuffer::Format::kRGB: + return ConvertRgb(buffer, output_buffer); + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return ConvertYuv(buffer, output_buffer); + default: + return absl::InternalError( + absl::StrFormat("Format %i is not supported.", buffer.format())); + } +} + +int GetFrameBufferByteSize(FrameBuffer::Dimension dimension, + FrameBuffer::Format format) { + switch (format) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return /*y plane*/ dimension.Size() + + /*uv plane*/ (dimension.width + 1) / 2 * (dimension.height + 1) / + 2 * 2; + case FrameBuffer::Format::kRGB: + return dimension.Size() * kRgbPixelBytes; + case FrameBuffer::Format::kRGBA: + return dimension.Size() * kRgbaPixelBytes; + case FrameBuffer::Format::kGRAY: + return dimension.Size(); + default: + return 0; + } +} + +absl::StatusOr GetPixelStrides(FrameBuffer::Format format) { + switch (format) { + case FrameBuffer::Format::kGRAY: + return kGrayPixelBytes; + case FrameBuffer::Format::kRGB: + return kRgbPixelBytes; + case FrameBuffer::Format::kRGBA: + return kRgbaPixelBytes; + default: + return absl::InvalidArgumentError(absl::StrFormat( + "GetPixelStrides does not support format: %i.", format)); + } +} + +absl::StatusOr GetUvRawBuffer(const FrameBuffer& buffer) { + if (buffer.format() != FrameBuffer::Format::kNV12 && + buffer.format() != FrameBuffer::Format::kNV21) { + return absl::InvalidArgumentError( + "Only support getting biplanar UV buffer from NV12/NV21 frame 
buffer."); + } + ASSIGN_OR_RETURN(FrameBuffer::YuvData yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(buffer)); + const uint8_t* uv_buffer = buffer.format() == FrameBuffer::Format::kNV12 + ? yuv_data.u_buffer + : yuv_data.v_buffer; + return uv_buffer; +} + +absl::StatusOr GetUvPlaneDimension( + FrameBuffer::Dimension dimension, FrameBuffer::Format format) { + if (dimension.width <= 0 || dimension.height <= 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Invalid input dimension: {%d, %d}.", dimension.width, + dimension.height)); + } + switch (format) { + case FrameBuffer::Format::kNV12: + case FrameBuffer::Format::kNV21: + case FrameBuffer::Format::kYV12: + case FrameBuffer::Format::kYV21: + return FrameBuffer::Dimension{(dimension.width + 1) / 2, + (dimension.height + 1) / 2}; + default: + return absl::InvalidArgumentError( + absl::StrFormat("Input format is not YUV-like: %i.", format)); + } +} + +FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1) { + return {x1 - x0 + 1, y1 - y0 + 1}; +} + +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/frame_buffer_util.h b/mediapipe/util/frame_buffer/frame_buffer_util.h new file mode 100644 index 000000000..eb17e0094 --- /dev/null +++ b/mediapipe/util/frame_buffer/frame_buffer_util.h @@ -0,0 +1,131 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_FRAME_BUFFER_FRAME_BUFFER_UTIL_H_ +#define MEDIAPIPE_UTIL_FRAME_BUFFER_FRAME_BUFFER_UTIL_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mediapipe/framework/formats/frame_buffer.h" + +namespace mediapipe { +namespace frame_buffer { + +// Creation helpers. +//------------------------------------------------------------------------------ + +// Default stride value for creating frame buffer from raw buffer. When using +// this default value, the default row stride and pixel stride values will be +// applied. e.g. for an RGB image: +// row_stride = width * 3 +// pixel_stride = 3. +inline constexpr FrameBuffer::Stride kDefaultStride = {0, 0}; + +// Creates a FrameBuffer from raw RGBA buffer and passing arguments. +std::shared_ptr CreateFromRgbaRawBuffer( + uint8_t* input, FrameBuffer::Dimension dimension, + FrameBuffer::Stride stride = kDefaultStride); + +// Creates a FrameBuffer from raw RGB buffer and passing arguments. +std::shared_ptr CreateFromRgbRawBuffer( + uint8_t* input, FrameBuffer::Dimension dimension, + FrameBuffer::Stride stride = kDefaultStride); + +// Creates a FrameBuffer from raw grayscale buffer and passing arguments. +std::shared_ptr CreateFromGrayRawBuffer( + uint8_t* input, FrameBuffer::Dimension dimension, + FrameBuffer::Stride stride = kDefaultStride); + +// Creates a FrameBuffer from raw YUV buffer and passing arguments. 
+absl::StatusOr<std::shared_ptr<FrameBuffer>> CreateFromYuvRawBuffer(
+    uint8_t* y_plane, uint8_t* u_plane, uint8_t* v_plane,
+    FrameBuffer::Format format, FrameBuffer::Dimension dimension,
+    int row_stride_y, int row_stride_uv, int pixel_stride_uv);
+
+// Creates a FrameBuffer instance from a raw buffer and the passed arguments.
+absl::StatusOr<std::shared_ptr<FrameBuffer>> CreateFromRawBuffer(
+    uint8_t* buffer, FrameBuffer::Dimension dimension,
+    FrameBuffer::Format target_format);
+
+// Transformations.
+//------------------------------------------------------------------------------
+
+// Crops `buffer` to the specified points.
+//
+// (x0, y0) represents the top-left point of the buffer.
+// (x1, y1) represents the bottom-right point of the buffer.
+//
+// The implementation performs origin moving and resizing operations.
+absl::Status Crop(const FrameBuffer& buffer, int x0, int y0, int x1, int y1,
+                  FrameBuffer* output_buffer);
+
+// Resizes `buffer` to the size of the given `output_buffer` using bilinear
+// interpolation.
+absl::Status Resize(const FrameBuffer& buffer, FrameBuffer* output_buffer);
+
+// Rotates `buffer` counter-clockwise by the given `angle_deg` (in degrees).
+//
+// The given angle must be a multiple of 90 degrees.
+absl::Status Rotate(const FrameBuffer& buffer, int angle_deg,
+                    FrameBuffer* output_buffer);
+
+// Flips `buffer` horizontally.
+absl::Status FlipHorizontally(const FrameBuffer& buffer,
+                              FrameBuffer* output_buffer);
+
+// Flips `buffer` vertically.
+absl::Status FlipVertically(const FrameBuffer& buffer,
+                            FrameBuffer* output_buffer);
+
+// Converts `buffer`'s format to the format of the given `output_buffer`.
+//
+// Note that the grayscale format does not convert to other formats.
+// Note that the NV21 to RGB/RGBA conversion may downsample by a factor of 2,
+// based on the buffer and output_buffer dimensions.
+absl::Status Convert(const FrameBuffer& buffer, FrameBuffer* output_buffer);
+
+// Miscellaneous methods.
+//------------------------------------------------------------------------------
+
+// Returns the frame buffer size in bytes based on the input format and
+// dimensions. GRAY and YV12/YV21 are planar formats, NV12/NV21 are
+// semi-planar formats with interleaved UV planes, and RGB/RGBA are
+// interleaved formats.
+int GetFrameBufferByteSize(FrameBuffer::Dimension dimension,
+                           FrameBuffer::Format format);
+
+// Returns the pixel stride for the kGRAY, kRGB, and kRGBA formats.
+absl::StatusOr<int> GetPixelStrides(FrameBuffer::Format format);
+
+// Returns the biplanar UV raw buffer for an NV12/NV21 frame buffer.
+absl::StatusOr<const uint8_t*> GetUvRawBuffer(const FrameBuffer& buffer);
+
+// Returns the U or V plane dimension for the given buffer `dimension` and
+// `format`. Only supports the NV12/NV21/YV12/YV21 formats. Returns
+// InvalidArgumentError if `dimension` is invalid or `format` is not one of
+// the supported formats. This method assumes the U and V planes share the
+// same dimension, in particular for the YV12 / YV21 formats.
+absl::StatusOr<FrameBuffer::Dimension> GetUvPlaneDimension(
+    FrameBuffer::Dimension dimension, FrameBuffer::Format format);
+
+// Returns the crop dimension based on the crop start and end points.
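// For example (mirroring the NV21Crop test below), a crop with corners
// (x0, y0) = (2, 0) and (x1, y1) = (5, 1) yields
//   GetCropDimension(/*x0=*/2, /*x1=*/5, /*y0=*/0, /*y1=*/1)  // -> {4, 2}
// since the result is {x1 - x0 + 1, y1 - y0 + 1}.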
+FrameBuffer::Dimension GetCropDimension(int x0, int x1, int y0, int y1); + +} // namespace frame_buffer +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_FRAME_BUFFER_FRAME_BUFFER_UTIL_H_ diff --git a/mediapipe/util/frame_buffer/frame_buffer_util_test.cc b/mediapipe/util/frame_buffer/frame_buffer_util_test.cc new file mode 100644 index 000000000..d92eb3f53 --- /dev/null +++ b/mediapipe/util/frame_buffer/frame_buffer_util_test.cc @@ -0,0 +1,1176 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/frame_buffer/frame_buffer_util.h" + +#include +#include +#include + +#include "mediapipe/framework/formats/frame_buffer.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_macros.h" +#include "mediapipe/framework/port/status_matchers.h" + +namespace mediapipe { +namespace frame_buffer { +namespace { + +// Grayscale unit tests. +//------------------------------------------------------------------------------ + +TEST(FrameBufferUtil, GrayCrop) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kOutputDimension = {.width = 1, .height = 1}; + uint8_t data[6] = {1, 2, 3, 4, 5, 6}; + uint8_t output_data[2]; + auto input = CreateFromGrayRawBuffer(data, kBufferDimension); + auto output = CreateFromGrayRawBuffer(output_data, kOutputDimension); + + MP_ASSERT_OK(Crop(*input, 0, 1, 0, 1, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 4); +} + +TEST(FrameBufferUtil, GrayResize) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 2, .height = 2}, + kOutputDimension = {.width = 3, .height = 2}; + uint8_t data[4] = {1, 2, 3, 4}; + uint8_t output_data[6]; + auto input = CreateFromGrayRawBuffer(data, kBufferDimension); + auto output = CreateFromGrayRawBuffer(output_data, kOutputDimension); + + MP_ASSERT_OK(Resize(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 1); + EXPECT_EQ(output->plane(0).buffer()[1], 2); + EXPECT_EQ(output->plane(0).buffer()[2], 2); + EXPECT_EQ(output->plane(0).buffer()[3], 3); + EXPECT_EQ(output->plane(0).buffer()[4], 4); + EXPECT_EQ(output->plane(0).buffer()[5], 4); +} + +TEST(FrameBufferUtil, GrayRotate) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kOutputDimension = {.width = 2, .height = 3}; + uint8_t data[6] = {1, 2, 3, 4, 5, 6}; + uint8_t output_data[6]; + auto input = CreateFromGrayRawBuffer(data, kBufferDimension); + auto output = CreateFromGrayRawBuffer(output_data, kOutputDimension); + + MP_ASSERT_OK(Rotate(*input, 90, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 3); + EXPECT_EQ(output->plane(0).buffer()[1], 6); + EXPECT_EQ(output->plane(0).buffer()[2], 2); + EXPECT_EQ(output->plane(0).buffer()[3], 5); + EXPECT_EQ(output->plane(0).buffer()[4], 1); + EXPECT_EQ(output->plane(0).buffer()[5], 4); +} + +TEST(FrameBufferUtil, GrayFlipHorizontally) { + constexpr FrameBuffer::Dimension 
kBufferDimension = {.width = 3, .height = 2}; + uint8_t data[6] = {1, 2, 3, 4, 5, 6}; + uint8_t output_data[6]; + auto input = CreateFromGrayRawBuffer(data, kBufferDimension); + auto output = CreateFromGrayRawBuffer(output_data, kBufferDimension); + + MP_ASSERT_OK(FlipHorizontally(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 3); + EXPECT_EQ(output->plane(0).buffer()[1], 2); + EXPECT_EQ(output->plane(0).buffer()[2], 1); + EXPECT_EQ(output->plane(0).buffer()[3], 6); + EXPECT_EQ(output->plane(0).buffer()[4], 5); + EXPECT_EQ(output->plane(0).buffer()[5], 4); +} + +TEST(FrameBufferUtil, GrayFlipVertically) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}; + uint8_t data[6] = {1, 2, 3, 4, 5, 6}; + uint8_t output_data[6]; + auto input = CreateFromGrayRawBuffer(data, kBufferDimension); + auto output = CreateFromGrayRawBuffer(output_data, kBufferDimension); + + MP_ASSERT_OK(FlipVertically(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 4); + EXPECT_EQ(output->plane(0).buffer()[1], 5); + EXPECT_EQ(output->plane(0).buffer()[2], 6); + EXPECT_EQ(output->plane(0).buffer()[3], 1); + EXPECT_EQ(output->plane(0).buffer()[4], 2); + EXPECT_EQ(output->plane(0).buffer()[5], 3); +} + +// Grayscale EndToEnd tests. +//------------------------------------------------------------------------------ + +struct GrayInputTestParam { + FrameBuffer::Dimension input_dimension; + FrameBuffer::Format input_format; + FrameBuffer::Dimension output_dimension; + FrameBuffer::Format output_format; + int rotation_angle; + int x0; + int y0; + int x1; + int y1; +}; + +enum Operation { + kRotate = 1, + kCrop = 2, + kResize = 3, + kHorizontalFlip = 4, + kVerticalFlip = 5, + kConvert = 6 +}; + +class GrayInputTest : public ::testing::TestWithParam< + std::tuple> {}; + +TEST_P(GrayInputTest, ValidateInputs) { + GrayInputTestParam inputs; + bool is_valid; + Operation operation; + std::tie(operation, inputs, is_valid) = GetParam(); + MP_ASSERT_OK_AND_ASSIGN( + auto input, + CreateFromRawBuffer(/*buffer=*/nullptr, inputs.input_dimension, + inputs.input_format)); + MP_ASSERT_OK_AND_ASSIGN(auto output, + CreateFromRawBuffer(nullptr, inputs.output_dimension, + inputs.output_format)); + + absl::Status status; + switch (operation) { + case kRotate: + status = Rotate(*input, inputs.rotation_angle, output.get()); + break; + case kResize: + status = Resize(*input, output.get()); + break; + case kCrop: { + status = Crop(*input, inputs.x0, inputs.y0, inputs.x1, inputs.y1, + output.get()); + break; + } + case kHorizontalFlip: + status = FlipHorizontally(*input, output.get()); + break; + case kVerticalFlip: + status = FlipVertically(*input, output.get()); + break; + case kConvert: + status = Convert(*input, output.get()); + break; + } + + if (is_valid) { + MP_EXPECT_OK(status); + } else { + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); + } +} + +std::tuple CreateGrayRotateInputTestParam( + int in_width, int in_height, FrameBuffer::Format in_format, int out_width, + int out_height, FrameBuffer::Format out_format, int angle, bool is_valid) { + GrayInputTestParam param = { + .input_dimension = FrameBuffer::Dimension{in_width, in_height}, + .input_format = in_format, + .output_dimension = FrameBuffer::Dimension{out_width, out_height}, + .output_format = out_format, + .rotation_angle = angle}; + return std::make_tuple(kRotate, param, is_valid); +} + +INSTANTIATE_TEST_SUITE_P( + ValidateRotateInputs, GrayInputTest, + testing::Values( + 
CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 2, 3, + FrameBuffer::Format::kGRAY, 30, false), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kRGB, 180, false), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kGRAY, 90, false), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kGRAY, 0, false), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 2, 3, + FrameBuffer::Format::kGRAY, -90, false), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 2, 3, + FrameBuffer::Format::kGRAY, 90, true), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kGRAY, 180, true), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 2, 3, + FrameBuffer::Format::kGRAY, 270, true), + CreateGrayRotateInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 2, 3, + FrameBuffer::Format::kGRAY, 450, + false))); + +std::tuple CreateGrayCropInputTestParam( + int in_width, int in_height, FrameBuffer::Format in_format, int out_width, + int out_height, FrameBuffer::Format out_format, int x0, int y0, int x1, + int y1, bool is_valid) { + GrayInputTestParam param = { + .input_dimension = FrameBuffer::Dimension{in_width, in_height}, + .input_format = in_format, + .output_dimension = FrameBuffer::Dimension{out_width, out_height}, + .output_format = out_format, + .x0 = x0, + .y0 = y0, + .x1 = x1, + .y1 = y1}; + return std::make_tuple(kCrop, param, is_valid); +} + +INSTANTIATE_TEST_SUITE_P( + ValidateCropInputs, GrayInputTest, + ::testing::Values( + CreateGrayCropInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kRGB, 0, 0, 3, 2, + false), + CreateGrayCropInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kGRAY, 1, 1, 1, 4, + false), + CreateGrayCropInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 2, 1, + FrameBuffer::Format::kGRAY, -1, 0, 1, 1, + false), + CreateGrayCropInputTestParam(5, 5, FrameBuffer::Format::kGRAY, 3, 3, + FrameBuffer::Format::kGRAY, 0, 0, 2, 2, + true), + CreateGrayCropInputTestParam(5, 5, FrameBuffer::Format::kGRAY, 2, 2, + FrameBuffer::Format::kGRAY, 1, 2, 2, 3, + true), + CreateGrayCropInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 1, 1, + FrameBuffer::Format::kGRAY, 0, 0, 0, 0, + true))); + +std::tuple CreateGrayResizeInputTestParam( + int in_width, int in_height, FrameBuffer::Format in_format, int out_width, + int out_height, FrameBuffer::Format out_format, bool is_valid) { + GrayInputTestParam param = { + .input_dimension = FrameBuffer::Dimension{in_width, in_height}, + .input_format = in_format, + .output_dimension = FrameBuffer::Dimension{out_width, out_height}, + .output_format = out_format}; + return std::make_tuple(kResize, param, is_valid); +} + +INSTANTIATE_TEST_SUITE_P( + ValidateResizeInputs, GrayInputTest, + ::testing::Values( + CreateGrayResizeInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 1, 1, + FrameBuffer::Format::kRGB, false), + CreateGrayResizeInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 5, 5, + FrameBuffer::Format::kRGB, false), + CreateGrayResizeInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 2, 1, + FrameBuffer::Format::kGRAY, true), + CreateGrayResizeInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 7, 9, + FrameBuffer::Format::kGRAY, true))); + +std::tuple CreateGrayFlipInputTestParam( + int in_width, int in_height, FrameBuffer::Format in_format, int out_width, + int out_height, 
FrameBuffer::Format out_format, bool horizontal_flip, + bool is_valid) { + GrayInputTestParam param = { + .input_dimension = FrameBuffer::Dimension{in_width, in_height}, + .input_format = in_format, + .output_dimension = FrameBuffer::Dimension{out_width, out_height}, + .output_format = out_format}; + return std::make_tuple(horizontal_flip ? kHorizontalFlip : kVerticalFlip, + param, is_valid); +} + +INSTANTIATE_TEST_SUITE_P( + ValidateFlipInputs, GrayInputTest, + ::testing::Values( + CreateGrayFlipInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kRGB, true, false), + CreateGrayFlipInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 3, + FrameBuffer::Format::kGRAY, true, false), + CreateGrayFlipInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kGRAY, true, true), + CreateGrayFlipInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kRGB, false, false), + CreateGrayFlipInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 3, + FrameBuffer::Format::kGRAY, false, false), + CreateGrayFlipInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kGRAY, false, true))); + +std::tuple CreateGrayConvertInputTestParam( + int in_width, int in_height, FrameBuffer::Format in_format, int out_width, + int out_height, FrameBuffer::Format out_format, bool is_valid) { + GrayInputTestParam param = { + .input_dimension = FrameBuffer::Dimension{in_width, in_height}, + .input_format = in_format, + .output_dimension = FrameBuffer::Dimension{out_width, out_height}, + .output_format = out_format}; + return std::make_tuple(kConvert, param, is_valid); +} + +INSTANTIATE_TEST_SUITE_P( + ValidateConvertInputs, GrayInputTest, + ::testing::Values( + CreateGrayConvertInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kRGB, false), + CreateGrayConvertInputTestParam(3, 2, FrameBuffer::Format::kGRAY, 3, 2, + FrameBuffer::Format::kGRAY, false))); + +// Rgb unit tests. 
+//------------------------------------------------------------------------------ + +struct FrameBufferPlanarFormat { + FrameBufferPlanarFormat() + : format(FrameBuffer::Format::kGRAY), plane_count(0) {} + FrameBufferPlanarFormat(FrameBuffer::Format format, int plane_count) + : format(format), plane_count(plane_count) {} + + FrameBuffer::Format format; + int plane_count; +}; + +class RgbaConvertTest + : public testing::TestWithParam< + std::tuple> { + public: + void SetUp() override { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 2, + .height = 1}; + constexpr int kBufferSize = 20; + std::tie(input_format_, output_planar_format_) = GetParam(); + + // Setup input frame buffer + input_data_ = std::make_unique(kBufferSize); + FrameBuffer::Stride input_stride; + if (input_format_ == FrameBuffer::Format::kRGBA) { + uint8_t data[] = {200, 100, 0, 1, 0, 200, 100, 50}; + std::copy(data, data + 8, input_data_.get()); + input_stride = {/*row_stride_bytes=*/8, /*pixel_stride_bytes=*/4}; + } else { + uint8_t data[] = {200, 100, 0, 0, 200, 100}; + std::copy(data, data + 6, input_data_.get()); + input_stride = {/*row_stride_bytes=*/6, /*pixel_stride_bytes=*/3}; + } + FrameBuffer::Plane input_plane(/*buffer=*/input_data_.get(), input_stride); + std::vector input_planes = {input_plane}; + input_frame_buffer_ = std::make_shared( + input_planes, kBufferDimension, input_format_); + + // Setup output frame buffer + if (output_planar_format_.format == FrameBuffer::Format::kRGBA) { + output_data_1_ = std::make_unique(kBufferSize); + FrameBuffer::Plane output_plane_1( + /*buffer=*/output_data_1_.get(), + /*stride=*/{/*row_stride_bytes=*/8, /*pixel_stride_bytes=*/4}); + std::vector output_planes = {output_plane_1}; + output_frame_buffer_ = std::make_shared( + output_planes, kBufferDimension, output_planar_format_.format); + } else if (output_planar_format_.format == FrameBuffer::Format::kRGB) { + output_data_1_ = std::make_unique(kBufferSize); + FrameBuffer::Plane output_plane_1( + /*buffer=*/output_data_1_.get(), + /*stride=*/{/*row_stride_bytes=*/6, /*pixel_stride_bytes=*/3}); + std::vector output_planes = {output_plane_1}; + output_frame_buffer_ = std::make_shared( + output_planes, kBufferDimension, output_planar_format_.format); + } else if (output_planar_format_.plane_count == 1) { + output_data_1_ = std::make_unique(kBufferSize); + FrameBuffer::Plane output_plane_1( + /*buffer=*/output_data_1_.get(), + /*stride=*/{/*row_stride_bytes=*/2, + /*pixel_stride_bytes=*/1}); + std::vector output_planes = {output_plane_1}; + output_frame_buffer_ = std::make_shared( + output_planes, kBufferDimension, output_planar_format_.format); + } else if (output_planar_format_.plane_count == 2) { + output_data_1_ = std::make_unique(kBufferSize); + FrameBuffer::Plane output_plane_1( + /*buffer=*/output_data_1_.get(), + /*stride=*/{/*row_stride_bytes=*/2, /*pixel_stride_bytes=*/1}); + output_data_2_ = std::make_unique(kBufferSize); + FrameBuffer::Plane output_plane_2( + /*buffer=*/output_data_2_.get(), + /*stride=*/{/*row_stride_bytes=*/1, /*pixel_stride_bytes=*/2}); + std::vector planes = {output_plane_1, output_plane_2}; + output_frame_buffer_ = std::make_shared( + planes, kBufferDimension, output_planar_format_.format); + } else { + output_data_1_ = std::make_unique(kBufferSize); + output_data_2_ = std::make_unique(kBufferSize); + output_data_3_ = std::make_unique(kBufferSize); + FrameBuffer::Plane output_plane_1( + /*buffer=*/output_data_1_.get(), + /*stride=*/{/*row_stride_bytes=*/2, 
/*pixel_stride_bytes=*/1}); + FrameBuffer::Plane output_plane_2( + /*buffer=*/output_data_2_.get(), + /*stride=*/{/*row_stride_bytes=*/1, /*pixel_stride_bytes=*/1}); + FrameBuffer::Plane output_plane_3( + /*buffer=*/output_data_3_.get(), + /*stride=*/{/*row_stride_bytes=*/1, /*pixel_stride_bytes=*/1}); + std::vector planes = {output_plane_1, output_plane_2, + output_plane_3}; + output_frame_buffer_ = std::make_shared( + planes, kBufferDimension, output_planar_format_.format); + } + } + + protected: + FrameBuffer::Format input_format_; + FrameBufferPlanarFormat output_planar_format_; + + std::unique_ptr output_data_1_; + std::unique_ptr output_data_2_; + std::unique_ptr output_data_3_; + std::unique_ptr input_data_; + + std::shared_ptr input_frame_buffer_; + std::shared_ptr output_frame_buffer_; +}; + +TEST_P(RgbaConvertTest, RgbaToOtherFormatConversion) { + absl::Status status = + Convert(*input_frame_buffer_, output_frame_buffer_.get()); + if (output_planar_format_.format == FrameBuffer::Format::kGRAY) { + MP_ASSERT_OK(status); + EXPECT_EQ(output_data_1_[0], 118); + EXPECT_EQ(output_data_1_[1], 129); + } else if (output_frame_buffer_->format() == FrameBuffer::Format::kNV12 || + output_frame_buffer_->format() == FrameBuffer::Format::kNV21 || + output_frame_buffer_->format() == FrameBuffer::Format::kYV12 || + output_frame_buffer_->format() == FrameBuffer::Format::kYV21) { + MP_ASSERT_OK(status); + MP_ASSERT_OK_AND_ASSIGN( + FrameBuffer::YuvData yuv_data, + FrameBuffer::GetYuvDataFromFrameBuffer(*output_frame_buffer_)); + EXPECT_EQ(yuv_data.y_buffer[0], 118); + EXPECT_EQ(yuv_data.y_buffer[1], 129); + EXPECT_EQ(yuv_data.u_buffer[0], 61); + EXPECT_EQ(yuv_data.v_buffer[0], 186); + } else if (input_format_ == FrameBuffer::Format::kRGBA && + output_frame_buffer_->format() == FrameBuffer::Format::kRGB) { + EXPECT_EQ(output_data_1_[0], 200); + EXPECT_EQ(output_data_1_[1], 100); + EXPECT_EQ(output_data_1_[2], 0); + EXPECT_EQ(output_data_1_[3], 0); + MP_ASSERT_OK(status); + } else if (input_format_ == FrameBuffer::Format::kRGB && + output_frame_buffer_->format() == FrameBuffer::Format::kRGBA) { + MP_ASSERT_OK(status); + EXPECT_EQ(output_data_1_[0], 200); + EXPECT_EQ(output_data_1_[1], 100); + EXPECT_EQ(output_data_1_[2], 0); + EXPECT_EQ(output_data_1_[3], 255); + } else { + ASSERT_FALSE(status.ok()); + } +} + +INSTANTIATE_TEST_SUITE_P( + RgbaToOtherFormatConversion, RgbaConvertTest, + testing::Combine( + testing::Values(FrameBuffer::Format::kRGBA, FrameBuffer::Format::kRGB), + testing::Values(FrameBufferPlanarFormat(FrameBuffer::Format::kGRAY, + /*plane_count=*/1), + FrameBufferPlanarFormat(FrameBuffer::Format::kRGBA, + /*plane_count=*/1), + FrameBufferPlanarFormat(FrameBuffer::Format::kRGB, + /*plane_count=*/1), + FrameBufferPlanarFormat(FrameBuffer::Format::kNV21, + /*plane_count=*/1), + FrameBufferPlanarFormat(FrameBuffer::Format::kNV21, + /*plane_count=*/2), + FrameBufferPlanarFormat(FrameBuffer::Format::kNV21, + /*plane_count=*/3), + FrameBufferPlanarFormat(FrameBuffer::Format::kNV12, + /*plane_count=*/1), + FrameBufferPlanarFormat(FrameBuffer::Format::kNV12, + /*plane_count=*/2), + FrameBufferPlanarFormat(FrameBuffer::Format::kNV12, + /*plane_count=*/3), + FrameBufferPlanarFormat(FrameBuffer::Format::kYV21, + /*plane_count=*/1), + FrameBufferPlanarFormat(FrameBuffer::Format::kYV21, + /*plane_count=*/3), + FrameBufferPlanarFormat(FrameBuffer::Format::kYV12, + /*plane_count=*/1), + FrameBufferPlanarFormat(FrameBuffer::Format::kYV12, + /*plane_count=*/3)))); + +TEST(FrameBufferUtil, 
RgbaToRgbConversion) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 2, .height = 1}; + uint8_t data[] = {200, 100, 0, 1, 0, 200, 100, 50}; + auto input = CreateFromRgbaRawBuffer(data, kBufferDimension); + uint8_t output_data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}; + auto output = CreateFromRgbRawBuffer(output_data, kBufferDimension); + + MP_ASSERT_OK(Convert(*input, output.get())); + EXPECT_EQ(output_data[0], 200); + EXPECT_EQ(output_data[1], 100); + EXPECT_EQ(output_data[2], 0); + EXPECT_EQ(output_data[3], 0); + EXPECT_EQ(output_data[4], 200); + EXPECT_EQ(output_data[5], 100); +} + +TEST(FrameBufferUtil, RgbaCrop) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kOutputDimension = {.width = 1, .height = 1}; + uint8_t kRgbaTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + uint8_t output_data[4]; + auto input = CreateFromRgbaRawBuffer(kRgbaTestData, kBufferDimension); + auto output = CreateFromRgbaRawBuffer(output_data, kOutputDimension); + + MP_ASSERT_OK(Crop(*input, 0, 1, 0, 1, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 13); + EXPECT_EQ(output->plane(0).buffer()[1], 14); + EXPECT_EQ(output->plane(0).buffer()[2], 15); + EXPECT_EQ(output->plane(0).buffer()[3], 16); +} + +TEST(FrameBufferUtil, RgbCrop) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kOutputDimension = {.width = 1, .height = 1}; + uint8_t kRgbTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + uint8_t output_data[3]; + auto input = CreateFromRgbRawBuffer(kRgbTestData, kBufferDimension); + auto output = CreateFromRgbRawBuffer(output_data, kOutputDimension); + + MP_ASSERT_OK(Crop(*input, 0, 1, 0, 1, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 10); + EXPECT_EQ(output->plane(0).buffer()[1], 11); + EXPECT_EQ(output->plane(0).buffer()[2], 12); +} + +TEST(FrameBufferUtil, RgbaFlipHorizontally) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 1}; + uint8_t kRgbaTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + uint8_t output_data[sizeof(kRgbaTestData) / 2]; + auto input = CreateFromRgbaRawBuffer(kRgbaTestData, kBufferDimension); + auto output = CreateFromRgbaRawBuffer(output_data, kBufferDimension); + + MP_ASSERT_OK(FlipHorizontally(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 9); + EXPECT_EQ(output->plane(0).buffer()[1], 10); + EXPECT_EQ(output->plane(0).buffer()[2], 11); + EXPECT_EQ(output->plane(0).buffer()[3], 12); + EXPECT_EQ(output->plane(0).buffer()[4], 5); + EXPECT_EQ(output->plane(0).buffer()[5], 6); + EXPECT_EQ(output->plane(0).buffer()[6], 7); + EXPECT_EQ(output->plane(0).buffer()[7], 8); + EXPECT_EQ(output->plane(0).buffer()[8], 1); + EXPECT_EQ(output->plane(0).buffer()[9], 2); + EXPECT_EQ(output->plane(0).buffer()[10], 3); + EXPECT_EQ(output->plane(0).buffer()[11], 4); +} + +TEST(FrameBufferUtil, RgbFlipHorizontally) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 1}; + uint8_t kRgbTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + uint8_t output_data[sizeof(kRgbTestData) / 2]; + auto input = CreateFromRgbRawBuffer(kRgbTestData, kBufferDimension); + auto output = CreateFromRgbRawBuffer(output_data, kBufferDimension); + + MP_ASSERT_OK(FlipHorizontally(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 7); + 
EXPECT_EQ(output->plane(0).buffer()[1], 8); + EXPECT_EQ(output->plane(0).buffer()[2], 9); + EXPECT_EQ(output->plane(0).buffer()[3], 4); + EXPECT_EQ(output->plane(0).buffer()[4], 5); + EXPECT_EQ(output->plane(0).buffer()[5], 6); + EXPECT_EQ(output->plane(0).buffer()[6], 1); + EXPECT_EQ(output->plane(0).buffer()[7], 2); + EXPECT_EQ(output->plane(0).buffer()[8], 3); +} + +TEST(FrameBufferUtil, RgbaFlipVertically) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}; + uint8_t kRgbaTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + uint8_t output_data[sizeof(kRgbaTestData)]; + auto input = CreateFromRgbaRawBuffer(kRgbaTestData, kBufferDimension); + auto output = CreateFromRgbaRawBuffer(output_data, kBufferDimension); + + MP_ASSERT_OK(FlipVertically(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 13); + EXPECT_EQ(output->plane(0).buffer()[1], 14); + EXPECT_EQ(output->plane(0).buffer()[2], 15); + EXPECT_EQ(output->plane(0).buffer()[3], 16); + EXPECT_EQ(output->plane(0).buffer()[12], 1); + EXPECT_EQ(output->plane(0).buffer()[13], 2); + EXPECT_EQ(output->plane(0).buffer()[14], 3); + EXPECT_EQ(output->plane(0).buffer()[15], 4); +} + +TEST(FrameBufferUtil, RgbFlipVertically) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}; + uint8_t kRgbTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + uint8_t output_data[sizeof(kRgbTestData)]; + auto input = CreateFromRgbRawBuffer(kRgbTestData, kBufferDimension); + auto output = CreateFromRgbRawBuffer(output_data, kBufferDimension); + + MP_ASSERT_OK(FlipVertically(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 10); + EXPECT_EQ(output->plane(0).buffer()[1], 11); + EXPECT_EQ(output->plane(0).buffer()[2], 12); + EXPECT_EQ(output->plane(0).buffer()[9], 1); + EXPECT_EQ(output->plane(0).buffer()[10], 2); + EXPECT_EQ(output->plane(0).buffer()[11], 3); +} + +TEST(FrameBufferUtil, RgbaResize) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kResizeUpDimension = {.width = 4, + .height = 2}, + kResizeDownDimension = {.width = 2, + .height = 2}; + uint8_t kRgbaTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + uint8_t output_data_up[32]; + auto input = CreateFromRgbaRawBuffer(kRgbaTestData, kBufferDimension); + auto output = CreateFromRgbaRawBuffer(output_data_up, kResizeUpDimension); + + // Test increasing the size. + MP_ASSERT_OK(Resize(*input, output.get())); + uint8_t resize_result_size_increase[] = { + 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 9, 10, 11, 12, + 13, 14, 15, 16, 16, 17, 18, 19, 19, 20, 21, 22, 21, 22, 23, 24}; + for (int i = 0; i < sizeof(output_data_up); i++) { + EXPECT_EQ(output->plane(0).buffer()[i], resize_result_size_increase[i]); + } + + // Test shrinking the image by half. 
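// (With bilinear interpolation, the 3x2 RGBA input maps onto the 2x2 output
// so that the second output pixel of each row is the average of source
// pixels 1 and 2, e.g. (5 + 9) / 2 = 7 for the first channel below.)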
+ uint8_t output_data_down[16]; + output = CreateFromRgbaRawBuffer(output_data_down, kResizeDownDimension); + + MP_ASSERT_OK(Resize(*input, output.get())); + uint8_t resize_result_size_decrease[] = {1, 2, 3, 4, 7, 8, 9, 10, + 13, 14, 15, 16, 19, 20, 21, 22}; + for (int i = 0; i < sizeof(output_data_down); i++) { + EXPECT_EQ(output->plane(0).buffer()[i], resize_result_size_decrease[i]); + } +} + +TEST(FrameBufferUtil, RgbResize) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kResizeUpDimension = {.width = 4, + .height = 3}, + kResizeDownDimension = {.width = 2, + .height = 2}; + uint8_t kRgbTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + auto input = CreateFromRgbRawBuffer(kRgbTestData, kBufferDimension); + + // Test increasing the size. + uint8_t output_data_up[36]; + auto output = CreateFromRgbRawBuffer(output_data_up, kResizeUpDimension); + MP_ASSERT_OK(Resize(*input, output.get())); + uint8_t resize_result_size_increase[] = { + 1, 2, 3, 3, 4, 5, 5, 6, 7, 7, 8, 9, 7, 8, 9, 9, 10, 11, + 11, 12, 13, 13, 14, 15, 10, 11, 12, 12, 13, 14, 14, 15, 16, 16, 17, 18}; + for (int i = 0; i < sizeof(output_data_up); i++) { + EXPECT_EQ(output_data_up[i], resize_result_size_increase[i]); + } + + // Test decreasing the size. + uint8_t output_data_down[12]; + output = CreateFromRgbRawBuffer(output_data_down, kResizeDownDimension); + MP_ASSERT_OK(Resize(*input, output.get())); + + uint8_t resize_result_size_decrease[] = {1, 2, 3, 5, 6, 7, + 10, 11, 12, 14, 15, 16}; + for (int i = 0; i < sizeof(resize_result_size_decrease); i++) { + EXPECT_EQ(output->plane(0).buffer()[i], resize_result_size_decrease[i]); + } +} + +TEST(FrameBufferUtil, RgbaRotate) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kRotatedDimension = {.width = 2, + .height = 3}; + uint8_t kRgbaTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}; + uint8_t output_data[sizeof(kRgbaTestData)]; + auto input = CreateFromRgbaRawBuffer(kRgbaTestData, kBufferDimension); + const std::array kAnglesToTest = {90, 180, 270}; + std::map> kOutputBuffers; + kOutputBuffers[90] = CreateFromRgbaRawBuffer(output_data, kRotatedDimension); + kOutputBuffers[180] = CreateFromRgbaRawBuffer(output_data, kBufferDimension); + kOutputBuffers[270] = CreateFromRgbaRawBuffer(output_data, kRotatedDimension); + const std::map> kRotationResults{ + {90, {9, 10, 11, 12, 21, 22, 23, 24, 5, 6, 7, 8, + 17, 18, 19, 20, 1, 2, 3, 4, 13, 14, 15, 16}}, + {180, {21, 22, 23, 24, 17, 18, 19, 20, 13, 14, 15, 16, + 9, 10, 11, 12, 5, 6, 7, 8, 1, 2, 3, 4}}, + {270, {13, 14, 15, 16, 1, 2, 3, 4, 17, 18, 19, 20, + 5, 6, 7, 8, 21, 22, 23, 24, 9, 10, 11, 12}}}; + + for (auto angle : kAnglesToTest) { + auto output = kOutputBuffers.at(angle).get(); + MP_ASSERT_OK(Rotate(*input, angle, output)); + auto results = kRotationResults.at(angle); + for (int i = 0; i < results.size(); i++) { + EXPECT_EQ(output->plane(0).buffer()[i], results[i]); + } + } +} + +TEST(FrameBufferUtil, RgbRotate) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 3, .height = 2}, + kRotatedDimension = {.width = 2, + .height = 3}; + uint8_t kRgbTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + uint8_t output_data[sizeof(kRgbTestData)]; + auto input = CreateFromRgbRawBuffer(kRgbTestData, kBufferDimension); + const std::array kAnglesToTest = {90, 180, 270}; + std::map> kOutputBuffers; + kOutputBuffers[90] = 
CreateFromRgbRawBuffer(output_data, kRotatedDimension); + kOutputBuffers[180] = CreateFromRgbRawBuffer(output_data, kBufferDimension); + kOutputBuffers[270] = CreateFromRgbRawBuffer(output_data, kRotatedDimension); + const std::map> kRotationResults{ + {90, {7, 8, 9, 16, 17, 18, 4, 5, 6, 13, 14, 15, 1, 2, 3, 10, 11, 12}}, + {180, {16, 17, 18, 13, 14, 15, 10, 11, 12, 7, 8, 9, 4, 5, 6, 1, 2, 3}}, + {270, {10, 11, 12, 1, 2, 3, 13, 14, 15, 4, 5, 6, 16, 17, 18, 7, 8, 9}}}; + + for (auto angle : kAnglesToTest) { + auto output = kOutputBuffers.at(angle).get(); + MP_ASSERT_OK(Rotate(*input, angle, output)); + auto results = kRotationResults.at(angle); + for (int i = 0; i < results.size(); i++) { + EXPECT_EQ(output->plane(0).buffer()[i], results[i]); + } + } +} + +// Nv21 unit tests. +//------------------------------------------------------------------------------ + +// Helper function to create YUV buffer. +absl::StatusOr> CreateYuvBuffer( + uint8_t* buffer, FrameBuffer::Dimension dimension, int plane_count, + FrameBuffer::Format format) { + DCHECK(plane_count > 0 && plane_count < 4); + ASSIGN_OR_RETURN(auto uv_dimension, GetUvPlaneDimension(dimension, format)); + + if (plane_count == 1) { + const std::vector planes = { + {buffer, /*stride=*/{/*row_stride_bytes=*/dimension.width, + /*pixel_stride_bytes=*/1}}}; + return std::make_shared(planes, dimension, format); + } else if (plane_count == 2) { + CHECK(format == FrameBuffer::Format::kNV12 || + format == FrameBuffer::Format::kNV21); + const std::vector planes = { + {buffer, + /*stride=*/{/*row_stride_bytes=*/dimension.width, + /*pixel_stride_bytes=*/1}}, + {buffer + dimension.Size(), + /*stride=*/{/*row_stride_bytes=*/uv_dimension.width * 2, + /*pixel_stride_bytes=*/2}}}; + return std::make_shared(planes, dimension, format); + } else if (plane_count == 3) { + std::vector planes = { + {buffer, + /*stride=*/{/*row_stride_bytes=*/dimension.width, + /*pixel_stride_bytes=*/1}}, + {buffer + dimension.Size(), + /*stride=*/{/*row_stride_bytes=*/uv_dimension.width, + /*pixel_stride_bytes=*/1}}, + {buffer + dimension.Size() + uv_dimension.Size(), + /*stride=*/{/*row_stride_bytes=*/uv_dimension.width, + /*pixel_stride_bytes=*/1}}}; + return std::make_shared(planes, dimension, format); + } + + return absl::InvalidArgumentError("The plane_count must between 1 and 3."); +} + +TEST(FrameBufferUtil, NV21CreatePlanarYuvBuffer) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}, + kOutputDimension = {.width = 4, .height = 2}; + uint8_t kYTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; + uint8_t kUTestData[] = {13, 15, 17, 0, 0, 0}; + uint8_t kVTestData[] = {14, 16, 18, 0, 0, 0}; + uint8_t kNV21VUTestData[] = {14, 13, 16, 15, 18, 17}; + const std::vector three_input_planes = { + {kYTestData, /*stride=*/{6, 1}}, + {kUTestData, /*stride=*/{3, 1}}, + {kVTestData, /*stride=*/{3, 1}}}; + FrameBuffer three_planar_input(three_input_planes, kBufferDimension, + FrameBuffer::Format::kYV21); + + const std::vector two_input_planes = { + {kYTestData, /*stride=*/{6, 1}}, {kNV21VUTestData, /*stride=*/{6, 2}}}; + FrameBuffer two_planar_input(two_input_planes, kBufferDimension, + FrameBuffer::Format::kNV21); + + uint8_t output_y[8], output_u[2], output_v[2]; + const std::vector output_planes = { + {output_y, /*stride=*/{4, 1}}, + {output_u, /*stride=*/{2, 1}}, + {output_v, /*stride=*/{2, 1}}}; + FrameBuffer output(output_planes, kOutputDimension, + FrameBuffer::Format::kYV12); + + MP_ASSERT_OK(Crop(three_planar_input, 2, 0, 5, 1, &output)); + 
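// (Cropping from x0 = 2 starts at chroma column 1, so the cropped U sample is
// 15 and the V sample is 16; because the output format is kYV12, plane(1)
// holds V and plane(2) holds U, which is what the checks below expect.)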
EXPECT_EQ(output.plane(0).buffer()[0], 3); + EXPECT_EQ(output.plane(0).buffer()[1], 4); + EXPECT_EQ(output.plane(0).buffer()[2], 5); + EXPECT_EQ(output.plane(1).buffer()[0], 16); + EXPECT_EQ(output.plane(2).buffer()[0], 15); + + memset(output_y, 0, sizeof(output_y)); + memset(output_u, 0, sizeof(output_u)); + memset(output_v, 0, sizeof(output_v)); + MP_ASSERT_OK(Crop(two_planar_input, 2, 0, 5, 1, &output)); + EXPECT_EQ(output.plane(0).buffer()[0], 3); + EXPECT_EQ(output.plane(0).buffer()[1], 4); + EXPECT_EQ(output.plane(0).buffer()[2], 5); + EXPECT_EQ(output.plane(1).buffer()[0], 16); + EXPECT_EQ(output.plane(2).buffer()[0], 15); +} + +TEST(FrameBufferUtil, NV21Crop) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}, + kOutputDimension = {.width = 4, .height = 2}; + uint8_t kNV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + MP_ASSERT_OK_AND_ASSIGN(auto input, + CreateFromRawBuffer(kNV21TestData, kBufferDimension, + FrameBuffer::Format::kNV21)); + uint8_t output_data[12]; + MP_ASSERT_OK_AND_ASSIGN(auto output, + CreateFromRawBuffer(output_data, kOutputDimension, + FrameBuffer::Format::kNV21)); + + MP_ASSERT_OK(Crop(*input, 2, 0, 5, 1, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 3); + EXPECT_EQ(output->plane(0).buffer()[1], 4); + EXPECT_EQ(output->plane(0).buffer()[2], 5); + EXPECT_EQ(output->plane(0).buffer()[8], 15); + EXPECT_EQ(output->plane(0).buffer()[9], 16); +} + +TEST(FrameBufferUtil, YV21Crop) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}, + kOutputDimension = {.width = 4, .height = 2}; + uint8_t kYV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 15, 17, 14, 16, 18}; + MP_ASSERT_OK_AND_ASSIGN( + auto input, + CreateYuvBuffer(kYV21TestData, kBufferDimension, /*plane_count=*/3, + FrameBuffer::Format::kYV21)); + uint8_t output_data[12]{}; + MP_ASSERT_OK_AND_ASSIGN( + auto output, + CreateYuvBuffer(output_data, kOutputDimension, /*plane_count=*/3, + FrameBuffer::Format::kYV21)); + + MP_ASSERT_OK( + Crop(*input, /*x0=*/2, /*y0=*/0, /*x1=*/5, /*y1=*/1, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 3); + EXPECT_EQ(output->plane(0).buffer()[1], 4); + EXPECT_EQ(output->plane(0).buffer()[2], 5); + EXPECT_EQ(output->plane(1).buffer()[0], 15); + EXPECT_EQ(output->plane(2).buffer()[0], 16); +} + +TEST(FrameBufferUtil, NV21HorizontalFlip) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}; + uint8_t kNV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + MP_ASSERT_OK_AND_ASSIGN(auto input, + CreateFromRawBuffer(kNV21TestData, kBufferDimension, + FrameBuffer::Format::kNV21)); + uint8_t output_data[18]; + MP_ASSERT_OK_AND_ASSIGN(auto output, + CreateFromRawBuffer(output_data, kBufferDimension, + FrameBuffer::Format::kNV21)); + + MP_ASSERT_OK(FlipHorizontally(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 6); + EXPECT_EQ(output->plane(0).buffer()[1], 5); + EXPECT_EQ(output->plane(0).buffer()[2], 4); + EXPECT_EQ(output->plane(0).buffer()[12], 17); + EXPECT_EQ(output->plane(0).buffer()[13], 18); +} + +TEST(FrameBufferUtil, NV21VerticalFlip) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}; + uint8_t kNV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + MP_ASSERT_OK_AND_ASSIGN(auto input, + CreateFromRawBuffer(kNV21TestData, kBufferDimension, + FrameBuffer::Format::kNV21)); + uint8_t output_data[18]; + 
MP_ASSERT_OK_AND_ASSIGN(auto output, + CreateFromRawBuffer(output_data, kBufferDimension, + FrameBuffer::Format::kNV21)); + + MP_ASSERT_OK(FlipVertically(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 7); + EXPECT_EQ(output->plane(0).buffer()[1], 8); + EXPECT_EQ(output->plane(0).buffer()[2], 9); + EXPECT_EQ(output->plane(0).buffer()[12], 13); + EXPECT_EQ(output->plane(0).buffer()[13], 14); +} + +TEST(FrameBufferUtil, NV21Rotate) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}, + kRotatedDimension = {.width = 2, + .height = 6}; + uint8_t kNV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + MP_ASSERT_OK_AND_ASSIGN(auto input, + CreateFromRawBuffer(kNV21TestData, kBufferDimension, + FrameBuffer::Format::kNV21)); + uint8_t output_data[18]; + MP_ASSERT_OK_AND_ASSIGN(auto output, + CreateFromRawBuffer(output_data, kRotatedDimension, + FrameBuffer::Format::kNV21)); + + MP_ASSERT_OK(Rotate(*input, 90, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 6); + EXPECT_EQ(output->plane(0).buffer()[1], 12); + EXPECT_EQ(output->plane(0).buffer()[2], 5); + EXPECT_EQ(output->plane(0).buffer()[12], 17); + EXPECT_EQ(output->plane(0).buffer()[13], 18); +} + +TEST(FrameBufferUtil, NV21Resize) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}, + kOutputDimension = {.width = 1, .height = 1}; + uint8_t kNV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + MP_ASSERT_OK_AND_ASSIGN(auto input, + CreateFromRawBuffer(kNV21TestData, kBufferDimension, + FrameBuffer::Format::kNV21)); + uint8_t output_data[6]; + MP_ASSERT_OK_AND_ASSIGN(auto output, + CreateFromRawBuffer(output_data, kOutputDimension, + FrameBuffer::Format::kNV21)); + + MP_ASSERT_OK(Resize(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 1); + EXPECT_EQ(output->plane(0).buffer()[1], 13); +} + +TEST(FrameBufferUtil, NV21ConvertGray) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}; + uint8_t kNV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18}; + MP_ASSERT_OK_AND_ASSIGN(auto input, + CreateFromRawBuffer(kNV21TestData, kBufferDimension, + FrameBuffer::Format::kNV21)); + const int kOutputSize = + GetFrameBufferByteSize(kBufferDimension, FrameBuffer::Format::kGRAY); + std::vector output_data(kOutputSize); + auto output = CreateFromGrayRawBuffer(output_data.data(), kBufferDimension); + + MP_ASSERT_OK(Convert(*input, output.get())); + EXPECT_EQ(output->plane(0).buffer()[0], 1); + EXPECT_EQ(output->plane(0).buffer()[1], 2); + EXPECT_EQ(output->plane(0).buffer()[11], 12); +} + +TEST(FrameBufferUtil, PaddedYuvConvertGray) { + constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2}; + uint8_t kNV21PaddedTestData[] = {1, 2, 3, 4, 5, 6, 100, 100, + 7, 8, 9, 10, 11, 12, 100, 100, + 13, 14, 15, 16, 17, 18, 100, 100}; + constexpr int row_stride_y = 8; + const std::vector planes = { + {kNV21PaddedTestData, /*stride=*/{row_stride_y, 1}}, + {kNV21PaddedTestData + (row_stride_y * kBufferDimension.width), + /*stride=*/{row_stride_y, 2}}}; + auto input = std::make_shared(planes, kBufferDimension, + FrameBuffer::Format::kNV21); + const int kOutputSize = + GetFrameBufferByteSize(kBufferDimension, FrameBuffer::Format::kGRAY); + std::vector output_data(kOutputSize); + auto output = CreateFromGrayRawBuffer(output_data.data(), kBufferDimension); + + MP_ASSERT_OK(Convert(*input, output.get())); + 
EXPECT_EQ(output->plane(0).buffer()[0], 1);
+  EXPECT_EQ(output->plane(0).buffer()[1], 2);
+  EXPECT_EQ(output->plane(0).buffer()[6], 7);
+  EXPECT_EQ(output->plane(0).buffer()[7], 8);
+  EXPECT_EQ(output->plane(0).buffer()[11], 12);
+}
+
+TEST(FrameBufferUtil, NV21ConvertRgb) {
+  constexpr FrameBuffer::Dimension kBufferDimension = {.width = 32,
+                                                       .height = 8};
+  // Note that RGB conversion expects images with width >= 32 because the
+  // implementation is vectorized.
+  const int kInputSize =
+      GetFrameBufferByteSize(kBufferDimension, FrameBuffer::Format::kNV21);
+  std::vector<uint8_t> input_data(kInputSize);
+  input_data.data()[0] = 1;
+  input_data.data()[1] = 2;
+  input_data.data()[32] = 7;
+  input_data.data()[33] = 8;
+  input_data.data()[256] = 13;
+  input_data.data()[257] = 14;
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto input, CreateFromRawBuffer(input_data.data(), kBufferDimension,
+                                      FrameBuffer::Format::kNV21));
+  const int kOutputSize =
+      GetFrameBufferByteSize(kBufferDimension, FrameBuffer::Format::kRGB);
+  std::vector<uint8_t> output_data(kOutputSize);
+  auto output = CreateFromRgbRawBuffer(output_data.data(), kBufferDimension);
+
+  MP_ASSERT_OK(Convert(*input, output.get()));
+  EXPECT_EQ(output->plane(0).buffer()[0], 0);
+  EXPECT_EQ(output->plane(0).buffer()[1], 122);
+}
+
+TEST(FrameBufferUtil, NV21ConvertHalfRgb) {
+  constexpr FrameBuffer::Dimension kBufferDimension = {.width = 64,
+                                                       .height = 16},
+                                   kOutputDimension = {.width = 32,
+                                                       .height = 8};
+  // Note that RGB conversion expects images with width >= 32 because the
+  // implementation is vectorized.
+  uint8_t data[1576];
+  for (int i = 0; i < sizeof(data); i++) {
+    data[i] = (i + 1);
+  }
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto input,
+      CreateFromRawBuffer(data, kBufferDimension, FrameBuffer::Format::kNV21));
+  uint8_t output_data[768];
+  auto output = CreateFromRgbRawBuffer(output_data, kOutputDimension);
+
+  MP_ASSERT_OK(Convert(*input, output.get()));
+  EXPECT_EQ(output->plane(0).buffer()[0], 0);
+  EXPECT_EQ(output->plane(0).buffer()[1], 135);
+}
+
+TEST(FrameBufferUtil, NV12ConvertGray) {
+  constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2};
+  uint8_t kYTestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  uint8_t kNV12UVTestData[] = {13, 14, 15, 16, 17, 18};
+  const std::vector<FrameBuffer::Plane> planes_nv12 = {
+      {kYTestData, /*stride=*/{kBufferDimension.width, 1}},
+      {kNV12UVTestData, /*stride=*/{kBufferDimension.width, 2}}};
+  auto buffer_nv12 = std::make_shared<FrameBuffer>(
+      planes_nv12, kBufferDimension, FrameBuffer::Format::kNV12);
+  const int kOutputSize =
+      GetFrameBufferByteSize(kBufferDimension, FrameBuffer::Format::kGRAY);
+  std::vector<uint8_t> output_data(kOutputSize);
+  auto output = CreateFromGrayRawBuffer(output_data.data(), kBufferDimension);
+
+  MP_ASSERT_OK(Convert(*buffer_nv12, output.get()));
+  EXPECT_EQ(output->plane(0).buffer()[0], kYTestData[0]);
+  EXPECT_EQ(output->plane(0).buffer()[1], kYTestData[1]);
+  EXPECT_EQ(output->plane(0).buffer()[11], kYTestData[11]);
+}
+
+TEST(FrameBufferUtil, NV12ConvertRgb) {
+  constexpr FrameBuffer::Dimension kBufferDimension = {.width = 32,
+                                                       .height = 8};
+  MP_ASSERT_OK_AND_ASSIGN(
+      FrameBuffer::Dimension uv_dimension,
+      GetUvPlaneDimension(kBufferDimension, FrameBuffer::Format::kNV12));
+  // The Halide RGB converter expects images with width >= 32 because the
+  // implementation is vectorized.
+  auto y_data = std::make_unique<uint8_t[]>(kBufferDimension.Size());
+  auto uv_data = std::make_unique<uint8_t[]>(uv_dimension.Size() * 2);
+  y_data[0] = 1;
+  y_data[1] = 2;
+  y_data[32] = 7;
+  y_data[33] = 8;
+  uv_data[0] = 13;
+  uv_data[1] = 14;
+  const std::vector<FrameBuffer::Plane> planes_nv12 = {
+      {y_data.get(), /*stride=*/{kBufferDimension.width, 1}},
+      {uv_data.get(), /*stride=*/{kBufferDimension.width, 2}}};
+  auto buffer_nv12 = std::make_shared<FrameBuffer>(
+      planes_nv12, kBufferDimension, FrameBuffer::Format::kNV12);
+  const int kOutputSize =
+      GetFrameBufferByteSize(kBufferDimension, FrameBuffer::Format::kRGB);
+  std::vector<uint8_t> output_data(kOutputSize);
+  auto output = CreateFromRgbRawBuffer(output_data.data(), kBufferDimension);
+
+  MP_ASSERT_OK(Convert(*buffer_nv12, output.get()));
+  EXPECT_EQ(output->plane(0).buffer()[0], 0);
+  EXPECT_EQ(output->plane(0).buffer()[1], 122);
+}
+
+TEST(FrameBufferUtil, NV12ConvertHalfRgb) {
+  constexpr FrameBuffer::Dimension kBufferDimension = {.width = 64,
+                                                       .height = 16};
+  MP_ASSERT_OK_AND_ASSIGN(
+      FrameBuffer::Dimension uv_dimension,
+      GetUvPlaneDimension(kBufferDimension, FrameBuffer::Format::kNV12));
+  // The Halide RGB converter expects images with width >= 32 because the
+  // implementation is vectorized.
+  auto y_data = std::make_unique<uint8_t[]>(kBufferDimension.Size());
+  auto uv_data = std::make_unique<uint8_t[]>(uv_dimension.Size() * 2);
+  for (int i = 0; i < kBufferDimension.Size(); i++) {
+    y_data[i] = (i + 1) % 256;
+  }
+  for (int i = 0; i < uv_dimension.Size() * 2; i++) {
+    uv_data[i] = (i + 1) % 256;
+  }
+  const std::vector<FrameBuffer::Plane> planes_nv12 = {
+      {y_data.get(), /*stride=*/{kBufferDimension.width, 1}},
+      {uv_data.get(), /*stride=*/{kBufferDimension.width, 2}}};
+  auto buffer_nv12 = std::make_shared<FrameBuffer>(
+      planes_nv12, kBufferDimension, FrameBuffer::Format::kNV12);
+  constexpr FrameBuffer::Dimension kOutputDimension = {
+      .width = kBufferDimension.width / 2,
+      .height = kBufferDimension.height / 2};
+  const int kOutputSize =
+      GetFrameBufferByteSize(kOutputDimension, FrameBuffer::Format::kRGB);
+  std::vector<uint8_t> output_data(kOutputSize);
+  auto output = CreateFromRgbRawBuffer(output_data.data(), kOutputDimension);
+
+  MP_ASSERT_OK(Convert(*buffer_nv12, output.get()));
+  EXPECT_EQ(output->plane(0).buffer()[0], 0);
+  EXPECT_EQ(output->plane(0).buffer()[1], 135);
+}
+
+TEST(FrameBufferUtil, NV21ConvertYV12) {
+  constexpr FrameBuffer::Dimension kBufferDimension = {.width = 6, .height = 2};
+  uint8_t kNV21TestData[] = {1, 2, 3, 4, 5, 6, 7, 8, 9,
+                             10, 11, 12, 13, 14, 15, 16, 17, 18};
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto nv21,
+      CreateYuvBuffer(kNV21TestData, kBufferDimension, /*plane_count=*/2,
+                      FrameBuffer::Format::kNV21));
+  MP_ASSERT_OK_AND_ASSIGN(FrameBuffer::YuvData nv21_data,
+                          FrameBuffer::GetYuvDataFromFrameBuffer(*nv21));
+  const int kOutputSize =
+      GetFrameBufferByteSize(kBufferDimension, FrameBuffer::Format::kYV12);
+  std::vector<uint8_t> output_data(kOutputSize);
+  MP_ASSERT_OK_AND_ASSIGN(
+      auto yv12,
+      CreateYuvBuffer(output_data.data(), kBufferDimension, /*plane_count=*/3,
+                      FrameBuffer::Format::kYV12));
+  MP_ASSERT_OK_AND_ASSIGN(FrameBuffer::YuvData yv12_data,
+                          FrameBuffer::GetYuvDataFromFrameBuffer(*yv12));
+
+  MP_ASSERT_OK(Convert(*nv21, yv12.get()));
+  EXPECT_EQ(nv21_data.y_buffer[0], yv12_data.y_buffer[0]);
+  EXPECT_EQ(nv21_data.u_buffer[0], yv12_data.u_buffer[0]);
+  EXPECT_EQ(nv21_data.v_buffer[0], yv12_data.v_buffer[0]);
+}
+
+}  // namespace
+}  // namespace frame_buffer
+}  // namespace mediapipe
diff --git a/mediapipe/util/frame_buffer/gray_buffer.cc b/mediapipe/util/frame_buffer/gray_buffer.cc
new
file mode 100644 index 000000000..51f7b09e2 --- /dev/null +++ b/mediapipe/util/frame_buffer/gray_buffer.cc @@ -0,0 +1,95 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/frame_buffer/gray_buffer.h" + +#include + +#include "mediapipe/util/frame_buffer/buffer_common.h" +#include "mediapipe/util/frame_buffer/halide/gray_flip_halide.h" +#include "mediapipe/util/frame_buffer/halide/gray_resize_halide.h" +#include "mediapipe/util/frame_buffer/halide/gray_rotate_halide.h" +#include "mediapipe/util/frame_buffer/yuv_buffer.h" + +namespace mediapipe { +namespace frame_buffer { + +GrayBuffer::GrayBuffer(uint8_t* buffer, int width, int height) + : owned_buffer_(nullptr) { + Initialize(buffer, width, height); +} + +GrayBuffer::GrayBuffer(int width, int height) + : owned_buffer_(new uint8_t[ByteSize(width, height)]) { + Initialize(owned_buffer_.get(), width, height); +} + +GrayBuffer::GrayBuffer(const GrayBuffer& other) : buffer_(other.buffer_) {} + +GrayBuffer::GrayBuffer(GrayBuffer&& other) { *this = std::move(other); } + +GrayBuffer& GrayBuffer::operator=(const GrayBuffer& other) { + if (this != &other) { + buffer_ = other.buffer_; + } + return *this; +} + +GrayBuffer& GrayBuffer::operator=(GrayBuffer&& other) { + if (this != &other) { + owned_buffer_ = std::move(other.owned_buffer_); + buffer_ = other.buffer_; + } + return *this; +} + +GrayBuffer::~GrayBuffer() {} + +void GrayBuffer::Initialize(uint8_t* data, int width, int height) { + buffer_ = Halide::Runtime::Buffer(data, width, height); +} + +bool GrayBuffer::Crop(int x0, int y0, int x1, int y1) { + // Twiddle the buffer start and extents to crop images. + return common::crop_buffer(x0, y0, x1, y1, buffer()); +} + +bool GrayBuffer::Resize(GrayBuffer* output) { + const int result = gray_resize_halide( + buffer(), static_cast(width()) / output->width(), + static_cast(height()) / output->height(), output->buffer()); + return result == 0; +} + +bool GrayBuffer::Rotate(int angle, GrayBuffer* output) { + const int result = gray_rotate_halide(buffer(), angle, output->buffer()); + return result == 0; +} + +bool GrayBuffer::FlipHorizontally(GrayBuffer* output) { + const int result = gray_flip_halide(buffer(), + false, // horizontal + output->buffer()); + return result == 0; +} + +bool GrayBuffer::FlipVertically(GrayBuffer* output) { + const int result = gray_flip_halide(buffer(), + true, // vertical + output->buffer()); + return result == 0; +} + +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/gray_buffer.h b/mediapipe/util/frame_buffer/gray_buffer.h new file mode 100644 index 000000000..fa181acec --- /dev/null +++ b/mediapipe/util/frame_buffer/gray_buffer.h @@ -0,0 +1,137 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MEDIAPIPE_UTIL_FRAME_BUFFER_GRAY_BUFFER_H_
+#define MEDIAPIPE_UTIL_FRAME_BUFFER_GRAY_BUFFER_H_
+
+#include <memory>
+
+#include "HalideBuffer.h"
+#include "HalideRuntime.h"
+
+namespace mediapipe {
+namespace frame_buffer {
+
+// GrayBuffer represents a view over a grayscale (i.e. luminance,
+// or Y-only) buffer.
+// GrayBuffer may be copied and moved efficiently; copies share the backing
+// buffer and never deep copy it.
+// GrayBuffer requires a minimum image width depending on the natural vector
+// size of the platform, e.g., 16px. This is not validated by GrayBuffer.
+class GrayBuffer {
+ public:
+  // Returns the size (in bytes) of a grayscale image of the given
+  // dimensions. The given dimensions are assumed to include any padding.
+  static int ByteSize(int buffer_width, int buffer_height) {
+    const int size = buffer_width * buffer_height;
+    return size;
+  }
+
+  // Builds a grayscale buffer of size width * height. The buffer should
+  // be in row-major order with no padding.
+  //
+  // Does not take ownership of any backing buffers, which must be large
+  // enough to fit their contents.
+  GrayBuffer(uint8_t* buffer, int width, int height);
+
+  // Builds a grayscale buffer of size width * height.
+  //
+  // The underlying backing buffer is allocated and owned by this
+  // GrayBuffer.
+  GrayBuffer(int width, int height);
+
+  // GrayBuffer is copyable. The source retains ownership of its backing
+  // buffer.
+  //
+  // If the source owns its backing buffer, the source must outlive this
+  // instance; otherwise, the externally provided backing buffer must outlive
+  // this instance.
+  GrayBuffer(const GrayBuffer& other);
+  // GrayBuffer is movable. The source loses ownership of any backing buffers.
+  // Specifically, if the source owns its backing buffer, after the move,
+  // Release() will return nullptr.
+  GrayBuffer(GrayBuffer&& other);
+
+  // GrayBuffer is assignable. The source retains ownership of its backing
+  // buffer.
+  //
+  // If the source owns its backing buffer, the source must outlive this
+  // instance; otherwise, the externally provided backing buffer must outlive
+  // this instance.
+  GrayBuffer& operator=(const GrayBuffer& other);
+  GrayBuffer& operator=(GrayBuffer&& other);
+
+  ~GrayBuffer();
+
+  // Performs an in-place crop. Modifies this buffer so that the new extent
+  // matches that of the given crop rectangle -- (x0, y0) becomes (0, 0) and
+  // the new width and height are x1 - x0 + 1 and y1 - y0 + 1, respectively.
+  bool Crop(int x0, int y0, int x1, int y1);
+
+  // Resizes this image to match the dimensions of the given output GrayBuffer
+  // and places the result into output's backing buffer.
+  //
+  // Note, if the output backing buffer is shared with multiple instances, by
+  // calling this method, all the instances' backing buffers will change.
+  bool Resize(GrayBuffer* output);
+
+  // Rotates this image into the given buffer by the given angle (90, 180, 270).
+ // + // Rotation is specified in degrees counter-clockwise such that when rotating + // by 90 degrees, the top-right corner of the source becomes the top-left of + // the output. The output buffer must have its height and width swapped when + // rotating by 90 or 270. + // + // Any angle values other than (90, 180, 270) are invalid. + // + // Note, if the output backing buffer is shared with multiple instances, by + // calling this method, all the instances' backing buffers will change. + bool Rotate(int angle, GrayBuffer* output); + + // Flips this image horizontally/vertically into the given buffer. Both buffer + // dimensions must match. + // + // Note, if the output backing buffer is shared with multiple instances, by + // calling this method, all the instances' backing buffers will change. + bool FlipHorizontally(GrayBuffer* output); + bool FlipVertically(GrayBuffer* output); + + // Releases ownership of the owned backing buffer. + uint8_t* Release() { return owned_buffer_.release(); } + + // Returns the halide_buffer_t* for the image. + halide_buffer_t* buffer() { return buffer_.raw_buffer(); } + + // Returns the image width. + const int width() const { return buffer_.dim(0).extent(); } + // Returns the image height. + const int height() const { return buffer_.dim(1).extent(); } + + private: + void Initialize(uint8_t* data, int width, int height); + + // Non-NULL iff this GrayBuffer owns its buffer. + std::unique_ptr owned_buffer_; + + // Backing buffer: layout is always width x height. The backing buffer binds + // to either "owned_buffer_" or an external buffer. + Halide::Runtime::Buffer buffer_; +}; + +} // namespace frame_buffer +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_FRAME_BUFFER_GRAY_BUFFER_H_ diff --git a/mediapipe/util/frame_buffer/gray_buffer_test.cc b/mediapipe/util/frame_buffer/gray_buffer_test.cc new file mode 100644 index 000000000..f6f9e9e34 --- /dev/null +++ b/mediapipe/util/frame_buffer/gray_buffer_test.cc @@ -0,0 +1,223 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/frame_buffer/gray_buffer.h" + +#include + +#include "absl/log/log.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" + +// The default implementation of halide_error calls abort(), which we don't +// want. Instead, log the error and let the filter invocation fail. +extern "C" void halide_error(void*, const char* message) { + LOG(ERROR) << "Halide Error: " << message; +} + +namespace mediapipe { +namespace frame_buffer { +namespace { + +// Fill a GrayBuffer with zeroes. 
+void Fill(GrayBuffer* gray_buffer) { + halide_buffer_t* buffer = gray_buffer->buffer(); + for (int y = 0; y < buffer->dim[1].extent; ++y) { + for (int x = 0; x < buffer->dim[0].extent; ++x) { + buffer->host[buffer->dim[1].stride * y + buffer->dim[0].stride * x] = 0; + } + } +} + +TEST(GrayBufferTest, Properties) { + GrayBuffer buffer(5, 4); + EXPECT_EQ(5, buffer.width()); + EXPECT_EQ(4, buffer.height()); +} + +TEST(GrayBufferTest, Release) { + GrayBuffer buffer(4, 4); + delete[] buffer.Release(); +} + +TEST(GrayBufferTest, Assign) { + GrayBuffer buffer(3, 2); + GrayBuffer sink(nullptr, 0, 0); + Fill(&buffer); + sink = buffer; + EXPECT_EQ(3, sink.width()); + EXPECT_EQ(2, sink.height()); + + sink = GrayBuffer(5, 4); + EXPECT_EQ(5, sink.width()); + EXPECT_EQ(4, sink.height()); +} + +TEST(GrayBufferTest, MoveAssign) { + GrayBuffer buffer(3, 2); + GrayBuffer sink(nullptr, 0, 0); + Fill(&buffer); + sink = std::move(buffer); + EXPECT_EQ(nullptr, buffer.Release()); + EXPECT_EQ(3, sink.width()); + EXPECT_EQ(2, sink.height()); +} + +TEST(GrayBufferTest, MoveConstructor) { + GrayBuffer buffer(5, 4); + GrayBuffer sink(std::move(buffer)); + Fill(&buffer); + EXPECT_EQ(nullptr, buffer.Release()); + EXPECT_EQ(5, sink.width()); + EXPECT_EQ(4, sink.height()); +} + +TEST(GrayBufferTest, Crop) { + GrayBuffer source(8, 8); + EXPECT_TRUE(source.Crop(2, 2, 6, 6)); +} + +TEST(GrayBufferTest, Resize_Even) { + uint8_t* data = new uint8_t[16]; + for (int y = 0; y < 4; ++y) { + for (int x = 0; x < 4; ++x) { + data[x + y * 4] = x + y * 4; + } + } + GrayBuffer source(data, 4, 4); + GrayBuffer result(2, 2); + EXPECT_TRUE(source.Resize(&result)); + EXPECT_EQ(0, result.buffer()->host[0]); + EXPECT_EQ(2, result.buffer()->host[1]); + EXPECT_EQ(8, result.buffer()->host[2]); + EXPECT_EQ(10, result.buffer()->host[3]); + delete[] data; +} + +TEST(GrayBufferTest, Resize_Odd) { + uint8_t* data = new uint8_t[16]; + for (int y = 0; y < 4; ++y) { + for (int x = 0; x < 4; ++x) { + data[x + y * 4] = x + y * 4; + } + } + GrayBuffer source(data, 4, 4); + GrayBuffer result(1, 3); + EXPECT_TRUE(source.Resize(&result)); + EXPECT_EQ(0, result.buffer()->host[0]); + EXPECT_EQ(5, result.buffer()->host[1]); + EXPECT_EQ(11, result.buffer()->host[2]); + delete[] data; +} + +TEST(GrayBufferTest, Rotate) { + GrayBuffer buffer(5, 4); + GrayBuffer result(4, 5); + Fill(&buffer); + EXPECT_TRUE(buffer.Rotate(90, &result)); +} + +TEST(GrayBufferTest, Rotate_90) { + uint8_t* data = new uint8_t[4]; + data[0] = 1; + data[1] = 2; + data[2] = 3; + data[3] = 4; + GrayBuffer buffer(data, 2, 2); + GrayBuffer result(2, 2); + EXPECT_TRUE(buffer.Rotate(90, &result)); + + EXPECT_EQ(2, result.buffer()->host[0]); + EXPECT_EQ(4, result.buffer()->host[1]); + EXPECT_EQ(1, result.buffer()->host[2]); + EXPECT_EQ(3, result.buffer()->host[3]); + + delete[] data; +} + +TEST(GrayBufferTest, Rotate_180) { + uint8_t* data = new uint8_t[4]; + data[0] = 1; + data[1] = 2; + data[2] = 3; + data[3] = 4; + GrayBuffer buffer(data, 2, 2); + GrayBuffer result(2, 2); + EXPECT_TRUE(buffer.Rotate(180, &result)); + EXPECT_EQ(4, result.buffer()->host[0]); + EXPECT_EQ(3, result.buffer()->host[1]); + EXPECT_EQ(2, result.buffer()->host[2]); + EXPECT_EQ(1, result.buffer()->host[3]); + delete[] data; +} + +TEST(GrayBufferTest, Rotate_270) { + uint8_t* data = new uint8_t[4]; + data[0] = 1; + data[1] = 2; + data[2] = 3; + data[3] = 4; + GrayBuffer buffer(data, 2, 2); + GrayBuffer result(2, 2); + EXPECT_TRUE(buffer.Rotate(270, &result)); + EXPECT_EQ(3, result.buffer()->host[0]); + EXPECT_EQ(1, 
result.buffer()->host[1]); + EXPECT_EQ(4, result.buffer()->host[2]); + EXPECT_EQ(2, result.buffer()->host[3]); + delete[] data; +} + +TEST(GrayBufferTest, Flip) { + GrayBuffer buffer(5, 4); + GrayBuffer result(5, 4); + Fill(&buffer); + EXPECT_TRUE(buffer.FlipHorizontally(&result)); + EXPECT_TRUE(buffer.FlipVertically(&result)); +} + +TEST(GrayBufferTest, Flip_Horizontally) { + uint8_t* data = new uint8_t[4]; + data[0] = 1; + data[1] = 2; + data[2] = 3; + data[3] = 4; + GrayBuffer buffer(data, 2, 2); + GrayBuffer result(2, 2); + EXPECT_TRUE(buffer.FlipHorizontally(&result)); + EXPECT_EQ(2, result.buffer()->host[0]); + EXPECT_EQ(1, result.buffer()->host[1]); + EXPECT_EQ(4, result.buffer()->host[2]); + EXPECT_EQ(3, result.buffer()->host[3]); + delete[] data; +} + +TEST(GrayBufferTest, Flip_Vertically) { + uint8_t* data = new uint8_t[4]; + data[0] = 1; + data[1] = 2; + data[2] = 3; + data[3] = 4; + GrayBuffer buffer(data, 2, 2); + GrayBuffer result(2, 2); + EXPECT_TRUE(buffer.FlipVertically(&result)); + EXPECT_EQ(3, result.buffer()->host[0]); + EXPECT_EQ(4, result.buffer()->host[1]); + EXPECT_EQ(1, result.buffer()->host[2]); + EXPECT_EQ(2, result.buffer()->host[3]); + delete[] data; +} + +} // namespace +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/halide/BUILD b/mediapipe/util/frame_buffer/halide/BUILD new file mode 100644 index 000000000..619a16d26 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/BUILD @@ -0,0 +1,118 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@halide//:halide.bzl", "halide_library") + +package(default_visibility = ["//mediapipe/util/frame_buffer:__subpackages__"]) + +# Common Halide library: +cc_library( + name = "common", + srcs = ["common.cc"], + hdrs = ["common.h"], + deps = ["@halide//:language"], +) + +# Enable Halide's built-in profiler with: +# bazel ... --define halide_target_features=profile +# and HTML output of its intermediate representation with: +# bazel ... 
--define halide_extra_outputs=html + +# RGB operations: +halide_library( + name = "rgb_flip_halide", + srcs = ["rgb_flip_generator.cc"], + generator_name = "rgb_flip_generator", +) + +halide_library( + name = "rgb_resize_halide", + srcs = ["rgb_resize_generator.cc"], + generator_deps = [":common"], + generator_name = "rgb_resize_generator", +) + +halide_library( + name = "rgb_rotate_halide", + srcs = ["rgb_rotate_generator.cc"], + generator_deps = [":common"], + generator_name = "rgb_rotate_generator", +) + +halide_library( + name = "rgb_yuv_halide", + srcs = ["rgb_yuv_generator.cc"], + generator_name = "rgb_yuv_generator", +) + +halide_library( + name = "rgb_rgb_halide", + srcs = ["rgb_rgb_generator.cc"], + generator_name = "rgb_rgb_generator", +) + +# YUV operations: +halide_library( + name = "yuv_flip_halide", + srcs = ["yuv_flip_generator.cc"], + generator_name = "yuv_flip_generator", +) + +halide_library( + name = "yuv_rgb_halide", + srcs = ["yuv_rgb_generator.cc"], + generator_name = "yuv_rgb_generator", +) + +halide_library( + name = "yuv_resize_halide", + srcs = ["yuv_resize_generator.cc"], + generator_deps = [":common"], + generator_name = "yuv_resize_generator", +) + +halide_library( + name = "yuv_rotate_halide", + srcs = ["yuv_rotate_generator.cc"], + generator_deps = [":common"], + generator_name = "yuv_rotate_generator", +) + +# Grayscale operations: + +halide_library( + name = "rgb_gray_halide", + srcs = ["rgb_gray_generator.cc"], + generator_name = "rgb_gray_generator", +) + +halide_library( + name = "gray_rotate_halide", + srcs = ["gray_rotate_generator.cc"], + generator_deps = [":common"], + generator_name = "gray_rotate_generator", +) + +halide_library( + name = "gray_flip_halide", + srcs = ["gray_flip_generator.cc"], + generator_name = "gray_flip_generator", +) + +halide_library( + name = "gray_resize_halide", + srcs = ["gray_resize_generator.cc"], + generator_deps = [":common"], + generator_name = "gray_resize_generator", +) diff --git a/mediapipe/util/frame_buffer/halide/common.cc b/mediapipe/util/frame_buffer/halide/common.cc new file mode 100644 index 000000000..8142c82f5 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/common.cc @@ -0,0 +1,89 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mediapipe/util/frame_buffer/halide/common.h" + +namespace mediapipe { +namespace frame_buffer { +namespace halide { +namespace common { + +namespace { +using ::Halide::_; +} + +void resize_nn(Halide::Func input, Halide::Func result, Halide::Expr fx, + Halide::Expr fy) { + Halide::Var x{"x"}, y{"y"}; + result(x, y, _) = input(Halide::cast((x + 0.5f) * fx), + Halide::cast((y + 0.5f) * fy), _); +} + +// Borrowed from photos/editing/halide/src/resize_image_bilinear_generator.cc: +void resize_bilinear(Halide::Func input, Halide::Func result, Halide::Expr fx, + Halide::Expr fy) { + Halide::Var x{"x"}, y{"y"}; + Halide::Func x_interpolated("x_interpolated"); + + Halide::Expr xi = Halide::cast(x * fx); + Halide::Expr xr = x * fx - xi; + Halide::Expr x0 = input(xi + 0, y, _); + Halide::Expr x1 = input(xi + 1, y, _); + x_interpolated(x, y, _) = lerp(x0, x1, xr); + + Halide::Expr yi = Halide::cast(y * fy); + Halide::Expr yr = y * fy - yi; + Halide::Expr y0 = x_interpolated(x, yi + 0, _); + Halide::Expr y1 = x_interpolated(x, yi + 1, _); + result(x, y, _) = lerp(y0, y1, yr); +} + +void resize_bilinear_int(Halide::Func input, Halide::Func result, + Halide::Expr fx, Halide::Expr fy) { + Halide::Var x{"x"}, y{"y"}; + Halide::Func x_interpolated("x_interpolated"); + + fx = Halide::cast(fx * 65536); + Halide::Expr xi = Halide::cast(x * fx / 65536); + Halide::Expr xr = Halide::cast(x * fx % 65536); + Halide::Expr x0 = input(xi + 0, y, _); + Halide::Expr x1 = input(xi + 1, y, _); + x_interpolated(x, y, _) = lerp(x0, x1, xr); + + fy = Halide::cast(fy * 65536); + Halide::Expr yi = Halide::cast(y * fy / 65536); + Halide::Expr yr = Halide::cast(y * fy % 65536); + Halide::Expr y0 = x_interpolated(x, yi + 0, _); + Halide::Expr y1 = x_interpolated(x, yi + 1, _); + result(x, y, _) = lerp(y0, y1, yr); +} + +void rotate(Halide::Func input, Halide::Func result, Halide::Expr width, + Halide::Expr height, Halide::Expr angle) { + Halide::Var x{"x"}, y{"y"}; + Halide::Func result_90_degrees, result_180_degrees, result_270_degrees; + result_90_degrees(x, y, _) = input(width - 1 - y, x, _); + result_180_degrees(x, y, _) = input(width - 1 - x, height - 1 - y, _); + result_270_degrees(x, y, _) = input(y, height - 1 - x, _); + + result(x, y, _) = + select(angle == 90, result_90_degrees(x, y, _), angle == 180, + result_180_degrees(x, y, _), angle == 270, + result_270_degrees(x, y, _), input(x, y, _)); +} + +} // namespace common +} // namespace halide +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/halide/common.h b/mediapipe/util/frame_buffer/halide/common.h new file mode 100644 index 000000000..4b7023a9f --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/common.h @@ -0,0 +1,61 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MEDIAPIPE_UTIL_FRAME_BUFFER_HALIDE_COMMON_H_ +#define MEDIAPIPE_UTIL_FRAME_BUFFER_HALIDE_COMMON_H_ + +#include "Halide.h" + +namespace mediapipe { +namespace frame_buffer { +namespace halide { +namespace common { + +template +Halide::Expr is_planar(const T& buffer) { + return buffer.dim(0).stride() == 1; +} + +template +Halide::Expr is_interleaved(const T& buffer) { + return buffer.dim(0).stride() == buffer.dim(2).extent() && + buffer.dim(2).stride() == 1; +} + +// Resize scale parameters (fx, fy) are the ratio of source size to output +// size; thus if you want to produce an image half as wide and twice as tall +// as the input, (fx, fy) should be (2, 0.5). + +// Nearest-neighbor resize: fast, but low-quality (prone to aliasing). +void resize_nn(Halide::Func input, Halide::Func result, Halide::Expr fx, + Halide::Expr fy); + +// Resize with bilinear interpolation: slower but higher-quality. +void resize_bilinear(Halide::Func input, Halide::Func result, Halide::Expr fx, + Halide::Expr fy); +// Identical to the above, except that it uses fixed point integer math. +void resize_bilinear_int(Halide::Func input, Halide::Func result, + Halide::Expr fx, Halide::Expr fy); + +// Note: width and height are the source image dimensions; angle must be one +// of [0, 90, 180, 270] or the result is undefined. +void rotate(Halide::Func input, Halide::Func result, Halide::Expr width, + Halide::Expr height, Halide::Expr angle); + +} // namespace common +} // namespace halide +} // namespace frame_buffer +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_FRAME_BUFFER_HALIDE_COMMON_H_ diff --git a/mediapipe/util/frame_buffer/halide/gray_flip_generator.cc b/mediapipe/util/frame_buffer/halide/gray_flip_generator.cc new file mode 100644 index 000000000..528489c84 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/gray_flip_generator.cc @@ -0,0 +1,61 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" + +namespace { + +using ::Halide::_; + +class GrayFlip : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_y{"src_y"}; + + // Flip vertically if true; flips horizontally (mirroring) otherwise. + Input flip_vertical{"flip_vertical", false}; + + Output dst_y{"dst_y", UInt(8), 2}; + + void generate(); + void schedule(); + + private: + void flip(Func input, Func result, Expr width, Expr height, Expr vertical); +}; + +void GrayFlip::generate() { + Halide::Func flip_x, flip_y; + flip_x(x, y, _) = src_y(src_y.dim(0).extent() - x - 1, y, _); + flip_y(x, y, _) = src_y(x, src_y.dim(1).extent() - y - 1, _); + + dst_y(x, y, _) = select(flip_vertical, flip_y(x, y, _), flip_x(x, y, _)); +} + +void GrayFlip::schedule() { + Halide::Func dst_y_func = dst_y; + + // Y plane dimensions start at zero and destination bounds must match. 
+ Halide::OutputImageParam dst_y_output = dst_y_func.output_buffer(); + src_y.dim(0).set_min(0); + src_y.dim(1).set_min(0); + dst_y_output.dim(0).set_bounds(0, src_y.dim(0).extent()); + dst_y_output.dim(1).set_bounds(0, src_y.dim(1).extent()); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(GrayFlip, gray_flip_generator) diff --git a/mediapipe/util/frame_buffer/halide/gray_resize_generator.cc b/mediapipe/util/frame_buffer/halide/gray_resize_generator.cc new file mode 100644 index 000000000..adda9c8d5 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/gray_resize_generator.cc @@ -0,0 +1,60 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" +#include "mediapipe/util/frame_buffer/halide/common.h" + +namespace { + +using ::Halide::BoundaryConditions::repeat_edge; +using ::mediapipe::frame_buffer::halide::common::resize_bilinear_int; + +class GrayResize : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + Input> src_y{"src_y"}; + Input scale_x{"scale_x", 1.0f, 0.0f, 1024.0f}; + Input scale_y{"scale_y", 1.0f, 0.0f, 1024.0f}; + + Output dst_y{"dst_y", UInt(8), 2}; + + void generate(); + void schedule(); +}; + +void GrayResize::generate() { + resize_bilinear_int(repeat_edge(src_y), dst_y, scale_x, scale_y); +} + +void GrayResize::schedule() { + // Grayscale image dimensions start at zero. + Halide::Func dst_y_func = dst_y; + Halide::OutputImageParam dst_y_output = dst_y_func.output_buffer(); + src_y.dim(0).set_min(0); + src_y.dim(1).set_min(0); + dst_y_output.dim(0).set_min(0); + dst_y_output.dim(1).set_min(0); + + // We must ensure that the image is wide enough to support vector + // operations. + const int vector_size = natural_vector_size(); + Halide::Expr min_y_width = + Halide::min(src_y.dim(0).extent(), dst_y_output.dim(0).extent()); + dst_y_func.specialize(min_y_width >= vector_size).vectorize(x, vector_size); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(GrayResize, gray_resize_generator) diff --git a/mediapipe/util/frame_buffer/halide/gray_rotate_generator.cc b/mediapipe/util/frame_buffer/halide/gray_rotate_generator.cc new file mode 100644 index 000000000..825741a5f --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/gray_rotate_generator.cc @@ -0,0 +1,63 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
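For context, the gray_resize generator above is the kernel that GrayBuffer::Resize (earlier in this diff) dispatches to, passing width()/output->width() and height()/output->height() as scale_x and scale_y. A hypothetical usage sketch, not part of the diff (the buffer sizes and function name are made up):

#include <cstdint>

#include "mediapipe/util/frame_buffer/gray_buffer.h"

// Downscales a 640x480 grayscale frame into a caller-provided 320x240 buffer;
// GrayBuffer::Resize passes scale_x = scale_y = 2.0f to gray_resize_halide.
bool DownscaleByTwo(uint8_t* src, uint8_t* dst) {
  mediapipe::frame_buffer::GrayBuffer input(src, /*width=*/640, /*height=*/480);
  mediapipe::frame_buffer::GrayBuffer output(dst, /*width=*/320, /*height=*/240);
  return input.Resize(&output);  // false if the Halide pipeline reports an error
}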
+ +#include "Halide.h" +#include "mediapipe/util/frame_buffer/halide/common.h" + +namespace { + +using ::mediapipe::frame_buffer::halide::common::rotate; + +class GrayRotate : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_y{"src_y"}; + + // Rotation angle in degrees counter-clockwise. Must be in {0, 90, 180, 270}. + Input rotation_angle{"rotation_angle", 0}; + + Output dst_y{"dst_y", UInt(8), 2}; + + void generate(); + void schedule(); +}; + +void GrayRotate::generate() { + const Halide::Expr width = src_y.dim(0).extent(); + const Halide::Expr height = src_y.dim(1).extent(); + + rotate(src_y, dst_y, width, height, rotation_angle); +} + +void GrayRotate::schedule() { + Halide::Func dst_y_func = dst_y; + dst_y_func.specialize(rotation_angle == 0).reorder(x, y); + dst_y_func.specialize(rotation_angle == 90).reorder(y, x); + dst_y_func.specialize(rotation_angle == 180).reorder(x, y); + dst_y_func.specialize(rotation_angle == 270).reorder(y, x); + + // Y plane dimensions start at zero. We could additionally constrain the + // extent to be even, but that doesn't seem to have any benefit. + Halide::OutputImageParam dst_y_output = dst_y_func.output_buffer(); + src_y.dim(0).set_min(0); + src_y.dim(1).set_min(0); + dst_y_output.dim(0).set_min(0); + dst_y_output.dim(1).set_min(0); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(GrayRotate, gray_rotate_generator) diff --git a/mediapipe/util/frame_buffer/halide/rgb_flip_generator.cc b/mediapipe/util/frame_buffer/halide/rgb_flip_generator.cc new file mode 100644 index 000000000..2b2723a68 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/rgb_flip_generator.cc @@ -0,0 +1,84 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" + +namespace { + +using ::Halide::_; + +class RgbFlip : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_rgb{"src_rgb"}; + // Flip vertically if true; flips horizontally (mirroring) otherwise. + Input flip_vertical{"flip_vertical", false}; + + Output dst_rgb{"dst_rgb", UInt(8), 3}; + + void generate(); + void schedule(); + + private: + void flip(Func input, Func result, Expr width, Expr height, Expr vertical); +}; + +void RgbFlip::flip(Halide::Func input, Halide::Func result, Halide::Expr width, + Halide::Expr height, Halide::Expr vertical) { + Halide::Func flip_x, flip_y; + flip_x(x, y, _) = input(width - x - 1, y, _); + flip_y(x, y, _) = input(x, height - y - 1, _); + + result(x, y, _) = select(vertical, flip_y(x, y, _), flip_x(x, y, _)); +} + +void RgbFlip::generate() { + const Halide::Expr width = src_rgb.dim(0).extent(); + const Halide::Expr height = src_rgb.dim(1).extent(); + + // Flip each of the RGB planes independently. 
+ flip(src_rgb, dst_rgb, width, height, flip_vertical); +} + +void RgbFlip::schedule() { + Halide::Func dst_rgb_func = dst_rgb; + Halide::Var c = dst_rgb_func.args()[2]; + Halide::OutputImageParam rgb_output = dst_rgb_func.output_buffer(); + + // Iterate over channel in the innermost loop, then x, then y. + dst_rgb_func.reorder(c, x, y); + + // RGB planes starts at index zero in every dimension and destination bounds + // must match. + src_rgb.dim(0).set_min(0); + src_rgb.dim(1).set_min(0); + src_rgb.dim(2).set_min(0); + rgb_output.dim(0).set_bounds(0, src_rgb.dim(0).extent()); + rgb_output.dim(1).set_bounds(0, src_rgb.dim(1).extent()); + rgb_output.dim(2).set_bounds(0, src_rgb.dim(2).extent()); + + // Require that the input/output buffer be interleaved and tightly- + // packed; that is, either RGBRGBRGB[...] or RGBARGBARGBA[...], + // without gaps between pixels. + src_rgb.dim(0).set_stride(src_rgb.dim(2).extent()); + src_rgb.dim(2).set_stride(1); + rgb_output.dim(0).set_stride(rgb_output.dim(2).extent()); + rgb_output.dim(2).set_stride(1); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(RgbFlip, rgb_flip_generator) diff --git a/mediapipe/util/frame_buffer/halide/rgb_gray_generator.cc b/mediapipe/util/frame_buffer/halide/rgb_gray_generator.cc new file mode 100644 index 000000000..1df96540a --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/rgb_gray_generator.cc @@ -0,0 +1,65 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" + +namespace { + +class RgbGray : public Halide::Generator { + public: + Var x{"x"}, y{"y"}, c{"c"}; + + Input> src_rgb{"rgb"}; + Output> convert{"convert"}; + + void generate(); + void schedule(); +}; + +// Integer math versions of the full-range JFIF RGB-Y coefficients. +// Y = 0.2990*R + 0.5870*G + 0.1140*B +// See https://www.w3.org/Graphics/JPEG/jfif3.pdf. These coefficients are +// similar to, but not identical, to those used in Android. +Halide::Expr rgby(Halide::Expr r, Halide::Expr g, Halide::Expr b) { + r = Halide::cast(r); + g = Halide::cast(g); + b = Halide::cast(b); + return (19595 * r + 38470 * g + 7474 * b + 32768) >> 16; +} + +void RgbGray::generate() { + Halide::Func gray("gray"); + gray(x, y) = rgby(src_rgb(x, y, 0), src_rgb(x, y, 1), src_rgb(x, y, 2)); + convert(x, y) = Halide::saturating_cast(gray(x, y)); +} + +void RgbGray::schedule() { + // RGB images starts at index zero in every dimension. + src_rgb.dim(0).set_min(0); + src_rgb.dim(1).set_min(0); + src_rgb.dim(2).set_min(0); + + // Require that the input buffer be interleaved and tightly-packed; + // with no gaps between pixels. + src_rgb.dim(0).set_stride(src_rgb.dim(2).extent()); + src_rgb.dim(2).set_stride(1); + + // Grayscale images starts at index zero in every dimension. 
+ convert.dim(0).set_min(0); + convert.dim(1).set_min(0); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(RgbGray, rgb_gray_generator) diff --git a/mediapipe/util/frame_buffer/halide/rgb_resize_generator.cc b/mediapipe/util/frame_buffer/halide/rgb_resize_generator.cc new file mode 100644 index 000000000..469120ec3 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/rgb_resize_generator.cc @@ -0,0 +1,85 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" +#include "mediapipe/util/frame_buffer/halide/common.h" + +namespace { + +using ::Halide::BoundaryConditions::repeat_edge; +using ::mediapipe::frame_buffer::halide::common::resize_bilinear_int; + +class RgbResize : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + Input> src_rgb{"src_rgb"}; + Input scale_x{"scale_x", 1.0f, 0.0f, 1024.0f}; + Input scale_y{"scale_y", 1.0f, 0.0f, 1024.0f}; + + Output dst_rgb{"dst_rgb", UInt(8), 3}; + + void generate(); + void schedule(); +}; + +void RgbResize::generate() { + // Resize each of the RGB planes independently. + resize_bilinear_int(repeat_edge(src_rgb), dst_rgb, scale_x, scale_y); +} + +void RgbResize::schedule() { + Halide::Func dst_rgb_func = dst_rgb; + Halide::Var c = dst_rgb_func.args()[2]; + Halide::OutputImageParam rgb_output = dst_rgb_func.output_buffer(); + Halide::Expr input_rgb_channels = src_rgb.dim(2).extent(); + Halide::Expr output_rgb_channels = rgb_output.dim(2).extent(); + Halide::Expr min_width = + Halide::min(src_rgb.dim(0).extent(), rgb_output.dim(0).extent()); + + // Specialize the generated code for RGB and RGBA (input and output channels + // must match); further, specialize the vectorized implementation so it only + // runs on images wide enough to support it. + const int vector_size = natural_vector_size(); + const Expr channel_specializations[] = { + input_rgb_channels == 3 && output_rgb_channels == 3, + input_rgb_channels == 4 && output_rgb_channels == 4, + }; + dst_rgb_func.reorder(c, x, y); + for (const Expr& channel_specialization : channel_specializations) { + dst_rgb_func.specialize(channel_specialization && min_width >= vector_size) + .unroll(c) + .vectorize(x, vector_size); + } + + // Require that the input/output buffer be interleaved and tightly- + // packed; that is, either RGBRGBRGB[...] or RGBARGBARGBA[...], + // without gaps between pixels. + src_rgb.dim(0).set_stride(input_rgb_channels); + src_rgb.dim(2).set_stride(1); + rgb_output.dim(0).set_stride(output_rgb_channels); + rgb_output.dim(2).set_stride(1); + + // RGB planes starts at index zero in every dimension. 
+ src_rgb.dim(0).set_min(0); + src_rgb.dim(1).set_min(0); + src_rgb.dim(2).set_min(0); + rgb_output.dim(0).set_min(0); + rgb_output.dim(1).set_min(0); + rgb_output.dim(2).set_min(0); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(RgbResize, rgb_resize_generator) diff --git a/mediapipe/util/frame_buffer/halide/rgb_rgb_generator.cc b/mediapipe/util/frame_buffer/halide/rgb_rgb_generator.cc new file mode 100644 index 000000000..99a016896 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/rgb_rgb_generator.cc @@ -0,0 +1,64 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "Halide.h" + +namespace { + +// Convert rgb_buffer between 3 and 4 channels. When converting from 3 channels +// to 4 channels, the alpha value is always 255. +class RgbRgb : public Halide::Generator { + public: + Var x{"x"}, y{"y"}, c{"c"}; + + Input> src_rgb{"src_rgb"}; + Output> dst_rgb{"dst_rgb"}; + + void generate(); + void schedule(); +}; + +void RgbRgb::generate() { + // We use Halide::clamp to avoid evaluating src_rgb(x, y, c) when c == 3 and + // the src_rgb only has c <= 2 (rgb -> rgba conversion case). + dst_rgb(x, y, c) = + Halide::select(c == 3, 255, src_rgb(x, y, Halide::clamp(c, 0, 2))); +} + +void RgbRgb::schedule() { + Halide::Expr input_rgb_channels = src_rgb.dim(2).extent(); + Halide::Expr output_rgb_channels = dst_rgb.dim(2).extent(); + + // The source buffer starts at zero in every dimension and requires an + // interleaved format. + src_rgb.dim(0).set_min(0); + src_rgb.dim(1).set_min(0); + src_rgb.dim(2).set_min(0); + src_rgb.dim(0).set_stride(input_rgb_channels); + src_rgb.dim(2).set_stride(1); + + // The destination buffer starts at zero in every dimension and requires an + // interleaved format. + dst_rgb.dim(0).set_min(0); + dst_rgb.dim(1).set_min(0); + dst_rgb.dim(2).set_min(0); + dst_rgb.dim(0).set_stride(output_rgb_channels); + dst_rgb.dim(2).set_stride(1); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(RgbRgb, rgb_rgb_generator) diff --git a/mediapipe/util/frame_buffer/halide/rgb_rotate_generator.cc b/mediapipe/util/frame_buffer/halide/rgb_rotate_generator.cc new file mode 100644 index 000000000..aa2bb24ec --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/rgb_rotate_generator.cc @@ -0,0 +1,76 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
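The select/clamp expression in the rgb_rgb generator above reads more naturally as per-element scalar code. The following sketch is illustrative only and not part of the diff (the function name is invented): when expanding RGB to RGBA, channel 3 always becomes an opaque alpha of 255, channels 0-2 copy through, and the clamp keeps the source read in bounds so a 3-channel input is never indexed at channel 3.

#include <algorithm>
#include <cstdint>

// Scalar equivalent of: select(c == 3, 255, src_rgb(x, y, clamp(c, 0, 2))).
uint8_t RgbToRgbaElement(const uint8_t* src_pixel /* 3 bytes: R, G, B */, int c) {
  if (c == 3) return 255;                 // alpha channel is always opaque
  return src_pixel[std::clamp(c, 0, 2)];  // straight copy for c in [0, 2]
}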
+ +#include "Halide.h" +#include "mediapipe/util/frame_buffer/halide/common.h" + +namespace { + +using ::mediapipe::frame_buffer::halide::common::rotate; + +class RgbRotate : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_rgb{"src_rgb"}; + // Rotation angle in degrees counter-clockwise. Must be in {0, 90, 180, 270}. + Input rotation_angle{"rotation_angle", 0}; + + Output dst_rgb{"dst_rgb", UInt(8), 3}; + + void generate(); + void schedule(); +}; + +void RgbRotate::generate() { + const Halide::Expr width = src_rgb.dim(0).extent(); + const Halide::Expr height = src_rgb.dim(1).extent(); + + // Rotate each of the RGB planes independently. + rotate(src_rgb, dst_rgb, width, height, rotation_angle); +} + +void RgbRotate::schedule() { + // TODO: Remove specialization for (angle == 0) since that is + // a no-op and callers should simply skip rotation. Doing so would cause + // a bounds assertion crash if called with angle=0, however. + Halide::Func dst_rgb_func = dst_rgb; + Halide::Var c = dst_rgb_func.args()[2]; + Halide::OutputImageParam rgb_output = dst_rgb_func.output_buffer(); + dst_rgb_func.specialize(rotation_angle == 0).reorder(c, x, y); + dst_rgb_func.specialize(rotation_angle == 90).reorder(c, y, x); + dst_rgb_func.specialize(rotation_angle == 180).reorder(c, x, y); + dst_rgb_func.specialize(rotation_angle == 270).reorder(c, y, x); + + // RGB planes starts at index zero in every dimension. + src_rgb.dim(0).set_min(0); + src_rgb.dim(1).set_min(0); + src_rgb.dim(2).set_min(0); + rgb_output.dim(0).set_min(0); + rgb_output.dim(1).set_min(0); + rgb_output.dim(2).set_min(0); + + // Require that the input/output buffer be interleaved and tightly- + // packed; that is, either RGBRGBRGB[...] or RGBARGBARGBA[...], + // without gaps between pixels. + src_rgb.dim(0).set_stride(src_rgb.dim(2).extent()); + src_rgb.dim(2).set_stride(1); + rgb_output.dim(0).set_stride(rgb_output.dim(2).extent()); + rgb_output.dim(2).set_stride(1); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(RgbRotate, rgb_rotate_generator) diff --git a/mediapipe/util/frame_buffer/halide/rgb_yuv_generator.cc b/mediapipe/util/frame_buffer/halide/rgb_yuv_generator.cc new file mode 100644 index 000000000..283dcec8a --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/rgb_yuv_generator.cc @@ -0,0 +1,101 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" + +namespace { + +class RgbYuv : public Halide::Generator { + public: + Var x{"x"}, y{"y"}, c{"c"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_rgb{"rgb"}; + + Output dst_y{"dst_y", UInt(8), 2}; + Output dst_uv{"dst_uv", UInt(8), 3}; + + void generate(); + void schedule(); +}; + +// Integer math versions of the full-range JFIF RGB-YUV coefficients. 
+// Y = 0.2990*R + 0.5870*G + 0.1140*B +// U = -0.1687*R - 0.3313*G + 0.5000*B + 128 +// V = 0.5000*R - 0.4187*G - 0.0813*B + 128 +// See https://www.w3.org/Graphics/JPEG/jfif3.pdf. These coefficients are +// similar to, but not identical, to those used in Android. +Halide::Tuple rgbyuv(Halide::Expr r, Halide::Expr g, Halide::Expr b) { + r = Halide::cast(r); + g = Halide::cast(g); + b = Halide::cast(b); + return { + (19595 * r + 38470 * g + 7474 * b + 32768) >> 16, + ((-11056 * r - 21712 * g + 32768 * b + 32768) >> 16) + 128, + ((32768 * r - 27440 * g - 5328 * b + 32768) >> 16) + 128, + }; +} + +void RgbYuv::generate() { + Halide::Func yuv_tuple("yuv_tuple"); + yuv_tuple(x, y) = + rgbyuv(src_rgb(x, y, 0), src_rgb(x, y, 1), src_rgb(x, y, 2)); + + // Y values are copied one-for-one; UV values are sampled 1/4. + // TODO: Take the average UV values across the 2x2 block. + dst_y(x, y) = Halide::saturating_cast(yuv_tuple(x, y)[0]); + dst_uv(x, y, c) = Halide::saturating_cast(Halide::select( + c == 0, yuv_tuple(x * 2, y * 2)[2], yuv_tuple(x * 2, y * 2)[1])); + // NOTE: uv channel indices above assume NV21; this can be abstracted out + // by twiddling strides in calling code. +} + +void RgbYuv::schedule() { + // RGB images starts at index zero in every dimension. + src_rgb.dim(0).set_min(0); + src_rgb.dim(1).set_min(0); + src_rgb.dim(2).set_min(0); + + // Require that the input buffer be interleaved and tightly-packed; + // that is, either RGBRGBRGB[...] or RGBARGBARGBA[...], without gaps + // between pixels. + src_rgb.dim(0).set_stride(src_rgb.dim(2).extent()); + src_rgb.dim(2).set_stride(1); + + // Y plane dimensions start at zero. We could additionally constrain the + // extent to be even, but that doesn't seem to have any benefit. + Halide::Func dst_y_func = dst_y; + Halide::OutputImageParam dst_y_output = dst_y_func.output_buffer(); + dst_y_output.dim(0).set_min(0); + dst_y_output.dim(1).set_min(0); + + // UV plane has two channels and is half the size of the Y plane in X/Y. + Halide::Func dst_uv_func = dst_uv; + Halide::OutputImageParam dst_uv_output = dst_uv_func.output_buffer(); + dst_uv_output.dim(0).set_bounds(0, (dst_y_output.dim(0).extent() + 1) / 2); + dst_uv_output.dim(1).set_bounds(0, (dst_y_output.dim(1).extent() + 1) / 2); + dst_uv_output.dim(2).set_bounds(0, 2); + + // UV channel processing should be loop unrolled. + dst_uv_func.reorder(c, x, y); + dst_uv_func.unroll(c); + + // Remove default memory layout constraints and accept/produce generic UV + // (including semi-planar and planar). + dst_uv_output.dim(0).set_stride(Expr()); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(RgbYuv, rgb_yuv_generator) diff --git a/mediapipe/util/frame_buffer/halide/yuv_flip_generator.cc b/mediapipe/util/frame_buffer/halide/yuv_flip_generator.cc new file mode 100644 index 000000000..83080f3d7 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/yuv_flip_generator.cc @@ -0,0 +1,90 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
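// A scalar sketch of the fixed-point conversion in rgbyuv() above, with one
// worked value; the struct and function names are illustrative. Each constant
// is round(coefficient * 65536), and the +32768 term rounds the >> 16 shift
// (arithmetic on negative intermediates, matching Halide's Int(32) shift).
//
// Worked example for pure red (255, 0, 0):
//   Y = (19595*255 + 32768) >> 16           = 76   (float: 76.2)
//   U = ((-11056*255 + 32768) >> 16) + 128  = 85   (float: 85.0)
//   V = ((32768*255 + 32768) >> 16) + 128   = 256, clipped to 255 by the
//                                              saturating cast in generate().
struct YuvTriple {
  int y, u, v;  // Pre-saturation values; generate() casts them to uint8_t.
};

inline YuvTriple RgbToYuvReference(int r, int g, int b) {
  return {
      (19595 * r + 38470 * g + 7474 * b + 32768) >> 16,
      ((-11056 * r - 21712 * g + 32768 * b + 32768) >> 16) + 128,
      ((32768 * r - 27440 * g - 5328 * b + 32768) >> 16) + 128,
  };
}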
+ +#include "Halide.h" + +namespace { + +using ::Halide::_; + +class YuvFlip : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_y{"src_y"}; + Input> src_uv{"src_uv"}; + // Flip vertically if true; flips horizontally (mirroring) otherwise. + Input flip_vertical{"flip_vertical", false}; + + Output dst_y{"dst_y", UInt(8), 2}; + Output dst_uv{"dst_uv", UInt(8), 3}; + + void generate(); + void schedule(); + + private: + void flip(Func input, Func result, Expr width, Expr height, Expr vertical); +}; + +void YuvFlip::flip(Halide::Func input, Halide::Func result, Halide::Expr width, + Halide::Expr height, Halide::Expr vertical) { + Halide::Func flip_x, flip_y; + flip_x(x, y, _) = input(width - x - 1, y, _); + flip_y(x, y, _) = input(x, height - y - 1, _); + + result(x, y, _) = select(vertical, flip_y(x, y, _), flip_x(x, y, _)); +} + +void YuvFlip::generate() { + const Halide::Expr width = src_y.dim(0).extent(); + const Halide::Expr height = src_y.dim(1).extent(); + + // Flip each of the YUV planes independently. + flip(src_y, dst_y, width, height, flip_vertical); + flip(src_uv, dst_uv, (width + 1) / 2, (height + 1) / 2, flip_vertical); +} + +void YuvFlip::schedule() { + Halide::Func dst_y_func = dst_y; + Halide::Func dst_uv_func = dst_uv; + Halide::Var c = dst_uv_func.args()[2]; + dst_uv_func.unroll(c); + dst_uv_func.reorder(c, x, y); + + // Y plane dimensions start at zero and destination bounds must match. + Halide::OutputImageParam dst_y_output = dst_y_func.output_buffer(); + src_y.dim(0).set_min(0); + src_y.dim(1).set_min(0); + dst_y_output.dim(0).set_bounds(0, src_y.dim(0).extent()); + dst_y_output.dim(1).set_bounds(0, src_y.dim(1).extent()); + + // UV plane has two channels and is half the size of the Y plane in X/Y. + Halide::OutputImageParam dst_uv_output = dst_uv_func.output_buffer(); + src_uv.dim(0).set_bounds(0, (src_y.dim(0).extent() + 1) / 2); + src_uv.dim(1).set_bounds(0, (src_y.dim(1).extent() + 1) / 2); + src_uv.dim(2).set_bounds(0, 2); + dst_uv_output.dim(0).set_bounds(0, (dst_y_output.dim(0).extent() + 1) / 2); + dst_uv_output.dim(1).set_bounds(0, (dst_y_output.dim(1).extent() + 1) / 2); + dst_uv_output.dim(2).set_bounds(0, 2); + + // Remove default memory layout constraints and accept/produce generic UV + // (including semi-planar and planar). + src_uv.dim(0).set_stride(Expr()); + dst_uv_output.dim(0).set_stride(Expr()); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(YuvFlip, yuv_flip_generator) diff --git a/mediapipe/util/frame_buffer/halide/yuv_resize_generator.cc b/mediapipe/util/frame_buffer/halide/yuv_resize_generator.cc new file mode 100644 index 000000000..805877fca --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/yuv_resize_generator.cc @@ -0,0 +1,91 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
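// A scalar sketch of the index arithmetic in YuvFlip::flip() above, applied
// to one tightly-packed plane; the function name is illustrative.
#include <cstdint>

// Horizontal flip (mirroring) reads column (width - x - 1); vertical flip
// reads row (height - y - 1). The same kernel runs on the Y plane and, at
// (width + 1) / 2 by (height + 1) / 2, on the two-channel UV plane.
inline void FlipPlaneReference(const uint8_t* src, uint8_t* dst, int width,
                               int height, bool vertical) {
  for (int y = 0; y < height; ++y) {
    for (int x = 0; x < width; ++x) {
      const int sx = vertical ? x : width - x - 1;
      const int sy = vertical ? height - y - 1 : y;
      dst[y * width + x] = src[sy * width + sx];
    }
  }
}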
+ +#include "Halide.h" +#include "mediapipe/util/frame_buffer/halide/common.h" + +namespace { + +using ::Halide::BoundaryConditions::repeat_edge; +using ::mediapipe::frame_buffer::halide::common::is_interleaved; +using ::mediapipe::frame_buffer::halide::common::is_planar; +using ::mediapipe::frame_buffer::halide::common::resize_bilinear_int; + +class YuvResize : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + Input> src_y{"src_y"}; + Input> src_uv{"src_uv"}; + Input scale_x{"scale_x", 1.0f, 0.0f, 1024.0f}; + Input scale_y{"scale_y", 1.0f, 0.0f, 1024.0f}; + + Output dst_y{"dst_y", UInt(8), 2}; + Output dst_uv{"dst_uv", UInt(8), 3}; + + void generate(); + void schedule(); +}; + +void YuvResize::generate() { + // Resize each of the YUV planes independently. + resize_bilinear_int(repeat_edge(src_y), dst_y, scale_x, scale_y); + resize_bilinear_int(repeat_edge(src_uv), dst_uv, scale_x, scale_y); +} + +void YuvResize::schedule() { + // Y plane dimensions start at zero. We could additionally constrain the + // extent to be even, but that doesn't seem to have any benefit. + Halide::Func dst_y_func = dst_y; + Halide::OutputImageParam dst_y_output = dst_y_func.output_buffer(); + src_y.dim(0).set_min(0); + src_y.dim(1).set_min(0); + dst_y_output.dim(0).set_min(0); + dst_y_output.dim(1).set_min(0); + + // UV plane has two channels and is half the size of the Y plane in X/Y. + Halide::Func dst_uv_func = dst_uv; + Halide::OutputImageParam dst_uv_output = dst_uv_func.output_buffer(); + src_uv.dim(0).set_bounds(0, (src_y.dim(0).extent() + 1) / 2); + src_uv.dim(1).set_bounds(0, (src_y.dim(1).extent() + 1) / 2); + src_uv.dim(2).set_bounds(0, 2); + dst_uv_output.dim(0).set_bounds(0, (dst_y_output.dim(0).extent() + 1) / 2); + dst_uv_output.dim(1).set_bounds(0, (dst_y_output.dim(1).extent() + 1) / 2); + dst_uv_output.dim(2).set_bounds(0, 2); + + // With bilinear filtering enabled, Y plane resize is profitably vectorizable + // though we must ensure that the image is wide enough to support vector + // operations. + const int vector_size = natural_vector_size(); + Halide::Expr min_y_width = + Halide::min(src_y.dim(0).extent(), dst_y_output.dim(0).extent()); + dst_y_func.specialize(min_y_width >= vector_size).vectorize(x, vector_size); + + // Remove default memory layout constraints and generate specialized + // fast-path implementations when both UV source and output are either + // planar or interleaved. Everything else falls onto a slow path. + src_uv.dim(0).set_stride(Expr()); + dst_uv_output.dim(0).set_stride(Expr()); + + Halide::Var c = dst_uv_func.args()[2]; + dst_uv_func + .specialize(is_interleaved(src_uv) && is_interleaved(dst_uv_output)) + .reorder(c, x, y) + .unroll(c); + dst_uv_func.specialize(is_planar(src_uv) && is_planar(dst_uv_output)); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(YuvResize, yuv_resize_generator) diff --git a/mediapipe/util/frame_buffer/halide/yuv_rgb_generator.cc b/mediapipe/util/frame_buffer/halide/yuv_rgb_generator.cc new file mode 100644 index 000000000..916ea290a --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/yuv_rgb_generator.cc @@ -0,0 +1,110 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" + +namespace { + +class YuvRgb : public Halide::Generator { + public: + Var x{"x"}, y{"y"}, c{"c"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_y{"src_y"}; + Input> src_uv{"src_uv"}; + Input halve{"halve", false}; + + Output rgb{"rgb", UInt(8), 3}; + + void generate(); + void schedule(); +}; + +Halide::Expr demux(Halide::Expr c, Halide::Tuple values) { + return select(c == 0, values[0], c == 1, values[1], c == 2, values[2], 255); +} + +// Integer math versions of the full-range JFIF YUV-RGB coefficients. +// R = Y' + 1.40200*(V-128) +// G = Y' - 0.34414*(U-128) - 0.71414*(V-128) +// B = Y' + 1.77200*(U-128) +// See https://www.w3.org/Graphics/JPEG/jfif3.pdf. These coefficients are +// similar to, but not identical, to those used in Android. +Halide::Tuple yuvrgb(Halide::Expr y, Halide::Expr u, Halide::Expr v) { + y = Halide::cast(y); + u = Halide::cast(u) - 128; + v = Halide::cast(v) - 128; + return { + y + ((91881 * v + 32768) >> 16), + y - ((22544 * u + 46802 * v + 32768) >> 16), + y + ((116130 * u + 32768) >> 16), + }; +} + +void YuvRgb::generate() { + // Each 2x2 block of Y pixels shares the same UV values, so UV-coordinates + // advance half as slowly as Y-coordinates. When taking advantage of the + // "free" 2x downsampling, use every UV value but skip every other Y. + Halide::Expr yx = select(halve, 2 * x, x), yy = select(halve, 2 * y, y); + Halide::Expr uvx = select(halve, x, x / 2), uvy = select(halve, y, y / 2); + + rgb(x, y, c) = Halide::saturating_cast(demux( + c, yuvrgb(src_y(yx, yy), src_uv(uvx, uvy, 1), src_uv(uvx, uvy, 0)))); + // NOTE: uv channel indices above assume NV21; this can be abstracted out + // by twiddling strides in calling code. +} + +void YuvRgb::schedule() { + // Y plane dimensions start at zero. We could additionally constrain the + // extent to be even, but that doesn't seem to have any benefit. + src_y.dim(0).set_min(0); + src_y.dim(1).set_min(0); + + // UV plane has two channels and is half the size of the Y plane in X/Y. + src_uv.dim(0).set_bounds(0, (src_y.dim(0).extent() + 1) / 2); + src_uv.dim(1).set_bounds(0, (src_y.dim(1).extent() + 1) / 2); + src_uv.dim(2).set_bounds(0, 2); + + // Remove default memory layout constraints on the UV source so that we + // accept generic UV (including semi-planar and planar). + // + // TODO: Investigate whether it's worth specializing the cross- + // product of [semi-]planar and RGB/RGBA; this would result in 9 codepaths. + src_uv.dim(0).set_stride(Expr()); + + Halide::Func rgb_func = rgb; + Halide::OutputImageParam rgb_output = rgb_func.output_buffer(); + Halide::Expr rgb_channels = rgb_output.dim(2).extent(); + + // Specialize the generated code for RGB and RGBA. + const int vector_size = natural_vector_size(); + rgb_func.reorder(c, x, y); + rgb_func.specialize(rgb_channels == 3).unroll(c).vectorize(x, vector_size); + rgb_func.specialize(rgb_channels == 4).unroll(c).vectorize(x, vector_size); + + // Require that the output buffer be interleaved and tightly-packed; + // that is, either RGBRGBRGB[...] 
or RGBARGBARGBA[...], without gaps + // between pixels. + rgb_output.dim(0).set_stride(rgb_channels); + rgb_output.dim(2).set_stride(1); + + // RGB output starts at index zero in every dimension. + rgb_output.dim(0).set_min(0); + rgb_output.dim(1).set_min(0); + rgb_output.dim(2).set_min(0); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(YuvRgb, yuv_rgb_generator) diff --git a/mediapipe/util/frame_buffer/halide/yuv_rotate_generator.cc b/mediapipe/util/frame_buffer/halide/yuv_rotate_generator.cc new file mode 100644 index 000000000..46f654a23 --- /dev/null +++ b/mediapipe/util/frame_buffer/halide/yuv_rotate_generator.cc @@ -0,0 +1,91 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "Halide.h" +#include "mediapipe/util/frame_buffer/halide/common.h" + +namespace { + +using ::mediapipe::frame_buffer::halide::common::rotate; + +class YuvRotate : public Halide::Generator { + public: + Var x{"x"}, y{"y"}; + + // Input because that allows us to apply constraints on stride, etc. + Input> src_y{"src_y"}; + Input> src_uv{"src_uv"}; + // Rotation angle in degrees counter-clockwise. Must be in {0, 90, 180, 270}. + Input rotation_angle{"rotation_angle", 0}; + + Output dst_y{"dst_y", UInt(8), 2}; + Output dst_uv{"dst_uv", UInt(8), 3}; + + void generate(); + void schedule(); +}; + +void YuvRotate::generate() { + const Halide::Expr width = src_y.dim(0).extent(); + const Halide::Expr height = src_y.dim(1).extent(); + + // Rotate each of the YUV planes independently. + rotate(src_y, dst_y, width, height, rotation_angle); + rotate(src_uv, dst_uv, (width + 1) / 2, (height + 1) / 2, rotation_angle); +} + +void YuvRotate::schedule() { + // TODO: Remove specialization for (angle == 0) since that is + // a no-op and callers should simply skip rotation. Doing so would cause + // a bounds assertion crash if called with angle=0, however. + Halide::Func dst_y_func = dst_y; + dst_y_func.specialize(rotation_angle == 0).reorder(x, y); + dst_y_func.specialize(rotation_angle == 90).reorder(y, x); + dst_y_func.specialize(rotation_angle == 180).reorder(x, y); + dst_y_func.specialize(rotation_angle == 270).reorder(y, x); + + Halide::Func dst_uv_func = dst_uv; + Halide::Var c = dst_uv_func.args()[2]; + dst_uv_func.unroll(c); + dst_uv_func.specialize(rotation_angle == 0).reorder(c, x, y); + dst_uv_func.specialize(rotation_angle == 90).reorder(c, y, x); + dst_uv_func.specialize(rotation_angle == 180).reorder(c, x, y); + dst_uv_func.specialize(rotation_angle == 270).reorder(c, y, x); + + // Y plane dimensions start at zero. We could additionally constrain the + // extent to be even, but that doesn't seem to have any benefit. + Halide::OutputImageParam dst_y_output = dst_y_func.output_buffer(); + src_y.dim(0).set_min(0); + src_y.dim(1).set_min(0); + dst_y_output.dim(0).set_min(0); + dst_y_output.dim(1).set_min(0); + + // UV plane has two channels and is half the size of the Y plane in X/Y. 
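  // (For the rounding involved: "half the size" means (extent + 1) / 2, so a
  // 5x3 Y plane pairs with a 3x2 UV plane; the same convention is used by the
  // other YUV generators and by YuvBuffer::ByteSize().)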
+ Halide::OutputImageParam dst_uv_output = dst_uv_func.output_buffer(); + src_uv.dim(0).set_bounds(0, (src_y.dim(0).extent() + 1) / 2); + src_uv.dim(1).set_bounds(0, (src_y.dim(1).extent() + 1) / 2); + src_uv.dim(2).set_bounds(0, 2); + dst_uv_output.dim(0).set_bounds(0, (dst_y_output.dim(0).extent() + 1) / 2); + dst_uv_output.dim(1).set_bounds(0, (dst_y_output.dim(1).extent() + 1) / 2); + dst_uv_output.dim(2).set_bounds(0, 2); + + // Remove default memory layout constraints and accept/produce generic UV + // (including semi-planar and planar). + src_uv.dim(0).set_stride(Expr()); + dst_uv_output.dim(0).set_stride(Expr()); +} + +} // namespace + +HALIDE_REGISTER_GENERATOR(YuvRotate, yuv_rotate_generator) diff --git a/mediapipe/util/frame_buffer/rgb_buffer.cc b/mediapipe/util/frame_buffer/rgb_buffer.cc new file mode 100644 index 000000000..9ae849eab --- /dev/null +++ b/mediapipe/util/frame_buffer/rgb_buffer.cc @@ -0,0 +1,132 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/frame_buffer/rgb_buffer.h" + +#include + +#include "mediapipe/util/frame_buffer/buffer_common.h" +#include "mediapipe/util/frame_buffer/gray_buffer.h" +#include "mediapipe/util/frame_buffer/halide/rgb_flip_halide.h" +#include "mediapipe/util/frame_buffer/halide/rgb_gray_halide.h" +#include "mediapipe/util/frame_buffer/halide/rgb_resize_halide.h" +#include "mediapipe/util/frame_buffer/halide/rgb_rgb_halide.h" +#include "mediapipe/util/frame_buffer/halide/rgb_rotate_halide.h" +#include "mediapipe/util/frame_buffer/halide/rgb_yuv_halide.h" +#include "mediapipe/util/frame_buffer/yuv_buffer.h" + +namespace mediapipe { +namespace frame_buffer { + +RgbBuffer::RgbBuffer(uint8_t* data, int width, int height, bool alpha) + : owned_buffer_(nullptr) { + Initialize(data, width, height, alpha); +} + +RgbBuffer::RgbBuffer(uint8_t* data, int width, int height, int row_stride, + bool alpha) { + const int channels = alpha ? 4 : 3; + const halide_dimension_t dimensions[3] = {{/*m=*/0, width, channels}, + {/*m=*/0, height, row_stride}, + {/*m=*/0, channels, 1}}; + buffer_ = Halide::Runtime::Buffer(data, /*d=*/3, dimensions); +} + +RgbBuffer::RgbBuffer(int width, int height, bool alpha) + : owned_buffer_(new uint8_t[ByteSize(width, height, alpha)]) { + Initialize(owned_buffer_.get(), width, height, alpha); +} + +RgbBuffer::RgbBuffer(const RgbBuffer& other) : buffer_(other.buffer_) { + // Never copy owned_buffer; ownership remains with the source of the copy. 
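  // The copy is therefore a shallow view: both objects alias the same pixel
  // data, and only the instance that allocated it can Release() it.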
+} + +RgbBuffer::RgbBuffer(RgbBuffer&& other) { *this = std::move(other); } + +RgbBuffer& RgbBuffer::operator=(const RgbBuffer& other) { + if (this != &other) { + buffer_ = other.buffer_; + } + return *this; +} +RgbBuffer& RgbBuffer::operator=(RgbBuffer&& other) { + if (this != &other) { + owned_buffer_ = std::move(other.owned_buffer_); + buffer_ = other.buffer_; + } + return *this; +} + +RgbBuffer::~RgbBuffer() {} + +bool RgbBuffer::Crop(int x0, int y0, int x1, int y1) { + // Twiddle the buffer start and extents to crop images. + return common::crop_buffer(x0, y0, x1, y1, buffer()); +} + +bool RgbBuffer::Resize(RgbBuffer* output) { + if (output->channels() > channels()) { + // Fail fast; the Halide implementation would otherwise output garbage + // alpha values (i.e. duplicate the blue channel into alpha). + return false; + } + const int result = rgb_resize_halide( + buffer(), static_cast(width()) / output->width(), + static_cast(height()) / output->height(), output->buffer()); + return result == 0; +} + +bool RgbBuffer::Rotate(int angle, RgbBuffer* output) { + const int result = rgb_rotate_halide(buffer(), angle, output->buffer()); + return result == 0; +} + +bool RgbBuffer::FlipHorizontally(RgbBuffer* output) { + const int result = rgb_flip_halide(buffer(), + false, // horizontal + output->buffer()); + return result == 0; +} + +bool RgbBuffer::FlipVertically(RgbBuffer* output) { + const int result = rgb_flip_halide(buffer(), + true, // vertical + output->buffer()); + return result == 0; +} + +bool RgbBuffer::Convert(YuvBuffer* output) { + const int result = + rgb_yuv_halide(buffer(), output->y_buffer(), output->uv_buffer()); + return result == 0; +} + +bool RgbBuffer::Convert(GrayBuffer* output) { + const int result = rgb_gray_halide(buffer(), output->buffer()); + return result == 0; +} + +bool RgbBuffer::Convert(RgbBuffer* output) { + const int result = rgb_rgb_halide(buffer(), output->buffer()); + return result == 0; +} + +void RgbBuffer::Initialize(uint8_t* data, int width, int height, bool alpha) { + const int channels = alpha ? 4 : 3; + buffer_ = Halide::Runtime::Buffer::make_interleaved( + data, width, height, channels); +} + +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/rgb_buffer.h b/mediapipe/util/frame_buffer/rgb_buffer.h new file mode 100644 index 000000000..06423c3f6 --- /dev/null +++ b/mediapipe/util/frame_buffer/rgb_buffer.h @@ -0,0 +1,139 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MEDIAPIPE_UTIL_FRAME_BUFFER_RGB_BUFFER_H_ +#define MEDIAPIPE_UTIL_FRAME_BUFFER_RGB_BUFFER_H_ + +#include + +#include "HalideBuffer.h" +#include "HalideRuntime.h" +#include "mediapipe/util/frame_buffer/gray_buffer.h" +#include "mediapipe/util/frame_buffer/yuv_buffer.h" + +namespace mediapipe { +namespace frame_buffer { + +// RgbBuffer represents a view over an interleaved RGB/RGBA image. 
+// +// RgbBuffers may be copied and moved efficiently; their backing buffers are +// shared and never deep copied. +// +// RgbBuffer requires a minimum image width depending on the natural vector +// size of the platform, e.g., 16px. This is not validated by RgbBuffer. +class RgbBuffer { + public: + // Returns the size (in bytes) of an RGB/RGBA image of the given dimensions + // without padding. + static int ByteSize(int width, int height, bool alpha) { + return width * height * (alpha ? 4 : 3); + } + + // Builds a RgbBuffer using the given backing buffer and dimensions. + // + // Does not take ownership of the backing buffer (provided in 'data'). + RgbBuffer(uint8_t* data, int width, int height, bool alpha); + + // Builds a RgbBuffer using the given backing buffer and dimensions. + // 'row_stride' must be greater than or equal to 'width'. Padding bytes are at + // the end of each row, following the image bytes. + // + // Does not take ownership of the backing buffer (provided in 'data'). + RgbBuffer(uint8_t* data, int width, int height, int row_stride, bool alpha); + + // Builds a RgbBuffer using the given dimensions. + // + // The underlying backing buffer is allocated and owned by this RgbBuffer. + RgbBuffer(int width, int height, bool alpha); + + // RgbBuffer is copyable. The source retains ownership of its backing buffer. + RgbBuffer(const RgbBuffer& other); + // RgbBuffer is moveable. The source loses ownership of any backing buffers. + RgbBuffer(RgbBuffer&& other); + // RgbBuffer is assignable. + RgbBuffer& operator=(const RgbBuffer& other); + RgbBuffer& operator=(RgbBuffer&& other); + + ~RgbBuffer(); + + // Performs an in-place crop. Modifies this buffer so that the new extent + // matches that of the given crop rectangle -- (x0, y0) becomes (0, 0) and + // the new width and height are x1 - x0 + 1 and y1 - y0 + 1, respectively. + bool Crop(int x0, int y0, int x1, int y1); + + // Resize this image to match the dimensions of the given output RgbBuffer + // and places the result into its backing buffer. + // + // Performs a resize with bilinear interpolation (over four source pixels). + // Resizing with an RGB source buffer and RGBA destination is currently + // unsupported. + bool Resize(RgbBuffer* output); + + // Rotate this image into the given buffer by the given angle (90, 180, 270). + // + // Rotation is specified in degrees counter-clockwise such that when rotating + // by 90 degrees, the top-right corner of the source becomes the top-left of + // the output. The output buffer must have its height and width swapped when + // rotating by 90 or 270. + // + // Any angle values other than (90, 180, 270) are invalid. + bool Rotate(int angle, RgbBuffer* output); + + // Flip this image horizontally/vertically into the given buffer. Both buffer + // dimensions and formats must match (this method does not convert RGB-to-RGBA + // nor RGBA-to-RGB). + bool FlipHorizontally(RgbBuffer* output); + bool FlipVertically(RgbBuffer* output); + + // Performs a RGB-to-YUV color format conversion and places the result + // in the given output YuvBuffer. Both buffer dimensions must match. + bool Convert(YuvBuffer* output); + + // Performs a RGB to grayscale format conversion. + bool Convert(GrayBuffer* output); + + // Performs a rgb to rgba / rgba to rgb format conversion. + bool Convert(RgbBuffer* output); + + // Release ownership of the owned backing buffer. + uint8_t* Release() { return owned_buffer_.release(); } + + // Returns the halide_buffer_t* for the image. 
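  // (Lifetime note: the returned halide_buffer_t points into the same backing
  // store as this RgbBuffer, so it is only valid while that backing store,
  // owned or caller-provided, remains alive.)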
+ const halide_buffer_t* buffer() const { return buffer_.raw_buffer(); } + // Returns the halide_buffer_t* for the image. + halide_buffer_t* buffer() { return buffer_.raw_buffer(); } + + // Returns the image width. + const int width() const { return buffer_.dim(0).extent(); } + // Returns the image height. + const int height() const { return buffer_.dim(1).extent(); } + // Returns the number of color channels (3, or 4 if RGBA). + const int channels() const { return buffer_.dim(2).extent(); } + // Returns the image row stride. + const int row_stride() const { return buffer_.dim(1).stride(); } + + private: + void Initialize(uint8_t* data, int width, int height, bool alpha); + + // Non-NULL iff this RgbBuffer owns its backing buffer. + std::unique_ptr owned_buffer_; + + // Backing buffer: layout is always width x height x channel (interleaved). + Halide::Runtime::Buffer buffer_; +}; + +} // namespace frame_buffer +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_FRAME_BUFFER_RGB_BUFFER_H_ diff --git a/mediapipe/util/frame_buffer/rgb_buffer_test.cc b/mediapipe/util/frame_buffer/rgb_buffer_test.cc new file mode 100644 index 000000000..e5cb39c69 --- /dev/null +++ b/mediapipe/util/frame_buffer/rgb_buffer_test.cc @@ -0,0 +1,606 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/frame_buffer/rgb_buffer.h" + +#include + +#include "absl/log/log.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/util/frame_buffer/gray_buffer.h" +#include "mediapipe/util/frame_buffer/yuv_buffer.h" + +// The default implementation of halide_error calls abort(), which we don't +// want. Instead, log the error and let the filter invocation fail. +extern "C" void halide_error(void*, const char* message) { + LOG(ERROR) << "Halide Error: " << message; +} + +namespace mediapipe { +namespace frame_buffer { +namespace { + +// Fill a halide_buffer_t channel with the given value. +void Fill(halide_buffer_t* buffer, int channel, int value) { + for (int y = 0; y < buffer->dim[1].extent; ++y) { + for (int x = 0; x < buffer->dim[0].extent; ++x) { + buffer->host[buffer->dim[1].stride * y + buffer->dim[0].stride * x + + buffer->dim[2].stride * channel] = value; + } + } +} + +// Fill an RgbBuffer with (0, 0, 0). Fills the alpha channel if present. +void Fill(RgbBuffer* buffer) { + for (int c = 0; c < buffer->channels(); ++c) { + Fill(buffer->buffer(), c, 0); + } +} + +// Returns a padded RGB buffer. The metadata are defined as width: 4, height: 2, +// row_stride: 18, channels: 3. +RgbBuffer GetPaddedRgbBuffer() { + static uint8_t rgb_buffer_with_padding[] = { + 10, 20, 30, 20, 30, 40, 30, 40, 50, 40, 50, 60, 0, 0, 0, 0, 0, 0, + 20, 40, 60, 40, 60, 80, 60, 80, 100, 80, 100, 120, 0, 0, 0, 0, 0, 0}; + return RgbBuffer(rgb_buffer_with_padding, + /*width=*/4, /*height=*/2, + /*row_stride=*/18, /*alpha=*/false); +} + +// Returns a padded RGB buffer. 
The metadata are defined as width: 4, height: 2, +// row_stride: 24, channels: 4. +RgbBuffer GetPaddedRgbaBuffer() { + static uint8_t rgb_buffer_with_padding[] = { + 10, 20, 30, 255, 20, 30, 40, 255, 30, 40, 50, 255, 40, 50, 60, 255, + 0, 0, 0, 0, 0, 0, 0, 0, 20, 40, 60, 255, 40, 60, 80, 255, + 60, 80, 100, 255, 80, 100, 120, 255, 0, 0, 0, 0, 0, 0, 0, 0}; + return RgbBuffer(rgb_buffer_with_padding, + /*width=*/4, /*height=*/2, + /*row_stride=*/24, /*alpha=*/true); +} + +// TODO: Consider move these helper methods into a util class. +// Returns true if the data in the two arrays are the same. Otherwise, return +// false. +bool CompareArray(const uint8_t* lhs_ptr, const uint8_t* rhs_ptr, int width, + int height) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + if (lhs_ptr[i * width + j] != rhs_ptr[i * width + j]) { + return false; + } + } + } + return true; +} + +// Returns true if the halide buffers of two input GrayBuffer are identical. +// Otherwise, returns false; +bool CompareBuffer(const GrayBuffer& lhs, const GrayBuffer& rhs) { + if (lhs.width() != rhs.width() || lhs.height() != rhs.height()) { + return false; + } + const uint8_t* reference_ptr = const_cast(lhs).buffer()->host; + const uint8_t* converted_ptr = const_cast(rhs).buffer()->host; + return CompareArray(reference_ptr, converted_ptr, lhs.width(), lhs.height()); +} + +// Returns true if the halide buffers of two input RgbBuffer are identical. +// Otherwise, returns false; +bool CompareBuffer(const RgbBuffer& lhs, const RgbBuffer& rhs) { + if (lhs.width() != rhs.width() || lhs.height() != rhs.height() || + lhs.row_stride() != rhs.row_stride() || + lhs.channels() != rhs.channels()) { + return false; + } + const uint8_t* reference_ptr = const_cast(lhs).buffer()->host; + const uint8_t* converted_ptr = const_cast(rhs).buffer()->host; + return CompareArray(reference_ptr, converted_ptr, lhs.row_stride(), + lhs.height()); +} + +// Returns true if the halide buffers of two input YuvBuffer are identical. 
+// Otherwise, returns false; +bool CompareBuffer(const YuvBuffer& lhs, const YuvBuffer& rhs) { + if (lhs.width() != rhs.width() || lhs.height() != rhs.height()) { + return false; + } + const uint8_t* reference_ptr = const_cast(lhs).y_buffer()->host; + const uint8_t* converted_ptr = const_cast(rhs).y_buffer()->host; + if (!CompareArray(reference_ptr, converted_ptr, lhs.width(), lhs.height())) { + return false; + } + reference_ptr = const_cast(lhs).uv_buffer()->host; + converted_ptr = const_cast(rhs).uv_buffer()->host; + return CompareArray(reference_ptr, converted_ptr, lhs.width(), + lhs.height() / 2); +} + +TEST(RgbBufferTest, Properties) { + RgbBuffer rgb(2, 8, false), rgba(2, 8, true); + EXPECT_EQ(2, rgb.width()); + EXPECT_EQ(8, rgb.height()); + EXPECT_EQ(3, rgb.channels()); + + EXPECT_EQ(2, rgba.width()); + EXPECT_EQ(8, rgba.height()); + EXPECT_EQ(4, rgba.channels()); +} + +TEST(RgbBufferTest, PropertiesOfPaddedRgb) { + RgbBuffer rgb_buffer = GetPaddedRgbBuffer(); + EXPECT_EQ(rgb_buffer.width(), 4); + EXPECT_EQ(rgb_buffer.height(), 2); + EXPECT_EQ(rgb_buffer.row_stride(), 18); + EXPECT_EQ(rgb_buffer.channels(), 3); +} + +TEST(RgbBufferTest, PropertiesOfPaddedRgba) { + RgbBuffer rgb_buffer = GetPaddedRgbaBuffer(); + EXPECT_EQ(rgb_buffer.width(), 4); + EXPECT_EQ(rgb_buffer.height(), 2); + EXPECT_EQ(rgb_buffer.row_stride(), 24); + EXPECT_EQ(rgb_buffer.channels(), 4); +} + +TEST(RgbBufferTest, Release) { + RgbBuffer source(8, 8, true); + delete[] source.Release(); +} + +TEST(RgbBufferTest, Assign) { + RgbBuffer source(8, 8, false); + RgbBuffer sink(nullptr, 0, 0, false); + sink = source; + EXPECT_EQ(8, sink.width()); + EXPECT_EQ(8, sink.height()); + EXPECT_EQ(3, sink.channels()); + + sink = RgbBuffer(16, 16, true); + EXPECT_EQ(16, sink.width()); + EXPECT_EQ(16, sink.height()); + EXPECT_EQ(4, sink.channels()); +} + +TEST(RgbBufferTest, MoveAssign) { + RgbBuffer source(8, 8, false); + RgbBuffer sink(nullptr, 0, 0, true); + sink = std::move(source); + EXPECT_EQ(nullptr, source.Release()); + EXPECT_EQ(8, sink.width()); + EXPECT_EQ(8, sink.height()); +} + +TEST(RgbBufferTest, MoveConstructor) { + RgbBuffer source(8, 8, false); + RgbBuffer sink(std::move(source)); + EXPECT_EQ(nullptr, source.Release()); + EXPECT_EQ(8, sink.width()); + EXPECT_EQ(8, sink.height()); +} + +TEST(RgbBufferTest, RgbCrop) { + RgbBuffer source(8, 8, false); + EXPECT_TRUE(source.Crop(2, 2, 6, 6)); +} + +TEST(RgbBufferTest, RgbaCrop) { + RgbBuffer source(8, 8, true); + EXPECT_TRUE(source.Crop(2, 2, 6, 6)); +} + +// Some operations expect images with a platform-dependent minimum width +// because their implementations are vectorized. + +TEST(RgbBufferTest, RgbResize) { + RgbBuffer source(128, 8, false); + RgbBuffer result(32, 4, false); + Fill(&source); + EXPECT_TRUE(source.Resize(&result)); + + // Test odd result sizes too. + source = RgbBuffer(64, 16, false); + result = RgbBuffer(32, 7, false); + Fill(&source); + EXPECT_TRUE(source.Resize(&result)); +} + +TEST(RgbBufferTest, RgbaResize) { + RgbBuffer source(128, 8, true); + RgbBuffer result(32, 4, true); + Fill(&source); + EXPECT_TRUE(source.Resize(&result)); + + // Test odd result sizes too. + source = RgbBuffer(64, 16, true); + result = RgbBuffer(32, 7, true); + Fill(&source); + EXPECT_TRUE(source.Resize(&result)); +} + +// Note: RGB-to-RGBA conversion currently doesn't work. 
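// A minimal end-to-end sketch of the RgbBuffer operations exercised by these
// tests, assuming the buffers are already filled with image data; the
// function name and dimensions are illustrative.
inline bool DownscaleAndConvertSketch() {
  RgbBuffer source(64, 16, /*alpha=*/false);     // Owned 64x16 RGB allocation.
  if (!source.Crop(0, 0, 31, 15)) return false;  // In-place crop to 32x16.
  RgbBuffer small(16, 8, /*alpha=*/false);
  if (!source.Resize(&small)) return false;      // Bilinear downscale.
  YuvBuffer yuv(16, 8, YuvBuffer::NV21);
  return small.Convert(&yuv);                    // RGB -> NV21.
}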
+TEST(RgbBufferTest, RgbResizeDifferentFormat) { + RgbBuffer source(128, 8, false); + RgbBuffer result(16, 4, true); + Fill(&source); + EXPECT_FALSE(source.Resize(&result)); +} + +TEST(RgbBufferTest, RgbaResizeDifferentFormat) { + RgbBuffer source(128, 8, true); + RgbBuffer result(16, 4, false); + Fill(&source); + EXPECT_TRUE(source.Resize(&result)); +} + +TEST(RgbBufferTest, PaddedRgbResize) { + const int target_width = 2; + const int target_height = 1; + RgbBuffer source = GetPaddedRgbBuffer(); + RgbBuffer result(target_width, target_height, /*alpha=*/false); + + ASSERT_TRUE(source.Resize(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + EXPECT_EQ(result.channels(), 3); + EXPECT_EQ(result.row_stride(), target_width * /*pixel_stride=*/3); + + uint8_t rgb_data[] = {10, 20, 30, 30, 40, 50}; + RgbBuffer rgb_buffer = + RgbBuffer(rgb_data, target_width, target_height, /*alpha=*/false); + EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} + +TEST(RgbBufferTest, PaddedRgbaResize) { + const int target_width = 2; + const int target_height = 1; + RgbBuffer source = GetPaddedRgbaBuffer(); + RgbBuffer result(target_width, target_height, /*alpha=*/true); + + ASSERT_TRUE(source.Resize(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + EXPECT_EQ(result.channels(), 4); + EXPECT_EQ(result.row_stride(), target_width * /*pixel_stride=*/4); + + uint8_t rgb_data[] = {10, 20, 30, 255, 30, 40, 50, 255}; + RgbBuffer rgb_buffer = + RgbBuffer(rgb_data, target_width, target_height, /*alpha=*/true); + EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} + +TEST(RgbBufferTest, RgbRotateCheckSize) { + RgbBuffer source(4, 8, false); + RgbBuffer result(8, 4, false); + Fill(&source); + EXPECT_TRUE(source.Rotate(90, &result)); +} + +TEST(RgbBufferTest, RgbRotateCheckData) { + uint8_t* data = new uint8_t[12]; + data[0] = data[1] = data[2] = 1; // Pixel 1 + data[3] = data[4] = data[5] = 2; // Pixel 2 + data[6] = data[7] = data[8] = 3; // Pixel 3 + data[9] = data[10] = data[11] = 4; // Pixel 4 + RgbBuffer source(data, 2, 2, false); + RgbBuffer result(2, 2, false); + source.Rotate(90, &result); + EXPECT_EQ(2, result.buffer()->host[0]); + EXPECT_EQ(4, result.buffer()->host[3]); + EXPECT_EQ(1, result.buffer()->host[6]); + EXPECT_EQ(3, result.buffer()->host[9]); + delete[] data; +} + +TEST(RgbBufferTest, RgbRotateDifferentFormat) { + RgbBuffer source(4, 8, true); + RgbBuffer result(8, 4, false); + Fill(&source); + EXPECT_TRUE(source.Rotate(90, &result)); +} + +// Note: RGB-to-RGBA conversion currently doesn't work. +TEST(RgbBufferTest, RgbRotateDifferentFormatFail) { + RgbBuffer source(4, 8, false); + RgbBuffer result(8, 4, true); + Fill(&source); + EXPECT_FALSE(source.Rotate(90, &result)); +} + +TEST(RgbBufferTest, RgbaRotate) { + RgbBuffer source(4, 8, true); + RgbBuffer result(8, 4, true); + Fill(&source); + EXPECT_TRUE(source.Rotate(90, &result)); +} + +TEST(RgbBufferTest, RgbaRotateDifferentFormat) { + RgbBuffer source(4, 8, true); + RgbBuffer result(8, 4, false); + Fill(&source); + EXPECT_TRUE(source.Rotate(90, &result)); +} + +// Note: RGB-to-RGBA conversion currently doesn't work. 
+TEST(RgbBufferTest, RgbaRotateDifferentFormatFail) { + RgbBuffer source(4, 8, false); + RgbBuffer result(8, 4, true); + Fill(&source); + EXPECT_FALSE(source.Rotate(90, &result)); +} + +TEST(RgbBufferTest, PaddedRgbRotateCheckData) { + const int target_width = 2; + const int target_height = 4; + RgbBuffer source = GetPaddedRgbBuffer(); + RgbBuffer result(target_width, target_height, /*alpha=*/false); + + ASSERT_TRUE(source.Rotate(/*angle=*/90, &result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + EXPECT_EQ(result.channels(), 3); + EXPECT_EQ(result.row_stride(), target_width * /*pixel_stride=*/3); + + uint8_t rgb_data[] = {40, 50, 60, 80, 100, 120, 30, 40, 50, 60, 80, 100, + 20, 30, 40, 40, 60, 80, 10, 20, 30, 20, 40, 60}; + RgbBuffer rgb_buffer = + RgbBuffer(rgb_data, target_width, target_height, /*alpha=*/false); + EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} + +TEST(RgbBufferTest, PaddedRgbaRotateCheckData) { + const int target_width = 2; + const int target_height = 4; + RgbBuffer result(target_width, target_height, /*alpha=*/true); + RgbBuffer source = GetPaddedRgbaBuffer(); + + ASSERT_TRUE(source.Rotate(/*angle=*/90, &result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + EXPECT_EQ(result.channels(), 4); + EXPECT_EQ(result.row_stride(), target_width * /*pixel_stride=*/4); + + uint8_t rgb_data[] = {40, 50, 60, 255, 80, 100, 120, 255, 30, 40, 50, + 255, 60, 80, 100, 255, 20, 30, 40, 255, 40, 60, + 80, 255, 10, 20, 30, 255, 20, 40, 60, 255}; + RgbBuffer rgb_buffer = + RgbBuffer(rgb_data, target_width, target_height, /*alpha=*/true); + EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} + +TEST(RgbBufferTest, RgbaFlip) { + RgbBuffer source(16, 16, true); + RgbBuffer result(16, 16, true); + Fill(&source); + EXPECT_TRUE(source.FlipHorizontally(&result)); + EXPECT_TRUE(source.FlipVertically(&result)); +} + +// Note: Neither RGBA-to-RGB nor RGB-to-RGBA conversion currently works. 
+TEST(RgbBufferTest, RgbaFlipDifferentFormatFail) { + RgbBuffer source(16, 16, false); + RgbBuffer result(16, 16, true); + Fill(&source); + Fill(&result); + EXPECT_FALSE(source.FlipHorizontally(&result)); + EXPECT_FALSE(result.FlipHorizontally(&source)); + EXPECT_FALSE(source.FlipVertically(&result)); + EXPECT_FALSE(result.FlipVertically(&source)); +} + +TEST(RgbBufferTest, PaddedRgbFlipHorizontally) { + const int target_width = 4; + const int target_height = 2; + RgbBuffer result(target_width, target_height, /*alpha=*/false); + RgbBuffer source = GetPaddedRgbBuffer(); + + ASSERT_TRUE(source.FlipHorizontally(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + EXPECT_EQ(result.channels(), 3); + EXPECT_EQ(result.row_stride(), target_width * /*pixel_stride=*/3); + + uint8_t rgb_data[] = {40, 50, 60, 30, 40, 50, 20, 30, 40, 10, 20, 30, + 80, 100, 120, 60, 80, 100, 40, 60, 80, 20, 40, 60}; + RgbBuffer rgb_buffer = + RgbBuffer(rgb_data, target_width, target_height, /*alpha=*/false); + EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} + +TEST(RgbBufferTest, PaddedRgbaFlipHorizontally) { + const int target_width = 4; + const int target_height = 2; + RgbBuffer result(target_width, target_height, /*alpha=*/true); + RgbBuffer source = GetPaddedRgbaBuffer(); + + ASSERT_TRUE(source.FlipHorizontally(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + EXPECT_EQ(result.channels(), 4); + EXPECT_EQ(result.row_stride(), target_width * /*pixel_stride=*/4); + + uint8_t rgb_data[] = {40, 50, 60, 255, 30, 40, 50, 255, 20, 30, 40, + 255, 10, 20, 30, 255, 80, 100, 120, 255, 60, 80, + 100, 255, 40, 60, 80, 255, 20, 40, 60, 255}; + RgbBuffer rgb_buffer = + RgbBuffer(rgb_data, target_width, target_height, /*alpha=*/true); + EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} + +TEST(RgbBufferTest, RgbConvertNv21) { + RgbBuffer source(32, 8, false); + YuvBuffer result(32, 8, YuvBuffer::NV21); + Fill(&source); + EXPECT_TRUE(source.Convert(&result)); +} + +TEST(RgbBufferTest, RgbaConvertNv21) { + RgbBuffer source(32, 8, true); + YuvBuffer result(32, 8, YuvBuffer::NV21); + Fill(&source); + EXPECT_TRUE(source.Convert(&result)); +} + +TEST(RgbBufferTest, PaddedRgbConvertNv21) { + const int target_width = 4; + const int target_height = 2; + YuvBuffer result(target_width, target_height, YuvBuffer::NV21); + RgbBuffer source = GetPaddedRgbBuffer(); + + ASSERT_TRUE(source.Convert(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + + uint8_t yuv_data[] = {18, 28, 38, 48, 36, 56, 76, 96, 122, 135, 122, 135}; + YuvBuffer yuv_buffer = + YuvBuffer(yuv_data, target_width, target_height, YuvBuffer::NV21); + EXPECT_TRUE(CompareBuffer(yuv_buffer, result)); +} + +TEST(RgbBufferTest, PaddedRgbaConvertNv21) { + const int target_width = 4; + const int target_height = 2; + YuvBuffer result(target_width, target_height, YuvBuffer::NV21); + RgbBuffer source = GetPaddedRgbaBuffer(); + + ASSERT_TRUE(source.Convert(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + + uint8_t yuv_data[] = {18, 28, 38, 48, 36, 56, 76, 96, 122, 135, 122, 135}; + YuvBuffer yuv_buffer = + YuvBuffer(yuv_data, target_width, target_height, YuvBuffer::NV21); + EXPECT_TRUE(CompareBuffer(yuv_buffer, result)); +} + +TEST(RgbBufferTest, RgbConvertGray) { + uint8_t* data = new uint8_t[6]; + data[0] = 200; + data[1] = 100; + data[2] = 0; + data[3] = 0; + data[4] = 200; + data[5] = 100; 
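  // With the JFIF luma weights used by this change, the expected outputs are:
  //   pixel 0 = (200, 100, 0): (19595*200 + 38470*100 + 7474*0 + 32768) >> 16 = 118
  //   pixel 1 = (0, 200, 100): (19595*0 + 38470*200 + 7474*100 + 32768) >> 16 = 129
  // which is what the EXPECT_EQ checks below assert.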
+ RgbBuffer source(data, 2, 1, false); + GrayBuffer result(2, 1); + EXPECT_TRUE(source.Convert(&result)); + EXPECT_EQ(118, result.buffer()->host[0]); + EXPECT_EQ(129, result.buffer()->host[1]); + delete[] data; +} + +TEST(RgbBufferTest, RgbaConvertGray) { + uint8_t* data = new uint8_t[8]; + data[0] = 200; + data[1] = 100; + data[2] = 0; + data[3] = 1; + data[4] = 0; + data[5] = 200; + data[6] = 100; + data[7] = 50; + RgbBuffer source(data, 2, 1, true); + GrayBuffer result(2, 1); + EXPECT_TRUE(source.Convert(&result)); + EXPECT_EQ(118, result.buffer()->host[0]); + EXPECT_EQ(129, result.buffer()->host[1]); + delete[] data; +} + +TEST(RgbBufferTest, PaddedRgbConvertGray) { + const int target_width = 4; + const int target_height = 2; + GrayBuffer result(target_width, target_height); + RgbBuffer source = GetPaddedRgbBuffer(); + + ASSERT_TRUE(source.Convert(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + + uint8_t gray_data[] = {18, 28, 38, 48, 36, 56, 76, 96}; + GrayBuffer gray_buffer = GrayBuffer(gray_data, target_width, target_height); + EXPECT_TRUE(CompareBuffer(gray_buffer, result)); +} + +TEST(RgbBufferTest, PaddedRgbaConvertGray) { + const int target_width = 4; + const int target_height = 2; + GrayBuffer result(target_width, target_height); + RgbBuffer source = GetPaddedRgbaBuffer(); + + ASSERT_TRUE(source.Convert(&result)); + EXPECT_EQ(result.width(), target_width); + EXPECT_EQ(result.height(), target_height); + + uint8_t gray_data[] = {18, 28, 38, 48, 36, 56, 76, 96}; + GrayBuffer gray_buffer = GrayBuffer(gray_data, target_width, target_height); + EXPECT_TRUE(CompareBuffer(gray_buffer, result)); +} + +TEST(RgbBufferTest, RgbConvertRgba) { + constexpr int kWidth = 2, kHeight = 1; + uint8_t rgb_data[] = {200, 100, 50, 100, 50, 20}; + RgbBuffer source(rgb_data, kWidth, kHeight, false); + RgbBuffer result(kWidth, kHeight, true); + + ASSERT_TRUE(source.Convert(&result)); + + uint8_t rgba_data[] = {200, 100, 50, 255, 100, 50, 20, 255}; + RgbBuffer rgba_buffer = RgbBuffer(rgba_data, kWidth, kHeight, true); + EXPECT_TRUE(CompareBuffer(rgba_buffer, result)); +} + +TEST(RgbBufferTest, PaddedRgbConvertRgba) { + constexpr int kWidth = 4, kHeight = 2; + RgbBuffer source = GetPaddedRgbBuffer(); + RgbBuffer result(kWidth, kHeight, true); + ASSERT_TRUE(source.Convert(&result)); + + uint8_t rgba_data[]{10, 20, 30, 255, 20, 30, 40, 255, 30, 40, 50, + 255, 40, 50, 60, 255, 20, 40, 60, 255, 40, 60, + 80, 255, 60, 80, 100, 255, 80, 100, 120, 255}; + RgbBuffer rgba_buffer = RgbBuffer(rgba_data, kWidth, kHeight, true); + EXPECT_TRUE(CompareBuffer(rgba_buffer, result)); +} + +TEST(RgbBufferTest, RgbaConvertRgb) { + constexpr int kWidth = 2, kHeight = 1; + uint8_t rgba_data[] = {200, 100, 50, 30, 100, 50, 20, 70}; + RgbBuffer source(rgba_data, kWidth, kHeight, true); + RgbBuffer result(kWidth, kHeight, false); + + ASSERT_TRUE(source.Convert(&result)); + + uint8_t rgb_data[] = {200, 100, 50, 100, 50, 20}; + RgbBuffer rgb_buffer = RgbBuffer(rgb_data, kWidth, kHeight, false); + EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} + +TEST(RgbBufferTest, PaddedRgbaConvertRgb) { + constexpr int kWidth = 4, kHeight = 2; + RgbBuffer source = GetPaddedRgbaBuffer(); + RgbBuffer result(kWidth, kHeight, false); + + ASSERT_TRUE(source.Convert(&result)); + + uint8_t rgb_data[] = {10, 20, 30, 20, 30, 40, 30, 40, 50, 40, 50, 60, + 20, 40, 60, 40, 60, 80, 60, 80, 100, 80, 100, 120}; + RgbBuffer rgb_buffer = RgbBuffer(rgb_data, kWidth, kHeight, false); + 
EXPECT_TRUE(CompareBuffer(rgb_buffer, result)); +} +} // namespace +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/yuv_buffer.cc b/mediapipe/util/frame_buffer/yuv_buffer.cc new file mode 100644 index 000000000..f96282134 --- /dev/null +++ b/mediapipe/util/frame_buffer/yuv_buffer.cc @@ -0,0 +1,152 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mediapipe/util/frame_buffer/yuv_buffer.h" + +#include + +#include "mediapipe/util/frame_buffer/buffer_common.h" +#include "mediapipe/util/frame_buffer/halide/yuv_flip_halide.h" +#include "mediapipe/util/frame_buffer/halide/yuv_resize_halide.h" +#include "mediapipe/util/frame_buffer/halide/yuv_rgb_halide.h" +#include "mediapipe/util/frame_buffer/halide/yuv_rotate_halide.h" +#include "mediapipe/util/frame_buffer/rgb_buffer.h" + +namespace mediapipe { +namespace frame_buffer { + +YuvBuffer::YuvBuffer(uint8_t* y_plane, uint8_t* u_plane, uint8_t* v_plane, + int width, int height, int row_stride_y, int row_stride_uv, + int pixel_stride_uv) { + // Initialize the buffer shapes: {min, extent, stride} per dimension. + // TODO: Ensure that width is less than or equal to row stride. + const halide_dimension_t y_dimensions[2] = { + {0, width, 1}, + {0, height, row_stride_y}, + }; + y_buffer_ = Halide::Runtime::Buffer(y_plane, 2, y_dimensions); + + // Note that the Halide implementation expects the planes to be in VU + // order, so we point at the V plane first. + const halide_dimension_t uv_dimensions[3] = { + {0, (width + 1) / 2, pixel_stride_uv}, + {0, (height + 1) / 2, row_stride_uv}, + {0, 2, static_cast(u_plane - v_plane)}, + }; + uv_buffer_ = Halide::Runtime::Buffer(v_plane, 3, uv_dimensions); +} + +YuvBuffer::YuvBuffer(uint8_t* data, int width, int height, Format format) + : owned_buffer_(nullptr) { + Initialize(data, width, height, format); +} + +YuvBuffer::YuvBuffer(int width, int height, Format format) + : owned_buffer_(new uint8_t[ByteSize(width, height)]) { + Initialize(owned_buffer_.get(), width, height, format); +} + +YuvBuffer::YuvBuffer(const YuvBuffer& other) + : y_buffer_(other.y_buffer_), uv_buffer_(other.uv_buffer_) { + // Never copy owned_buffer; ownership remains with the source of the copy. 
+} + +YuvBuffer::YuvBuffer(YuvBuffer&& other) { *this = std::move(other); } + +YuvBuffer& YuvBuffer::operator=(const YuvBuffer& other) { + if (this != &other) { + y_buffer_ = other.y_buffer_; + uv_buffer_ = other.uv_buffer_; + } + return *this; +} +YuvBuffer& YuvBuffer::operator=(YuvBuffer&& other) { + if (this != &other) { + owned_buffer_ = std::move(other.owned_buffer_); + y_buffer_ = other.y_buffer_; + uv_buffer_ = other.uv_buffer_; + } + return *this; +} + +YuvBuffer::~YuvBuffer() {} + +void YuvBuffer::Initialize(uint8_t* data, int width, int height, + Format format) { + y_buffer_ = Halide::Runtime::Buffer(data, width, height); + + uint8_t* uv_data = data + (width * height); + switch (format) { + case NV21: + // Interleaved UV (actually VU order). + uv_buffer_ = Halide::Runtime::Buffer::make_interleaved( + uv_data, (width + 1) / 2, (height + 1) / 2, 2); + break; + case YV12: + // Planar UV (actually VU order). + uv_buffer_ = Halide::Runtime::Buffer(uv_data, (width + 1) / 2, + (height + 1) / 2, 2); + // NOTE: Halide operations have not been tested extensively in this + // configuration. + break; + } +} + +bool YuvBuffer::Crop(int x0, int y0, int x1, int y1) { + if (x0 & 1 || y0 & 1) { + // YUV images must be left-and top-aligned to even X/Y coordinates. + return false; + } + + // Twiddle the buffer start and extents for each plane to crop images. + return (common::crop_buffer(x0, y0, x1, y1, y_buffer()) && + common::crop_buffer(x0 / 2, y0 / 2, x1 / 2, y1 / 2, uv_buffer())); +} + +bool YuvBuffer::Resize(YuvBuffer* output) { + const int result = yuv_resize_halide( + y_buffer(), uv_buffer(), static_cast(width()) / output->width(), + static_cast(height()) / output->height(), output->y_buffer(), + output->uv_buffer()); + return result == 0; +} + +bool YuvBuffer::Rotate(int angle, YuvBuffer* output) { + const int result = yuv_rotate_halide(y_buffer(), uv_buffer(), angle, + output->y_buffer(), output->uv_buffer()); + return result == 0; +} + +bool YuvBuffer::FlipHorizontally(YuvBuffer* output) { + const int result = yuv_flip_halide(y_buffer(), uv_buffer(), + false, // horizontal + output->y_buffer(), output->uv_buffer()); + return result == 0; +} + +bool YuvBuffer::FlipVertically(YuvBuffer* output) { + const int result = yuv_flip_halide(y_buffer(), uv_buffer(), + true, // vertical + output->y_buffer(), output->uv_buffer()); + return result == 0; +} + +bool YuvBuffer::Convert(bool halve, RgbBuffer* output) { + const int result = + yuv_rgb_halide(y_buffer(), uv_buffer(), halve, output->buffer()); + return result == 0; +} + +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/frame_buffer/yuv_buffer.h b/mediapipe/util/frame_buffer/yuv_buffer.h new file mode 100644 index 000000000..8e4715dcf --- /dev/null +++ b/mediapipe/util/frame_buffer/yuv_buffer.h @@ -0,0 +1,160 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
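// Worked layout numbers for the YUV420 formats handled above, following the
// Initialize()/ByteSize() conventions in this file pair: for a 5x3 image the
// Y plane holds 5*3 = 15 bytes and the UV plane holds ((5+1)/2) * ((3+1)/2)
// * 2 = 12 bytes starting at byte offset 15, so ByteSize(5, 3) == 27. NV21
// interleaves V and U at that offset (VUVU...), while YV12 stores the whole
// V plane followed by the whole U plane.
static_assert(5 * 3 + ((5 + 1) / 2) * ((3 + 1) / 2) * 2 == 27,
              "a 5x3 YUV420 image occupies 27 bytes");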
+ +#ifndef MEDIAPIPE_UTIL_FRAME_BUFFER_YUV_BUFFER_H_ +#define MEDIAPIPE_UTIL_FRAME_BUFFER_YUV_BUFFER_H_ + +#include + +#include "HalideBuffer.h" +#include "HalideRuntime.h" + +namespace mediapipe { +namespace frame_buffer { +class RgbBuffer; + +// YuvBuffer represents a view over a YUV 4:2:0 image. +// +// YuvBuffers may be copied and moved efficiently; their backing buffers are +// shared and never deep copied. +// +// YuvBuffer requires a minimum image width depending on the natural vector +// size of the platform, e.g., 16px. This is not validated by YuvBuffer. +class YuvBuffer { + public: + // YUV formats. Rather than supporting every possible format, we prioritize + // formats with broad hardware/platform support. + // + // Enum values are FourCC codes; see http://fourcc.org/yuv.php for more. + enum Format { + NV21 = 0x3132564E, // YUV420SP (VU interleaved) + YV12 = 0x32315659, // YUV420P (VU planar) + }; + + // Returns the size (in bytes) of a YUV image of the given dimensions. + static int ByteSize(int width, int height) { + // 1 byte per pixel in the Y plane, 2 bytes per 2x2 block in the UV plane. + // Dimensions with odd sizes are rounded up. + const int y_size = width * height; + const int uv_size = ((width + 1) / 2) * ((height + 1) / 2) * 2; + return y_size + uv_size; + } + + // Builds a generic YUV420 YuvBuffer with the given backing buffers, + // dimensions and strides. Supports both interleaved or planar UV with + // custom strides. + // + // Does not take ownership of any backing buffers, which must be large + // enough to fit their contents. + YuvBuffer(uint8_t* y_plane, uint8_t* u_plane, uint8_t* v_plane, int width, + int height, int row_stride_y, int row_stride_uv, + int pixel_stride_uv); + + // Builds a YuvBuffer using the given backing buffer, dimensions, and format. + // Expects an NV21- or YV12-format image only. + // + // Does not take ownership of the backing buffer (provided in 'data'), which + // must be sized to hold at least the amount indicated by ByteSize(). + YuvBuffer(uint8_t* data, int width, int height, Format format); + + // Builds a YuvBuffer using the given dimensions and format. Expects + // an NV21- or YV12-format image only. + // + // The underlying backing buffer is allocated and owned by this YuvBuffer. + YuvBuffer(int width, int height, Format format); + + // YuvBuffer is copyable. The source retains ownership of its backing buffers. + YuvBuffer(const YuvBuffer& other); + // YuvBuffer is moveable. The source loses ownership of any backing buffers. + YuvBuffer(YuvBuffer&& other); + // YuvBuffer is assignable. + YuvBuffer& operator=(const YuvBuffer& other); + YuvBuffer& operator=(YuvBuffer&& other); + + ~YuvBuffer(); + + // Performs an in-place crop. Modifies this buffer so that the new extent + // matches that of the given crop rectangle -- (x0, y0) becomes (0, 0) and + // the new width and height are x1 - x0 + 1 and y1 - y0 + 1, respectively. + // + // Note that the top-left corner (x0, y0) coordinates must be even to + // maintain alignment between the Y and UV grids. + bool Crop(int x0, int y0, int x1, int y1); + + // Resize this image to match the dimensions of the given output YuvBuffer + // and places the result into its backing buffer. + // + // Performs a resize with bilinear interpolation (over four source pixels). + bool Resize(YuvBuffer* output); + + // Rotate this image into the given buffer by the given angle (90, 180, 270). 
+ // + // Rotation is specified in degrees counter-clockwise such that when rotating + // by 90 degrees, the top-right corner of the source becomes the top-left of + // the output. The output buffer must have its height and width swapped when + // rotating by 90 or 270. + // + // Any angle values other than (90, 180, 270) are invalid. + bool Rotate(int angle, YuvBuffer* output); + + // Flip this image horizontally/vertically into the given buffer. Both buffer + // dimensions must match. + bool FlipHorizontally(YuvBuffer* output); + bool FlipVertically(YuvBuffer* output); + + // Performs a YUV-to-RGB color format conversion and places the result + // in the given output RgbBuffer. Both buffer dimensions must match. + // + // When halve is true, the converted output is downsampled by a factor of + // two by discarding three of four luminance values in every 2x2 block. + bool Convert(bool halve, RgbBuffer* output); + + // Release ownership of the owned backing buffer. + uint8_t* Release() { return owned_buffer_.release(); } + + // Returns the halide_buffer_t* for the Y plane. + const halide_buffer_t* y_buffer() const { return y_buffer_.raw_buffer(); } + // Returns the halide_buffer_t* for the UV plane(s). + const halide_buffer_t* uv_buffer() const { return uv_buffer_.raw_buffer(); } + // Returns the halide_buffer_t* for the Y plane. + halide_buffer_t* y_buffer() { return y_buffer_.raw_buffer(); } + // Returns the halide_buffer_t* for the UV plane(s). + halide_buffer_t* uv_buffer() { return uv_buffer_.raw_buffer(); } + + // Returns the image width. + const int width() const { return y_buffer_.dim(0).extent(); } + // Returns the image height. + const int height() const { return y_buffer_.dim(1).extent(); } + + private: + void Initialize(uint8_t* data, int width, int height, Format format); + + // Non-NULL iff this YuvBuffer owns its buffer. + std::unique_ptr owned_buffer_; + + // Y (luminance) backing buffer: layout is always width x height. + Halide::Runtime::Buffer y_buffer_; + + // UV (chrominance) backing buffer; width/2 x height/2 x 2 (channel). + // May be interleaved or planar. + // + // Note that the planes are in the reverse of the usual order: channel 0 is V + // and channel 1 is U. + Halide::Runtime::Buffer uv_buffer_; +}; + +} // namespace frame_buffer +} // namespace mediapipe + +#endif // MEDIAPIPE_UTIL_FRAME_BUFFER_YUV_BUFFER_H_ diff --git a/mediapipe/util/frame_buffer/yuv_buffer_test.cc b/mediapipe/util/frame_buffer/yuv_buffer_test.cc new file mode 100644 index 000000000..a18b19a92 --- /dev/null +++ b/mediapipe/util/frame_buffer/yuv_buffer_test.cc @@ -0,0 +1,251 @@ +// Copyright 2023 The MediaPipe Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
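Before the unit tests, a compact sketch of how the operations declared above compose. The function name, dimensions, and flow are illustrative only; the RgbBuffer (width, height, alpha) constructor is the one exercised by the tests that follow, not something introduced here.

#include <cstdint>

#include "mediapipe/util/frame_buffer/rgb_buffer.h"
#include "mediapipe/util/frame_buffer/yuv_buffer.h"

namespace {

using ::mediapipe::frame_buffer::RgbBuffer;
using ::mediapipe::frame_buffer::YuvBuffer;

// Crops a 640x480 NV21 frame to its central 320x240 region, downsamples it to
// 160x120, and converts the result to RGB. `rgb_out` must be a 160x120
// RgbBuffer, e.g. RgbBuffer(160, 120, /*alpha=*/false).
bool CropScaleAndConvert(uint8_t* nv21_data, RgbBuffer* rgb_out) {
  // Wrap the caller's frame; the data is neither copied nor owned.
  YuvBuffer source(nv21_data, 640, 480, YuvBuffer::NV21);

  // Crop in place; the top-left corner must have even coordinates.
  if (!source.Crop(160, 120, 479, 359)) return false;

  // Bilinear resize into a separately owned 160x120 buffer.
  YuvBuffer small(160, 120, YuvBuffer::NV21);
  if (!source.Resize(&small)) return false;

  // Convert to RGB without halving; note that the tests below only exercise
  // RGB conversion for widths >= 32 because the implementation is vectorized.
  return small.Convert(/*halve=*/false, rgb_out);
}

}  // namespace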
+ +#include "mediapipe/util/frame_buffer/yuv_buffer.h" + +#include + +#include "absl/log/log.h" +#include "mediapipe/framework/port/gmock.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/util/frame_buffer/rgb_buffer.h" + +// The default implementation of halide_error calls abort(), which we don't +// want. Instead, log the error and let the filter invocation fail. +extern "C" void halide_error(void*, const char* message) { + LOG(ERROR) << "Halide Error: " << message; +} + +namespace mediapipe { +namespace frame_buffer { +namespace { + +// Fill a halide_buffer_t channel with the given value. +void Fill(halide_buffer_t* buffer, int channel, int value) { + for (int y = 0; y < buffer->dim[1].extent; ++y) { + for (int x = 0; x < buffer->dim[0].extent; ++x) { + buffer->host[buffer->dim[1].stride * y + buffer->dim[0].stride * x + + buffer->dim[2].stride * channel] = value; + } + } +} + +// Fill a YuvBuffer with the given YUV color. +void Fill(YuvBuffer* buffer, uint8_t y, uint8_t u, uint8_t v) { + Fill(buffer->y_buffer(), 0, y); + Fill(buffer->uv_buffer(), 1, u); + Fill(buffer->uv_buffer(), 0, v); +} + +TEST(YuvBufferTest, Properties) { + YuvBuffer yuv(2, 8, YuvBuffer::NV21); + EXPECT_EQ(2, yuv.width()); + EXPECT_EQ(8, yuv.height()); +} + +TEST(YuvBufferTest, Release) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + delete[] source.Release(); +} + +TEST(YuvBufferTest, Assign) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + YuvBuffer sink(nullptr, 0, 0, YuvBuffer::NV21); + sink = source; + EXPECT_EQ(8, sink.width()); + EXPECT_EQ(8, sink.height()); + + sink = YuvBuffer(16, 16, YuvBuffer::NV21); + EXPECT_EQ(16, sink.width()); + EXPECT_EQ(16, sink.height()); +} + +TEST(YuvBufferTest, MoveAssign) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + YuvBuffer sink(nullptr, 0, 0, YuvBuffer::NV21); + sink = std::move(source); + EXPECT_EQ(nullptr, source.Release()); + EXPECT_EQ(8, sink.width()); + EXPECT_EQ(8, sink.height()); +} + +TEST(YuvBufferTest, MoveConstructor) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + YuvBuffer sink(std::move(source)); + EXPECT_EQ(nullptr, source.Release()); + EXPECT_EQ(8, sink.width()); + EXPECT_EQ(8, sink.height()); +} + +TEST(YuvBufferTest, GenericSemiplanarLayout) { + uint8_t y_plane[16], uv_plane[8]; + YuvBuffer buffer(y_plane, uv_plane, uv_plane + 1, 4, 4, 4, 4, 2); + Fill(&buffer, 16, 32, 64); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(y_plane[i], 16) << i; + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(uv_plane[2 * i], 32); + EXPECT_EQ(uv_plane[2 * i + 1], 64); + } +} + +TEST(YuvBufferTest, GenericPlanarLayout) { + uint8_t y_plane[16], u_plane[4], v_plane[4]; + YuvBuffer buffer(y_plane, u_plane, v_plane, 4, 4, 4, 2, 1); + Fill(&buffer, 16, 32, 64); + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(y_plane[i], 16) << i; + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(u_plane[i], 32); + EXPECT_EQ(v_plane[i], 64); + } +} + +TEST(YuvBufferTest, Nv21Crop) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + EXPECT_TRUE(source.Crop(2, 2, 6, 6)); +} + +TEST(YuvBufferTest, Nv21Resize) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + YuvBuffer result(4, 4, YuvBuffer::NV21); + Fill(&source, 16, 32, 64); + EXPECT_TRUE(source.Resize(&result)); + + // Test odd result sizes too. 
+ source = YuvBuffer(500, 362, YuvBuffer::NV21); + result = YuvBuffer(320, 231, YuvBuffer::NV21); + Fill(&source, 16, 32, 64); + EXPECT_TRUE(source.Resize(&result)); +} + +TEST(YuvBufferTest, Nv21ResizeDifferentFormat) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + YuvBuffer result(4, 4, YuvBuffer::YV12); + Fill(&source, 16, 32, 64); + EXPECT_TRUE(source.Resize(&result)); +} + +TEST(YuvBufferTest, Nv21Rotate) { + YuvBuffer source(4, 8, YuvBuffer::NV21); + YuvBuffer result(8, 4, YuvBuffer::NV21); + Fill(&source, 16, 32, 64); + EXPECT_TRUE(source.Rotate(90, &result)); +} + +TEST(YuvBufferTest, Nv21RotateDifferentFormat) { + YuvBuffer source(8, 8, YuvBuffer::NV21); + YuvBuffer result(8, 8, YuvBuffer::YV12); + Fill(&source, 16, 32, 64); + EXPECT_TRUE(source.Rotate(90, &result)); +} + +TEST(YuvBufferTest, Nv21RotateFailBounds) { + // Expect failure if the destination doesn't have the correct bounds. + YuvBuffer source(4, 8, YuvBuffer::NV21); + YuvBuffer result(4, 8, YuvBuffer::NV21); + Fill(&source, 16, 32, 64); + EXPECT_FALSE(source.Rotate(90, &result)); +} + +TEST(YuvBufferTest, Nv21Flip) { + YuvBuffer source(16, 16, YuvBuffer::NV21); + YuvBuffer result(16, 16, YuvBuffer::NV21); + Fill(&source, 16, 32, 64); + EXPECT_TRUE(source.FlipHorizontally(&result)); + EXPECT_TRUE(source.FlipVertically(&result)); +} + +TEST(YuvBufferTest, Nv21FlipDifferentFormat) { + YuvBuffer source(16, 16, YuvBuffer::NV21); + YuvBuffer result(16, 16, YuvBuffer::YV12); + Fill(&source, 16, 32, 64); + EXPECT_TRUE(source.FlipHorizontally(&result)); + EXPECT_TRUE(source.FlipVertically(&result)); +} + +TEST(YuvBufferTest, Nv21ConvertRgb) { + // Note that RGB conversion expects at least images of width >= 32 because + // the implementation is vectorized. + YuvBuffer source(32, 8, YuvBuffer::NV21); + Fill(&source, 52, 170, 90); + + RgbBuffer result_rgb(32, 8, false); + EXPECT_TRUE(source.Convert(false, &result_rgb)); + + RgbBuffer result_rgba(32, 8, true); + EXPECT_TRUE(source.Convert(false, &result_rgba)); + + uint8_t* pixels = result_rgba.buffer()->host; + ASSERT_TRUE(pixels); + EXPECT_EQ(pixels[0], 0); + EXPECT_EQ(pixels[1], 65); + EXPECT_EQ(pixels[2], 126); + EXPECT_EQ(pixels[3], 255); +} + +TEST(YuvBufferTest, Nv21ConvertRgbCropped) { + // Note that RGB conversion expects at least images of width >= 32 because + // the implementation is vectorized. + YuvBuffer source(1024, 768, YuvBuffer::NV21); + Fill(&source, 52, 170, 90); + + // YUV images must be left-and top-aligned to even X/Y coordinates, + // regardless of whether the target image has even or odd width/height. 
+ EXPECT_FALSE(source.Crop(1, 1, 512, 384)); + EXPECT_FALSE(source.Crop(1, 1, 511, 383)); + + YuvBuffer source1(source); + EXPECT_TRUE(source1.Crop(64, 64, 512, 384)); + RgbBuffer result_rgb(source1.width(), source1.height(), false); + EXPECT_TRUE(source1.Convert(false, &result_rgb)); + + YuvBuffer source2(source); + EXPECT_TRUE(source2.Crop(64, 64, 511, 383)); + RgbBuffer result_rgba(source2.width(), source2.height(), true); + EXPECT_TRUE(source2.Convert(false, &result_rgba)); + + uint8_t* pixels = result_rgba.buffer()->host; + ASSERT_TRUE(pixels); + EXPECT_EQ(pixels[0], 0); + EXPECT_EQ(pixels[1], 65); + EXPECT_EQ(pixels[2], 126); + EXPECT_EQ(pixels[3], 255); +} + +TEST(YuvBufferTest, Nv21ConvertRgbHalve) { + YuvBuffer source(64, 8, YuvBuffer::NV21); + Fill(&source, 52, 170, 90); + + RgbBuffer result_rgb(32, 4, false); + EXPECT_TRUE(source.Convert(true, &result_rgb)); + + RgbBuffer result_rgba(32, 4, true); + EXPECT_TRUE(source.Convert(true, &result_rgba)); + + uint8_t* pixels = result_rgba.buffer()->host; + ASSERT_TRUE(pixels); + EXPECT_EQ(pixels[0], 0); + EXPECT_EQ(pixels[1], 65); + EXPECT_EQ(pixels[2], 126); + EXPECT_EQ(pixels[3], 255); +} + +} // namespace +} // namespace frame_buffer +} // namespace mediapipe diff --git a/mediapipe/util/label_map.proto b/mediapipe/util/label_map.proto index 79301d2b6..5c33269f5 100644 --- a/mediapipe/util/label_map.proto +++ b/mediapipe/util/label_map.proto @@ -16,6 +16,9 @@ syntax = "proto2"; package mediapipe; +option java_package = "com.google.mediapipe.util.proto"; +option java_outer_classname = "LabelMapProto"; + // Mapping a numerical class index output to a Knowledge Graph entity // ID or any other string label representing this class. Optionally it is // possible to specify an additional display name (in a given language) which is diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index da58d59cf..0dec0392a 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -30,10 +30,16 @@ cc_library( ], ) -cc_library( +# TODO: Re-evaluate which of these libraries we can avoid making +# cc_library_with_tflite and can be changed back to cc_library. +cc_library_with_tflite( name = "cpu_op_resolver", srcs = ["cpu_op_resolver.cc"], hdrs = ["cpu_op_resolver.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite:framework_stable", + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], visibility = ["//visibility:public"], deps = [ "//mediapipe/framework/port:logging", @@ -44,8 +50,6 @@ cc_library( "//mediapipe/util/tflite/operations:transform_tensor_bilinear", "//mediapipe/util/tflite/operations:transpose_conv_bias", "@org_tensorflow//tensorflow/lite:builtin_op_data", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", - "@org_tensorflow//tensorflow/lite/core/shims:framework_stable", ], # For using the symbol `MediaPipe_RegisterTfLiteOpResolver` in Python # with `tensorflow.lite.python.interpreter.InterpreterWithCustomOps`. @@ -63,13 +67,17 @@ cc_library( ], ) -cc_library( +# TODO: Re-evaluate which of these libraries we can avoid making +# cc_library_with_tflite and can be changed back to cc_library. 
+cc_library_with_tflite( name = "op_resolver", srcs = ["op_resolver.cc"], hdrs = ["op_resolver.h"], + tflite_deps = [ + "@org_tensorflow//tensorflow/lite/kernels:builtin_ops", + ], deps = [ "@org_tensorflow//tensorflow/lite:builtin_op_data", - "@org_tensorflow//tensorflow/lite/core/shims:builtin_ops", ], ) diff --git a/mediapipe/util/tflite/cpu_op_resolver.h b/mediapipe/util/tflite/cpu_op_resolver.h index 173531f11..887683013 100644 --- a/mediapipe/util/tflite/cpu_op_resolver.h +++ b/mediapipe/util/tflite/cpu_op_resolver.h @@ -15,7 +15,7 @@ #ifndef MEDIAPIPE_UTIL_TFLITE_CPU_OP_RESOLVER_H_ #define MEDIAPIPE_UTIL_TFLITE_CPU_OP_RESOLVER_H_ -#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow/lite/kernels/register.h" namespace mediapipe { @@ -27,8 +27,8 @@ extern "C" void MediaPipe_RegisterTfLiteOpResolver(tflite::MutableOpResolver*); // This resolver is used for the custom ops introduced by // `MediaPipe_RegisterTfLiteOpResolver` (see above). -class CpuOpResolver : public tflite_shims::ops::builtin:: - BuiltinOpResolverWithoutDefaultDelegates { +class CpuOpResolver + : public tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates { public: CpuOpResolver() { MediaPipe_RegisterTfLiteOpResolver(this); } }; diff --git a/mediapipe/util/tflite/op_resolver.h b/mediapipe/util/tflite/op_resolver.h index 8b04d5f1a..4ca179ef1 100644 --- a/mediapipe/util/tflite/op_resolver.h +++ b/mediapipe/util/tflite/op_resolver.h @@ -15,13 +15,13 @@ #ifndef MEDIAPIPE_UTIL_TFLITE_OP_RESOLVER_H_ #define MEDIAPIPE_UTIL_TFLITE_OP_RESOLVER_H_ -#include "tensorflow/lite/core/shims/cc/kernels/register.h" +#include "tensorflow/lite/kernels/register.h" namespace mediapipe { // This OpResolver is used for supporting "Convolution2DTransposeBias" on GPU. -class OpResolver : public tflite_shims::ops::builtin:: - BuiltinOpResolverWithoutDefaultDelegates { +class OpResolver + : public tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates { public: OpResolver(); }; diff --git a/mediapipe/web/graph_runner/graph_runner.ts b/mediapipe/web/graph_runner/graph_runner.ts index 4417f6a03..e2b1684a0 100644 --- a/mediapipe/web/graph_runner/graph_runner.ts +++ b/mediapipe/web/graph_runner/graph_runner.ts @@ -36,6 +36,17 @@ export type EmptyPacketListener = (timestamp: number) => void; export type VectorListener = (data: T, done: boolean, timestamp: number) => void; +/** + * A listener that receives the CalculatorGraphConfig in binary encoding. + */ +export type CalculatorGraphConfigListener = (graphConfig: Uint8Array) => void; + +/** + * The name of the internal listener that we use to obtain the calculator graph + * config. Intended for internal usage. Exported for testing only. + */ +export const CALCULATOR_GRAPH_CONFIG_LISTENER_NAME = '__graph_config__'; + /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. @@ -124,6 +135,10 @@ export declare interface WasmModule { _configureAudio: (channels: number, samples: number, sampleRate: number, streamNamePtr: number, headerNamePtr: number) => void; + // Get the graph configuration and invoke the listener configured under + // streamNamePtr + _getGraphConfig: (streamNamePtr: number, makeDeepCopy?: boolean) => void; + // TODO: Refactor to just use a few numbers (perhaps refactor away // from gl_graph_runner_internal.cc entirely to use something a little more // streamlined; new version is _processFrame above). 
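The `_getGraphConfig` binding above only surfaces a serialized `CalculatorGraphConfig`; the `getCalculatorGraphConfig()` wrapper added in the next hunk hands that `Uint8Array` to a caller-supplied listener. A minimal consumer sketch, assuming the caller already depends on a JSPB-generated `CalculatorGraphConfig` class with the usual `deserializeBinary`/`getNodeList` accessors (none of these names are introduced by this change):

// Stand-ins for objects the caller is assumed to already have in scope.
declare const graphRunner: GraphRunner;
declare const CalculatorGraphConfig: {
  deserializeBinary(bytes: Uint8Array):
      {getNodeList(): Array<{getCalculator(): string}>};
};

graphRunner.getCalculatorGraphConfig((binaryConfig: Uint8Array) => {
  // Deserialize with the caller's own proto class; the graph library itself
  // deliberately avoids a JSPB dependency.
  const config = CalculatorGraphConfig.deserializeBinary(binaryConfig);
  // For example, list the calculators in the running graph.
  console.log(config.getNodeList().map(node => node.getCalculator()));
});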
@@ -437,6 +452,29 @@ export class GraphRunner { this.wasmModule._free(heapSpace); } + /** + * Invokes the callback with the current calculator configuration (in binary + * format). + * + * Consumers must deserialize the binary representation themselves as this + * avoids addding a direct dependency on the Protobuf JSPB target in the graph + * library. + */ + getCalculatorGraphConfig( + callback: CalculatorGraphConfigListener, makeDeepCopy?: boolean): void { + const listener = CALCULATOR_GRAPH_CONFIG_LISTENER_NAME; + + // Create a short-lived listener to receive the binary encoded proto + this.setListener(listener, (data: Uint8Array) => { + callback(data); + }); + this.wrapStringPtr(listener, (outputStreamNamePtr: number) => { + this.wasmModule._getGraphConfig(outputStreamNamePtr, makeDeepCopy); + }); + + delete this.wasmModule.simpleListeners![listener]; + } + /** * Ensures existence of the simple listeners table and registers the callback. * Intended for internal usage. diff --git a/mediapipe/web/graph_runner/graph_runner_image_lib.ts b/mediapipe/web/graph_runner/graph_runner_image_lib.ts index 72d5ad965..d9bb0568b 100644 --- a/mediapipe/web/graph_runner/graph_runner_image_lib.ts +++ b/mediapipe/web/graph_runner/graph_runner_image_lib.ts @@ -10,10 +10,11 @@ type LibConstructor = new (...args: any[]) => GraphRunner; /** An image returned from a MediaPipe graph. */ export interface WasmImage { - data: Uint8Array|Float32Array; + data: Uint8ClampedArray|Float32Array; width: number; height: number; } + /** * Declarations for Emscripten's WebAssembly Module behavior, so TS compiler * doesn't break our JS/C++ bridge. diff --git a/third_party/BUILD b/third_party/BUILD index e2044cfd9..7522bab1b 100644 --- a/third_party/BUILD +++ b/third_party/BUILD @@ -112,6 +112,7 @@ cmake_external( "WITH_JPEG": "ON", "WITH_PNG": "ON", "WITH_TIFF": "ON", + "WITH_OPENCL": "OFF", "WITH_WEBP": "OFF", # Optimization flags "CV_ENABLE_INTRINSICS": "ON", diff --git a/third_party/external_files.bzl b/third_party/external_files.bzl index c140003bf..3a08c61c5 100644 --- a/third_party/external_files.bzl +++ b/third_party/external_files.bzl @@ -67,13 +67,7 @@ def external_files(): http_file( name = "com_google_mediapipe_BUILD", sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"], - ) - - http_file( - name = "com_google_mediapipe_BUILD_orig", - sha256 = "64d5343a6a5f9be06db0a5074a2260f9ae63a989fe01702832cd215680dc19c1", - urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD.orig?generation=1678323576393653"], + urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=166187566369397616783235763936531678737479599640"], ) http_file( @@ -136,6 +130,18 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs.jpg?generation=1661875684064150"], ) + http_file( + name = "com_google_mediapipe_cats_and_dogs_mask_dog1_png", + sha256 = "2ab37d56ba1e46e70b3ddbfe35dac51b18b597b76904c68d7d34c7c74c677d4c", + urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog1.png?generation=1678840350058498"], + ) + + http_file( + name = "com_google_mediapipe_cats_and_dogs_mask_dog2_png", + sha256 = "2010850e2dd7f520fe53b9086d70913b6fb53b178cae15a373e5ee7ffb46824a", + urls = ["https://storage.googleapis.com/mediapipe-assets/cats_and_dogs_mask_dog2.png?generation=1678840352961684"], + ) + http_file( name = 
"com_google_mediapipe_cats_and_dogs_no_resizing_jpg", sha256 = "9d55933ed66bcdc63cd6509ee2518d7eed75d12db609238387ee4cc50b173e58", @@ -208,10 +214,34 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/corrupted_mobilenet_v1_0.25_224_1_default_1.tflite?generation=1661875706780536"], ) + http_file( + name = "com_google_mediapipe_deeplabv3_json", + sha256 = "f299835bd9ea1cceb25fdf40a761a22716cbd20025cd67c365a860527f178b7f", + urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.json?generation=1678818040715103"], + ) + http_file( name = "com_google_mediapipe_deeplabv3_tflite", - sha256 = "9711334db2b01d5894feb8ed0f5cb3e97d125b8d229f8d8692f625801818f5ef", - urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1661875711618421"], + sha256 = "5faed2c653905d3e22a8f6f29ee198da84e9b0e7936a207bf431f17f6b4d87ff", + urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3.tflite?generation=1678775085237701"], + ) + + http_file( + name = "com_google_mediapipe_deeplabv3_with_activation_json", + sha256 = "a7633476d02f970db3cc30f5f027bcb608149e02207b2ccae36a4b69d730c82c", + urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_with_activation.json?generation=1678818047050984"], + ) + + http_file( + name = "com_google_mediapipe_deeplabv3_without_labels_json", + sha256 = "7d045a583a4046f17a52d2078b0175607a45ed0cc187558325f9c66534c08401", + urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_labels.json?generation=1678818050191996"], + ) + + http_file( + name = "com_google_mediapipe_deeplabv3_without_metadata_tflite", + sha256 = "68a539782c2c6a72f8aac3724600124a85ed977162b44e84cbae5db717c933c6", + urls = ["https://storage.googleapis.com/mediapipe-assets/deeplabv3_without_metadata.tflite?generation=1678818053623010"], ) http_file( @@ -318,8 +348,8 @@ def external_files(): http_file( name = "com_google_mediapipe_face_landmarker_with_blendshapes_task", - sha256 = "a75c1ba70e4b8568000af2ad0b355ed559ab5d5793db50fa9ad241f8dc4fad5f", - urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678323586260800"], + sha256 = "b44e4cae6f5822456d60f33e7c852640d78c7e342aee7eacc22589451a0b9dc2", + urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678504998301299"], ) http_file( @@ -390,8 +420,8 @@ def external_files(): http_file( name = "com_google_mediapipe_hair_segmentation_tflite", - sha256 = "0bec40bc9ba97c4143f3d4225a935014abffea37c1f3766ae32aba3f2748e711", - urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678218355806671"], + sha256 = "7cbddcfe6f6e10c3e0a509eb2e14225fda5c0de6c35e2e8c6ca8e3971988fc17", + urls = ["https://storage.googleapis.com/mediapipe-assets/hair_segmentation.tflite?generation=1678775089064550"], ) http_file( @@ -822,8 +852,8 @@ def external_files(): http_file( name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt", - sha256 = "5cc57b8da3ad0527dce581fe1309f6b36043e5837e3f4f5af5e24005a99dc52a", - urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678323601064393"], + sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678737486927530"], ) http_file( @@ -862,10 +892,40 @@ def 
external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_rotated.jpg?generation=1677194680138164"], ) + http_file( + name = "com_google_mediapipe_portrait_selfie_segmentation_expected_category_mask_jpg", + sha256 = "d8f20fa746e14067f668dd293f21bbc50ec81196d186386a6ded1278c3ec8f46", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_category_mask.jpg?generation=1678606935088873"], + ) + + http_file( + name = "com_google_mediapipe_portrait_selfie_segmentation_expected_confidence_mask_jpg", + sha256 = "25b723e90608edaf6ed92f382da703dc904a59c87525b6d271e60d9eed7a90e9", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_expected_confidence_mask.jpg?generation=1678606937358235"], + ) + + http_file( + name = "com_google_mediapipe_portrait_selfie_segmentation_landscape_expected_category_mask_jpg", + sha256 = "f5c3fa3d93f8e7289b69b8a89c2519276dfa5014dcc50ed6e86e8cd4d4ae7f27", + urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_selfie_segmentation_landscape_expected_category_mask.jpg?generation=1678606939469429"], + ) + http_file( name = "com_google_mediapipe_pose_detection_tflite", - sha256 = "a63c614bef30d35947f13be361820b1e4e3bec9cfeebf4d11216a18373108e85", - urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1661875889147923"], + sha256 = "9ba9dd3d42efaaba86b4ff0122b06f29c4122e756b329d89dca1e297fd8f866c", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_detection.tflite?generation=1678737489600422"], + ) + + http_file( + name = "com_google_mediapipe_pose_expected_detection_pbtxt", + sha256 = "e0d40e98dd5320a780a642c336d0c8720243ac5bcc0e39c4061ad970a503ae24", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose_expected_detection.pbtxt?generation=1678737492211540"], + ) + + http_file( + name = "com_google_mediapipe_pose_jpg", + sha256 = "c8a830ed683c0276d713dd5aeda28f415f10cd6291972084a40d0d8b934ed62b", + urls = ["https://storage.googleapis.com/mediapipe-assets/pose.jpg?generation=1678737494661975"], ) http_file( @@ -964,40 +1024,52 @@ def external_files(): urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_input_rotation0.jpg?generation=1661875914048401"], ) + http_file( + name = "com_google_mediapipe_segmentation_mask_meta_json", + sha256 = "4294d53b309c1fbe38a5184de4057576c3dec14e07d16491f1dd459ac9116ab3", + urls = ["https://storage.googleapis.com/mediapipe-assets/segmentation_mask_meta.json?generation=1678818065134737"], + ) + + http_file( + name = "com_google_mediapipe_segmenter_labelmap_txt", + sha256 = "d9efa78274f1799ddbcab1f87263e19dae338c1697de47a5b270c9526c45d364", + urls = ["https://storage.googleapis.com/mediapipe-assets/segmenter_labelmap.txt?generation=1678818068181025"], + ) + http_file( name = "com_google_mediapipe_selfie_segm_128_128_3_expected_mask_jpg", - sha256 = "a295f3ab394a5e0caff2db5041337da58341ec331f1413ef91f56e0d650b4a1e", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1661875916766416"], + sha256 = "1a2a068287d8bcd4184492485b3dbb95a09b763f4653fd729d14a836147eb383", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3_expected_mask.jpg?generation=1678606942616777"], ) http_file( name = "com_google_mediapipe_selfie_segm_128_128_3_tflite", - sha256 = "bb154f248543c0738e32f1c74375245651351a84746dc21f10bdfaabd8fae4ca", - urls = 
["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3.tflite?generation=1661875919964123"], + sha256 = "8322982866488b063af6531b1d16ac27c7bf404135b7905f20aaf5e6af7aa45b", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_128_128_3.tflite?generation=1678775097370282"], ) http_file( name = "com_google_mediapipe_selfie_segm_144_256_3_expected_mask_jpg", - sha256 = "cfc699db9670585c04414d0d1a07b289a027ba99d6903d2219f897d34e2c9952", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1661875922646736"], + sha256 = "2de433b6e8adabec2aaf80135232db900903ead4f2811c0c9378a6792b2a68b5", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3_expected_mask.jpg?generation=1678606945085676"], ) http_file( name = "com_google_mediapipe_selfie_segm_144_256_3_tflite", - sha256 = "5c770b8834ad50586599eae7710921be09d356898413fc0bf37a9458da0610eb", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3.tflite?generation=1661875925519713"], + sha256 = "f16a9551a408edeadd53f70d1d2911fc20f9f9de7a394129a268ca9faa2d6a08", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segm_144_256_3.tflite?generation=1678775099616375"], ) http_file( name = "com_google_mediapipe_selfie_segmentation_landscape_tflite", - sha256 = "4aafe6223bb8dac6fac8ca8ed56852870a33051ef3f6238822d282a109962894", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1661875928328455"], + sha256 = "28fb4c287d6295a2dba6c1f43b43315a37f927ddcd6693d635d625d176eef162", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation_landscape.tflite?generation=1678775102234495"], ) http_file( name = "com_google_mediapipe_selfie_segmentation_tflite", - sha256 = "8d13b7fae74af625c641226813616a2117bd6bca19eb3b75574621fc08557f27", - urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1661875931201364"], + sha256 = "b0e2ec6f95107795b952b27f3d92806b45f0bc069dac76dcd264cd1b90d61c6c", + urls = ["https://storage.googleapis.com/mediapipe-assets/selfie_segmentation.tflite?generation=1678775104900954"], ) http_file( @@ -1224,8 +1296,8 @@ def external_files(): http_file( name = "com_google_mediapipe_object_detection_saved_model_README_md", - sha256 = "fe163cf12fbd017738a2fd360c03d223e964ba6404ac75c635f5918784e9c34d", - urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1661875995856372"], + sha256 = "acc23dee09f69210717ac060035c844ba902e8271486f1086f29fb156c236690", + urls = ["https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md?generation=1678737498915254"], ) http_file( diff --git a/third_party/halide.BUILD b/third_party/halide.BUILD new file mode 100644 index 000000000..02e701585 --- /dev/null +++ b/third_party/halide.BUILD @@ -0,0 +1,70 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +load("@halide//:halide.bzl", "halide_language_copts") + +licenses(["notice"]) + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "language", + hdrs = ["include/Halide.h"], + copts = halide_language_copts(), + includes = ["include"], + deps = [ + ":runtime", + ], +) + +cc_library( + name = "runtime", + hdrs = glob([ + "include/HalideRuntime*.h", + "include/HalideBuffer*.h", + ]), + includes = ["include"], +) + +cc_library( + name = "lib_halide_static", + srcs = select({ + "@halide//:halide_config_windows_x86_64": [ + "lib/Release/Halide.lib", + "bin/Release/Halide.dll", + ], + "//conditions:default": [ + "lib/libHalide.a", + ], + }), + visibility = ["//visibility:private"], +) + +cc_library( + name = "gengen", + srcs = [ + "share/Halide/tools/GenGen.cpp", + ], + includes = [ + "include", + "share/Halide/tools", + ], + visibility = ["//visibility:public"], + deps = [ + ":language", + ":lib_halide_static", + ], +) diff --git a/third_party/halide/BUILD b/third_party/halide/BUILD new file mode 100644 index 000000000..82bab3ffd --- /dev/null +++ b/third_party/halide/BUILD @@ -0,0 +1 @@ +# This empty BUILD file is required to make Bazel treat this directory as a package. diff --git a/third_party/halide/BUILD.bazel b/third_party/halide/BUILD.bazel new file mode 100644 index 000000000..8b69a2503 --- /dev/null +++ b/third_party/halide/BUILD.bazel @@ -0,0 +1,45 @@ +# Copyright 2023 The MediaPipe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@halide//:halide.bzl", "halide_library_runtimes") + +licenses(["notice"]) + +package( + default_visibility = ["//visibility:public"], +) + +halide_library_runtimes() + +# Aliases to platform-specific targets. +[ + alias( + name = target_name, + actual = select( + { + ":halide_config_linux_x86_64": "@linux_halide//:%s" % target_name, + ":halide_config_macos_x86_64": "@macos_x86_64_halide//:%s" % target_name, + ":halide_config_macos_arm64": "@macos_arm_64_halide//:%s" % target_name, + ":halide_config_windows_x86_64": "@windows_halide//:%s" % target_name, + # deliberately no //condition:default clause here + }, + no_match_error = "Compiling Halide code requires that the build host is one of Linux x86-64, Windows x86-64, macOS x86-64, or macOS arm64.", + ), + ) + for target_name in [ + "language", + "runtime", + "gengen", + ] +] diff --git a/third_party/halide/halide.bzl b/third_party/halide/halide.bzl new file mode 100644 index 000000000..8d0d48e32 --- /dev/null +++ b/third_party/halide/halide.bzl @@ -0,0 +1,884 @@ +# Copyright 2023 The MediaPipe Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bazel build rules for Halide.""" + +load("@bazel_skylib//lib:collections.bzl", "collections") +load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "use_cpp_toolchain") + +def halide_language_copts(): + _common_opts = [ + "-fPIC", + "-frtti", + "-Wno-conversion", + "-Wno-sign-compare", + ] + _posix_opts = [ + "$(STACK_FRAME_UNLIMITED)", + "-fno-exceptions", + "-funwind-tables", + "-fvisibility-inlines-hidden", + ] + _msvc_opts = [ + "-D_CRT_SECURE_NO_WARNINGS", + "/MD", + ] + return _common_opts + select({ + "//conditions:default": _posix_opts, + "@mediapipe//mediapipe:windows": _msvc_opts, + }) + +def halide_language_linkopts(): + _linux_opts = [ + "-ldl", + "-lpthread", + "-lz", + "-rdynamic", + ] + _osx_opts = [ + "-lz", + "-Wl,-stack_size", + "-Wl,1000000", + ] + _msvc_opts = [] + return select({ + "//conditions:default": _linux_opts, + "@mediapipe//mediapipe:macos": _osx_opts, + "@mediapipe//mediapipe:windows": _msvc_opts, + }) + +def halide_runtime_linkopts(): + """ Return the linkopts needed when linking against halide_library_runtime. + + Returns: + List to be used for linkopts. + """ + _posix_opts = [ + "-ldl", + "-lpthread", + ] + _android_opts = [ + "-llog", + ] + _msvc_opts = [] + + return select({ + "//conditions:default": _posix_opts, + "@mediapipe//mediapipe:android": _android_opts, + "@mediapipe//mediapipe:windows": _msvc_opts, + }) + +# Map of halide-target-base -> config_settings +_HALIDE_TARGET_CONFIG_SETTINGS_MAP = { + # Android + "arm-32-android": ["@halide//:halide_config_android_arm"], + "arm-64-android": ["@halide//:halide_config_android_arm64"], + "x86-32-android": ["@halide//:halide_config_android_i386"], + "x86-64-android": ["@halide//:halide_config_android_x86_64"], + # iOS + "arm-32-ios": ["@halide//:halide_config_ios_arm"], + "arm-64-ios": ["@halide//:halide_config_ios_arm64"], + # OSX (or iOS simulator) + "x86-32-osx": ["@halide//:halide_config_macos_i386", "@halide//:halide_config_ios_i386"], + "x86-64-osx": ["@halide//:halide_config_macos_x86_64", "@halide//:halide_config_ios_x86_64"], + "arm-64-osx": ["@halide//:halide_config_macos_arm64"], + # Windows + "x86-64-windows": ["@halide//:halide_config_windows_x86_64"], + # Linux + "x86-64-linux": ["@halide//:halide_config_linux_x86_64"], + # deliberately nothing here using //conditions:default +} + +_HALIDE_TARGET_MAP_DEFAULT = { + "x86-64-linux": [ + "x86-64-linux-sse41-avx-avx2-fma", + "x86-64-linux-sse41", + "x86-64-linux", + ], + "x86-64-osx": [ + "x86-64-osx-sse41-avx-avx2-fma", + "x86-64-osx-sse41", + "x86-64-osx", + ], + "x86-64-windows": [ + "x86-64-windows-sse41-avx-avx2-fma", + "x86-64-windows-sse41", + "x86-64-windows", + ], +} + +def halide_library_default_target_map(): + return _HALIDE_TARGET_MAP_DEFAULT + +# Alphabetizes the features part of the target to make sure they always match no +# matter the concatenation order of the target string pieces. 
+def _canonicalize_target(halide_target): + if halide_target == "host": + return halide_target + if "," in halide_target: + fail("Multitarget may not be specified here") + tokens = halide_target.split("-") + if len(tokens) < 3: + fail("Illegal target: %s" % halide_target) + + # rejoin the tokens with the features sorted + return "-".join(tokens[0:3] + sorted(tokens[3:])) + +# Converts comma and dash separators to underscore and alphabetizes +# the features part of the target to make sure they always match no +# matter the concatenation order of the target string pieces. +def _halide_target_to_bazel_rule_name(multitarget): + subtargets = multitarget.split(",") + subtargets = [_canonicalize_target(st).replace("-", "_") for st in subtargets] + return "_".join(subtargets) + +# The second argument is True if there is a separate file generated +# for each subtarget of a multitarget output, False if not. The third +# argument is True if the output is a directory (vs. a single file). +# The fourth argument is a list of output group(s) that the files should +# be added to. + +_is_multi = True +_is_single = False +_is_file = False + +_output_extensions = { + "assembly": ("s", _is_multi, _is_file, []), + "bitcode": ("bc", _is_multi, _is_file, ["generated_bitcode"]), + "c_header": ("h", _is_single, _is_file, ["generated_headers"]), + "c_source": ("halide_generated.cpp", _is_multi, _is_file, []), + "compiler_log": ("halide_compiler_log", _is_single, _is_file, ["generated_object", "generated_compiler_log"]), + "cpp_stub": ("stub.h", _is_single, _is_file, []), + "featurization": ("featurization", _is_multi, _is_file, []), + "llvm_assembly": ("ll", _is_multi, _is_file, []), + "object": ("o", _is_single, _is_file, ["generated_object"]), + "python_extension": ("py.cpp", _is_single, _is_file, []), + "registration": ("registration.cpp", _is_single, _is_file, ["generated_registration"]), + "schedule": ("schedule.h", _is_single, _is_file, []), + "static_library": ("a", _is_single, _is_file, ["generated_object"]), + "stmt": ("stmt", _is_multi, _is_file, []), + "stmt_html": ("stmt.html", _is_multi, _is_file, []), +} + +def _add_output_file(f, fmt, output_files, output_dict, verbose_extra_outputs, verbose_output_paths): + if fmt in verbose_extra_outputs: + verbose_output_paths.append(f.path) + output_files.append(f) + if fmt in _output_extensions: + for group in _output_extensions[fmt][3]: + output_dict.setdefault(group, []).append(f) + +HalideFunctionNameInfo = provider(fields = ["function_name"]) +HalideGeneratorBinaryInfo = provider(fields = ["generator_binary"]) +HalideGeneratorNameInfo = provider(fields = ["generator_name_"]) +HalideGeneratorParamsInfo = provider(fields = ["generator_params"]) +HalideLibraryNameInfo = provider(fields = ["library_name"]) +HalideTargetFeaturesInfo = provider(fields = ["target_features"]) + +def _gengen_closure_impl(ctx): + return [ + HalideGeneratorBinaryInfo(generator_binary = ctx.attr.generator_binary), + HalideGeneratorNameInfo(generator_name_ = ctx.attr.generator_name_), + ] + +_gengen_closure = rule( + implementation = _gengen_closure_impl, + attrs = { + "generator_binary": attr.label( + executable = True, + allow_files = True, + mandatory = True, + cfg = "exec", + ), + # "generator_name" is apparently reserved by Bazel for attrs in rules + "generator_name_": attr.string(mandatory = True), + }, + provides = [HalideGeneratorBinaryInfo, HalideGeneratorNameInfo], +) + +def _halide_library_instance_impl(ctx): + generator_binary = 
ctx.attr.generator_closure[HalideGeneratorBinaryInfo].generator_binary if ctx.attr.generator_closure else "" + generator_name = ctx.attr.generator_closure[HalideGeneratorNameInfo].generator_name_ if ctx.attr.generator_closure else "" + return [ + HalideFunctionNameInfo(function_name = ctx.attr.function_name), + HalideGeneratorBinaryInfo(generator_binary = generator_binary), + HalideGeneratorNameInfo(generator_name_ = generator_name), + HalideGeneratorParamsInfo(generator_params = ctx.attr.generator_params), + HalideLibraryNameInfo(library_name = ctx.attr.library_name), + HalideTargetFeaturesInfo(target_features = ctx.attr.target_features), + ] + +_halide_library_instance = rule( + implementation = _halide_library_instance_impl, + attrs = { + "function_name": attr.string(), + "generator_closure": attr.label( + cfg = "exec", + providers = [HalideGeneratorBinaryInfo, HalideGeneratorNameInfo], + ), + "generator_params": attr.string_list(), + "library_name": attr.string(), + "target_features": attr.string_list(), + }, + provides = [ + HalideFunctionNameInfo, + HalideGeneratorBinaryInfo, + HalideGeneratorNameInfo, + HalideGeneratorParamsInfo, + HalideLibraryNameInfo, + HalideTargetFeaturesInfo, + ], +) + +def _gengen_impl(ctx): + if _has_dupes(ctx.attr.requested_outputs): + fail("Duplicate values in outputs: " + str(ctx.attr.requested_outputs)) + + function_name = ctx.attr.function_name[HalideFunctionNameInfo].function_name if ctx.attr.function_name else "" + generator_binary = ctx.attr.generator_binary[HalideGeneratorBinaryInfo].generator_binary if ctx.attr.generator_binary else "" + generator_name_ = ctx.attr.generator_name_[HalideGeneratorNameInfo].generator_name_ if ctx.attr.generator_name_ else "" + generator_params = ctx.attr.generator_params[HalideGeneratorParamsInfo].generator_params if ctx.attr.generator_params else [] + library_name = ctx.attr.library_name[HalideLibraryNameInfo].library_name if ctx.attr.library_name else "" + target_features = ctx.attr.target_features[HalideTargetFeaturesInfo].target_features if ctx.attr.target_features else [] + + for gp in generator_params: + if " " in gp: + fail("%s: Entries in generator_params must not contain spaces." % library_name) + + # Escape backslashes and double quotes. 
+ generator_params = [gp.replace("\\", '\\\\"').replace('"', '\\"') for gp in generator_params] + + execution_requirements = {} + + # --- Calculate the output type(s) we're going to produce (and which ones should be verbose) + quiet_extra_outputs = [] + verbose_extra_outputs = [] + if ctx.attr.consider_halide_extra_outputs: + if "halide_extra_outputs" in ctx.var: + verbose_extra_outputs = ctx.var.get("halide_extra_outputs", "").split(",") + if "halide_extra_outputs_quiet" in ctx.var: + quiet_extra_outputs = ctx.var.get("halide_extra_outputs_quiet", "").split(",") + requested_outputs = sorted(collections.uniq(ctx.attr.requested_outputs + + verbose_extra_outputs + + quiet_extra_outputs)) + + # --- Assemble halide_target, adding extra features if necessary + base_target = ctx.attr.halide_base_target + if "," in base_target: + fail("halide_base_target should never be a multitarget") + if len(base_target.split("-")) != 3: + fail("halide_base_target should have exactly 3 components") + + target_features = target_features + ctx.var.get("halide_target_features", "").split(",") + + if "no_runtime" in target_features: + fail("Specifying 'no_runtime' in halide_target_features is not supported; " + + "please add 'add_halide_runtime_deps = False' to the halide_library() rule instead.") + + for san in ["asan", "msan", "tsan"]: + if san in target_features: + fail("halide_library doesn't support '%s' in halide_target_features; please build with --config=%s instead." % (san, san)) + + # Append the features common to everything. + target_features.append("c_plus_plus_name_mangling") + target_features.append("no_runtime") + + # Make it all neat and tidy. + target_features = sorted(collections.uniq(target_features)) + + # Get the multitarget list (if any) from halide_target_map + halide_targets = ctx.attr.halide_target_map.get(base_target, [base_target]) + + # Add the extra features to all of them + halide_targets = _add_features_to_all(halide_targets, target_features) + + leaf_name = ctx.attr.filename.split("/")[-1] + + output_files = [] + output_dict = {} + verbose_output_paths = [] + inputs = [] + + env = { + "HL_DEBUG_CODEGEN": str(ctx.var.get("halide_debug_codegen", 0)), + # --define halide_llvm_args=-time-passes is a typical usage + "HL_LLVM_ARGS": str(ctx.var.get("halide_llvm_args", "")), + } + + be_very_quiet = ctx.var.get("halide_experimental_quiet", False) # I'm hunting wabbit... 
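For concreteness, one possible resolution of the multitarget bookkeeping above (an illustrative trace, not output produced by the rule): with `halide_base_target = "x86-64-linux"`, `target_features = ["sse41"]`, no `--define halide_target_features` override, and the default `halide_target_map`, the list consumed by the per-target output loop below would be:

halide_targets = [
    "x86-64-linux-avx-avx2-c_plus_plus_name_mangling-fma-no_runtime-sse41",
    "x86-64-linux-c_plus_plus_name_mangling-no_runtime-sse41",
    "x86-64-linux-c_plus_plus_name_mangling-no_runtime",
]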
+ + # --- Calculate the final set of output files + for fmt in requested_outputs: + if fmt not in _output_extensions: + fail("Unknown Halide output '%s'; known outputs are %s" % + (fmt, sorted(_output_extensions.keys()))) + ext, is_multiple, is_dir, _ = _output_extensions[fmt] + + # Special-case Windows file extensions + if "windows" in halide_targets[-1]: + if ext == "o": + ext = "obj" + if ext == "a": + ext = "lib" + if is_multiple and len(halide_targets) > 1: + for h in halide_targets: + suffix = _canonicalize_target(h) + name = "%s-%s.%s" % (ctx.attr.filename, suffix, ext) + f = ctx.actions.declare_directory(name) if is_dir else ctx.actions.declare_file(name) + _add_output_file(f, fmt, output_files, output_dict, verbose_extra_outputs, verbose_output_paths) + else: + name = "%s.%s" % (ctx.attr.filename, ext) + f = ctx.actions.declare_directory(name) if is_dir else ctx.actions.declare_file(name) + _add_output_file(f, fmt, output_files, output_dict, verbose_extra_outputs, verbose_output_paths) + + # --- Progress message(s), including log info about any 'extra' files being output due to --define halide_extra_output + progress_message = "Executing generator %s with target (%s) args (%s)." % ( + generator_name_, + ",".join(halide_targets), + " ".join(generator_params), + ) + + for f in output_files: + if any([f.path.endswith(suf) for suf in [".h", ".a", ".o", ".lib", ".registration.cpp", ".bc", ".halide_compiler_log"]]): + continue + + # If an extra output was specified via --define halide_extra_outputs=foo on the command line, + # add to the progress message (so that it is ephemeral and doesn't clog stdout). + # + # (Trailing space is intentional since Starlark will append a period to the end, + # making copy-n-paste harder than it might otherwise be...) + if not be_very_quiet: + extra_msg = "Emitting extra Halide output: %s " % f.path + progress_message += "\n" + extra_msg + if f.path in verbose_output_paths: + # buildifier: disable=print + print(extra_msg) + + # --- Construct the arguments list for the Generator + arguments = ctx.actions.args() + arguments.add("-o", output_files[0].dirname) + if ctx.attr.generate_runtime: + arguments.add("-r", leaf_name) + if len(halide_targets) > 1: + fail("Only one halide_target allowed when using generate_runtime") + if function_name: + fail("halide_function_name not allowed when using generate_runtime") + else: + arguments.add("-g", generator_name_) + arguments.add("-n", leaf_name) + if function_name: + arguments.add("-f", function_name) + + if requested_outputs: + arguments.add_joined("-e", requested_outputs, join_with = ",") + + # Can't use add_joined(), as it will insert a space after target= + arguments.add("target=%s" % (",".join(halide_targets))) + if generator_params: + for p in generator_params: + for s in ["target"]: + if p.startswith("%s=" % s): + fail("You cannot specify %s in the generator_params parameter in bazel." % s) + arguments.add_all(generator_params) + + show_gen_arg = ctx.var.get("halide_show_generator_command", "") + + # If it's an exact match of a fully qualified path, show just that one. + # If it's * or "all", match everything. + if library_name and show_gen_arg in [library_name, "all", "*"] and not ctx.attr.generate_runtime: + # The 'Args' object can be printed, but can't be usefully converted to a string, or iterated, + # so we'll reproduce the logic here. We'll also take the opportunity to add or augment + # some args to be more useful to whoever runs it (eg, add `-v=1`, add some output files). 
+ sg_args = ["-v", "1"] + sg_args += ["-o", "/tmp"] + sg_args += ["-g", generator_name_] + sg_args += ["-n", leaf_name] + if function_name: + sg_args += ["-f", function_name] + if requested_outputs: + # Ensure that several commonly-useful output are added + ro = sorted(collections.uniq(requested_outputs + ["stmt", "assembly", "llvm_assembly"])) + sg_args += ["-e", ",".join(ro)] + sg_args.append("target=%s" % (",".join(halide_targets))) + + if generator_params: + sg_args += generator_params + + # buildifier: disable=print + print( + "\n\nTo locally run the Generator for", + library_name, + "use the command:\n\n", + "bazel run -c opt", + generator_binary.label, + "--", + " ".join(sg_args), + "\n\n", + ) + + # Finally... run the Generator. + ctx.actions.run( + execution_requirements = execution_requirements, + arguments = [arguments], + env = env, + executable = generator_binary.files_to_run.executable, + mnemonic = "ExecuteHalideGenerator", + inputs = depset(direct = inputs), + outputs = output_files, + progress_message = progress_message, + exec_group = "generator", + ) + + return [ + DefaultInfo(files = depset(direct = output_files)), + OutputGroupInfo(**output_dict), + ] + +_gengen = rule( + implementation = _gengen_impl, + attrs = { + "consider_halide_extra_outputs": attr.bool(), + "filename": attr.string(), + "generate_runtime": attr.bool(default = False), + "generator_binary": attr.label( + cfg = "exec", + providers = [HalideGeneratorBinaryInfo], + ), + # "generator_name" is apparently reserved by Bazel for attrs in rules + "generator_name_": attr.label( + cfg = "exec", + providers = [HalideGeneratorNameInfo], + ), + "halide_base_target": attr.string(), + "function_name": attr.label( + cfg = "target", + providers = [ + HalideFunctionNameInfo, + ], + ), + "generator_params": attr.label( + cfg = "target", + providers = [ + HalideGeneratorParamsInfo, + ], + ), + "library_name": attr.label( + cfg = "target", + providers = [ + HalideLibraryNameInfo, + ], + ), + "target_features": attr.label( + cfg = "target", + providers = [ + HalideTargetFeaturesInfo, + ], + ), + "halide_target_map": attr.string_list_dict(), + "requested_outputs": attr.string_list(), + "_cc_toolchain": attr.label(default = "@bazel_tools//tools/cpp:current_cc_toolchain"), + }, + fragments = ["cpp"], + output_to_genfiles = True, + toolchains = use_cpp_toolchain(), + exec_groups = { + "generator": exec_group(), + }, +) + +def _add_target_features(target, features): + if "," in target: + fail("Cannot use multitarget here") + new_target = target.split("-") + for f in features: + if f and f not in new_target: + new_target.append(f) + return "-".join(new_target) + +def _add_features_to_all(halide_targets, features): + return [_canonicalize_target(_add_target_features(t, features)) for t in halide_targets] + +def _has_dupes(some_list): + clean = collections.uniq(some_list) + return sorted(some_list) != sorted(clean) + +# Target features which do not affect runtime compatibility. 
+_IRRELEVANT_FEATURES = collections.uniq([ + "arm_dot_prod", + "arm_fp16", + "c_plus_plus_name_mangling", + "check_unsafe_promises", + "embed_bitcode", + "enable_llvm_loop_opt", + "large_buffers", + "no_asserts", + "no_bounds_query", + "profile", + "strict_float", + "sve", + "sve2", + "trace_loads", + "trace_pipeline", + "trace_realizations", + "trace_stores", + "user_context", + "wasm_sat_float_to_int", + "wasm_signext", + "wasm_simd128", +]) + +def _discard_irrelevant_features(halide_target_features = []): + return sorted(collections.uniq([f for f in halide_target_features if f not in _IRRELEVANT_FEATURES])) + +def _halide_library_runtime_target_name(halide_target_features = []): + return "_".join(["halide_library_runtime"] + _discard_irrelevant_features(halide_target_features)) + +def _define_halide_library_runtime( + halide_target_features = [], + compatible_with = []): + target_name = _halide_library_runtime_target_name(halide_target_features) + + if not native.existing_rule("halide_library_runtime.generator"): + halide_generator( + name = "halide_library_runtime.generator", + srcs = [], + deps = [], + visibility = ["//visibility:private"], + ) + condition_deps = {} + for base_target, cfgs in _HALIDE_TARGET_CONFIG_SETTINGS_MAP.items(): + target_features = _discard_irrelevant_features(halide_target_features) + halide_target_name = _halide_target_to_bazel_rule_name(base_target) + gengen_name = "%s_%s" % (halide_target_name, target_name) + + _halide_library_instance( + name = "%s.library_instance" % gengen_name, + compatible_with = compatible_with, + function_name = "", + generator_closure = ":halide_library_runtime.generator_closure", + generator_params = [], + library_name = "", + target_features = target_features, + visibility = ["//visibility:private"], + ) + hl_instance = ":%s.library_instance" % gengen_name + + _gengen( + name = gengen_name, + compatible_with = compatible_with, + filename = "%s/%s" % (halide_target_name, target_name), + generate_runtime = True, + generator_binary = hl_instance, + generator_name_ = hl_instance, + halide_base_target = base_target, + requested_outputs = ["object"], + tags = ["manual"], + target_features = hl_instance, + visibility = ["@halide//:__subpackages__"], + ) + for cfg in cfgs: + condition_deps[cfg] = [":%s" % gengen_name] + + deps = [] + native.cc_library( + name = target_name, + compatible_with = compatible_with, + srcs = select(condition_deps), + linkopts = halide_runtime_linkopts(), + tags = ["manual"], + deps = deps, + visibility = ["//visibility:public"], + ) + + return target_name + +def _standard_library_runtime_features(): + _standard_features = [ + [], + ["cuda"], + ["metal"], + ["opencl"], + ["openglcompute"], + ["openglcompute", "egl"], + ] + return [f for f in _standard_features] + [f + ["debug"] for f in _standard_features] + +def _standard_library_runtime_names(): + return collections.uniq([_halide_library_runtime_target_name(f) for f in _standard_library_runtime_features()]) + +def halide_library_runtimes(compatible_with = []): + # Note that we don't use all of these combinations + # (and some are invalid), but that's ok. 
+ for cpu in ["arm", "arm64", "i386", "x86_64"]: + for os in ["android", "linux", "windows", "ios", "macos"]: + native.config_setting( + name = "halide_config_%s_%s" % (os, cpu), + constraint_values = [ + "@platforms//os:%s" % os, + "@platforms//cpu:%s" % cpu, + ], + visibility = ["//visibility:public"], + ) + + unused = [ + _define_halide_library_runtime(f, compatible_with = compatible_with) + for f in _standard_library_runtime_features() + ] + unused = unused # unused variable + +def halide_generator( + name, + srcs, + compatible_with = [], + copts = [], + deps = [], + generator_name = "", + includes = [], + tags = [], + testonly = False, + visibility = None): + if not name.endswith(".generator"): + fail("halide_generator rules must end in .generator") + + basename = name[:-10] # strip ".generator" suffix + if not generator_name: + generator_name = basename + + # Note: This target is public, but should not be needed by the vast + # majority of users. Unless you are writing a custom Bazel rule that + # involves Halide generation, you most probably won't need to depend on + # this rule. + native.cc_binary( + name = name, + copts = copts + halide_language_copts(), + linkopts = halide_language_linkopts(), + compatible_with = compatible_with, + srcs = srcs, + deps = [ + "@halide//:gengen", + "@halide//:language", + ] + deps, + tags = ["manual"] + tags, + testonly = testonly, + visibility = ["//visibility:public"], + ) + + _gengen_closure( + name = "%s_closure" % name, + generator_binary = name, + generator_name_ = generator_name, + compatible_with = compatible_with, + testonly = testonly, + visibility = ["//visibility:private"], + ) + +# This rule exists to allow us to select() on halide_target_features. +def _select_halide_library_runtime_impl(ctx): + f = ctx.attr.halide_target_features + + standard_runtimes = {t.label.name: t for t in ctx.attr._standard_runtimes} + + f = sorted(_discard_irrelevant_features(collections.uniq(f))) + runtime_name = _halide_library_runtime_target_name(f) + if runtime_name not in standard_runtimes: + fail(("There is no Halide runtime available for the feature set combination %s. 
" + + "Please use contact information from halide-lang.org to contact the Halide " + + "team to add the right combination.") % str(f)) + + return standard_runtimes[runtime_name][CcInfo] + +_select_halide_library_runtime = rule( + implementation = _select_halide_library_runtime_impl, + attrs = { + "halide_target_features": attr.string_list(), + "_standard_runtimes": attr.label_list( + default = ["@halide//:%s" % n for n in _standard_library_runtime_names()], + providers = [CcInfo], + ), + }, + provides = [CcInfo], +) + +def halide_library_from_generator( + name, + generator, + add_halide_runtime_deps = True, + compatible_with = [], + deps = [], + function_name = None, + generator_params = [], + halide_target_features = [], + halide_target_map = halide_library_default_target_map(), + includes = [], + namespace = None, + tags = [], + testonly = False, + visibility = None): + if not function_name: + function_name = name + + if namespace: + function_name = "%s::%s" % (namespace, function_name) + + generator_closure = "%s_closure" % generator + + _halide_library_instance( + name = "%s.library_instance" % name, + compatible_with = compatible_with, + function_name = function_name, + generator_closure = generator_closure, + generator_params = generator_params, + library_name = "//%s:%s" % (native.package_name(), name), + target_features = halide_target_features, + testonly = testonly, + visibility = ["//visibility:private"], + ) + hl_instance = ":%s.library_instance" % name + + condition_deps = {} + for base_target, cfgs in _HALIDE_TARGET_CONFIG_SETTINGS_MAP.items(): + base_target_name = _halide_target_to_bazel_rule_name(base_target) + gengen_name = "%s_%s" % (base_target_name, name) + _gengen( + name = gengen_name, + compatible_with = compatible_with, + consider_halide_extra_outputs = True, + filename = "%s/%s" % (base_target_name, name), + function_name = hl_instance, + generator_binary = generator_closure, + generator_name_ = generator_closure, + generator_params = hl_instance, + halide_base_target = base_target, + halide_target_map = halide_target_map, + library_name = hl_instance, + requested_outputs = ["static_library"], + tags = ["manual"] + tags, + target_features = hl_instance, + testonly = testonly, + ) + for cfg in cfgs: + condition_deps[cfg] = [":%s" % gengen_name] + + # Use a canonical target to build CC, regardless of config detected + cc_base_target = "x86-64-linux" + + for output, target_name in [ + ("c_header", "%s_h" % name), + ("c_source", "%s_cc" % name), + ]: + _gengen( + name = target_name, + compatible_with = compatible_with, + filename = name, + function_name = hl_instance, + generator_binary = generator_closure, + generator_name_ = generator_closure, + generator_params = hl_instance, + halide_base_target = cc_base_target, + library_name = hl_instance, + requested_outputs = [output], + tags = ["manual"] + tags, + target_features = hl_instance, + testonly = testonly, + ) + + _select_halide_library_runtime( + name = "%s.halide_library_runtime_deps" % name, + halide_target_features = halide_target_features, + compatible_with = compatible_with, + tags = tags, + visibility = ["//visibility:private"], + ) + + native.filegroup( + name = "%s_object" % name, + srcs = select(condition_deps), + output_group = "generated_object", + visibility = ["//visibility:private"], + compatible_with = compatible_with, + tags = tags, + testonly = testonly, + ) + + native.cc_library( + name = name, + srcs = ["%s_object" % name], + hdrs = [ + ":%s_h" % name, + ], + deps = deps + + 
["@halide//:runtime"] + # for HalideRuntime.h, etc + ([":%s.halide_library_runtime_deps" % name] if add_halide_runtime_deps else []), # for the runtime implementation + defines = ["HALIDE_FUNCTION_ATTRS=HALIDE_MUST_USE_RESULT"], + compatible_with = compatible_with, + includes = includes, + tags = tags, + testonly = testonly, + visibility = visibility, + linkstatic = 1, + ) + + # Return the fully-qualified built target name. + return "//%s:%s" % (native.package_name(), name) + +def halide_library( + name, + srcs = [], + add_halide_runtime_deps = True, + copts = [], + compatible_with = [], + filter_deps = [], + function_name = None, + generator_params = [], + generator_deps = [], + generator_name = None, + halide_target_features = [], + halide_target_map = halide_library_default_target_map(), + includes = [], + namespace = None, + tags = [], + testonly = False, + visibility = None): + if not srcs and not generator_deps: + fail("halide_library needs at least one of srcs or generator_deps to provide a generator") + + halide_generator( + name = "%s.generator" % name, + srcs = srcs, + compatible_with = compatible_with, + generator_name = generator_name, + deps = generator_deps, + includes = includes, + copts = copts, + tags = tags, + testonly = testonly, + visibility = visibility, + ) + + return halide_library_from_generator( + name = name, + generator = ":%s.generator" % name, + add_halide_runtime_deps = add_halide_runtime_deps, + compatible_with = compatible_with, + deps = filter_deps, + function_name = function_name, + generator_params = generator_params, + halide_target_features = halide_target_features, + halide_target_map = halide_target_map, + includes = includes, + namespace = namespace, + tags = tags, + testonly = testonly, + visibility = visibility, + )