Merge branch 'google:master' into face-landmarker-python
|  | @ -118,9 +118,9 @@ on how to build MediaPipe examples. | |||
| *   With a TensorFlow Model | ||||
| 
 | ||||
|     This uses the | ||||
|     [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) | ||||
|     [TensorFlow model](https://github.com/google/mediapipe/tree/v0.8.10/mediapipe/models/object_detection_saved_model) | ||||
|     (see also | ||||
|     [model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)), | ||||
|     [model info](https://github.com/google/mediapipe/tree/master/mediapipe/g3doc/solutions/object_detection_saved_model.md)), | ||||
|     and the pipeline is implemented in this | ||||
|     [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt). | ||||
| 
 | ||||
|  |  | |||
docs/solutions/object_detection_saved_model.md (new file, 62 lines)
						|  | @ -0,0 +1,62 @@ | |||
## TensorFlow/TFLite Object Detection Model

### TensorFlow model

The model is trained on the [MSCOCO 2014](http://cocodataset.org) dataset using the [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with a 0.5 depth multiplier; the detailed training configuration is in the provided `pipeline.config`. The model is relatively compact, reaching `0.171 mAP` while staying fast enough for real-time performance on mobile devices. You can compare it with other models in the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).

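For reference (an editorial addition, not in the original write-up), the mAP figure is the standard COCO-style mean average precision: the per-class average precision, i.e. the area under the precision-recall curve, averaged over all classes (COCO additionally averages over IoU thresholds from 0.50 to 0.95):

$$
\mathrm{AP}_c = \int_0^1 p_c(r)\,dr, \qquad
\mathrm{mAP} = \frac{1}{|C|} \sum_{c \in C} \mathrm{AP}_c
$$
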
### TFLite model

The TFLite model is converted from the TensorFlow model above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193), with minor modifications. Assume we have a trained TensorFlow model that includes the checkpoint files and the training configuration file, for example the files provided in this repo:

   * `model.ckpt.index`
   * `model.ckpt.meta`
   * `model.ckpt.data-00000-of-00001`
   * `pipeline.config`

Make sure you have installed these [Python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then, to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:

```bash
$ PATH_TO_MODEL=path/to/the/model
$ bazel run object_detection:export_tflite_ssd_graph -- \
    --pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
    --trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
    --output_directory ${PATH_TO_MODEL} \
    --add_postprocessing_op=False
```

The exported model contains two files:

   * `tflite_graph.pb`
   * `tflite_graph.pbtxt`

The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. MediaPipe provides all the calculators needed for post-processing, so the custom TFLite post-processing ops in the original graph (e.g., non-maximum suppression) can be excluded. This leaves the flexibility to integrate different post-processing algorithms and implementations.

Optional: you can install and use the [Graph Transform Tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the inputs and outputs of the exported model:

```bash
$ bazel run graph_transforms:summarize_graph -- \
    --in_graph=${PATH_TO_MODEL}/tflite_graph.pb
```

You should see that the input image size of the model is 320x320 and that its outputs are:

   * `raw_outputs/box_encodings`
   * `raw_outputs/class_predictions`

The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more details. For this example, you just need to run:

```bash
$ tflite_convert \
  --graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
  --output_file=${PATH_TO_MODEL}/model.tflite \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --inference_type=FLOAT \
  --input_shapes=1,320,320,3 \
  --input_arrays=normalized_input_image_tensor \
  --output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
```

Now you have the TFLite model `model.tflite` ready to use with the MediaPipe object detection graphs. Please see the examples for more details.
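Optional sanity check (an editorial addition, not part of the original doc): a minimal C++ sketch, assuming the TFLite C++ library and headers are available, that loads the converted `model.tflite` and prints its input and output tensors, so you can confirm the `1x320x320x3` float input and the two `raw_outputs/*` outputs listed above:

```cpp
#include <cstdio>
#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  // Load the converted flatbuffer model.
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile("model.tflite");
  if (!model) return 1;

  // Only builtin ops are needed, since the custom post-processing ops were
  // excluded at export time (add_postprocessing_op=False).
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter || interpreter->AllocateTensors() != kTfLiteOk) return 1;

  // Expect one 1x320x320x3 float input ...
  const TfLiteTensor* input = interpreter->input_tensor(0);
  std::printf("input: %s [%d, %d, %d, %d]\n", input->name,
              input->dims->data[0], input->dims->data[1],
              input->dims->data[2], input->dims->data[3]);

  // ... and the two raw output tensors listed above.
  for (int i : interpreter->outputs()) {
    std::printf("output: %s\n", interpreter->tensor(i)->name);
  }
  return 0;
}
```
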
|  | @ -269,6 +269,7 @@ Supported configuration options: | |||
| ```python | ||||
| import cv2 | ||||
| import mediapipe as mp | ||||
| import numpy as np | ||||
| mp_drawing = mp.solutions.drawing_utils | ||||
| mp_drawing_styles = mp.solutions.drawing_styles | ||||
| mp_pose = mp.solutions.pose | ||||
|  |  | |||
|  | @ -748,6 +748,7 @@ cc_test( | |||
|         "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png", | ||||
|         "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png", | ||||
|         "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", | ||||
|         "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png", | ||||
|         "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png", | ||||
|     ], | ||||
|     tags = ["desktop_only_test"], | ||||
|  |  | |||
|  | @ -29,6 +29,9 @@ class AffineTransformation { | |||
|   // pixels will be calculated.
 | ||||
|   enum class BorderMode { kZero, kReplicate }; | ||||
| 
 | ||||
|   // Pixel sampling interpolation method.
 | ||||
|   enum class Interpolation { kLinear, kCubic }; | ||||
| 
 | ||||
|   struct Size { | ||||
|     int width; | ||||
|     int height; | ||||
|  |  | |||
|  | @ -77,8 +77,11 @@ class GlTextureWarpAffineRunner | |||
|                                           std::unique_ptr<GpuBuffer>> { | ||||
|  public: | ||||
|   GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper, | ||||
|                             GpuOrigin::Mode gpu_origin) | ||||
|       : gl_helper_(gl_helper), gpu_origin_(gpu_origin) {} | ||||
|                             GpuOrigin::Mode gpu_origin, | ||||
|                             AffineTransformation::Interpolation interpolation) | ||||
|       : gl_helper_(gl_helper), | ||||
|         gpu_origin_(gpu_origin), | ||||
|         interpolation_(interpolation) {} | ||||
|   absl::Status Init() { | ||||
|     return gl_helper_->RunInGlContext([this]() -> absl::Status { | ||||
|       const GLint attr_location[kNumAttributes] = { | ||||
|  | @ -103,28 +106,83 @@ class GlTextureWarpAffineRunner | |||
|             } | ||||
|           )"; | ||||
| 
 | ||||
|       // TODO Move bicubic code to common shared place.
 | ||||
|       constexpr GLchar kFragShader[] = R"( | ||||
|             DEFAULT_PRECISION(highp, float) | ||||
|             in vec2 sample_coordinate; | ||||
|             uniform sampler2D input_texture; | ||||
|         DEFAULT_PRECISION(highp, float) | ||||
| 
 | ||||
|           #ifdef GL_ES | ||||
|             #define fragColor gl_FragColor | ||||
|           #else | ||||
|             out vec4 fragColor; | ||||
|           #endif  // defined(GL_ES);
 | ||||
|         in vec2 sample_coordinate; | ||||
|         uniform sampler2D input_texture; | ||||
|         uniform vec2 input_size; | ||||
| 
 | ||||
|             void main() { | ||||
|               vec4 color = texture2D(input_texture, sample_coordinate); | ||||
|           #ifdef CUSTOM_ZERO_BORDER_MODE | ||||
|               float out_of_bounds = | ||||
|                   float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || | ||||
|                         sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0); | ||||
|               color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); | ||||
|           #endif  // defined(CUSTOM_ZERO_BORDER_MODE)
 | ||||
|               fragColor = color; | ||||
|             } | ||||
|           )"; | ||||
|       #ifdef GL_ES | ||||
|         #define fragColor gl_FragColor | ||||
|       #else | ||||
|         out vec4 fragColor; | ||||
|       #endif  // defined(GL_ES);
 | ||||
| 
 | ||||
|       #ifdef CUBIC_INTERPOLATION | ||||
|         vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) { | ||||
|           const vec2 halve = vec2(0.5,0.5); | ||||
|           const vec2 one = vec2(1.0,1.0); | ||||
|           const vec2 two = vec2(2.0,2.0); | ||||
|           const vec2 three = vec2(3.0,3.0); | ||||
|           const vec2 six = vec2(6.0,6.0); | ||||
| 
 | ||||
|           // Calculate the fraction and integer.
 | ||||
|           tex_coord = tex_coord * tex_size - halve; | ||||
|           vec2 frac = fract(tex_coord); | ||||
|           vec2 index = tex_coord - frac + halve; | ||||
| 
 | ||||
|           // Calculate weights for Catmull-Rom filter.
 | ||||
|           vec2 w0 = frac * (-halve + frac * (one - halve * frac)); | ||||
|           vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac); | ||||
|           vec2 w2 = frac * (halve + frac * (two - three/two * frac)); | ||||
|           vec2 w3 = frac * frac * (-halve + halve * frac); | ||||
| 
 | ||||
|           // Calculate weights to take advantage of bilinear texture lookup.
 | ||||
|           vec2 w12 = w1 + w2; | ||||
|           vec2 offset12 = w2 / (w1 + w2); | ||||
| 
 | ||||
|           vec2 index_tl = index - one; | ||||
|           vec2 index_br = index + two; | ||||
|           vec2 index_eq = index + offset12; | ||||
| 
 | ||||
|           index_tl /= tex_size; | ||||
|           index_br /= tex_size; | ||||
|           index_eq /= tex_size; | ||||
| 
 | ||||
|           // 9 texture lookups and linear blending.
 | ||||
|           vec4 color = vec4(0.0); | ||||
|           color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y; | ||||
|           color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y; | ||||
|           color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y; | ||||
| 
 | ||||
|           color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y; | ||||
|           color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y; | ||||
|           color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y; | ||||
| 
 | ||||
|           color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y; | ||||
|           color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y; | ||||
|           color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y; | ||||
|           return color; | ||||
|         } | ||||
|       #else | ||||
|         vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) { | ||||
|           return texture2D(tex, tex_coord); | ||||
|         } | ||||
|       #endif  // defined(CUBIC_INTERPOLATION)
 | ||||
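          // Editor's note (added): for a fractional offset t in [0,1], the
          // weights above expand to the Catmull-Rom (a = -1/2) basis:
          //   w0(t) = -t/2 +   t^2 -     t^3/2
          //   w1(t) =  1  - 5*t^2/2 + 3*t^3/2
          //   w2(t) =  t/2 + 2*t^2 - 3*t^3/2
          //   w3(t) = -t^2/2 +       t^3/2
          // w1 and w2 are non-negative on [0,1], so each (w1, w2) pair can be
          // fetched with a single hardware bilinear tap at offset w2/(w1+w2),
          // reducing the 16 nearest texels to the 9 lookups used above.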
| 
 | ||||
|         void main() { | ||||
|           vec4 color = sample(input_texture, sample_coordinate, input_size); | ||||
|       #ifdef CUSTOM_ZERO_BORDER_MODE | ||||
|           float out_of_bounds = | ||||
|               float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || | ||||
|                     sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0); | ||||
|           color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); | ||||
|       #endif  // defined(CUSTOM_ZERO_BORDER_MODE)
 | ||||
|           fragColor = color; | ||||
|         } | ||||
|       )"; | ||||
| 
 | ||||
|       // Create program and set parameters.
 | ||||
|       auto create_fn = [&](const std::string& vs, | ||||
|  | @ -137,14 +195,28 @@ class GlTextureWarpAffineRunner | |||
|         glUseProgram(program); | ||||
|         glUniform1i(glGetUniformLocation(program, "input_texture"), 1); | ||||
|         GLint matrix_id = glGetUniformLocation(program, "transform_matrix"); | ||||
|         return Program{.id = program, .matrix_id = matrix_id}; | ||||
|         GLint size_id = glGetUniformLocation(program, "input_size"); | ||||
|         return Program{ | ||||
|             .id = program, .matrix_id = matrix_id, .size_id = size_id}; | ||||
|       }; | ||||
| 
 | ||||
|       const std::string vert_src = | ||||
|           absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader); | ||||
| 
 | ||||
|       const std::string frag_src = absl::StrCat( | ||||
|           mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader); | ||||
|       std::string interpolation_def; | ||||
|       switch (interpolation_) { | ||||
|         case AffineTransformation::Interpolation::kCubic: | ||||
|           interpolation_def = R"( | ||||
|             #define CUBIC_INTERPOLATION | ||||
|           )"; | ||||
|           break; | ||||
|         case AffineTransformation::Interpolation::kLinear: | ||||
|           break; | ||||
|       } | ||||
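      // Added note: kLinear purposely injects no define, so the shader falls
      // back to the plain texture2D() path, which relies on the sampler's
      // hardware bilinear filtering.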
| 
 | ||||
|       const std::string frag_src = | ||||
|           absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, | ||||
|                        interpolation_def, kFragShader); | ||||
| 
 | ||||
|       ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src)); | ||||
| 
 | ||||
|  | @ -152,9 +224,9 @@ class GlTextureWarpAffineRunner | |||
|         std::string custom_zero_border_mode_def = R"( | ||||
|           #define CUSTOM_ZERO_BORDER_MODE | ||||
|         )"; | ||||
|         const std::string frag_custom_zero_src = | ||||
|             absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, | ||||
|                          custom_zero_border_mode_def, kFragShader); | ||||
|         const std::string frag_custom_zero_src = absl::StrCat( | ||||
|             mediapipe::kMediaPipeFragmentShaderPreamble, | ||||
|             custom_zero_border_mode_def, interpolation_def, kFragShader); | ||||
|         return create_fn(vert_src, frag_custom_zero_src); | ||||
|       }; | ||||
| #if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED | ||||
|  | @ -256,6 +328,7 @@ class GlTextureWarpAffineRunner | |||
|     } | ||||
|     glUseProgram(program->id); | ||||
| 
 | ||||
|     // uniforms
 | ||||
|     Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data()); | ||||
|     if (IsMatrixVerticalFlipNeeded(gpu_origin_)) { | ||||
|       // @matrix describes affine transformation in terms of TOP LEFT origin, so
 | ||||
|  | @ -275,6 +348,10 @@ class GlTextureWarpAffineRunner | |||
|     eigen_mat.transposeInPlace(); | ||||
|     glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data()); | ||||
| 
 | ||||
|     if (interpolation_ == AffineTransformation::Interpolation::kCubic) { | ||||
|       glUniform2f(program->size_id, texture.width(), texture.height()); | ||||
|     } | ||||
| 
 | ||||
|     // vao
 | ||||
|     glBindVertexArray(vao_); | ||||
| 
 | ||||
|  | @ -327,6 +404,7 @@ class GlTextureWarpAffineRunner | |||
|   struct Program { | ||||
|     GLuint id; | ||||
|     GLint matrix_id; | ||||
|     GLint size_id; | ||||
|   }; | ||||
|   std::shared_ptr<GlCalculatorHelper> gl_helper_; | ||||
|   GpuOrigin::Mode gpu_origin_; | ||||
|  | @ -335,6 +413,8 @@ class GlTextureWarpAffineRunner | |||
|   Program program_; | ||||
|   std::optional<Program> program_custom_zero_; | ||||
|   GLuint framebuffer_ = 0; | ||||
|   AffineTransformation::Interpolation interpolation_ = | ||||
|       AffineTransformation::Interpolation::kLinear; | ||||
| }; | ||||
| 
 | ||||
| #undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED | ||||
|  | @ -344,9 +424,10 @@ class GlTextureWarpAffineRunner | |||
| absl::StatusOr<std::unique_ptr< | ||||
|     AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>> | ||||
| CreateAffineTransformationGlRunner( | ||||
|     std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) { | ||||
|   auto runner = | ||||
|       absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin); | ||||
|     std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin, | ||||
|     AffineTransformation::Interpolation interpolation) { | ||||
|   auto runner = absl::make_unique<GlTextureWarpAffineRunner>( | ||||
|       gl_helper, gpu_origin, interpolation); | ||||
|   MP_RETURN_IF_ERROR(runner->Init()); | ||||
|   return runner; | ||||
| } | ||||
|  |  | |||
|  | @ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner< | |||
|     mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>> | ||||
| CreateAffineTransformationGlRunner( | ||||
|     std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper, | ||||
|     mediapipe::GpuOrigin::Mode gpu_origin); | ||||
|     mediapipe::GpuOrigin::Mode gpu_origin, | ||||
|     AffineTransformation::Interpolation interpolation); | ||||
| 
 | ||||
| }  // namespace mediapipe
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv( | |||
|   } | ||||
| } | ||||
| 
 | ||||
| int GetInterpolationForOpenCv( | ||||
|     AffineTransformation::Interpolation interpolation) { | ||||
|   switch (interpolation) { | ||||
|     case AffineTransformation::Interpolation::kLinear: | ||||
|       return cv::INTER_LINEAR; | ||||
|     case AffineTransformation::Interpolation::kCubic: | ||||
|       return cv::INTER_CUBIC; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| class OpenCvRunner | ||||
|     : public AffineTransformation::Runner<ImageFrame, ImageFrame> { | ||||
|  public: | ||||
|   OpenCvRunner(AffineTransformation::Interpolation interpolation) | ||||
|       : interpolation_(GetInterpolationForOpenCv(interpolation)) {} | ||||
| 
 | ||||
|   absl::StatusOr<ImageFrame> Run( | ||||
|       const ImageFrame& input, const std::array<float, 16>& matrix, | ||||
|       const AffineTransformation::Size& size, | ||||
|  | @ -142,19 +155,23 @@ class OpenCvRunner | |||
| 
 | ||||
|     cv::warpAffine(in_mat, out_mat, cv_affine_transform, | ||||
|                    cv::Size(out_mat.cols, out_mat.rows), | ||||
|                    /*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP, | ||||
|                    /*flags=*/interpolation_ | cv::WARP_INVERSE_MAP, | ||||
|                    GetBorderModeForOpenCv(border_mode)); | ||||
| 
 | ||||
|     return out_image; | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   int interpolation_ = cv::INTER_LINEAR; | ||||
| }; | ||||
| 
 | ||||
| }  // namespace
 | ||||
| 
 | ||||
| absl::StatusOr< | ||||
|     std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>> | ||||
| CreateAffineTransformationOpenCvRunner() { | ||||
|   return absl::make_unique<OpenCvRunner>(); | ||||
| CreateAffineTransformationOpenCvRunner( | ||||
|     AffineTransformation::Interpolation interpolation) { | ||||
|   return absl::make_unique<OpenCvRunner>(interpolation); | ||||
| } | ||||
| 
 | ||||
| }  // namespace mediapipe
 | ||||
|  |  | |||
|  | @ -25,7 +25,8 @@ namespace mediapipe { | |||
| 
 | ||||
| absl::StatusOr< | ||||
|     std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>> | ||||
| CreateAffineTransformationOpenCvRunner(); | ||||
| CreateAffineTransformationOpenCvRunner( | ||||
|     AffineTransformation::Interpolation interpolation); | ||||
| 
 | ||||
| }  // namespace mediapipe
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode( | |||
|   } | ||||
| } | ||||
| 
 | ||||
| AffineTransformation::Interpolation GetInterpolation( | ||||
|     mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) { | ||||
|   switch (interpolation) { | ||||
|     case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED: | ||||
|     case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR: | ||||
|       return AffineTransformation::Interpolation::kLinear; | ||||
|     case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC: | ||||
|       return AffineTransformation::Interpolation::kCubic; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| template <typename ImageT> | ||||
| class WarpAffineRunnerHolder {}; | ||||
| 
 | ||||
|  | @ -61,16 +72,22 @@ template <> | |||
| class WarpAffineRunnerHolder<ImageFrame> { | ||||
|  public: | ||||
|   using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>; | ||||
|   absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); } | ||||
|   absl::Status Open(CalculatorContext* cc) { | ||||
|     interpolation_ = GetInterpolation( | ||||
|         cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation()); | ||||
|     return absl::OkStatus(); | ||||
|   } | ||||
|   absl::StatusOr<RunnerType*> GetRunner() { | ||||
|     if (!runner_) { | ||||
|       ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner()); | ||||
|       ASSIGN_OR_RETURN(runner_, | ||||
|                        CreateAffineTransformationOpenCvRunner(interpolation_)); | ||||
|     } | ||||
|     return runner_.get(); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   std::unique_ptr<RunnerType> runner_; | ||||
|   AffineTransformation::Interpolation interpolation_; | ||||
| }; | ||||
| #endif  // !MEDIAPIPE_DISABLE_OPENCV
 | ||||
| 
 | ||||
|  | @ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> { | |||
|     gpu_origin_ = | ||||
|         cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin(); | ||||
|     gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>(); | ||||
|     interpolation_ = GetInterpolation( | ||||
|         cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation()); | ||||
|     return gl_helper_->Open(cc); | ||||
|   } | ||||
|   absl::StatusOr<RunnerType*> GetRunner() { | ||||
|     if (!runner_) { | ||||
|       ASSIGN_OR_RETURN( | ||||
|           runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_)); | ||||
|       ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner( | ||||
|                                     gl_helper_, gpu_origin_, interpolation_)); | ||||
|     } | ||||
|     return runner_.get(); | ||||
|   } | ||||
|  | @ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> { | |||
|   mediapipe::GpuOrigin::Mode gpu_origin_; | ||||
|   std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_; | ||||
|   std::unique_ptr<RunnerType> runner_; | ||||
|   AffineTransformation::Interpolation interpolation_; | ||||
| }; | ||||
| #endif  // !MEDIAPIPE_DISABLE_GPU
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -31,6 +31,13 @@ message WarpAffineCalculatorOptions { | |||
|     BORDER_REPLICATE = 2; | ||||
|   } | ||||
| 
 | ||||
|   // Pixel sampling interpolation methods. See @interpolation. | ||||
|   enum Interpolation { | ||||
|     INTER_UNSPECIFIED = 0; | ||||
|     INTER_LINEAR = 1; | ||||
|     INTER_CUBIC = 2; | ||||
|   } | ||||
| 
 | ||||
|   // Pixel extrapolation method. | ||||
|   // When converting image to tensor it may happen that tensor needs to read | ||||
|   // pixels outside image boundaries. Border mode helps to specify how such | ||||
|  | @ -43,4 +50,10 @@ message WarpAffineCalculatorOptions { | |||
|   // to be flipped vertically as tensors are expected to start at top. | ||||
|   // (DEFAULT or unset interpreted as CONVENTIONAL.) | ||||
|   optional GpuOrigin.Mode gpu_origin = 2; | ||||
| 
 | ||||
|   // Sampling method for neighboring pixels. | ||||
|   // INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors. | ||||
|   // INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights. | ||||
|   // INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR. | ||||
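  // Illustrative usage in a graph config (an added example, mirroring the
  // WarpAffineCalculator test configs in this change):
  //
  //   options {
  //     [mediapipe.WarpAffineCalculatorOptions.ext] {
  //       interpolation: INTER_CUBIC
  //     }
  //   }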
|   optional Interpolation interpolation = 3; | ||||
| } | ||||
|  |  | |||
|  | @ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag, | |||
|              const cv::Mat& input, cv::Mat expected_result, | ||||
|              float similarity_threshold, std::array<float, 16> matrix, | ||||
|              int out_width, int out_height, | ||||
|              absl::optional<AffineTransformation::BorderMode> border_mode) { | ||||
|              std::optional<AffineTransformation::BorderMode> border_mode, | ||||
|              std::optional<AffineTransformation::Interpolation> interpolation) { | ||||
|   std::string border_mode_str; | ||||
|   if (border_mode) { | ||||
|     switch (*border_mode) { | ||||
|  | @ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag, | |||
|         break; | ||||
|     } | ||||
|   } | ||||
|   std::string interpolation_str; | ||||
|   if (interpolation) { | ||||
|     switch (*interpolation) { | ||||
|       case AffineTransformation::Interpolation::kLinear: | ||||
|         interpolation_str = "interpolation: INTER_LINEAR"; | ||||
|         break; | ||||
|       case AffineTransformation::Interpolation::kCubic: | ||||
|         interpolation_str = "interpolation: INTER_CUBIC"; | ||||
|         break; | ||||
|     } | ||||
|   } | ||||
|   auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( | ||||
|       absl::Substitute(graph_text, /*$0=*/border_mode_str)); | ||||
|       absl::Substitute(graph_text, /*$0=*/border_mode_str, | ||||
|                        /*$1=*/interpolation_str)); | ||||
| 
 | ||||
|   std::vector<Packet> output_packets; | ||||
|   tool::AddVectorSink("output_image", &graph_config, &output_packets); | ||||
|  | @ -132,7 +145,8 @@ struct SimilarityConfig { | |||
| void RunTest(cv::Mat input, cv::Mat expected_result, | ||||
|              const SimilarityConfig& similarity, std::array<float, 16> matrix, | ||||
|              int out_width, int out_height, | ||||
|              absl::optional<AffineTransformation::BorderMode> border_mode) { | ||||
|              std::optional<AffineTransformation::BorderMode> border_mode, | ||||
|              std::optional<AffineTransformation::Interpolation> interpolation) { | ||||
|   RunTest(R"( | ||||
|         input_stream: "input_image" | ||||
|         input_stream: "output_size" | ||||
|  | @ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result, | |||
|           options { | ||||
|             [mediapipe.WarpAffineCalculatorOptions.ext] { | ||||
|               $0 # border mode | ||||
|               $1 # interpolation | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|         )", | ||||
|           "cpu", input, expected_result, similarity.threshold_on_cpu, matrix, | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| 
 | ||||
|   RunTest(R"( | ||||
|         input_stream: "input_image" | ||||
|  | @ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, | |||
|           options { | ||||
|             [mediapipe.WarpAffineCalculatorOptions.ext] { | ||||
|               $0 # border mode | ||||
|               $1 # interpolation | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|  | @ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, | |||
|         } | ||||
|         )", | ||||
|           "cpu_image", input, expected_result, similarity.threshold_on_cpu, | ||||
|           matrix, out_width, out_height, border_mode); | ||||
|           matrix, out_width, out_height, border_mode, interpolation); | ||||
| 
 | ||||
|   RunTest(R"( | ||||
|         input_stream: "input_image" | ||||
|  | @ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, | |||
|           options { | ||||
|             [mediapipe.WarpAffineCalculatorOptions.ext] { | ||||
|               $0 # border mode | ||||
|               $1 # interpolation | ||||
|               gpu_origin: TOP_LEFT | ||||
|             } | ||||
|           } | ||||
|  | @ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, | |||
|         } | ||||
|         )", | ||||
|           "gpu", input, expected_result, similarity.threshold_on_gpu, matrix, | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| 
 | ||||
|   RunTest(R"( | ||||
|         input_stream: "input_image" | ||||
|  | @ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, | |||
|           options { | ||||
|             [mediapipe.WarpAffineCalculatorOptions.ext] { | ||||
|               $0 # border mode | ||||
|               $1 # interpolation | ||||
|               gpu_origin: TOP_LEFT | ||||
|             } | ||||
|           } | ||||
|  | @ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result, | |||
|         } | ||||
|         )", | ||||
|           "gpu_image", input, expected_result, similarity.threshold_on_gpu, | ||||
|           matrix, out_width, out_height, border_mode); | ||||
|           matrix, out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi, | ||||
|  | @ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) { | |||
|   int out_height = 256; | ||||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = {}; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { | ||||
|  | @ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { | ||||
|  | @ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kReplicate; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { | ||||
|  | @ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { | ||||
|  | @ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { | |||
|   bool keep_aspect_ratio = false; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kReplicate; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { | ||||
|  | @ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { | |||
|   bool keep_aspect_ratio = false; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) { | ||||
|   mediapipe::NormalizedRect roi; | ||||
|   roi.set_x_center(0.65f); | ||||
|   roi.set_y_center(0.4f); | ||||
|   roi.set_width(0.5f); | ||||
|   roi.set_height(0.5f); | ||||
|   roi.set_rotation(M_PI * -45.0f / 180.0f); | ||||
|   auto input = GetRgb( | ||||
|       "/mediapipe/calculators/" | ||||
|       "tensor/testdata/image_to_tensor/input.jpg"); | ||||
|   auto expected_output = GetRgb( | ||||
|       "/mediapipe/calculators/" | ||||
|       "tensor/testdata/image_to_tensor/" | ||||
|       "medium_sub_rect_with_rotation_border_zero_interp_cubic.png"); | ||||
|   int out_width = 256; | ||||
|   int out_height = 256; | ||||
|   bool keep_aspect_ratio = false; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = | ||||
|       AffineTransformation::Interpolation::kCubic; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, LargeSubRect) { | ||||
|  | @ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) { | |||
|   bool keep_aspect_ratio = false; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kReplicate; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { | ||||
|  | @ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { | |||
|   bool keep_aspect_ratio = false; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { | ||||
|  | @ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kReplicate; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { | ||||
|  | @ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) { | ||||
|  | @ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) { | |||
|   int out_height = 128; | ||||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = {}; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { | ||||
|  | @ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, NoOp) { | ||||
|  | @ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kReplicate; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| TEST(WarpAffineCalculatorTest, NoOpBorderZero) { | ||||
|  | @ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) { | |||
|   bool keep_aspect_ratio = true; | ||||
|   std::optional<AffineTransformation::BorderMode> border_mode = | ||||
|       AffineTransformation::BorderMode::kZero; | ||||
|   std::optional<AffineTransformation::Interpolation> interpolation = {}; | ||||
|   RunTest(input, expected_output, | ||||
|           {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, | ||||
|           GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), | ||||
|           out_width, out_height, border_mode); | ||||
|           out_width, out_height, border_mode, interpolation); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
|  |  | |||
|  | @ -997,17 +997,20 @@ cc_library( | |||
|             ":image_to_tensor_converter_gl_buffer", | ||||
|             "//mediapipe/gpu:gl_calculator_helper", | ||||
|             "//mediapipe/gpu:gpu_buffer", | ||||
|             "//mediapipe/gpu:gpu_service", | ||||
|         ], | ||||
|         "//mediapipe:apple": [ | ||||
|             ":image_to_tensor_converter_metal", | ||||
|             "//mediapipe/gpu:gl_calculator_helper", | ||||
|             "//mediapipe/gpu:MPPMetalHelper", | ||||
|             "//mediapipe/gpu:gpu_buffer", | ||||
|             "//mediapipe/gpu:gpu_service", | ||||
|         ], | ||||
|         "//conditions:default": [ | ||||
|             ":image_to_tensor_converter_gl_buffer", | ||||
|             "//mediapipe/gpu:gl_calculator_helper", | ||||
|             "//mediapipe/gpu:gpu_buffer", | ||||
|             "//mediapipe/gpu:gpu_service", | ||||
|         ], | ||||
|     }), | ||||
| ) | ||||
|  | @ -1045,6 +1048,10 @@ cc_test( | |||
|         ":image_to_tensor_calculator", | ||||
|         ":image_to_tensor_converter", | ||||
|         ":image_to_tensor_utils", | ||||
|         "@com_google_absl//absl/flags:flag", | ||||
|         "@com_google_absl//absl/memory", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@com_google_absl//absl/strings:str_format", | ||||
|         "//mediapipe/framework:calculator_framework", | ||||
|         "//mediapipe/framework:calculator_runner", | ||||
|         "//mediapipe/framework/deps:file_path", | ||||
|  | @ -1061,11 +1068,10 @@ cc_test( | |||
|         "//mediapipe/framework/port:opencv_imgproc", | ||||
|         "//mediapipe/framework/port:parse_text_proto", | ||||
|         "//mediapipe/util:image_test_utils", | ||||
|         "@com_google_absl//absl/flags:flag", | ||||
|         "@com_google_absl//absl/memory", | ||||
|         "@com_google_absl//absl/strings", | ||||
|         "@com_google_absl//absl/strings:str_format", | ||||
|     ], | ||||
|     ] + select({ | ||||
|         "//mediapipe:apple": [], | ||||
|         "//conditions:default": ["//mediapipe/gpu:gl_context"], | ||||
|     }), | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|  |  | |||
|  | @ -45,9 +45,11 @@ | |||
| #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 | ||||
| #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h" | ||||
| #include "mediapipe/gpu/gl_calculator_helper.h" | ||||
| #include "mediapipe/gpu/gpu_service.h" | ||||
| #else | ||||
| #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h" | ||||
| #include "mediapipe/gpu/gl_calculator_helper.h" | ||||
| #include "mediapipe/gpu/gpu_service.h" | ||||
| #endif  // MEDIAPIPE_METAL_ENABLED
 | ||||
| #endif  // !MEDIAPIPE_DISABLE_GPU
 | ||||
| 
 | ||||
|  | @ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node { | |||
| #if MEDIAPIPE_METAL_ENABLED | ||||
|     MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); | ||||
| #else | ||||
|     MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); | ||||
|     cc->UseService(kGpuService).Optional(); | ||||
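    // Added note: requesting the GPU service as optional (rather than
    // requiring it via GlCalculatorHelper::UpdateContract) lets CPU-only
    // graphs initialize without a GPU context; per the tests in this change,
    // a GPU frame arriving while the service is unavailable fails with
    // "GPU service not available" instead of failing at startup.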
| #endif  // MEDIAPIPE_METAL_ENABLED
 | ||||
| #endif  // MEDIAPIPE_DISABLE_GPU
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -41,6 +41,10 @@ | |||
| #include "mediapipe/framework/port/status_matchers.h" | ||||
| #include "mediapipe/util/image_test_utils.h" | ||||
| 
 | ||||
| #if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED | ||||
| #include "mediapipe/gpu/gl_context.h" | ||||
| #endif  // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
 | ||||
| 
 | ||||
| namespace mediapipe { | ||||
| namespace { | ||||
| 
 | ||||
|  | @ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) { | |||
|           /*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt, | ||||
|           /*keep_aspect=*/false, BorderMode::kZero, roi); | ||||
| } | ||||
| 
 | ||||
| TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) { | ||||
|   auto graph_config = | ||||
|       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||
|         input_stream: "input_image" | ||||
|         node { | ||||
|           calculator: "ImageToTensorCalculator" | ||||
|           input_stream: "IMAGE:input_image" | ||||
|           output_stream: "TENSORS:tensor" | ||||
|           options { | ||||
|             [mediapipe.ImageToTensorCalculatorOptions.ext] { | ||||
|               output_tensor_float_range { min: 0.0f max: 1.0f } | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       )pb"); | ||||
|   CalculatorGraph graph; | ||||
|   MP_ASSERT_OK(graph.Initialize(graph_config)); | ||||
|   MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization()); | ||||
|   MP_ASSERT_OK(graph.StartRun({})); | ||||
|   auto image_frame = | ||||
|       std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4); | ||||
|   Image image = Image(std::move(image_frame)); | ||||
|   Packet packet = MakePacket<Image>(std::move(image)); | ||||
|   MP_ASSERT_OK( | ||||
|       graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1)))); | ||||
|   MP_ASSERT_OK(graph.WaitUntilIdle()); | ||||
|   MP_ASSERT_OK(graph.CloseAllPacketSources()); | ||||
|   MP_ASSERT_OK(graph.WaitUntilDone()); | ||||
| } | ||||
| 
 | ||||
| #if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED | ||||
| 
 | ||||
| TEST(ImageToTensorCalculatorTest, | ||||
|      FailsGracefullyWhenGpuServiceNeededButNotAvailable) { | ||||
|   auto graph_config = | ||||
|       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb( | ||||
|         input_stream: "input_image" | ||||
|         node { | ||||
|           calculator: "ImageToTensorCalculator" | ||||
|           input_stream: "IMAGE:input_image" | ||||
|           output_stream: "TENSORS:tensor" | ||||
|           options { | ||||
|             [mediapipe.ImageToTensorCalculatorOptions.ext] { | ||||
|               output_tensor_float_range { min: 0.0f max: 1.0f } | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       )pb"); | ||||
|   CalculatorGraph graph; | ||||
|   MP_ASSERT_OK(graph.Initialize(graph_config)); | ||||
|   MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization()); | ||||
|   MP_ASSERT_OK(graph.StartRun({})); | ||||
| 
 | ||||
|   MP_ASSERT_OK_AND_ASSIGN(auto context, | ||||
|                           GlContext::Create(nullptr, /*create_thread=*/true)); | ||||
|   Packet packet; | ||||
|   context->Run([&packet]() { | ||||
|     auto image_frame = | ||||
|         std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4); | ||||
|     Image image = Image(std::move(image_frame)); | ||||
|     // Ensure image is available on GPU to force ImageToTensorCalculator to
 | ||||
|     // run on GPU.
 | ||||
|     ASSERT_TRUE(image.ConvertToGpu()); | ||||
|     packet = MakePacket<Image>(std::move(image)); | ||||
|   }); | ||||
|   MP_ASSERT_OK( | ||||
|       graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1)))); | ||||
|   EXPECT_THAT(graph.WaitUntilIdle(), | ||||
|               StatusIs(absl::StatusCode::kInternal, | ||||
|                        HasSubstr("GPU service not available"))); | ||||
| } | ||||
| #endif  // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
 | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace mediapipe
 | ||||
|  |  | |||
(new binary image file: 64 KiB)
|  | @ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio, | |||
|     std::string result_image; | ||||
|     MP_ASSERT_OK( | ||||
|         mediapipe::file::GetContents(result_string_path, &result_image)); | ||||
|     EXPECT_EQ(result_image, output_string); | ||||
|     if (result_image != output_string) { | ||||
|       // There may be slight differences due to the way the JPEG was encoded or
 | ||||
|       // the OpenCV version used to generate the reference files. Compare
 | ||||
|       // pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
 | ||||
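      // Added note: for 8-bit images, PSNR = 10 * log10(255^2 / MSE) dB, so
      // the 45 dB threshold below corresponds to a mean squared error of
      // roughly 2.1 per channel: loose enough to tolerate encoder
      // differences, tight enough to catch real regressions.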
|       cv::Mat result_mat = | ||||
|           cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1, | ||||
|                                const_cast<char*>(result_image.data())), | ||||
|                        cv::IMREAD_UNCHANGED); | ||||
|       cv::Mat output_mat = | ||||
|           cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1, | ||||
|                                const_cast<char*>(output_string.data())), | ||||
|                        cv::IMREAD_UNCHANGED); | ||||
|       ASSERT_EQ(result_mat.rows, output_mat.rows); | ||||
|       ASSERT_EQ(result_mat.cols, output_mat.cols); | ||||
|       ASSERT_EQ(result_mat.type(), output_mat.type()); | ||||
|       EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0); | ||||
|     } | ||||
|   } else { | ||||
|     std::string output_string_path = mediapipe::file::JoinPath( | ||||
|         absl::GetFlag(FLAGS_output_folder), | ||||
|  |  | |||
(four binary image files modified; sizes effectively unchanged: 3.2 KiB, 6.1 KiB, 8.2 KiB, 7.6 KiB)
|  | @ -19,9 +19,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"]) | |||
| cc_library( | ||||
|     name = "face_geometry_from_landmarks_graph", | ||||
|     srcs = ["face_geometry_from_landmarks_graph.cc"], | ||||
|     data = [ | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/data:geometry_pipeline_metadata_landmarks", | ||||
|     ], | ||||
|     deps = [ | ||||
|         "//mediapipe/calculators/core:begin_loop_calculator", | ||||
|         "//mediapipe/calculators/core:end_loop_calculator", | ||||
|  | @ -39,6 +36,7 @@ cc_library( | |||
|         "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto", | ||||
|         "//mediapipe/util:graph_builder_utils", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|     ], | ||||
|  |  | |||
|  | @ -45,6 +45,7 @@ mediapipe_proto_library( | |||
|     srcs = ["geometry_pipeline_calculator.proto"], | ||||
|     deps = [ | ||||
|         "//mediapipe/framework:calculator_options_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:external_file_proto", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
|  | @ -59,6 +60,8 @@ cc_library( | |||
|         "//mediapipe/framework/port:ret_check", | ||||
|         "//mediapipe/framework/port:status", | ||||
|         "//mediapipe/framework/port:statusor", | ||||
|         "//mediapipe/tasks/cc/core:external_file_handler", | ||||
|         "//mediapipe/tasks/cc/core/proto:external_file_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", | ||||
|  |  | |||
|  | @ -24,6 +24,8 @@ | |||
| #include "mediapipe/framework/port/status.h" | ||||
| #include "mediapipe/framework/port/status_macros.h" | ||||
| #include "mediapipe/framework/port/statusor.h" | ||||
| #include "mediapipe/tasks/cc/core/external_file_handler.h" | ||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h" | ||||
|  | @ -69,8 +71,8 @@ using ::mediapipe::tasks::vision::face_geometry::proto:: | |||
| //     A vector of face geometry data.
 | ||||
| //
 | ||||
| // Options:
 | ||||
| //   metadata_path (`string`, optional):
 | ||||
| //     Defines a path for the geometry pipeline metadata file.
 | ||||
| //   metadata_file (`ExternalFile`, optional):
 | ||||
| //     Defines an ExternalFile for the geometry pipeline metadata file.
 | ||||
| //
 | ||||
| //     The geometry pipeline metadata file format must be the binary
 | ||||
| //     `GeometryPipelineMetadata` proto.
 | ||||
|  | @ -95,7 +97,7 @@ class GeometryPipelineCalculator : public CalculatorBase { | |||
| 
 | ||||
|     ASSIGN_OR_RETURN( | ||||
|         GeometryPipelineMetadata metadata, | ||||
|         ReadMetadataFromFile(options.metadata_path()), | ||||
|         ReadMetadataFromFile(options.metadata_file()), | ||||
|         _ << "Failed to read the geometry pipeline metadata from file!"); | ||||
| 
 | ||||
|     MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata)) | ||||
|  | @ -155,32 +157,19 @@ class GeometryPipelineCalculator : public CalculatorBase { | |||
| 
 | ||||
|  private: | ||||
|   static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile( | ||||
|       const std::string& metadata_path) { | ||||
|     ASSIGN_OR_RETURN(std::string metadata_blob, | ||||
|                      ReadContentBlobFromFile(metadata_path), | ||||
|                      _ << "Failed to read a metadata blob from file!"); | ||||
|       const core::proto::ExternalFile& metadata_file) { | ||||
|     ASSIGN_OR_RETURN( | ||||
|         const auto file_handler, | ||||
|         core::ExternalFileHandler::CreateFromExternalFile(&metadata_file)); | ||||
| 
 | ||||
|     GeometryPipelineMetadata metadata; | ||||
|     RET_CHECK(metadata.ParseFromString(metadata_blob)) | ||||
|     RET_CHECK( | ||||
|         metadata.ParseFromString(std::string(file_handler->GetFileContent()))) | ||||
|         << "Failed to parse a metadata proto from a binary blob!"; | ||||
| 
 | ||||
|     return metadata; | ||||
|   } | ||||
| 
 | ||||
|   static absl::StatusOr<std::string> ReadContentBlobFromFile( | ||||
|       const std::string& unresolved_path) { | ||||
|     ASSIGN_OR_RETURN(std::string resolved_path, | ||||
|                      mediapipe::PathToResourceAsFile(unresolved_path), | ||||
|                      _ << "Failed to resolve path! Path = " << unresolved_path); | ||||
| 
 | ||||
|     std::string content_blob; | ||||
|     MP_RETURN_IF_ERROR( | ||||
|         mediapipe::GetResourceContents(resolved_path, &content_blob)) | ||||
|         << "Failed to read content blob! Resolved path = " << resolved_path; | ||||
| 
 | ||||
|     return content_blob; | ||||
|   } | ||||
| 
 | ||||
|   std::unique_ptr<GeometryPipeline> geometry_pipeline_; | ||||
| }; | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,11 +17,12 @@ syntax = "proto2"; | |||
| package mediapipe.tasks.vision.face_geometry; | ||||
| 
 | ||||
| import "mediapipe/framework/calculator_options.proto"; | ||||
| import "mediapipe/tasks/cc/core/proto/external_file.proto"; | ||||
| 
 | ||||
| message FaceGeometryPipelineCalculatorOptions { | ||||
|   extend mediapipe.CalculatorOptions { | ||||
|     optional FaceGeometryPipelineCalculatorOptions ext = 512499200; | ||||
|   } | ||||
| 
 | ||||
|   optional string metadata_path = 1; | ||||
|   optional core.proto.ExternalFile metadata_file = 1; | ||||
| } | ||||
|  |  | |||
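| A minimal sketch of a node config under the new schema, for orientation; the | ||||
| metadata file name below is a placeholder, not a file shipped with this change: | ||||
| 
 | ||||
| ```cpp | ||||
| #include "mediapipe/framework/calculator.pb.h" | ||||
| #include "mediapipe/framework/port/parse_text_proto.h" | ||||
| 
 | ||||
| // Sketch only: configuring FaceGeometryPipelineCalculator through the new | ||||
| // ExternalFile-based `metadata_file` option instead of `metadata_path`. | ||||
| mediapipe::CalculatorGraphConfig::Node node = | ||||
|     mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig::Node>(R"pb( | ||||
|       calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator" | ||||
|       options { | ||||
|         [mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculatorOptions | ||||
|              .ext] { | ||||
|           metadata_file { file_name: "path/to/metadata.binarypb" } | ||||
|         } | ||||
|       } | ||||
|     )pb"); | ||||
| ``` | ||||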
|  | @ -28,6 +28,7 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h" | ||||
| #include "mediapipe/util/graph_builder_utils.h" | ||||
| 
 | ||||
| namespace mediapipe::tasks::vision::face_geometry { | ||||
|  | @ -49,10 +50,6 @@ constexpr char kIterableTag[] = "ITERABLE"; | |||
| constexpr char kBatchEndTag[] = "BATCH_END"; | ||||
| constexpr char kItemTag[] = "ITEM"; | ||||
| 
 | ||||
| constexpr char kGeometryPipelineMetadataPath[] = | ||||
|     "mediapipe/tasks/cc/vision/face_geometry/data/" | ||||
|     "geometry_pipeline_metadata_landmarks.binarypb"; | ||||
| 
 | ||||
| struct FaceGeometryOuts { | ||||
|   Stream<std::vector<FaceGeometry>> multi_face_geometry; | ||||
| }; | ||||
|  | @ -127,6 +124,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph { | |||
|     } | ||||
|     ASSIGN_OR_RETURN(auto outs, | ||||
|                      BuildFaceGeometryFromLandmarksGraph( | ||||
|                          *sc->MutableOptions<proto::FaceGeometryGraphOptions>(), | ||||
|                          graph.In(kFaceLandmarksTag) | ||||
|                              .Cast<std::vector<NormalizedLandmarkList>>(), | ||||
|                          graph.In(kImageSizeTag).Cast<std::pair<int, int>>(), | ||||
|  | @ -138,6 +136,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph { | |||
| 
 | ||||
|  private: | ||||
|   absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph( | ||||
|       proto::FaceGeometryGraphOptions& graph_options, | ||||
|       Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks, | ||||
|       Stream<std::pair<int, int>> image_size, | ||||
|       std::optional<SidePacket<Environment>> environment, Graph& graph) { | ||||
|  | @ -185,7 +184,8 @@ class FaceGeometryFromLandmarksGraph : public Subgraph { | |||
|         "mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator"); | ||||
|     auto& geometry_pipeline_options = | ||||
|         geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>(); | ||||
|     geometry_pipeline_options.set_metadata_path(kGeometryPipelineMetadataPath); | ||||
|     geometry_pipeline_options.Swap( | ||||
|         graph_options.mutable_geometry_pipeline_options()); | ||||
|     image_size >> geometry_pipeline.In(kImageSizeTag); | ||||
|     multi_face_landmarks_no_iris >> | ||||
|         geometry_pipeline.In(kMultiFaceLandmarksTag); | ||||
|  |  | |||
|  | @ -20,6 +20,7 @@ limitations under the License. | |||
| #include "absl/status/statusor.h" | ||||
| #include "absl/strings/str_format.h" | ||||
| #include "absl/strings/string_view.h" | ||||
| #include "absl/strings/substitute.h" | ||||
| #include "mediapipe/framework/api2/port.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/framework/calculator_runner.h" | ||||
|  | @ -31,6 +32,7 @@ limitations under the License. | |||
| #include "mediapipe/framework/port/gtest.h" | ||||
| #include "mediapipe/framework/port/parse_text_proto.h" | ||||
| #include "mediapipe/framework/tool/sink.h" | ||||
| #include "mediapipe/tasks/cc/core/proto/external_file.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" | ||||
| 
 | ||||
|  | @ -49,6 +51,9 @@ constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/"; | |||
| constexpr char kFaceLandmarksFileName[] = | ||||
|     "face_blendshapes_in_landmarks.prototxt"; | ||||
| constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt"; | ||||
| constexpr char kGeometryPipelineMetadataPath[] = | ||||
|     "mediapipe/tasks/cc/vision/face_geometry/data/" | ||||
|     "geometry_pipeline_metadata_landmarks.binarypb"; | ||||
| 
 | ||||
| std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) { | ||||
|   NormalizedLandmarkList landmarks; | ||||
|  | @ -89,17 +94,25 @@ void MakeInputPacketsAndRunGraph(CalculatorGraph& graph) { | |||
| 
 | ||||
| TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) { | ||||
|   CalculatorGraphConfig graph_config = ParseTextProtoOrDie< | ||||
|       CalculatorGraphConfig>(R"pb( | ||||
|     input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|     input_stream: "IMAGE_SIZE:image_size" | ||||
|     output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|     node { | ||||
|       calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" | ||||
|       input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|       input_stream: "IMAGE_SIZE:image_size" | ||||
|       output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|     } | ||||
|   )pb"); | ||||
|       CalculatorGraphConfig>(absl::Substitute( | ||||
|       R"pb( | ||||
|         input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|         input_stream: "IMAGE_SIZE:image_size" | ||||
|         output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|         node { | ||||
|           calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" | ||||
|           input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|           input_stream: "IMAGE_SIZE:image_size" | ||||
|           output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|           options: { | ||||
|             [mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions | ||||
|                  .ext] { | ||||
|               geometry_pipeline_options { metadata_file { file_name: "$0" } } | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       )pb", | ||||
|       kGeometryPipelineMetadataPath)); | ||||
|   std::vector<Packet> output_packets; | ||||
|   tool::AddVectorSink("face_geometry", &graph_config, &output_packets); | ||||
| 
 | ||||
|  | @ -116,19 +129,27 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) { | |||
| 
 | ||||
| TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) { | ||||
|   CalculatorGraphConfig graph_config = ParseTextProtoOrDie< | ||||
|       CalculatorGraphConfig>(R"pb( | ||||
|     input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|     input_stream: "IMAGE_SIZE:image_size" | ||||
|     input_side_packet: "ENVIRONMENT:environment" | ||||
|     output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|     node { | ||||
|       calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" | ||||
|       input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|       input_stream: "IMAGE_SIZE:image_size" | ||||
|       input_side_packet: "ENVIRONMENT:environment" | ||||
|       output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|     } | ||||
|   )pb"); | ||||
|       CalculatorGraphConfig>(absl::Substitute( | ||||
|       R"pb( | ||||
|         input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|         input_stream: "IMAGE_SIZE:image_size" | ||||
|         input_side_packet: "ENVIRONMENT:environment" | ||||
|         output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|         node { | ||||
|           calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph" | ||||
|           input_stream: "FACE_LANDMARKS:face_landmarks" | ||||
|           input_stream: "IMAGE_SIZE:image_size" | ||||
|           input_side_packet: "ENVIRONMENT:environment" | ||||
|           output_stream: "FACE_GEOMETRY:face_geometry" | ||||
|           options: { | ||||
|             [mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions | ||||
|                  .ext] { | ||||
|               geometry_pipeline_options { metadata_file { file_name: "$0" } } | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|       )pb", | ||||
|       kGeometryPipelineMetadataPath)); | ||||
|   std::vector<Packet> output_packets; | ||||
|   tool::AddVectorSink("face_geometry", &graph_config, &output_packets); | ||||
| 
 | ||||
|  |  | |||
|  | @ -44,3 +44,12 @@ mediapipe_proto_library( | |||
|     name = "mesh_3d_proto", | ||||
|     srcs = ["mesh_3d.proto"], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_proto_library( | ||||
|     name = "face_geometry_graph_options_proto", | ||||
|     srcs = ["face_geometry_graph_options.proto"], | ||||
|     deps = [ | ||||
|         "//mediapipe/framework:calculator_options_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_proto", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -0,0 +1,28 @@ | |||
| // Copyright 2023 The MediaPipe Authors. | ||||
| // | ||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| // you may not use this file except in compliance with the License. | ||||
| // You may obtain a copy of the License at | ||||
| // | ||||
| //      http://www.apache.org/licenses/LICENSE-2.0 | ||||
| // | ||||
| // Unless required by applicable law or agreed to in writing, software | ||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| // See the License for the specific language governing permissions and | ||||
| // limitations under the License. | ||||
| 
 | ||||
| syntax = "proto2"; | ||||
| 
 | ||||
| package mediapipe.tasks.vision.face_geometry.proto; | ||||
| 
 | ||||
| import "mediapipe/framework/calculator_options.proto"; | ||||
| import "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto"; | ||||
| 
 | ||||
| message FaceGeometryGraphOptions { | ||||
|   extend mediapipe.CalculatorOptions { | ||||
|     optional FaceGeometryGraphOptions ext = 515723506; | ||||
|   } | ||||
| 
 | ||||
|   optional FaceGeometryPipelineCalculatorOptions geometry_pipeline_options = 1; | ||||
| } | ||||
|  | @ -210,8 +210,10 @@ cc_library( | |||
|         "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph", | ||||
|         "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry:face_geometry_from_landmarks_graph", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto", | ||||
|  |  | |||
|  | @ -40,8 +40,10 @@ limitations under the License. | |||
| #include "mediapipe/tasks/cc/core/utils.h" | ||||
| #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h" | ||||
| #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h" | ||||
|  | @ -93,6 +95,8 @@ constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite"; | |||
| constexpr char kFaceLandmarksDetectorTFLiteName[] = | ||||
|     "face_landmarks_detector.tflite"; | ||||
| constexpr char kFaceBlendshapeTFLiteName[] = "face_blendshapes.tflite"; | ||||
| constexpr char kFaceGeometryPipelineMetadataName[] = | ||||
|     "geometry_pipeline_metadata_landmarks.binarypb"; | ||||
| 
 | ||||
| struct FaceLandmarkerOutputs { | ||||
|   Source<std::vector<NormalizedLandmarkList>> landmark_lists; | ||||
|  | @ -305,6 +309,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { | |||
|   absl::StatusOr<CalculatorGraphConfig> GetConfig( | ||||
|       SubgraphContext* sc) override { | ||||
|     Graph graph; | ||||
|     bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag); | ||||
|     if (sc->Options<FaceLandmarkerGraphOptions>() | ||||
|             .base_options() | ||||
|             .has_model_asset()) { | ||||
|  | @ -318,6 +323,18 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { | |||
|           sc->MutableOptions<FaceLandmarkerGraphOptions>(), | ||||
|           !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService) | ||||
|                .IsAvailable())); | ||||
|       if (output_geometry) { | ||||
|         // Set the face geometry metadata file for
 | ||||
|         // FaceGeometryFromLandmarksGraph.
 | ||||
|         ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file, | ||||
|                          model_asset_bundle_resources->GetModelFile( | ||||
|                              kFaceGeometryPipelineMetadataName)); | ||||
|         SetExternalFile(face_geometry_pipeline_metadata_file, | ||||
|                         sc->MutableOptions<FaceLandmarkerGraphOptions>() | ||||
|                             ->mutable_face_geometry_graph_options() | ||||
|                             ->mutable_geometry_pipeline_options() | ||||
|                             ->mutable_metadata_file()); | ||||
|       } | ||||
|     } | ||||
|     std::optional<SidePacket<Environment>> environment; | ||||
|     if (HasSideInput(sc->OriginalNode(), kEnvironmentTag)) { | ||||
|  | @ -338,7 +355,6 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { | |||
|               .face_landmarks_detector_graph_options() | ||||
|               .has_face_blendshapes_graph_options())); | ||||
|     } | ||||
|     bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag); | ||||
|     ASSIGN_OR_RETURN( | ||||
|         auto outs, | ||||
|         BuildFaceLandmarkerGraph( | ||||
|  | @ -481,6 +497,9 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph { | |||
|       auto& face_geometry_from_landmarks = graph.AddNode( | ||||
|           "mediapipe.tasks.vision.face_geometry." | ||||
|           "FaceGeometryFromLandmarksGraph"); | ||||
|       face_geometry_from_landmarks | ||||
|           .GetOptions<face_geometry::proto::FaceGeometryGraphOptions>() | ||||
|           .Swap(tasks_options.mutable_face_geometry_graph_options()); | ||||
|       if (environment.has_value()) { | ||||
|         *environment >> face_geometry_from_landmarks.SideIn(kEnvironmentTag); | ||||
|       } | ||||
|  |  | |||
|  | @ -60,5 +60,6 @@ mediapipe_proto_library( | |||
|         "//mediapipe/framework:calculator_proto", | ||||
|         "//mediapipe/tasks/cc/core/proto:base_options_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto", | ||||
|         "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_proto", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -21,6 +21,7 @@ import "mediapipe/framework/calculator.proto"; | |||
| import "mediapipe/framework/calculator_options.proto"; | ||||
| import "mediapipe/tasks/cc/core/proto/base_options.proto"; | ||||
| import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto"; | ||||
| import "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.proto"; | ||||
| import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto"; | ||||
| 
 | ||||
| option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto"; | ||||
|  | @ -45,4 +46,8 @@ message FaceLandmarkerGraphOptions { | |||
|   // Minimum confidence for face landmarks tracking to be considered | ||||
|   // successfully. | ||||
|   optional float min_tracking_confidence = 4 [default = 0.5]; | ||||
| 
 | ||||
|   // Options for FaceGeometryGraph to get facial transformation matrix. | ||||
|   optional face_geometry.proto.FaceGeometryGraphOptions | ||||
|       face_geometry_graph_options = 5; | ||||
| } | ||||
|  |  | |||
|  | @ -47,6 +47,7 @@ cc_library( | |||
|         "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto", | ||||
|         "@com_google_absl//absl/status:statusor", | ||||
|     ], | ||||
|     alwayslink = 1, | ||||
| ) | ||||
| 
 | ||||
| cc_library( | ||||
|  |  | |||
|  | @ -294,7 +294,7 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) { | |||
|                   threadsPerThreadgroup:threads_per_group]; | ||||
|   [compute_encoder endEncoding]; | ||||
|   [command_buffer commit]; | ||||
| 
 | ||||
|   [command_buffer waitUntilCompleted]; | ||||
|   kOutputImage(cc).Send(Image(output)); | ||||
|   return absl::OkStatus(); | ||||
| } | ||||
|  |  | |||
|  | @ -36,6 +36,7 @@ cc_library( | |||
|         ":hand_association_calculator_cc_proto", | ||||
|         "//mediapipe/calculators/util:association_calculator", | ||||
|         "//mediapipe/framework:calculator_framework", | ||||
|         "//mediapipe/framework:collection_item_id", | ||||
|         "//mediapipe/framework/api2:node", | ||||
|         "//mediapipe/framework/formats:rect_cc_proto", | ||||
|         "//mediapipe/framework/port:rectangle", | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ limitations under the License. | |||
| 
 | ||||
| #include "mediapipe/framework/api2/node.h" | ||||
| #include "mediapipe/framework/calculator_framework.h" | ||||
| #include "mediapipe/framework/collection_item_id.h" | ||||
| #include "mediapipe/framework/formats/rect.pb.h" | ||||
| #include "mediapipe/framework/port/rectangle.h" | ||||
| #include "mediapipe/framework/port/status.h" | ||||
|  | @ -29,30 +30,55 @@ namespace mediapipe::api2 { | |||
| 
 | ||||
| using ::mediapipe::NormalizedRect; | ||||
| 
 | ||||
| // HandAssociationCalculator accepts multiple inputs of vectors of
 | ||||
| // NormalizedRect. The output is a vector of NormalizedRect that contains
 | ||||
| // rects from the input vectors that don't overlap with each other. When two
 | ||||
| // rects overlap, the rect that comes in from an earlier input stream is
 | ||||
| // kept in the output. If a rect has no ID (i.e. from detection stream),
 | ||||
| // then a unique rect ID is assigned to it.
 | ||||
| 
 | ||||
| // The rects in multiple input streams are effectively flattened to a single
 | ||||
| // list.  For example:
 | ||||
| // Stream1: rect 1, rect 2
 | ||||
| // Stream2: rect 3, rect 4
 | ||||
| // Stream3: rect 5, rect 6
 | ||||
| // (Conceptually) flattened list: rect 1, 2, 3, 4, 5, 6
 | ||||
| // In the flattened list, if a rect with a higher index overlaps with a rect
 | ||||
| // with a lower index, beyond a specified IOU threshold, the rect with the
 | ||||
| // lower index will be kept in the output, and the rect with the higher index
 | ||||
| // will be discarded.
 | ||||
| // Input:
 | ||||
| //  BASE_RECTS - Vector of NormalizedRect.
 | ||||
| //  RECTS - Vector of NormalizedRect.
 | ||||
| //
 | ||||
| // Output:
 | ||||
| //  No tag - Vector of NormalizedRect.
 | ||||
| //
 | ||||
| // Example use:
 | ||||
| // node {
 | ||||
| //   calculator: "HandAssociationCalculator"
 | ||||
| //   input_stream: "BASE_RECTS:base_rects"
 | ||||
| //   input_stream: "RECTS:0:rects0"
 | ||||
| //   input_stream: "RECTS:1:rects1"
 | ||||
| //   input_stream: "RECTS:2:rects2"
 | ||||
| //   output_stream: "output_rects"
 | ||||
| //   options {
 | ||||
| //     [mediapipe.HandAssociationCalculatorOptions.ext] {
 | ||||
| //       min_similarity_threshold: 0.1
 | ||||
| //     }
 | ||||
| //   }
 | ||||
| // }
 | ||||
| //
 | ||||
| // IMPORTANT Notes:
 | ||||
| //  - Rects from input streams tagged with "BASE_RECTS" are always preserved.
 | ||||
| //  - This calculator checks for overlap among rects from input streams tagged
 | ||||
| //    with "RECTS". Rects are prioritized based on their index in the vector and
 | ||||
| //    input streams to the calculator. When two rects overlap, the rect that
 | ||||
| //    comes from an input stream with lower tag-index is kept in the output.
 | ||||
| //  - Example of inputs for the node above:
 | ||||
| //      "base_rects": rect 0, rect 1
 | ||||
| //      "rects0": rect 2, rect 3
 | ||||
| //      "rects1": rect 4, rect 5
 | ||||
| //      "rects2": rect 6, rect 7
 | ||||
| //    (Conceptually) flattened list: 0, 1, 2, 3, 4, 5, 6, 7.
 | ||||
| //    Rects 0, 1 will be preserved. Rects 2, 3, 4, 5, 6, 7 will be checked for
 | ||||
| //    overlap. If a rect with a higher index overlaps with a rect with a
 | ||||
| //    lower index, beyond a specified IOU threshold, the rect with the lower
 | ||||
| //    index will be kept in the output, and the rect with the higher index
 | ||||
| //    will be discarded.
 | ||||
| // TODO: Upgrade this to latest API for calculators
 | ||||
| class HandAssociationCalculator : public CalculatorBase { | ||||
|  public: | ||||
|   static absl::Status GetContract(CalculatorContract* cc) { | ||||
|     // Initialize input and output streams.
 | ||||
|     for (auto& input_stream : cc->Inputs()) { | ||||
|       input_stream.Set<std::vector<NormalizedRect>>(); | ||||
|     for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS"); | ||||
|          id != cc->Inputs().EndId("BASE_RECTS"); ++id) { | ||||
|       cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>(); | ||||
|     } | ||||
|     for (CollectionItemId id = cc->Inputs().BeginId("RECTS"); | ||||
|          id != cc->Inputs().EndId("RECTS"); ++id) { | ||||
|       cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>(); | ||||
|     } | ||||
|     cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>(); | ||||
| 
 | ||||
|  | @ -89,7 +115,24 @@ class HandAssociationCalculator : public CalculatorBase { | |||
|       CalculatorContext* cc) { | ||||
|     std::vector<NormalizedRect> result; | ||||
| 
 | ||||
|     for (const auto& input_stream : cc->Inputs()) { | ||||
|     for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS"); | ||||
|          id != cc->Inputs().EndId("BASE_RECTS"); ++id) { | ||||
|       const auto& input_stream = cc->Inputs().Get(id); | ||||
|       if (input_stream.IsEmpty()) { | ||||
|         continue; | ||||
|       } | ||||
| 
 | ||||
|       for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) { | ||||
|         if (!rect.has_rect_id()) { | ||||
|           rect.set_rect_id(GetNextRectId()); | ||||
|         } | ||||
|         result.push_back(rect); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     for (CollectionItemId id = cc->Inputs().BeginId("RECTS"); | ||||
|          id != cc->Inputs().EndId("RECTS"); ++id) { | ||||
|       const auto& input_stream = cc->Inputs().Get(id); | ||||
|       if (input_stream.IsEmpty()) { | ||||
|         continue; | ||||
|       } | ||||
|  |  | |||
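| To make the association rule concrete, the logic documented in the calculator | ||||
| comment above boils down to roughly the following self-contained sketch (plain | ||||
| structs and an illustrative IoU helper; the real calculator operates on | ||||
| mediapipe::NormalizedRect, uses its own overlap helper, and also assigns rect | ||||
| IDs): | ||||
| 
 | ||||
| ```cpp | ||||
| #include <algorithm> | ||||
| #include <vector> | ||||
| 
 | ||||
| // Illustrative axis-aligned box; not a MediaPipe type. | ||||
| struct Box { | ||||
|   float x_min, y_min, x_max, y_max; | ||||
| }; | ||||
| 
 | ||||
| float IntersectionOverUnion(const Box& a, const Box& b) { | ||||
|   const float ix = | ||||
|       std::max(0.0f, std::min(a.x_max, b.x_max) - std::max(a.x_min, b.x_min)); | ||||
|   const float iy = | ||||
|       std::max(0.0f, std::min(a.y_max, b.y_max) - std::max(a.y_min, b.y_min)); | ||||
|   const float intersection = ix * iy; | ||||
|   const float union_area = (a.x_max - a.x_min) * (a.y_max - a.y_min) + | ||||
|                            (b.x_max - b.x_min) * (b.y_max - b.y_min) - | ||||
|                            intersection; | ||||
|   return union_area > 0.0f ? intersection / union_area : 0.0f; | ||||
| } | ||||
| 
 | ||||
| // BASE_RECTS are copied to the output unconditionally (and are not checked | ||||
| // against each other). RECTS, already flattened across streams in stream | ||||
| // order, are kept only if they do not overlap any previously kept box beyond | ||||
| // the threshold. | ||||
| std::vector<Box> Associate(const std::vector<Box>& base_rects, | ||||
|                            const std::vector<Box>& rects, | ||||
|                            float min_similarity_threshold) { | ||||
|   std::vector<Box> result = base_rects; | ||||
|   for (const Box& candidate : rects) { | ||||
|     const bool overlaps = std::any_of( | ||||
|         result.begin(), result.end(), [&](const Box& kept) { | ||||
|           return IntersectionOverUnion(candidate, kept) > | ||||
|                  min_similarity_threshold; | ||||
|         }); | ||||
|     if (!overlaps) result.push_back(candidate); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| ``` | ||||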
|  | @ -27,6 +27,8 @@ namespace mediapipe { | |||
| namespace { | ||||
| 
 | ||||
| using ::mediapipe::NormalizedRect; | ||||
| using ::testing::ElementsAre; | ||||
| using ::testing::EqualsProto; | ||||
| 
 | ||||
| class HandAssociationCalculatorTest : public testing::Test { | ||||
|  protected: | ||||
|  | @ -87,9 +89,9 @@ class HandAssociationCalculatorTest : public testing::Test { | |||
| TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { | ||||
|   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( | ||||
|     calculator: "HandAssociationCalculator" | ||||
|     input_stream: "input_vec_0" | ||||
|     input_stream: "input_vec_1" | ||||
|     input_stream: "input_vec_2" | ||||
|     input_stream: "BASE_RECTS:input_vec_0" | ||||
|     input_stream: "RECTS:0:input_vec_1" | ||||
|     input_stream: "RECTS:1:input_vec_2" | ||||
|     output_stream: "output_vec" | ||||
|     options { | ||||
|       [mediapipe.HandAssociationCalculatorOptions.ext] { | ||||
|  | @ -103,20 +105,23 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { | |||
|   input_vec_0->push_back(nr_0_); | ||||
|   input_vec_0->push_back(nr_1_); | ||||
|   input_vec_0->push_back(nr_2_); | ||||
|   runner.MutableInputs()->Index(0).packets.push_back( | ||||
|       Adopt(input_vec_0.release()).At(Timestamp(1))); | ||||
|   runner.MutableInputs() | ||||
|       ->Tag("BASE_RECTS") | ||||
|       .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   // Input Stream 1: nr_3, nr_4.
 | ||||
|   auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_1->push_back(nr_3_); | ||||
|   input_vec_1->push_back(nr_4_); | ||||
|   runner.MutableInputs()->Index(1).packets.push_back( | ||||
|   auto index_id = runner.MutableInputs()->GetId("RECTS", 0); | ||||
|   runner.MutableInputs()->Get(index_id).packets.push_back( | ||||
|       Adopt(input_vec_1.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   // Input Stream 2: nr_5.
 | ||||
|   auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_2->push_back(nr_5_); | ||||
|   runner.MutableInputs()->Index(2).packets.push_back( | ||||
|   index_id = runner.MutableInputs()->GetId("RECTS", 1); | ||||
|   runner.MutableInputs()->Get(index_id).packets.push_back( | ||||
|       Adopt(input_vec_2.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; | ||||
|  | @ -134,25 +139,18 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) { | |||
|   EXPECT_EQ(3, assoc_rects.size()); | ||||
| 
 | ||||
|   // Check that IDs are filled in and contents match.
 | ||||
|   EXPECT_EQ(assoc_rects[0].rect_id(), 1); | ||||
|   assoc_rects[0].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); | ||||
| 
 | ||||
|   EXPECT_EQ(assoc_rects[1].rect_id(), 2); | ||||
|   assoc_rects[1].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); | ||||
| 
 | ||||
|   EXPECT_EQ(assoc_rects[2].rect_id(), 3); | ||||
|   assoc_rects[2].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); | ||||
|   nr_0_.set_rect_id(1); | ||||
|   nr_1_.set_rect_id(2); | ||||
|   nr_2_.set_rect_id(3); | ||||
|   EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_), | ||||
|                                        EqualsProto(nr_2_))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { | ||||
|   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( | ||||
|     calculator: "HandAssociationCalculator" | ||||
|     input_stream: "input_vec_0" | ||||
|     input_stream: "input_vec_1" | ||||
|     input_stream: "input_vec_2" | ||||
|     input_stream: "BASE_RECTS:input_vec_0" | ||||
|     input_stream: "RECTS:0:input_vec_1" | ||||
|     output_stream: "output_vec" | ||||
|     options { | ||||
|       [mediapipe.HandAssociationCalculatorOptions.ext] { | ||||
|  | @ -169,14 +167,15 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { | |||
|   input_vec_0->push_back(nr_0_); | ||||
|   nr_1_.set_rect_id(-1); | ||||
|   input_vec_0->push_back(nr_1_); | ||||
|   runner.MutableInputs()->Index(0).packets.push_back( | ||||
|       Adopt(input_vec_0.release()).At(Timestamp(1))); | ||||
|   runner.MutableInputs() | ||||
|       ->Tag("BASE_RECTS") | ||||
|       .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   // Input Stream 1: nr_2, nr_3. Newly detected palms.
 | ||||
|   auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_1->push_back(nr_2_); | ||||
|   input_vec_1->push_back(nr_3_); | ||||
|   runner.MutableInputs()->Index(1).packets.push_back( | ||||
|   runner.MutableInputs()->Tag("RECTS").packets.push_back( | ||||
|       Adopt(input_vec_1.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; | ||||
|  | @ -192,23 +191,17 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) { | |||
|   EXPECT_EQ(3, assoc_rects.size()); | ||||
| 
 | ||||
|   // Check that IDs are filled in and contents match.
 | ||||
|   EXPECT_EQ(assoc_rects[0].rect_id(), -2); | ||||
|   EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_)); | ||||
| 
 | ||||
|   EXPECT_EQ(assoc_rects[1].rect_id(), -1); | ||||
|   EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_)); | ||||
| 
 | ||||
|   EXPECT_EQ(assoc_rects[2].rect_id(), 1); | ||||
|   assoc_rects[2].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_)); | ||||
|   nr_2_.set_rect_id(1); | ||||
|   EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_), | ||||
|                                        EqualsProto(nr_2_))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { | ||||
|   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( | ||||
|     calculator: "HandAssociationCalculator" | ||||
|     input_stream: "input_vec_0" | ||||
|     input_stream: "input_vec_1" | ||||
|     input_stream: "input_vec_2" | ||||
|     input_stream: "BASE_RECTS:input_vec_0" | ||||
|     input_stream: "RECTS:0:input_vec_1" | ||||
|     input_stream: "RECTS:1:input_vec_2" | ||||
|     output_stream: "output_vec" | ||||
|     options { | ||||
|       [mediapipe.HandAssociationCalculatorOptions.ext] { | ||||
|  | @ -220,14 +213,16 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { | |||
|   // Input Stream 0: nr_5.
 | ||||
|   auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_0->push_back(nr_5_); | ||||
|   runner.MutableInputs()->Index(0).packets.push_back( | ||||
|       Adopt(input_vec_0.release()).At(Timestamp(1))); | ||||
|   runner.MutableInputs() | ||||
|       ->Tag("BASE_RECTS") | ||||
|       .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   // Input Stream 1: nr_4, nr_3
 | ||||
|   auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_1->push_back(nr_4_); | ||||
|   input_vec_1->push_back(nr_3_); | ||||
|   runner.MutableInputs()->Index(1).packets.push_back( | ||||
|   auto index_id = runner.MutableInputs()->GetId("RECTS", 0); | ||||
|   runner.MutableInputs()->Get(index_id).packets.push_back( | ||||
|       Adopt(input_vec_1.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   // Input Stream 2: nr_2, nr_1, nr_0.
 | ||||
|  | @ -235,7 +230,8 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { | |||
|   input_vec_2->push_back(nr_2_); | ||||
|   input_vec_2->push_back(nr_1_); | ||||
|   input_vec_2->push_back(nr_0_); | ||||
|   runner.MutableInputs()->Index(2).packets.push_back( | ||||
|   index_id = runner.MutableInputs()->GetId("RECTS", 1); | ||||
|   runner.MutableInputs()->Get(index_id).packets.push_back( | ||||
|       Adopt(input_vec_2.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; | ||||
|  | @ -253,23 +249,78 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) { | |||
|   EXPECT_EQ(3, assoc_rects.size()); | ||||
| 
 | ||||
|   // Outputs are in same order as inputs, and IDs are filled in.
 | ||||
|   EXPECT_EQ(assoc_rects[0].rect_id(), 1); | ||||
|   assoc_rects[0].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_)); | ||||
|   nr_5_.set_rect_id(1); | ||||
|   nr_4_.set_rect_id(2); | ||||
|   nr_0_.set_rect_id(3); | ||||
|   EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_4_), | ||||
|                                        EqualsProto(nr_0_))); | ||||
| } | ||||
| 
 | ||||
|   EXPECT_EQ(assoc_rects[1].rect_id(), 2); | ||||
|   assoc_rects[1].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_)); | ||||
| TEST_F(HandAssociationCalculatorTest, NormRectAssocTestPreservesBaseRects) { | ||||
|   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( | ||||
|     calculator: "HandAssociationCalculator" | ||||
|     input_stream: "BASE_RECTS:input_vec_0" | ||||
|     input_stream: "RECTS:0:input_vec_1" | ||||
|     input_stream: "RECTS:1:input_vec_2" | ||||
|     output_stream: "output_vec" | ||||
|     options { | ||||
|       [mediapipe.HandAssociationCalculatorOptions.ext] { | ||||
|         min_similarity_threshold: 0.1 | ||||
|       } | ||||
|     } | ||||
|   )pb")); | ||||
| 
 | ||||
|   EXPECT_EQ(assoc_rects[2].rect_id(), 3); | ||||
|   assoc_rects[2].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_)); | ||||
|   // Input Stream 0: nr_5, nr_3, nr_1.
 | ||||
|   auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_0->push_back(nr_5_); | ||||
|   input_vec_0->push_back(nr_3_); | ||||
|   input_vec_0->push_back(nr_1_); | ||||
|   runner.MutableInputs() | ||||
|       ->Tag("BASE_RECTS") | ||||
|       .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   // Input Stream 1: nr_4.
 | ||||
|   auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_1->push_back(nr_4_); | ||||
|   auto index_id = runner.MutableInputs()->GetId("RECTS", 0); | ||||
|   runner.MutableInputs()->Get(index_id).packets.push_back( | ||||
|       Adopt(input_vec_1.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   // Input Stream 2: nr_2, nr_0.
 | ||||
|   auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec_2->push_back(nr_2_); | ||||
|   input_vec_2->push_back(nr_0_); | ||||
|   index_id = runner.MutableInputs()->GetId("RECTS", 1); | ||||
|   runner.MutableInputs()->Get(index_id).packets.push_back( | ||||
|       Adopt(input_vec_2.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; | ||||
|   const std::vector<Packet>& output = runner.Outputs().Index(0).packets; | ||||
|   EXPECT_EQ(1, output.size()); | ||||
|   auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>(); | ||||
| 
 | ||||
|   // Rectangles are added in the following sequence:
 | ||||
|   // nr_5 is added because it is in BASE_RECTS input stream.
 | ||||
|   // nr_3 is added because it is in BASE_RECTS input stream.
 | ||||
|   // nr_1 is added because it is in BASE_RECTS input stream.
 | ||||
|   // nr_4 is added because it does not overlap with nr_5.
 | ||||
|   // nr_2 is NOT added because it overlaps with nr_4.
 | ||||
|   // nr_0 is NOT added because it overlaps with nr_3.
 | ||||
|   EXPECT_EQ(4, assoc_rects.size()); | ||||
| 
 | ||||
|   // Outputs are in same order as inputs, and IDs are filled in.
 | ||||
|   nr_5_.set_rect_id(1); | ||||
|   nr_3_.set_rect_id(2); | ||||
|   nr_1_.set_rect_id(3); | ||||
|   nr_4_.set_rect_id(4); | ||||
|   EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_3_), | ||||
|                                        EqualsProto(nr_1_), EqualsProto(nr_4_))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { | ||||
|   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb( | ||||
|     calculator: "HandAssociationCalculator" | ||||
|     input_stream: "input_vec" | ||||
|     input_stream: "BASE_RECTS:input_vec" | ||||
|     output_stream: "output_vec" | ||||
|     options { | ||||
|       [mediapipe.HandAssociationCalculatorOptions.ext] { | ||||
|  | @ -282,8 +333,9 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { | |||
|   auto input_vec = std::make_unique<std::vector<NormalizedRect>>(); | ||||
|   input_vec->push_back(nr_3_); | ||||
|   input_vec->push_back(nr_5_); | ||||
|   runner.MutableInputs()->Index(0).packets.push_back( | ||||
|       Adopt(input_vec.release()).At(Timestamp(1))); | ||||
|   runner.MutableInputs() | ||||
|       ->Tag("BASE_RECTS") | ||||
|       .packets.push_back(Adopt(input_vec.release()).At(Timestamp(1))); | ||||
| 
 | ||||
|   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed."; | ||||
|   const std::vector<Packet>& output = runner.Outputs().Index(0).packets; | ||||
|  | @ -292,12 +344,12 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) { | |||
| 
 | ||||
|   // Rectangles are added in the following sequence:
 | ||||
|   // nr_3 is added 1st.
 | ||||
|   // nr_5 is NOT added because it overlaps with nr_3.
 | ||||
|   EXPECT_EQ(1, assoc_rects.size()); | ||||
|   // nr_5 is added 2nd because rects in BASE_RECTS are always preserved.
 | ||||
|   EXPECT_EQ(2, assoc_rects.size()); | ||||
| 
 | ||||
|   EXPECT_EQ(assoc_rects[0].rect_id(), 1); | ||||
|   assoc_rects[0].clear_rect_id(); | ||||
|   EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_)); | ||||
|   nr_3_.set_rect_id(1); | ||||
|   nr_5_.set_rect_id(2); | ||||
|   EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_3_), EqualsProto(nr_5_))); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
|  |  | |||
|  | @ -318,9 +318,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph { | |||
|           .set_min_similarity_threshold( | ||||
|               tasks_options.min_tracking_confidence()); | ||||
|       prev_hand_rects_from_landmarks >> | ||||
|           hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0]; | ||||
|           hand_association[Input<std::vector<NormalizedRect>>("BASE_RECTS")]; | ||||
|       hand_rects_from_hand_detector >> | ||||
|           hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1]; | ||||
|           hand_association[Input<std::vector<NormalizedRect>>("RECTS")]; | ||||
|       auto hand_rects = hand_association.Out(""); | ||||
|       hand_rects >> clip_hand_rects.In(""); | ||||
|     } else { | ||||
|  |  | |||
|  | @ -34,18 +34,19 @@ _AUDIO_TASKS_JAVA_PROTO_LITE_TARGETS = [ | |||
| ] | ||||
| 
 | ||||
| _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [ | ||||
|     "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite", | ||||
|     "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite", | ||||
| ] | ||||
| 
 | ||||
| _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ | ||||
|  |  | |||
|  | @ -36,11 +36,9 @@ export abstract class AudioTaskRunner<T> extends TaskRunner { | |||
| 
 | ||||
|   /** Sends a single audio clip to the graph and awaits results. */ | ||||
|   protected processAudioClip(audioData: Float32Array, sampleRate?: number): T { | ||||
|     // Increment the timestamp by 1 millisecond to guarantee that we send
 | ||||
|     // monotonically increasing timestamps to the graph.
 | ||||
|     const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; | ||||
|     return this.process( | ||||
|         audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp); | ||||
|         audioData, sampleRate ?? this.defaultSampleRate, | ||||
|         this.getSyntheticTimestamp()); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,6 +15,11 @@ mediapipe_ts_declaration( | |||
|     deps = [":category"], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_ts_declaration( | ||||
|     name = "keypoint", | ||||
|     srcs = ["keypoint.d.ts"], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_ts_declaration( | ||||
|     name = "landmark", | ||||
|     srcs = ["landmark.d.ts"], | ||||
|  |  | |||
							
								
								
									
mediapipe/tasks/web/components/containers/keypoint.d.ts (new file, vendored, 33 lines)
									
								
							
							
						
						|  | @ -0,0 +1,33 @@ | |||
| /** | ||||
|  * Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| /** | ||||
|  * A keypoint, defined by the coordinates (x, y), normalized by the image | ||||
|  * dimensions. | ||||
|  */ | ||||
| export declare interface NormalizedKeypoint { | ||||
|   /** X in normalized image coordinates. */ | ||||
|   x: number; | ||||
| 
 | ||||
|   /** Y in normalized image coordinates. */ | ||||
|   y: number; | ||||
| 
 | ||||
|   /** Optional label of the keypoint. */ | ||||
|   label?: string; | ||||
| 
 | ||||
|   /** Optional score of the keypoint. */ | ||||
|   score?: number; | ||||
| } | ||||
|  | @ -175,9 +175,13 @@ export abstract class TaskRunner { | |||
|         Math.max(this.latestOutputTimestamp, timestamp); | ||||
|   } | ||||
| 
 | ||||
|   /** Returns the latest output timestamp. */ | ||||
|   protected getLatestOutputTimestamp() { | ||||
|     return this.latestOutputTimestamp; | ||||
|   /** | ||||
|    * Gets a synthetic timestamp in ms that can be used for the next packet | ||||
|    * sent to the graph. The timestamp is one millisecond past the last | ||||
|    * timestamp received from the graph, which guarantees that timestamps | ||||
|    * sent to the graph are monotonically increasing. | ||||
|    */ | ||||
|   protected getSyntheticTimestamp(): number { | ||||
|     return this.latestOutputTimestamp + 1; | ||||
|   } | ||||
| 
 | ||||
|   /** Throws the error from the error listener if an error was raised. */ | ||||
|  |  | |||
|  | @ -131,11 +131,9 @@ export class TextClassifier extends TaskRunner { | |||
|    * @return The classification result of the text | ||||
|    */ | ||||
|   classify(text: string): TextClassifierResult { | ||||
|     // Increment the timestamp by 1 millisecond to guarantee that we send
 | ||||
|     // monotonically increasing timestamps to the graph.
 | ||||
|     const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; | ||||
|     this.classificationResult = {classifications: []}; | ||||
|     this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); | ||||
|     this.graphRunner.addStringToStream( | ||||
|         text, INPUT_STREAM, this.getSyntheticTimestamp()); | ||||
|     this.finishProcessing(); | ||||
|     return this.classificationResult; | ||||
|   } | ||||
|  |  | |||
|  | @ -135,10 +135,8 @@ export class TextEmbedder extends TaskRunner { | |||
|    * @return The embedding results of the text | ||||
|    */ | ||||
|   embed(text: string): TextEmbedderResult { | ||||
|     // Increment the timestamp by 1 millisecond to guarantee that we send
 | ||||
|     // monotonically increasing timestamps to the graph.
 | ||||
|     const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; | ||||
|     this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp); | ||||
|     this.graphRunner.addStringToStream( | ||||
|         text, INPUT_STREAM, this.getSyntheticTimestamp()); | ||||
|     this.finishProcessing(); | ||||
|     return this.embeddingResult; | ||||
|   } | ||||
|  |  | |||
|  | @ -2,23 +2,42 @@ | |||
| 
 | ||||
| This package contains the vision tasks for MediaPipe. | ||||
| 
 | ||||
| ## Object Detection | ||||
| ## Gesture Recognition | ||||
| 
 | ||||
| The MediaPipe Object Detector task lets you detect the presence and location of | ||||
| multiple classes of objects within images or videos. | ||||
| The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real | ||||
| time, and provides the recognized hand gesture results along with the landmarks | ||||
| of the detected hands. You can use this task to recognize specific hand gestures | ||||
| from a user, and invoke application features that correspond to those gestures. | ||||
| 
 | ||||
| ``` | ||||
| const vision = await FilesetResolver.forVisionTasks( | ||||
|     "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" | ||||
| ); | ||||
| const objectDetector = await ObjectDetector.createFromModelPath(vision, | ||||
|     "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" | ||||
| const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, | ||||
|     "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task" | ||||
| ); | ||||
| const image = document.getElementById("image") as HTMLImageElement; | ||||
| const detections = objectDetector.detect(image); | ||||
| const recognitions = gestureRecognizer.recognize(image); | ||||
| ``` | ||||
| 
 | ||||
| For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation. | ||||
| ## Hand Landmark Detection | ||||
| 
 | ||||
| The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in | ||||
| an image. You can use this Task to localize key points of the hands and render | ||||
| visual effects over the hands. | ||||
| 
 | ||||
| ``` | ||||
| const vision = await FilesetResolver.forVisionTasks( | ||||
|     "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" | ||||
| ); | ||||
| const handLandmarker = await HandLandmarker.createFromModelPath(vision, | ||||
|     "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task" | ||||
| ); | ||||
| const image = document.getElementById("image") as HTMLImageElement; | ||||
| const landmarks = handLandmarker.detect(image); | ||||
| ``` | ||||
| 
 | ||||
| For more information, refer to the [Hand Landmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation. | ||||
| 
 | ||||
| ## Image Classification | ||||
| 
 | ||||
|  | @ -56,40 +75,21 @@ imageSegmenter.segment(image, (masks, width, height) => { | |||
| }); | ||||
| ``` | ||||
| 
 | ||||
| ## Gesture Recognition | ||||
| ## Object Detection | ||||
| 
 | ||||
| The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real | ||||
| time, and provides the recognized hand gesture results along with the landmarks | ||||
| of the detected hands. You can use this task to recognize specific hand gestures | ||||
| from a user, and invoke application features that correspond to those gestures. | ||||
| The MediaPipe Object Detector task lets you detect the presence and location of | ||||
| multiple classes of objects within images or videos. | ||||
| 
 | ||||
| ``` | ||||
| const vision = await FilesetResolver.forVisionTasks( | ||||
|     "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" | ||||
| ); | ||||
| const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision, | ||||
|     "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task" | ||||
| const objectDetector = await ObjectDetector.createFromModelPath(vision, | ||||
|     "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite" | ||||
| ); | ||||
| const image = document.getElementById("image") as HTMLImageElement; | ||||
| const recognitions = gestureRecognizer.recognize(image); | ||||
| const detections = objectDetector.detect(image); | ||||
| ``` | ||||
| 
 | ||||
| ## Handlandmark Detection | ||||
| 
 | ||||
| The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in | ||||
| an image. You can use this Task to localize key points of the hands and render | ||||
| visual effects over the hands. | ||||
| 
 | ||||
| ``` | ||||
| const vision = await FilesetResolver.forVisionTasks( | ||||
|     "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm" | ||||
| ); | ||||
| const handLandmarker = await HandLandmarker.createFromModelPath(vision, | ||||
|     "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task" | ||||
| ); | ||||
| const image = document.getElementById("image") as HTMLImageElement; | ||||
| const landmarks = handLandmarker.detect(image); | ||||
| ``` | ||||
| 
 | ||||
| For more information, refer to the [Handlandmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation. | ||||
| For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation. | ||||
| 
 | ||||
|  |  | |||
|  | @ -21,6 +21,14 @@ mediapipe_ts_declaration( | |||
|     ], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_ts_declaration( | ||||
|     name = "types", | ||||
|     srcs = ["types.d.ts"], | ||||
|     deps = [ | ||||
|         "//mediapipe/tasks/web/components/containers:keypoint", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_ts_library( | ||||
|     name = "vision_task_runner", | ||||
|     srcs = ["vision_task_runner.ts"], | ||||
|  | @ -51,6 +59,11 @@ mediapipe_ts_library( | |||
|     ], | ||||
| ) | ||||
| 
 | ||||
| mediapipe_ts_library( | ||||
|     name = "render_utils", | ||||
|     srcs = ["render_utils.ts"], | ||||
| ) | ||||
| 
 | ||||
| jasmine_node_test( | ||||
|     name = "vision_task_runner_test", | ||||
|     deps = [":vision_task_runner_test_lib"], | ||||
|  |  | |||
							
								
								
									
mediapipe/tasks/web/vision/core/render_utils.ts (new file, 78 lines)
									
								
							
							
						
						|  | @ -0,0 +1,78 @@ | |||
| /** @fileoverview Utility functions used in the vision demos. */ | ||||
| 
 | ||||
| /** | ||||
|  * Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| // Pre-baked color table for a maximum of 12 classes.
 | ||||
| const CM_ALPHA = 128; | ||||
| const COLOR_MAP = [ | ||||
|   [0, 0, 0, CM_ALPHA],        // class 0 is BG = transparent
 | ||||
|   [255, 0, 0, CM_ALPHA],      // class 1 is red
 | ||||
|   [0, 255, 0, CM_ALPHA],      // class 2 is light green
 | ||||
|   [0, 0, 255, CM_ALPHA],      // class 3 is blue
 | ||||
|   [255, 255, 0, CM_ALPHA],    // class 4 is yellow
 | ||||
|   [255, 0, 255, CM_ALPHA],    // class 5 is light purple / magenta
 | ||||
|   [0, 255, 255, CM_ALPHA],    // class 6 is light blue / aqua
 | ||||
|   [128, 128, 128, CM_ALPHA],  // class 7 is gray
 | ||||
|   [255, 128, 0, CM_ALPHA],    // class 8 is orange
 | ||||
|   [128, 0, 255, CM_ALPHA],    // class 9 is dark purple
 | ||||
|   [0, 128, 0, CM_ALPHA],      // class 10 is dark green
 | ||||
|   [255, 255, 255, CM_ALPHA]   // class 11 is white; could do black instead?
 | ||||
| ]; | ||||
| 
 | ||||
| 
 | ||||
| /** Helper function to draw a confidence mask */ | ||||
| export function drawConfidenceMask( | ||||
|     ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array, | ||||
|     width: number, height: number): void { | ||||
|   const uint8ClampedArray = new Uint8ClampedArray(width * height * 4); | ||||
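|   // Dark-red overlay; the per-pixel confidence value drives the alpha channel.
 | ||||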
|   for (let i = 0; i < image.length; i++) { | ||||
|     uint8ClampedArray[4 * i] = 128; | ||||
|     uint8ClampedArray[4 * i + 1] = 0; | ||||
|     uint8ClampedArray[4 * i + 2] = 0; | ||||
|     uint8ClampedArray[4 * i + 3] = image[i] * 255; | ||||
|   } | ||||
|   ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0); | ||||
| } | ||||
| 
 | ||||
| /** | ||||
|  * Helper function to draw a category mask. For GPU, we only have F32Arrays | ||||
|  * for now. | ||||
|  */ | ||||
| export function drawCategoryMask( | ||||
|     ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array, | ||||
|     width: number, height: number): void { | ||||
|   const uint8ClampedArray = new Uint8ClampedArray(width * height * 4); | ||||
|   const isFloatArray = image instanceof Float32Array; | ||||
|   for (let i = 0; i < image.length; i++) { | ||||
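|     // Float-valued (GPU) masks encode the class index as index / 255; recover
 | ||||
|     // the integer index before the color table lookup.
 | ||||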
|     const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i]; | ||||
|     const color = COLOR_MAP[colorIndex]; | ||||
| 
 | ||||
|     // When we're given a confidence mask by accident, we just log and return.
 | ||||
|     // TODO: We should fix this.
 | ||||
|     if (!color) { | ||||
|       console.warn('No color for ', colorIndex); | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     uint8ClampedArray[4 * i] = color[0]; | ||||
|     uint8ClampedArray[4 * i + 1] = color[1]; | ||||
|     uint8ClampedArray[4 * i + 2] = color[2]; | ||||
|     uint8ClampedArray[4 * i + 3] = color[3]; | ||||
|   } | ||||
|   ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0); | ||||
| } | ||||
							
								
								
									
mediapipe/tasks/web/vision/core/types.d.ts (new file, vendored, 42 lines)
									
								
							
							
						
						|  | @ -0,0 +1,42 @@ | |||
| /** | ||||
|  * Copyright 2023 The MediaPipe Authors. All Rights Reserved. | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *     http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint'; | ||||
|  | ||||
| /** | ||||
|  * The segmentation tasks return the segmentation result as a Uint8Array | ||||
|  * (when the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for | ||||
|  * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved | ||||
|  * for future usage. | ||||
|  */ | ||||
| export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture; | ||||
|  | ||||
| /** | ||||
|  * A callback that receives the computed masks from the segmentation tasks. The | ||||
|  * callback either receives a single element array with a category mask (as a | ||||
|  * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`). | ||||
|  * The returned data is only valid for the duration of the callback. If | ||||
|  * asynchronous processing is needed, all data needs to be copied before the | ||||
|  * callback returns. | ||||
|  */ | ||||
| export type SegmentationMaskCallback = | ||||
|     (masks: SegmentationMask[], width: number, height: number) => void; | ||||
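Because the mask buffers are only valid while the callback runs, any asynchronous consumer must work on copies. A minimal sketch, with the downstream `consume` function purely hypothetical:

```ts
declare function consume(
    masks: SegmentationMask[], width: number, height: number): void;

const onMasks: SegmentationMaskCallback = (masks, width, height) => {
  // Copy each typed array before the callback returns; the underlying
  // buffers may be reused by the graph afterwards.
  const copies = masks.map(
      mask => mask instanceof WebGLTexture ? mask : mask.slice());
  setTimeout(() => consume(copies, width, height));
};
```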
|  | ||||
| /** A Region-Of-Interest (ROI) to represent a region within an image. */ | ||||
| export declare interface RegionOfInterest { | ||||
|   /** The ROI in keypoint format. */ | ||||
|   keypoint: NormalizedKeypoint; | ||||
| } | ||||
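For illustration, an ROI at the image center might look as follows, assuming `NormalizedKeypoint` carries normalized `x`/`y` coordinates in [0, 1]:

```ts
// (0.5, 0.5) addresses the center of the image in normalized coordinates.
const roi: RegionOfInterest = {keypoint: {x: 0.5, y: 0.5}};
```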
|  | @ -74,11 +74,7 @@ export abstract class VisionTaskRunner extends TaskRunner { | |||
|           'Task is not initialized with image mode. ' + | ||||
|           '\'runningMode\' must be set to \'IMAGE\'.'); | ||||
|     } | ||||
|  | ||||
|     // Increment the timestamp by 1 millisecond to guarantee that we send | ||||
|     // monotonically increasing timestamps to the graph. | ||||
|     const syntheticTimestamp = this.getLatestOutputTimestamp() + 1; | ||||
|     this.process(image, imageProcessingOptions, syntheticTimestamp); | ||||
|     this.process(image, imageProcessingOptions, this.getSyntheticTimestamp()); | ||||
|   } | ||||
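The inline bookkeeping moves into a shared helper that this hunk doesn't show. Judging by the comment being removed, it presumably amounts to the following sketch; its placement on `VisionTaskRunner` and its visibility are assumptions:

```ts
/**
 * Assumed shape of the shared helper: advance the latest output timestamp
 * by 1 ms so the graph always receives monotonically increasing timestamps.
 */
protected getSyntheticTimestamp(): number {
  return this.getLatestOutputTimestamp() + 1;
}
```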
|  | ||||
|   /** Sends a single video frame to the graph and awaits results. */ | ||||
|  |  | |||
|  | @ -19,8 +19,8 @@ mediapipe_ts_library( | |||
|         "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto", | ||||
|         "//mediapipe/tasks/web/core", | ||||
|         "//mediapipe/tasks/web/vision/core:image_processing_options", | ||||
|         "//mediapipe/tasks/web/vision/core:types", | ||||
|         "//mediapipe/tasks/web/vision/core:vision_task_runner", | ||||
|         "//mediapipe/web/graph_runner:graph_runner_image_lib_ts", | ||||
|         "//mediapipe/web/graph_runner:graph_runner_ts", | ||||
|     ], | ||||
| ) | ||||
|  |  | |||
|  | @ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../ | |||
| import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb'; | ||||
| import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset'; | ||||
| import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options'; | ||||
| import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types'; | ||||
| import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner'; | ||||
| import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner'; | ||||
| // Placeholder for internal dependency on trusted resource url | ||||
|  | @ -28,27 +29,9 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner | |||
| import {ImageSegmenterOptions} from './image_segmenter_options'; | ||||
|  | ||||
| export * from './image_segmenter_options'; | ||||
| export {SegmentationMask, SegmentationMaskCallback}; | ||||
| export {ImageSource};  // Used in the public API | ||||
|  | ||||
| /** | ||||
|  * The ImageSegmenter returns the segmentation result as a Uint8Array (when | ||||
|  * the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for | ||||
|  * output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved | ||||
|  * for future usage. | ||||
|  */ | ||||
| export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture; | ||||
|  | ||||
| /** | ||||
|  * A callback that receives the computed masks from the image segmenter. The | ||||
|  * callback either receives a single element array with a category mask (as a | ||||
|  * `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`). | ||||
|  * The returned data is only valid for the duration of the callback. If | ||||
|  * asynchronous processing is needed, all data needs to be copied before the | ||||
|  * callback returns. | ||||
|  */ | ||||
| export type SegmentationMaskCallback = | ||||
|     (masks: SegmentationMask[], width: number, height: number) => void; | ||||
|  | ||||
| const IMAGE_STREAM = 'image_in'; | ||||
| const NORM_RECT_STREAM = 'norm_rect'; | ||||
| const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks'; | ||||
|  |  | |||
							
								
								
									
third_party/BUILD (vendored, 1 line changed)
						|  | @ -112,6 +112,7 @@ cmake_external( | |||
|         "WITH_JPEG": "ON", | ||||
|         "WITH_PNG": "ON", | ||||
|         "WITH_TIFF": "ON", | ||||
|         "WITH_OPENCL": "OFF", | ||||
|         "WITH_WEBP": "OFF", | ||||
|         # Optimization flags | ||||
|         "CV_ENABLE_INTRINSICS": "ON", | ||||
|  |  | |||
							
								
								
									
third_party/external_files.bzl (vendored, 10 lines changed)
						|  | @ -67,7 +67,7 @@ def external_files(): | |||
|     http_file( | ||||
|         name = "com_google_mediapipe_BUILD", | ||||
|         sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"], | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=16618756636939761678323576393653"], | ||||
|     ) | ||||
|  | ||||
|     http_file( | ||||
|  | @ -318,8 +318,8 @@ def external_files(): | |||
|  | ||||
|     http_file( | ||||
|         name = "com_google_mediapipe_face_landmarker_with_blendshapes_task", | ||||
|         sha256 = "a75c1ba70e4b8568000af2ad0b355ed559ab5d5793db50fa9ad241f8dc4fad5f", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678323586260800"], | ||||
|         sha256 = "b44e4cae6f5822456d60f33e7c852640d78c7e342aee7eacc22589451a0b9dc2", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678504998301299"], | ||||
|     ) | ||||
|  | ||||
|     http_file( | ||||
|  | @ -822,8 +822,8 @@ def external_files(): | |||
|  | ||||
|     http_file( | ||||
|         name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt", | ||||
|         sha256 = "5cc57b8da3ad0527dce581fe1309f6b36043e5837e3f4f5af5e24005a99dc52a", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678323601064393"], | ||||
|         sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16", | ||||
|         urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678505004840652"], | ||||
|     ) | ||||
|  | ||||
|     http_file( | ||||
|  |  | |||