Merge branch 'google:master' into face-landmarker-python

Kinar R 2023-03-14 11:27:52 +05:30, committed by GitHub
commit 4a7489cd3a
56 changed files with 969 additions and 285 deletions

View File

@@ -118,9 +118,9 @@ on how to build MediaPipe examples.
* With a TensorFlow Model
  This uses the
- [TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model)
+ [TensorFlow model](https://github.com/google/mediapipe/tree/v0.8.10/mediapipe/models/object_detection_saved_model)
  (see also
- [model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)),
+ [model info](https://github.com/google/mediapipe/tree/master/mediapipe/g3doc/solutions/object_detection_saved_model.md)),
  and the pipeline is implemented in this
  [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).

View File

@@ -0,0 +1,62 @@
## TensorFlow/TFLite Object Detection Model
### TensorFlow model
The model is trained on the [MSCOCO 2014](http://cocodataset.org) dataset using the [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with a 0.5 depth multiplier. The detailed training configuration is in the provided `pipeline.config`. The model is relatively compact, achieving `0.171 mAP` while running in real time on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).
### TFLite model
The TFLite model is converted from the TensorFlow model above. The steps needed to convert it are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193), with minor modifications. Assume we have a trained TensorFlow model that includes the checkpoint files and the training configuration file, for example the files provided in this repo:
* `model.ckpt.index`
* `model.ckpt.meta`
* `model.ckpt.data-00000-of-00001`
* `pipeline.config`
Make sure you have installed these [Python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then, to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:
```bash
$ PATH_TO_MODEL=path/to/the/model
$ bazel run object_detection:export_tflite_ssd_graph -- \
--pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
--trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
--output_directory ${PATH_TO_MODEL} \
--add_postprocessing_op=False
```
The exported model contains two files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. MediaPipe provides all the calculators needed for post-processing, so the custom TFLite post-processing ops in the original graph, e.g., non-maximum suppression, can be excluded. This gives you the flexibility to integrate with different post-processing algorithms and implementations.
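To make the excluded step concrete, here is a minimal NumPy sketch of greedy non-maximum suppression; it assumes already-decoded boxes in `[ymin, xmin, ymax, xmax]` format and is only an illustration of the idea, not MediaPipe's actual calculator implementation.
```python
import numpy as np

def non_max_suppression(boxes, scores, iou_threshold=0.5):
    """Greedy NMS over boxes of shape [N, 4] given as [ymin, xmin, ymax, xmax]."""
    order = np.argsort(scores)[::-1]  # indices sorted by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Intersection of the top-scoring box with every remaining box.
        top_left = np.maximum(boxes[i, :2], boxes[order[1:], :2])
        bottom_right = np.minimum(boxes[i, 2:], boxes[order[1:], 2:])
        intersection = np.prod(np.clip(bottom_right - top_left, 0, None), axis=1)
        area_i = np.prod(boxes[i, 2:] - boxes[i, :2])
        areas = np.prod(boxes[order[1:], 2:] - boxes[order[1:], :2], axis=1)
        iou = intersection / (area_i + areas - intersection)
        # Keep only boxes that do not overlap the chosen box too much.
        order = order[1:][iou <= iou_threshold]
    return keep
```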
Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model:
```bash
$ bazel run graph_transforms:summarize_graph -- \
--in_graph=${PATH_TO_MODEL}/tflite_graph.pb
```
You should see that the model's input image size is 320x320 and that its outputs are:
* `raw_outputs/box_encodings`
* `raw_outputs/class_predictions`
The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run:
```bash
$ tflite_convert -- \
--graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
--output_file=${PATH_TO_MODEL}/model.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=FLOAT \
--input_shapes=1,320,320,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
```
Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.
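As a quick sanity check on the converted file (a sketch, not part of the original instructions; the model path is whatever you passed to `--output_file`), you can load it with the TFLite Python interpreter:
```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

# Expect a single [1, 320, 320, 3] float32 input ...
print(interpreter.get_input_details()[0]["shape"])

# ... and the two raw output tensors described above.
for detail in interpreter.get_output_details():
    print(detail["name"], detail["shape"])

# Run one inference on a dummy frame to confirm the model executes.
dummy = np.zeros((1, 320, 320, 3), dtype=np.float32)
interpreter.set_tensor(interpreter.get_input_details()[0]["index"], dummy)
interpreter.invoke()
```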

View File

@@ -269,6 +269,7 @@ Supported configuration options:
```python
import cv2
import mediapipe as mp
+import numpy as np
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose

View File

@@ -748,6 +748,7 @@ cc_test(
        "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
        "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
        "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
+       "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
        "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
    ],
    tags = ["desktop_only_test"],

View File

@@ -29,6 +29,9 @@ class AffineTransformation {
  // pixels will be calculated.
  enum class BorderMode { kZero, kReplicate };

+  // Pixel sampling interpolation method.
+  enum class Interpolation { kLinear, kCubic };
+
  struct Size {
    int width;
    int height;

View File

@@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
                             std::unique_ptr<GpuBuffer>> {
 public:
  GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
-                            GpuOrigin::Mode gpu_origin)
-      : gl_helper_(gl_helper), gpu_origin_(gpu_origin) {}
+                            GpuOrigin::Mode gpu_origin,
+                            AffineTransformation::Interpolation interpolation)
+      : gl_helper_(gl_helper),
+        gpu_origin_(gpu_origin),
+        interpolation_(interpolation) {}
  absl::Status Init() {
    return gl_helper_->RunInGlContext([this]() -> absl::Status {
      const GLint attr_location[kNumAttributes] = {
@@ -103,10 +106,13 @@ class GlTextureWarpAffineRunner
        }
      )";
+      // TODO Move bicubic code to common shared place.
      constexpr GLchar kFragShader[] = R"(
        DEFAULT_PRECISION(highp, float)
        in vec2 sample_coordinate;
        uniform sampler2D input_texture;
+       uniform vec2 input_size;
        #ifdef GL_ES
        #define fragColor gl_FragColor
@@ -114,8 +120,60 @@ class GlTextureWarpAffineRunner
        out vec4 fragColor;
        #endif  // defined(GL_ES);
+       #ifdef CUBIC_INTERPOLATION
+       vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
+         const vec2 halve = vec2(0.5, 0.5);
+         const vec2 one = vec2(1.0, 1.0);
+         const vec2 two = vec2(2.0, 2.0);
+         const vec2 three = vec2(3.0, 3.0);
+         const vec2 six = vec2(6.0, 6.0);
+         // Calculate the fraction and integer.
+         tex_coord = tex_coord * tex_size - halve;
+         vec2 frac = fract(tex_coord);
+         vec2 index = tex_coord - frac + halve;
+         // Calculate weights for Catmull-Rom filter.
+         vec2 w0 = frac * (-halve + frac * (one - halve * frac));
+         vec2 w1 = one + frac * frac * (-(two + halve) + three / two * frac);
+         vec2 w2 = frac * (halve + frac * (two - three / two * frac));
+         vec2 w3 = frac * frac * (-halve + halve * frac);
+         // Calculate weights to take advantage of bilinear texture lookup.
+         vec2 w12 = w1 + w2;
+         vec2 offset12 = w2 / (w1 + w2);
+         vec2 index_tl = index - one;
+         vec2 index_br = index + two;
+         vec2 index_eq = index + offset12;
+         index_tl /= tex_size;
+         index_br /= tex_size;
+         index_eq /= tex_size;
+         // 9 texture lookup and linear blending.
+         vec4 color = vec4(0.0);
+         color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
+         color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x * w0.y;
+         color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
+         color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
+         color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x * w12.y;
+         color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
+         color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
+         color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x * w3.y;
+         color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
+         return color;
+       }
+       #else
+       vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
+         return texture2D(tex, tex_coord);
+       }
+       #endif  // defined(CUBIC_INTERPOLATION)
        void main() {
-         vec4 color = texture2D(input_texture, sample_coordinate);
+         vec4 color = sample(input_texture, sample_coordinate, input_size);
        #ifdef CUSTOM_ZERO_BORDER_MODE
          float out_of_bounds =
              float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
@@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
        glUseProgram(program);
        glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
        GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
-       return Program{.id = program, .matrix_id = matrix_id};
+       GLint size_id = glGetUniformLocation(program, "input_size");
+       return Program{
+           .id = program, .matrix_id = matrix_id, .size_id = size_id};
      };
      const std::string vert_src =
          absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
-     const std::string frag_src = absl::StrCat(
-         mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader);
+     std::string interpolation_def;
+     switch (interpolation_) {
+       case AffineTransformation::Interpolation::kCubic:
+         interpolation_def = R"(
+           #define CUBIC_INTERPOLATION
+         )";
+         break;
+       case AffineTransformation::Interpolation::kLinear:
+         break;
+     }
+     const std::string frag_src =
+         absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
+                      interpolation_def, kFragShader);
      ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
@@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
        std::string custom_zero_border_mode_def = R"(
          #define CUSTOM_ZERO_BORDER_MODE
        )";
-       const std::string frag_custom_zero_src =
-           absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
-                        custom_zero_border_mode_def, kFragShader);
+       const std::string frag_custom_zero_src = absl::StrCat(
+           mediapipe::kMediaPipeFragmentShaderPreamble,
+           custom_zero_border_mode_def, interpolation_def, kFragShader);
        return create_fn(vert_src, frag_custom_zero_src);
      };
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
    }
    glUseProgram(program->id);
+   // uniforms
    Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
    if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
      // @matrix describes affine transformation in terms of TOP LEFT origin, so
@@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
    eigen_mat.transposeInPlace();
    glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
+   if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
+     glUniform2f(program->size_id, texture.width(), texture.height());
+   }
    // vao
    glBindVertexArray(vao_);
@@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
  struct Program {
    GLuint id;
    GLint matrix_id;
+   GLint size_id;
  };
  std::shared_ptr<GlCalculatorHelper> gl_helper_;
  GpuOrigin::Mode gpu_origin_;
@@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
  Program program_;
  std::optional<Program> program_custom_zero_;
  GLuint framebuffer_ = 0;
+ AffineTransformation::Interpolation interpolation_ =
+     AffineTransformation::Interpolation::kLinear;
};
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
absl::StatusOr<std::unique_ptr<
    AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
CreateAffineTransformationGlRunner(
-   std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) {
- auto runner =
-     absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin);
+   std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
+   AffineTransformation::Interpolation interpolation) {
+ auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
+     gl_helper, gpu_origin, interpolation);
  MP_RETURN_IF_ERROR(runner->Init());
  return runner;
}
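The Catmull-Rom weights in the fragment shader above can be checked outside of GL; this small NumPy sketch (an illustration only, with function names of my choosing) mirrors the `w0..w3` expressions and verifies they form a partition of unity:
```python
import numpy as np

def catmull_rom_weights(f):
    """Per-axis Catmull-Rom weights for a fractional sample offset f in [0, 1).

    Mirrors the w0..w3 expressions in the fragment shader above.
    """
    w0 = f * (-0.5 + f * (1.0 - 0.5 * f))
    w1 = 1.0 + f * f * (-2.5 + 1.5 * f)
    w2 = f * (0.5 + f * (2.0 - 1.5 * f))
    w3 = f * f * (-0.5 + 0.5 * f)
    return np.array([w0, w1, w2, w3])

for f in np.linspace(0.0, 0.99, 12):
    # The weights sum to 1, so resampling preserves flat image regions.
    assert np.isclose(catmull_rom_weights(f).sum(), 1.0)

# The shader folds w1 and w2 into one bilinear fetch at offset w2 / (w1 + w2),
# reducing the 16-tap bicubic kernel to the 9 texture lookups seen above.
print(catmull_rom_weights(0.5))
```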

View File

@@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
    mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
CreateAffineTransformationGlRunner(
    std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
-   mediapipe::GpuOrigin::Mode gpu_origin);
+   mediapipe::GpuOrigin::Mode gpu_origin,
+   AffineTransformation::Interpolation interpolation);

}  // namespace mediapipe

View File

@@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
  }
}

+int GetInterpolationForOpenCv(
+    AffineTransformation::Interpolation interpolation) {
+  switch (interpolation) {
+    case AffineTransformation::Interpolation::kLinear:
+      return cv::INTER_LINEAR;
+    case AffineTransformation::Interpolation::kCubic:
+      return cv::INTER_CUBIC;
+  }
+}
+
class OpenCvRunner
    : public AffineTransformation::Runner<ImageFrame, ImageFrame> {
 public:
+  OpenCvRunner(AffineTransformation::Interpolation interpolation)
+      : interpolation_(GetInterpolationForOpenCv(interpolation)) {}
+
  absl::StatusOr<ImageFrame> Run(
      const ImageFrame& input, const std::array<float, 16>& matrix,
      const AffineTransformation::Size& size,
@@ -142,19 +155,23 @@ class OpenCvRunner
    cv::warpAffine(in_mat, out_mat, cv_affine_transform,
                   cv::Size(out_mat.cols, out_mat.rows),
-                  /*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP,
+                  /*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
                   GetBorderModeForOpenCv(border_mode));
    return out_image;
  }
+
+ private:
+  int interpolation_ = cv::INTER_LINEAR;
};

}  // namespace

absl::StatusOr<
    std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
-CreateAffineTransformationOpenCvRunner() {
-  return absl::make_unique<OpenCvRunner>();
+CreateAffineTransformationOpenCvRunner(
+    AffineTransformation::Interpolation interpolation) {
+  return absl::make_unique<OpenCvRunner>(interpolation);
}

}  // namespace mediapipe
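For intuition about what the new `interpolation_` flag changes on the CPU path, the same OpenCV call can be reproduced from Python (a sketch with a made-up rotation matrix, not test code from this commit):
```python
import cv2
import numpy as np

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)

# A hypothetical affine transform: slight rotation plus upscale, as a 2x3 matrix.
matrix = cv2.getRotationMatrix2D(center=(128, 128), angle=15, scale=1.5)

# Same flag combination the runner uses: interpolation OR'd with WARP_INVERSE_MAP.
linear = cv2.warpAffine(image, matrix, (256, 256),
                        flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP,
                        borderMode=cv2.BORDER_CONSTANT)
cubic = cv2.warpAffine(image, matrix, (256, 256),
                       flags=cv2.INTER_CUBIC | cv2.WARP_INVERSE_MAP,
                       borderMode=cv2.BORDER_CONSTANT)

# Bicubic sampling reads a 4x4 neighborhood per pixel, so results differ slightly.
print(np.abs(linear.astype(int) - cubic.astype(int)).mean())
```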

View File

@@ -25,7 +25,8 @@ namespace mediapipe {
absl::StatusOr<
    std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
-CreateAffineTransformationOpenCvRunner();
+CreateAffineTransformationOpenCvRunner(
+    AffineTransformation::Interpolation interpolation);

}  // namespace mediapipe

View File

@@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
  }
}

+AffineTransformation::Interpolation GetInterpolation(
+    mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
+  switch (interpolation) {
+    case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
+    case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
+      return AffineTransformation::Interpolation::kLinear;
+    case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
+      return AffineTransformation::Interpolation::kCubic;
+  }
+}
+
template <typename ImageT>
class WarpAffineRunnerHolder {};
@@ -61,16 +72,22 @@ template <>
class WarpAffineRunnerHolder<ImageFrame> {
 public:
  using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
-  absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
+  absl::Status Open(CalculatorContext* cc) {
+    interpolation_ = GetInterpolation(
+        cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
+    return absl::OkStatus();
+  }
  absl::StatusOr<RunnerType*> GetRunner() {
    if (!runner_) {
-      ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner());
+      ASSIGN_OR_RETURN(runner_,
+                       CreateAffineTransformationOpenCvRunner(interpolation_));
    }
    return runner_.get();
  }

 private:
  std::unique_ptr<RunnerType> runner_;
+  AffineTransformation::Interpolation interpolation_;
};
#endif  // !MEDIAPIPE_DISABLE_OPENCV
@@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
    gpu_origin_ =
        cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
    gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
+    interpolation_ = GetInterpolation(
+        cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
    return gl_helper_->Open(cc);
  }
  absl::StatusOr<RunnerType*> GetRunner() {
    if (!runner_) {
-      ASSIGN_OR_RETURN(
-          runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_));
+      ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
+                                    gl_helper_, gpu_origin_, interpolation_));
    }
    return runner_.get();
  }
@@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
  mediapipe::GpuOrigin::Mode gpu_origin_;
  std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
  std::unique_ptr<RunnerType> runner_;
+  AffineTransformation::Interpolation interpolation_;
};
#endif  // !MEDIAPIPE_DISABLE_GPU

View File

@@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
    BORDER_REPLICATE = 2;
  }

+  // Pixel sampling interpolation methods. See @interpolation.
+  enum Interpolation {
+    INTER_UNSPECIFIED = 0;
+    INTER_LINEAR = 1;
+    INTER_CUBIC = 2;
+  }
+
  // Pixel extrapolation method.
  // When converting image to tensor it may happen that tensor needs to read
  // pixels outside image boundaries. Border mode helps to specify how such
@@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
  // to be flipped vertically as tensors are expected to start at top.
  // (DEFAULT or unset interpreted as CONVENTIONAL.)
  optional GpuOrigin.Mode gpu_origin = 2;
+
+  // Sampling method for neighboring pixels.
+  // INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
+  // INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
+  // INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
+  optional Interpolation interpolation = 3;
}

View File

@@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
             const cv::Mat& input, cv::Mat expected_result,
             float similarity_threshold, std::array<float, 16> matrix,
             int out_width, int out_height,
-             absl::optional<AffineTransformation::BorderMode> border_mode) {
+             std::optional<AffineTransformation::BorderMode> border_mode,
+             std::optional<AffineTransformation::Interpolation> interpolation) {
  std::string border_mode_str;
  if (border_mode) {
    switch (*border_mode) {
@@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
        break;
    }
  }
+  std::string interpolation_str;
+  if (interpolation) {
+    switch (*interpolation) {
+      case AffineTransformation::Interpolation::kLinear:
+        interpolation_str = "interpolation: INTER_LINEAR";
+        break;
+      case AffineTransformation::Interpolation::kCubic:
+        interpolation_str = "interpolation: INTER_CUBIC";
+        break;
+    }
+  }
  auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
-      absl::Substitute(graph_text, /*$0=*/border_mode_str));
+      absl::Substitute(graph_text, /*$0=*/border_mode_str,
+                       /*$1=*/interpolation_str));
  std::vector<Packet> output_packets;
  tool::AddVectorSink("output_image", &graph_config, &output_packets);
@@ -132,7 +145,8 @@ struct SimilarityConfig {
void RunTest(cv::Mat input, cv::Mat expected_result,
             const SimilarityConfig& similarity, std::array<float, 16> matrix,
             int out_width, int out_height,
-             absl::optional<AffineTransformation::BorderMode> border_mode) {
+             std::optional<AffineTransformation::BorderMode> border_mode,
+             std::optional<AffineTransformation::Interpolation> interpolation) {
  RunTest(R"(
        input_stream: "input_image"
        input_stream: "output_size"
@@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
            options {
              [mediapipe.WarpAffineCalculatorOptions.ext] {
                $0 # border mode
+               $1 # interpolation
              }
            }
          }
      )",
      "cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
-      out_width, out_height, border_mode);
+      out_width, out_height, border_mode, interpolation);

  RunTest(R"(
        input_stream: "input_image"
@@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
            options {
              [mediapipe.WarpAffineCalculatorOptions.ext] {
                $0 # border mode
+               $1 # interpolation
              }
            }
          }
@@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
        }
      )",
      "cpu_image", input, expected_result, similarity.threshold_on_cpu,
-      matrix, out_width, out_height, border_mode);
+      matrix, out_width, out_height, border_mode, interpolation);

  RunTest(R"(
        input_stream: "input_image"
@@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
            options {
              [mediapipe.WarpAffineCalculatorOptions.ext] {
                $0 # border mode
+               $1 # interpolation
                gpu_origin: TOP_LEFT
              }
            }
@@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
        }
      )",
      "gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
-      out_width, out_height, border_mode);
+      out_width, out_height, border_mode, interpolation);

  RunTest(R"(
        input_stream: "input_image"
@@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
            options {
              [mediapipe.WarpAffineCalculatorOptions.ext] {
                $0 # border mode
+               $1 # interpolation
                gpu_origin: TOP_LEFT
              }
            }
@@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
        }
      )",
      "gpu_image", input, expected_result, similarity.threshold_on_gpu,
-      matrix, out_width, out_height, border_mode);
+      matrix, out_width, out_height, border_mode, interpolation);
}

std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
@@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
  int out_height = 256;
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode = {};
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
@@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
@@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kReplicate;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
@@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
@@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
  bool keep_aspect_ratio = false;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kReplicate;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
@@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
  bool keep_aspect_ratio = false;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

+TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
+  mediapipe::NormalizedRect roi;
+  roi.set_x_center(0.65f);
+  roi.set_y_center(0.4f);
+  roi.set_width(0.5f);
+  roi.set_height(0.5f);
+  roi.set_rotation(M_PI * -45.0f / 180.0f);
+  auto input = GetRgb(
+      "/mediapipe/calculators/"
+      "tensor/testdata/image_to_tensor/input.jpg");
+  auto expected_output = GetRgb(
+      "/mediapipe/calculators/"
+      "tensor/testdata/image_to_tensor/"
+      "medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
+  int out_width = 256;
+  int out_height = 256;
+  bool keep_aspect_ratio = false;
+  std::optional<AffineTransformation::BorderMode> border_mode =
+      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation =
+      AffineTransformation::Interpolation::kCubic;
+  RunTest(input, expected_output,
+          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
+          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
+          out_width, out_height, border_mode, interpolation);
+}

TEST(WarpAffineCalculatorTest, LargeSubRect) {
@@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
  bool keep_aspect_ratio = false;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kReplicate;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
@@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
  bool keep_aspect_ratio = false;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
@@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kReplicate;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
@@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
@@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
  int out_height = 128;
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode = {};
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
@@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, NoOp) {
@@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kReplicate;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
@@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
  bool keep_aspect_ratio = true;
  std::optional<AffineTransformation::BorderMode> border_mode =
      AffineTransformation::BorderMode::kZero;
+  std::optional<AffineTransformation::Interpolation> interpolation = {};
  RunTest(input, expected_output,
          {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
          GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
-          out_width, out_height, border_mode);
+          out_width, out_height, border_mode, interpolation);
}

}  // namespace

View File

@@ -997,17 +997,20 @@ cc_library(
            ":image_to_tensor_converter_gl_buffer",
            "//mediapipe/gpu:gl_calculator_helper",
            "//mediapipe/gpu:gpu_buffer",
+           "//mediapipe/gpu:gpu_service",
        ],
        "//mediapipe:apple": [
            ":image_to_tensor_converter_metal",
            "//mediapipe/gpu:gl_calculator_helper",
            "//mediapipe/gpu:MPPMetalHelper",
            "//mediapipe/gpu:gpu_buffer",
+           "//mediapipe/gpu:gpu_service",
        ],
        "//conditions:default": [
            ":image_to_tensor_converter_gl_buffer",
            "//mediapipe/gpu:gl_calculator_helper",
            "//mediapipe/gpu:gpu_buffer",
+           "//mediapipe/gpu:gpu_service",
        ],
    }),
)
@@ -1045,6 +1048,10 @@ cc_test(
        ":image_to_tensor_calculator",
        ":image_to_tensor_converter",
        ":image_to_tensor_utils",
+       "@com_google_absl//absl/flags:flag",
+       "@com_google_absl//absl/memory",
+       "@com_google_absl//absl/strings",
+       "@com_google_absl//absl/strings:str_format",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:calculator_runner",
        "//mediapipe/framework/deps:file_path",
@@ -1061,11 +1068,10 @@ cc_test(
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/util:image_test_utils",
-       "@com_google_absl//absl/flags:flag",
-       "@com_google_absl//absl/memory",
-       "@com_google_absl//absl/strings",
-       "@com_google_absl//absl/strings:str_format",
-   ],
+   ] + select({
+       "//mediapipe:apple": [],
+       "//conditions:default": ["//mediapipe/gpu:gl_context"],
+   }),
)
cc_library( cc_library(

View File

@@ -45,9 +45,11 @@
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
+#include "mediapipe/gpu/gpu_service.h"
#else
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
+#include "mediapipe/gpu/gpu_service.h"
#endif  // MEDIAPIPE_METAL_ENABLED
#endif  // !MEDIAPIPE_DISABLE_GPU
@@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node {
#if MEDIAPIPE_METAL_ENABLED
    MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#else
-   MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
+   cc->UseService(kGpuService).Optional();
#endif  // MEDIAPIPE_METAL_ENABLED
#endif  // MEDIAPIPE_DISABLE_GPU

View File

@@ -41,6 +41,10 @@
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/image_test_utils.h"

+#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
+#include "mediapipe/gpu/gl_context.h"
+#endif  // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
+
namespace mediapipe {
namespace {
@@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) {
      /*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
      /*keep_aspect=*/false, BorderMode::kZero, roi);
}

+TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) {
+  auto graph_config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: "input_image"
+        node {
+          calculator: "ImageToTensorCalculator"
+          input_stream: "IMAGE:input_image"
+          output_stream: "TENSORS:tensor"
+          options {
+            [mediapipe.ImageToTensorCalculatorOptions.ext] {
+              output_tensor_float_range { min: 0.0f max: 1.0f }
+            }
+          }
+        }
+      )pb");
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config));
+  MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
+  MP_ASSERT_OK(graph.StartRun({}));
+  auto image_frame =
+      std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
+  Image image = Image(std::move(image_frame));
+  Packet packet = MakePacket<Image>(std::move(image));
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
+  MP_ASSERT_OK(graph.WaitUntilIdle());
+  MP_ASSERT_OK(graph.CloseAllPacketSources());
+  MP_ASSERT_OK(graph.WaitUntilDone());
+}
+
+#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
+TEST(ImageToTensorCalculatorTest,
+     FailsGracefullyWhenGpuServiceNeededButNotAvailable) {
+  auto graph_config =
+      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
+        input_stream: "input_image"
+        node {
+          calculator: "ImageToTensorCalculator"
+          input_stream: "IMAGE:input_image"
+          output_stream: "TENSORS:tensor"
+          options {
+            [mediapipe.ImageToTensorCalculatorOptions.ext] {
+              output_tensor_float_range { min: 0.0f max: 1.0f }
+            }
+          }
+        }
+      )pb");
+  CalculatorGraph graph;
+  MP_ASSERT_OK(graph.Initialize(graph_config));
+  MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
+  MP_ASSERT_OK(graph.StartRun({}));
+  MP_ASSERT_OK_AND_ASSIGN(auto context,
+                          GlContext::Create(nullptr, /*create_thread=*/true));
+  Packet packet;
+  context->Run([&packet]() {
+    auto image_frame =
+        std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
+    Image image = Image(std::move(image_frame));
+    // Ensure image is available on GPU to force ImageToTensorCalculator to
+    // run on GPU.
+    ASSERT_TRUE(image.ConvertToGpu());
+    packet = MakePacket<Image>(std::move(image));
+  });
+  MP_ASSERT_OK(
+      graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
+  EXPECT_THAT(graph.WaitUntilIdle(),
+              StatusIs(absl::StatusCode::kInternal,
+                       HasSubstr("GPU service not available")));
+}
+#endif  // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED

}  // namespace
}  // namespace mediapipe

Binary file not shown. (added; 64 KiB)

View File

@@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio,
      std::string result_image;
      MP_ASSERT_OK(
          mediapipe::file::GetContents(result_string_path, &result_image));
-     EXPECT_EQ(result_image, output_string);
+     if (result_image != output_string) {
+       // There may be slight differences due to the way the JPEG was encoded
+       // or the OpenCV version used to generate the reference files. Compare
+       // pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
+       cv::Mat result_mat =
+           cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1,
+                                const_cast<char*>(result_image.data())),
+                        cv::IMREAD_UNCHANGED);
+       cv::Mat output_mat =
+           cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1,
+                                const_cast<char*>(output_string.data())),
+                        cv::IMREAD_UNCHANGED);
+       ASSERT_EQ(result_mat.rows, output_mat.rows);
+       ASSERT_EQ(result_mat.cols, output_mat.cols);
+       ASSERT_EQ(result_mat.type(), output_mat.type());
+       EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0);
+     }
    } else {
      std::string output_string_path = mediapipe::file::JoinPath(
          absl::GetFlag(FLAGS_output_folder),
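For reference (my gloss, not part of the change), the PSNR used above is, for 8-bit images,

$$\mathrm{PSNR}(a, b) = 10 \log_{10}\!\frac{255^2}{\mathrm{MSE}(a, b)},$$

so the `> 45.0` threshold admits a mean squared error of only about 2, i.e. images that are visually indistinguishable.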

Binary file not shown. (modified; 3.2 KiB before and after)

Binary file not shown. (modified; 6.1 KiB before and after)

Binary file not shown. (modified; 8.2 KiB before and after)

Binary file not shown. (modified; 7.6 KiB before and after)

View File

@@ -19,9 +19,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
cc_library(
    name = "face_geometry_from_landmarks_graph",
    srcs = ["face_geometry_from_landmarks_graph.cc"],
-   data = [
-       "//mediapipe/tasks/cc/vision/face_geometry/data:geometry_pipeline_metadata_landmarks",
-   ],
    deps = [
        "//mediapipe/calculators/core:begin_loop_calculator",
        "//mediapipe/calculators/core:end_loop_calculator",
@@ -39,6 +36,7 @@ cc_library(
        "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
+       "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto",
        "//mediapipe/util:graph_builder_utils",
        "@com_google_absl//absl/status:statusor",
    ],

View File

@@ -45,6 +45,7 @@ mediapipe_proto_library(
    srcs = ["geometry_pipeline_calculator.proto"],
    deps = [
        "//mediapipe/framework:calculator_options_proto",
+       "//mediapipe/tasks/cc/core/proto:external_file_proto",
    ],
)
@@ -59,6 +60,8 @@ cc_library(
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/port:statusor",
+       "//mediapipe/tasks/cc/core:external_file_handler",
+       "//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
        "//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils",
        "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",

View File

@@ -24,6 +24,8 @@
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h"
+#include "mediapipe/tasks/cc/core/external_file_handler.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
@@ -69,8 +71,8 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
//   A vector of face geometry data.
//
// Options:
-//   metadata_path (`string`, optional):
-//     Defines a path for the geometry pipeline metadata file.
+//   metadata_file (`ExternalFile`, optional):
+//     Defines an ExternalFile for the geometry pipeline metadata file.
//
//     The geometry pipeline metadata file format must be the binary
//     `GeometryPipelineMetadata` proto.
@@ -95,7 +97,7 @@ class GeometryPipelineCalculator : public CalculatorBase {
    ASSIGN_OR_RETURN(
        GeometryPipelineMetadata metadata,
-       ReadMetadataFromFile(options.metadata_path()),
+       ReadMetadataFromFile(options.metadata_file()),
        _ << "Failed to read the geometry pipeline metadata from file!");
    MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata))
@@ -155,32 +157,19 @@ class GeometryPipelineCalculator : public CalculatorBase {
 private:
  static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile(
-     const std::string& metadata_path) {
-   ASSIGN_OR_RETURN(std::string metadata_blob,
-                    ReadContentBlobFromFile(metadata_path),
-                    _ << "Failed to read a metadata blob from file!");
+     const core::proto::ExternalFile& metadata_file) {
+   ASSIGN_OR_RETURN(
+       const auto file_handler,
+       core::ExternalFileHandler::CreateFromExternalFile(&metadata_file));
    GeometryPipelineMetadata metadata;
-   RET_CHECK(metadata.ParseFromString(metadata_blob))
+   RET_CHECK(
+       metadata.ParseFromString(std::string(file_handler->GetFileContent())))
        << "Failed to parse a metadata proto from a binary blob!";
    return metadata;
  }

-  static absl::StatusOr<std::string> ReadContentBlobFromFile(
-      const std::string& unresolved_path) {
-    ASSIGN_OR_RETURN(std::string resolved_path,
-                     mediapipe::PathToResourceAsFile(unresolved_path),
-                     _ << "Failed to resolve path! Path = " << unresolved_path);
-    std::string content_blob;
-    MP_RETURN_IF_ERROR(
-        mediapipe::GetResourceContents(resolved_path, &content_blob))
-        << "Failed to read content blob! Resolved path = " << resolved_path;
-    return content_blob;
-  }
-
  std::unique_ptr<GeometryPipeline> geometry_pipeline_;
};

View File

@@ -17,11 +17,12 @@ syntax = "proto2";
package mediapipe.tasks.vision.face_geometry;

import "mediapipe/framework/calculator_options.proto";
+import "mediapipe/tasks/cc/core/proto/external_file.proto";

message FaceGeometryPipelineCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional FaceGeometryPipelineCalculatorOptions ext = 512499200;
  }

-  optional string metadata_path = 1;
+  optional core.proto.ExternalFile metadata_file = 1;
}

View File

@@ -28,6 +28,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
+#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h"

namespace mediapipe::tasks::vision::face_geometry {
@@ -49,10 +50,6 @@ constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kItemTag[] = "ITEM";

-constexpr char kGeometryPipelineMetadataPath[] =
-    "mediapipe/tasks/cc/vision/face_geometry/data/"
-    "geometry_pipeline_metadata_landmarks.binarypb";
-
struct FaceGeometryOuts {
  Stream<std::vector<FaceGeometry>> multi_face_geometry;
};
@@ -127,6 +124,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
    }
    ASSIGN_OR_RETURN(auto outs,
                     BuildFaceGeometryFromLandmarksGraph(
+                        *sc->MutableOptions<proto::FaceGeometryGraphOptions>(),
                         graph.In(kFaceLandmarksTag)
                             .Cast<std::vector<NormalizedLandmarkList>>(),
                         graph.In(kImageSizeTag).Cast<std::pair<int, int>>(),
@@ -138,6 +136,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
 private:
  absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph(
+     proto::FaceGeometryGraphOptions& graph_options,
      Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks,
      Stream<std::pair<int, int>> image_size,
      std::optional<SidePacket<Environment>> environment, Graph& graph) {
@@ -185,7 +184,8 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
        "mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator");
    auto& geometry_pipeline_options =
        geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>();
-   geometry_pipeline_options.set_metadata_path(kGeometryPipelineMetadataPath);
+   geometry_pipeline_options.Swap(
+       graph_options.mutable_geometry_pipeline_options());
    image_size >> geometry_pipeline.In(kImageSizeTag);
    multi_face_landmarks_no_iris >>
        geometry_pipeline.In(kMultiFaceLandmarksTag);

View File

@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/status/statusor.h"
 #include "absl/strings/str_format.h"
 #include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
 #include "mediapipe/framework/api2/port.h"
 #include "mediapipe/framework/calculator_framework.h"
 #include "mediapipe/framework/calculator_runner.h"
@@ -31,6 +32,7 @@ limitations under the License.
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/tool/sink.h"
+#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
 #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
 #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
@@ -49,6 +51,9 @@ constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
 constexpr char kFaceLandmarksFileName[] =
     "face_blendshapes_in_landmarks.prototxt";
 constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt";
+constexpr char kGeometryPipelineMetadataPath[] =
+    "mediapipe/tasks/cc/vision/face_geometry/data/"
+    "geometry_pipeline_metadata_landmarks.binarypb";

 std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) {
   NormalizedLandmarkList landmarks;
@@ -89,7 +94,8 @@ void MakeInputPacketsAndRunGraph(CalculatorGraph& graph) {
 TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
   CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
-      CalculatorGraphConfig>(R"pb(
+      CalculatorGraphConfig>(absl::Substitute(
+      R"pb(
         input_stream: "FACE_LANDMARKS:face_landmarks"
         input_stream: "IMAGE_SIZE:image_size"
         output_stream: "FACE_GEOMETRY:face_geometry"
@@ -98,8 +104,15 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
           input_stream: "FACE_LANDMARKS:face_landmarks"
           input_stream: "IMAGE_SIZE:image_size"
          output_stream: "FACE_GEOMETRY:face_geometry"
+          options: {
+            [mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
+                 .ext] {
+              geometry_pipeline_options { metadata_file { file_name: "$0" } }
+            }
+          }
         }
-      )pb");
+      )pb",
+      kGeometryPipelineMetadataPath));

   std::vector<Packet> output_packets;
   tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
@@ -116,7 +129,8 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
 TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
   CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
-      CalculatorGraphConfig>(R"pb(
+      CalculatorGraphConfig>(absl::Substitute(
+      R"pb(
         input_stream: "FACE_LANDMARKS:face_landmarks"
         input_stream: "IMAGE_SIZE:image_size"
         input_side_packet: "ENVIRONMENT:environment"
@@ -127,8 +141,15 @@ TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
           input_stream: "IMAGE_SIZE:image_size"
           input_side_packet: "ENVIRONMENT:environment"
           output_stream: "FACE_GEOMETRY:face_geometry"
+          options: {
+            [mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
+                 .ext] {
+              geometry_pipeline_options { metadata_file { file_name: "$0" } }
+            }
+          }
         }
-      )pb");
+      )pb",
+      kGeometryPipelineMetadataPath));

   std::vector<Packet> output_packets;
   tool::AddVectorSink("face_geometry", &graph_config, &output_packets);

View File

@@ -44,3 +44,12 @@ mediapipe_proto_library(
    name = "mesh_3d_proto",
    srcs = ["mesh_3d.proto"],
)
mediapipe_proto_library(
name = "face_geometry_graph_options_proto",
srcs = ["face_geometry_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_proto",
],
)

View File

@@ -0,0 +1,28 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe.tasks.vision.face_geometry.proto;
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto";
message FaceGeometryGraphOptions {
extend mediapipe.CalculatorOptions {
optional FaceGeometryGraphOptions ext = 515723506;
}
optional FaceGeometryPipelineCalculatorOptions geometry_pipeline_options = 1;
}

View File

@@ -210,8 +210,10 @@ cc_library(
        "//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry:face_geometry_from_landmarks_graph",
+       "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
        "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
+       "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
        "//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",

View File

@@ -40,8 +40,10 @@ limitations under the License.
 #include "mediapipe/tasks/cc/core/utils.h"
 #include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
 #include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
+#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
 #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
 #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
+#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
 #include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
@@ -93,6 +95,8 @@ constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite";
 constexpr char kFaceLandmarksDetectorTFLiteName[] =
     "face_landmarks_detector.tflite";
 constexpr char kFaceBlendshapeTFLiteName[] = "face_blendshapes.tflite";
+constexpr char kFaceGeometryPipelineMetadataName[] =
+    "geometry_pipeline_metadata_landmarks.binarypb";

 struct FaceLandmarkerOutputs {
   Source<std::vector<NormalizedLandmarkList>> landmark_lists;
@@ -305,6 +309,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
   absl::StatusOr<CalculatorGraphConfig> GetConfig(
       SubgraphContext* sc) override {
     Graph graph;
+    bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag);
     if (sc->Options<FaceLandmarkerGraphOptions>()
             .base_options()
             .has_model_asset()) {
@@ -318,6 +323,18 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
           sc->MutableOptions<FaceLandmarkerGraphOptions>(),
           !sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
                .IsAvailable()));
+      if (output_geometry) {
+        // Set the face geometry metadata file for
+        // FaceGeometryFromLandmarksGraph.
+        ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file,
+                         model_asset_bundle_resources->GetModelFile(
+                             kFaceGeometryPipelineMetadataName));
+        SetExternalFile(face_geometry_pipeline_metadata_file,
+                        sc->MutableOptions<FaceLandmarkerGraphOptions>()
+                            ->mutable_face_geometry_graph_options()
+                            ->mutable_geometry_pipeline_options()
+                            ->mutable_metadata_file());
+      }
     }
     std::optional<SidePacket<Environment>> environment;
     if (HasSideInput(sc->OriginalNode(), kEnvironmentTag)) {
@@ -338,7 +355,6 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
             .face_landmarks_detector_graph_options()
             .has_face_blendshapes_graph_options()));
     }
-    bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag);
     ASSIGN_OR_RETURN(
         auto outs,
         BuildFaceLandmarkerGraph(
@@ -481,6 +497,9 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
     auto& face_geometry_from_landmarks = graph.AddNode(
         "mediapipe.tasks.vision.face_geometry."
         "FaceGeometryFromLandmarksGraph");
+    face_geometry_from_landmarks
+        .GetOptions<face_geometry::proto::FaceGeometryGraphOptions>()
+        .Swap(tasks_options.mutable_face_geometry_graph_options());
     if (environment.has_value()) {
       *environment >> face_geometry_from_landmarks.SideIn(kEnvironmentTag);
     }

View File

@@ -60,5 +60,6 @@ mediapipe_proto_library(
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/tasks/cc/core/proto:base_options_proto",
        "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto",
+       "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_proto",
    ],
)

View File

@@ -21,6 +21,7 @@ import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/base_options.proto";
import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto";
+import "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.proto";
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto";

option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto";
@@ -45,4 +46,8 @@ message FaceLandmarkerGraphOptions {
  // Minimum confidence for face landmarks tracking to be considered
  // successful.
  optional float min_tracking_confidence = 4 [default = 0.5];
+
+  // Options for FaceGeometryGraph to get facial transformation matrix.
+  optional face_geometry.proto.FaceGeometryGraphOptions
+      face_geometry_graph_options = 5;
}

View File

@@ -47,6 +47,7 @@ cc_library(
        "//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
        "@com_google_absl//absl/status:statusor",
    ],
+   alwayslink = 1,
)

cc_library(

View File

@@ -294,7 +294,7 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) {
                threadsPerThreadgroup:threads_per_group];
  [compute_encoder endEncoding];
  [command_buffer commit];
-  [command_buffer waitUntilCompleted];
  kOutputImage(cc).Send(Image(output));
  return absl::OkStatus();
}

View File

@@ -36,6 +36,7 @@ cc_library(
        ":hand_association_calculator_cc_proto",
        "//mediapipe/calculators/util:association_calculator",
        "//mediapipe/framework:calculator_framework",
+       "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/formats:rect_cc_proto",
        "//mediapipe/framework/port:rectangle",

View File

@@ -19,6 +19,7 @@ limitations under the License.
 #include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/collection_item_id.h"
 #include "mediapipe/framework/formats/rect.pb.h"
 #include "mediapipe/framework/port/rectangle.h"
 #include "mediapipe/framework/port/status.h"
@@ -29,30 +30,55 @@ namespace mediapipe::api2 {
 using ::mediapipe::NormalizedRect;

-// HandAssociationCalculator accepts multiple inputs of vectors of
-// NormalizedRect. The output is a vector of NormalizedRect that contains
-// rects from the input vectors that don't overlap with each other. When two
-// rects overlap, the rect that comes in from an earlier input stream is
-// kept in the output. If a rect has no ID (i.e. from detection stream),
-// then a unique rect ID is assigned for it.
-//
-// The rects in multiple input streams are effectively flattened to a single
-// list. For example:
-// Stream1 : rect 1, rect 2
-// Stream2: rect 3, rect 4
-// Stream3: rect 5, rect 6
-// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
-// In the flattened list, if a rect with a higher index overlaps with a rect a
-// lower index, beyond a specified IOU threshold, the rect with the lower
-// index will be in the output, and the rect with higher index will be
-// discarded.
+// Input:
+//   BASE_RECTS - Vector of NormalizedRect.
+//   RECTS - Vector of NormalizedRect.
+//
+// Output:
+//   No tag - Vector of NormalizedRect.
+//
+// Example use:
+// node {
+//   calculator: "HandAssociationCalculator"
+//   input_stream: "BASE_RECTS:base_rects"
+//   input_stream: "RECTS:0:rects0"
+//   input_stream: "RECTS:1:rects1"
+//   input_stream: "RECTS:2:rects2"
+//   output_stream: "output_rects"
+//   options {
+//     [mediapipe.HandAssociationCalculatorOptions.ext] {
+//       min_similarity_threshold: 0.1
+//     }
+//   }
+// }
+//
+// IMPORTANT Notes:
+//  - Rects from input streams tagged with "BASE_RECTS" are always preserved.
+//  - This calculator checks for overlap among rects from input streams tagged
+//    with "RECTS". Rects are prioritized based on their index in the vector and
+//    input streams to the calculator. When two rects overlap, the rect that
+//    comes from an input stream with lower tag-index is kept in the output.
+//  - Example of inputs for the node above:
+//    "base_rects": rect 0, rect 1
+//    "rects0": rect 2, rect 3
+//    "rects1": rect 4, rect 5
+//    "rects2": rect 6, rect 7
+//    (Conceptually) flattened list: 0, 1, 2, 3, 4, 5, 6, 7.
+//    Rects 0, 1 will be preserved. Rects 2, 3, 4, 5, 6, 7 will be checked for
+//    overlap. If a rect with a higher index overlaps with a rect with lower
+//    index, beyond a specified IOU threshold, the rect with the lower index
+//    will be in the output, and the rect with higher index will be discarded.
 // TODO: Upgrade this to latest API for calculators
 class HandAssociationCalculator : public CalculatorBase {
  public:
   static absl::Status GetContract(CalculatorContract* cc) {
     // Initialize input and output streams.
-    for (auto& input_stream : cc->Inputs()) {
-      input_stream.Set<std::vector<NormalizedRect>>();
+    for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
+         id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
+      cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
+    }
+    for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
+         id != cc->Inputs().EndId("RECTS"); ++id) {
+      cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
     }
     cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
@@ -89,7 +115,24 @@ class HandAssociationCalculator : public CalculatorBase {
       CalculatorContext* cc) {
     std::vector<NormalizedRect> result;

+    for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
+         id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
+      const auto& input_stream = cc->Inputs().Get(id);
+      if (input_stream.IsEmpty()) {
+        continue;
+      }
+
+      for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
+        if (!rect.has_rect_id()) {
+          rect.set_rect_id(GetNextRectId());
+        }
+        result.push_back(rect);
+      }
+    }
+
-    for (const auto& input_stream : cc->Inputs()) {
+    for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
+         id != cc->Inputs().EndId("RECTS"); ++id) {
+      const auto& input_stream = cc->Inputs().Get(id);
       if (input_stream.IsEmpty()) {
         continue;
       }
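Read as prose, the new comment above amounts to a greedy IoU filter: base rects are always kept, candidate rects are kept only while they stay below the similarity threshold against everything already kept. A minimal TypeScript sketch of that rule follows; the `Rect`, `iou`, and `associate` names are illustrative (not MediaPipe APIs), and it assumes axis-aligned rects without the rotation field of `NormalizedRect`:

```
// Sketch of the association rule: BASE_RECTS are always preserved; a
// candidate from RECTS is kept only if its IoU with every kept rect
// stays at or below the similarity threshold.
interface Rect { xCenter: number; yCenter: number; width: number; height: number; }

function iou(a: Rect, b: Rect): number {
  const iw = Math.max(0, Math.min(a.xCenter + a.width / 2, b.xCenter + b.width / 2) -
                         Math.max(a.xCenter - a.width / 2, b.xCenter - b.width / 2));
  const ih = Math.max(0, Math.min(a.yCenter + a.height / 2, b.yCenter + b.height / 2) -
                         Math.max(a.yCenter - a.height / 2, b.yCenter - b.height / 2));
  const inter = iw * ih;
  const union = a.width * a.height + b.width * b.height - inter;
  return union > 0 ? inter / union : 0;
}

function associate(baseRects: Rect[], rects: Rect[], threshold: number): Rect[] {
  const result = [...baseRects];  // BASE_RECTS are never filtered.
  for (const candidate of rects) {
    if (result.every(kept => iou(kept, candidate) <= threshold)) {
      result.push(candidate);
    }
  }
  return result;
}
```

This mirrors the ordering guarantee in the comment: earlier streams and earlier vector indices win, because later candidates are tested against everything accepted before them.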

View File

@@ -27,6 +27,8 @@ namespace mediapipe {
 namespace {

 using ::mediapipe::NormalizedRect;
+using ::testing::ElementsAre;
+using ::testing::EqualsProto;

 class HandAssociationCalculatorTest : public testing::Test {
  protected:
@@ -87,9 +89,9 @@ class HandAssociationCalculatorTest : public testing::Test {
 TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
     calculator: "HandAssociationCalculator"
-    input_stream: "input_vec_0"
-    input_stream: "input_vec_1"
-    input_stream: "input_vec_2"
+    input_stream: "BASE_RECTS:input_vec_0"
+    input_stream: "RECTS:0:input_vec_1"
+    input_stream: "RECTS:1:input_vec_2"
     output_stream: "output_vec"
     options {
       [mediapipe.HandAssociationCalculatorOptions.ext] {
@@ -103,20 +105,23 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
   input_vec_0->push_back(nr_0_);
   input_vec_0->push_back(nr_1_);
   input_vec_0->push_back(nr_2_);
-  runner.MutableInputs()->Index(0).packets.push_back(
-      Adopt(input_vec_0.release()).At(Timestamp(1)));
+  runner.MutableInputs()
+      ->Tag("BASE_RECTS")
+      .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));

   // Input Stream 1: nr_3, nr_4.
   auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
   input_vec_1->push_back(nr_3_);
   input_vec_1->push_back(nr_4_);
-  runner.MutableInputs()->Index(1).packets.push_back(
+  auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
+  runner.MutableInputs()->Get(index_id).packets.push_back(
       Adopt(input_vec_1.release()).At(Timestamp(1)));

   // Input Stream 2: nr_5.
   auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
   input_vec_2->push_back(nr_5_);
-  runner.MutableInputs()->Index(2).packets.push_back(
+  index_id = runner.MutableInputs()->GetId("RECTS", 1);
+  runner.MutableInputs()->Get(index_id).packets.push_back(
       Adopt(input_vec_2.release()).At(Timestamp(1)));

   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@@ -134,25 +139,18 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
   EXPECT_EQ(3, assoc_rects.size());

   // Check that IDs are filled in and contents match.
-  EXPECT_EQ(assoc_rects[0].rect_id(), 1);
-  assoc_rects[0].clear_rect_id();
-  EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
-
-  EXPECT_EQ(assoc_rects[1].rect_id(), 2);
-  assoc_rects[1].clear_rect_id();
-  EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
-
-  EXPECT_EQ(assoc_rects[2].rect_id(), 3);
-  assoc_rects[2].clear_rect_id();
-  EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
+  nr_0_.set_rect_id(1);
+  nr_1_.set_rect_id(2);
+  nr_2_.set_rect_id(3);
+  EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
+                                       EqualsProto(nr_2_)));
 }
 TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
     calculator: "HandAssociationCalculator"
-    input_stream: "input_vec_0"
-    input_stream: "input_vec_1"
-    input_stream: "input_vec_2"
+    input_stream: "BASE_RECTS:input_vec_0"
+    input_stream: "RECTS:0:input_vec_1"
     output_stream: "output_vec"
     options {
       [mediapipe.HandAssociationCalculatorOptions.ext] {
@@ -169,14 +167,15 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
   input_vec_0->push_back(nr_0_);
   nr_1_.set_rect_id(-1);
   input_vec_0->push_back(nr_1_);
-  runner.MutableInputs()->Index(0).packets.push_back(
-      Adopt(input_vec_0.release()).At(Timestamp(1)));
+  runner.MutableInputs()
+      ->Tag("BASE_RECTS")
+      .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));

   // Input Stream 1: nr_2, nr_3. Newly detected palms.
   auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
   input_vec_1->push_back(nr_2_);
   input_vec_1->push_back(nr_3_);
-  runner.MutableInputs()->Index(1).packets.push_back(
+  runner.MutableInputs()->Tag("RECTS").packets.push_back(
       Adopt(input_vec_1.release()).At(Timestamp(1)));

   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@@ -192,23 +191,17 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
   EXPECT_EQ(3, assoc_rects.size());

   // Check that IDs are filled in and contents match.
-  EXPECT_EQ(assoc_rects[0].rect_id(), -2);
-  EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
-
-  EXPECT_EQ(assoc_rects[1].rect_id(), -1);
-  EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
-
-  EXPECT_EQ(assoc_rects[2].rect_id(), 1);
-  assoc_rects[2].clear_rect_id();
-  EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
+  nr_2_.set_rect_id(1);
+  EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
+                                       EqualsProto(nr_2_)));
 }

 TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
     calculator: "HandAssociationCalculator"
-    input_stream: "input_vec_0"
-    input_stream: "input_vec_1"
-    input_stream: "input_vec_2"
+    input_stream: "BASE_RECTS:input_vec_0"
+    input_stream: "RECTS:0:input_vec_1"
+    input_stream: "RECTS:1:input_vec_2"
     output_stream: "output_vec"
     options {
       [mediapipe.HandAssociationCalculatorOptions.ext] {
@@ -220,14 +213,16 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
   // Input Stream 0: nr_5.
   auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
   input_vec_0->push_back(nr_5_);
-  runner.MutableInputs()->Index(0).packets.push_back(
-      Adopt(input_vec_0.release()).At(Timestamp(1)));
+  runner.MutableInputs()
+      ->Tag("BASE_RECTS")
+      .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));

   // Input Stream 1: nr_4, nr_3
   auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
   input_vec_1->push_back(nr_4_);
   input_vec_1->push_back(nr_3_);
-  runner.MutableInputs()->Index(1).packets.push_back(
+  auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
+  runner.MutableInputs()->Get(index_id).packets.push_back(
       Adopt(input_vec_1.release()).At(Timestamp(1)));

   // Input Stream 2: nr_2, nr_1, nr_0.
@@ -235,7 +230,8 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
   input_vec_2->push_back(nr_2_);
   input_vec_2->push_back(nr_1_);
   input_vec_2->push_back(nr_0_);
-  runner.MutableInputs()->Index(2).packets.push_back(
+  index_id = runner.MutableInputs()->GetId("RECTS", 1);
+  runner.MutableInputs()->Get(index_id).packets.push_back(
       Adopt(input_vec_2.release()).At(Timestamp(1)));

   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
@@ -253,23 +249,78 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
   EXPECT_EQ(3, assoc_rects.size());

   // Outputs are in same order as inputs, and IDs are filled in.
-  EXPECT_EQ(assoc_rects[0].rect_id(), 1);
-  assoc_rects[0].clear_rect_id();
-  EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_));
-
-  EXPECT_EQ(assoc_rects[1].rect_id(), 2);
-  assoc_rects[1].clear_rect_id();
-  EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_));
-
-  EXPECT_EQ(assoc_rects[2].rect_id(), 3);
-  assoc_rects[2].clear_rect_id();
-  EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_));
+  nr_5_.set_rect_id(1);
+  nr_4_.set_rect_id(2);
+  nr_0_.set_rect_id(3);
+  EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_4_),
+                                       EqualsProto(nr_0_)));
+}
+
+TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReservesBaseRects) {
+  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
+    calculator: "HandAssociationCalculator"
+    input_stream: "BASE_RECTS:input_vec_0"
+    input_stream: "RECTS:0:input_vec_1"
+    input_stream: "RECTS:1:input_vec_2"
+    output_stream: "output_vec"
+    options {
+      [mediapipe.HandAssociationCalculatorOptions.ext] {
+        min_similarity_threshold: 0.1
+      }
+    }
+  )pb"));
+
+  // Input Stream 0: nr_5, nr_3, nr_1.
+  auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
+  input_vec_0->push_back(nr_5_);
+  input_vec_0->push_back(nr_3_);
+  input_vec_0->push_back(nr_1_);
+  runner.MutableInputs()
+      ->Tag("BASE_RECTS")
+      .packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
+
+  // Input Stream 1: nr_4.
+  auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
+  input_vec_1->push_back(nr_4_);
+  auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
+  runner.MutableInputs()->Get(index_id).packets.push_back(
+      Adopt(input_vec_1.release()).At(Timestamp(1)));
+
+  // Input Stream 2: nr_2, nr_0.
+  auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
+  input_vec_2->push_back(nr_2_);
+  input_vec_2->push_back(nr_0_);
+  index_id = runner.MutableInputs()->GetId("RECTS", 1);
+  runner.MutableInputs()->Get(index_id).packets.push_back(
+      Adopt(input_vec_2.release()).At(Timestamp(1)));
+
+  MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
+  const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
+  EXPECT_EQ(1, output.size());
+  auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
+
+  // Rectangles are added in the following sequence:
+  // nr_5 is added because it is in BASE_RECTS input stream.
+  // nr_3 is added because it is in BASE_RECTS input stream.
+  // nr_1 is added because it is in BASE_RECTS input stream.
+  // nr_4 is added because it does not overlap with nr_5.
+  // nr_2 is NOT added because it overlaps with nr_4.
+  // nr_0 is NOT added because it overlaps with nr_3.
+  EXPECT_EQ(4, assoc_rects.size());
+
+  // Outputs are in same order as inputs, and IDs are filled in.
+  nr_5_.set_rect_id(1);
+  nr_3_.set_rect_id(2);
+  nr_1_.set_rect_id(3);
+  nr_4_.set_rect_id(4);
+  EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_3_),
+                                       EqualsProto(nr_1_), EqualsProto(nr_4_)));
 }
 TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
   CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
     calculator: "HandAssociationCalculator"
-    input_stream: "input_vec"
+    input_stream: "BASE_RECTS:input_vec"
     output_stream: "output_vec"
     options {
       [mediapipe.HandAssociationCalculatorOptions.ext] {
@@ -282,8 +333,9 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
   auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
   input_vec->push_back(nr_3_);
   input_vec->push_back(nr_5_);
-  runner.MutableInputs()->Index(0).packets.push_back(
-      Adopt(input_vec.release()).At(Timestamp(1)));
+  runner.MutableInputs()
+      ->Tag("BASE_RECTS")
+      .packets.push_back(Adopt(input_vec.release()).At(Timestamp(1)));

   MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
   const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
@@ -292,12 +344,12 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
   // Rectangles are added in the following sequence:
   // nr_3 is added 1st.
-  // nr_5 is NOT added because it overlaps with nr_3.
-  EXPECT_EQ(1, assoc_rects.size());
-
-  EXPECT_EQ(assoc_rects[0].rect_id(), 1);
-  assoc_rects[0].clear_rect_id();
-  EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_));
+  // nr_5 is added 2nd. The calculator assumes it does not overlap with nr_3.
+  EXPECT_EQ(2, assoc_rects.size());
+
+  nr_3_.set_rect_id(1);
+  nr_5_.set_rect_id(2);
+  EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_3_), EqualsProto(nr_5_)));
 }

 }  // namespace

View File

@@ -318,9 +318,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
           .set_min_similarity_threshold(
               tasks_options.min_tracking_confidence());
       prev_hand_rects_from_landmarks >>
-          hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
+          hand_association[Input<std::vector<NormalizedRect>>("BASE_RECTS")];
       hand_rects_from_hand_detector >>
-          hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
+          hand_association[Input<std::vector<NormalizedRect>>("RECTS")];
       auto hand_rects = hand_association.Out("");
       hand_rects >> clip_hand_rects.In("");
     } else {

View File

@@ -34,18 +34,19 @@ _AUDIO_TASKS_JAVA_PROTO_LITE_TARGETS = [
 ]

 _VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
-    "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
-    "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
     "//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
+    "//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
 ]

 _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [ _TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [

View File

@@ -36,11 +36,9 @@ export abstract class AudioTaskRunner<T> extends TaskRunner {
   /** Sends a single audio clip to the graph and awaits results. */
   protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     return this.process(
-        audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp);
+        audioData, sampleRate ?? this.defaultSampleRate,
+        this.getSynctheticTimestamp());
   }
 }

View File

@@ -15,6 +15,11 @@ mediapipe_ts_declaration(
    deps = [":category"],
)

+mediapipe_ts_declaration(
+    name = "keypoint",
+    srcs = ["keypoint.d.ts"],
+)
+
mediapipe_ts_declaration(
    name = "landmark",
    srcs = ["landmark.d.ts"],

View File

@@ -0,0 +1,33 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* A keypoint, defined by the coordinates (x, y), normalized by the image
* dimensions.
*/
export declare interface NormalizedKeypoint {
/** X in normalized image coordinates. */
x: number;
/** Y in normalized image coordinates. */
y: number;
/** Optional label of the keypoint. */
label?: string;
/** Optional score of the keypoint. */
score?: number;
}
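As a quick illustration of the normalized convention documented above, a keypoint maps back to pixel space by multiplying with the image dimensions. The values, the canvas-free setup, and the relative import path below are illustrative only:

```
import {NormalizedKeypoint} from './keypoint';

// Hypothetical keypoint on a 640x480 image; x and y are fractions of the
// image width and height, so pixel coordinates are a simple multiply.
const keypoint: NormalizedKeypoint = {x: 0.25, y: 0.5, label: 'nose', score: 0.9};
const imageWidth = 640;
const imageHeight = 480;
const px = keypoint.x * imageWidth;   // 160
const py = keypoint.y * imageHeight;  // 240
console.log(`${keypoint.label ?? 'keypoint'} at (${px}, ${py})`);
```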

View File

@@ -175,9 +175,13 @@ export abstract class TaskRunner {
         Math.max(this.latestOutputTimestamp, timestamp);
   }

-  /** Returns the latest output timestamp. */
-  protected getLatestOutputTimestamp() {
-    return this.latestOutputTimestamp;
+  /**
+   * Gets a synthetic timestamp in ms that can be used with the next packet
+   * sent to the graph. The timestamp is one millisecond past the last
+   * timestamp received from the graph.
+   */
+  protected getSynctheticTimestamp(): number {
+    return this.latestOutputTimestamp + 1;
   }

   /** Throws the error from the error listener if an error was raised. */
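The refactor centralizes the "+1 ms" pattern that the task classes previously repeated inline: because the result is always one past the latest output timestamp, repeated synchronous calls each stamp their packet with a strictly larger time. A minimal sketch of how a subclass uses it (the `EchoTask` class and `'input_text'` stream name are hypothetical; `graphRunner.addStringToStream` and `finishProcessing` appear in the call sites below):

```
// Hypothetical subclass: each call stamps its packet one millisecond past
// the latest output timestamp, keeping graph time monotonically increasing.
class EchoTask extends TaskRunner {
  send(text: string): void {
    this.graphRunner.addStringToStream(
        text, 'input_text', this.getSynctheticTimestamp());
    this.finishProcessing();
  }
}
```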

View File

@@ -131,11 +131,9 @@ export class TextClassifier extends TaskRunner {
    * @return The classification result of the text
    */
   classify(text: string): TextClassifierResult {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
     this.classificationResult = {classifications: []};
-    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
+    this.graphRunner.addStringToStream(
+        text, INPUT_STREAM, this.getSynctheticTimestamp());
     this.finishProcessing();
     return this.classificationResult;
   }

View File

@@ -135,10 +135,8 @@ export class TextEmbedder extends TaskRunner {
    * @return The embedding results of the text
    */
   embed(text: string): TextEmbedderResult {
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
-    this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
+    this.graphRunner.addStringToStream(
+        text, INPUT_STREAM, this.getSynctheticTimestamp());
     this.finishProcessing();
     return this.embeddingResult;
   }

View File

@@ -2,23 +2,42 @@ (sections reordered alphabetically; resulting text shown once)

This package contains the vision tasks for MediaPipe.

## Gesture Recognition

The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
time, and provides the recognized hand gesture results along with the landmarks
of the detected hands. You can use this task to recognize specific hand
gestures from a user, and invoke application features that correspond to those
gestures.

```
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision,
    "https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task"
);
const image = document.getElementById("image") as HTMLImageElement;
const recognitions = gestureRecognizer.recognize(image);
```

## Hand Landmark Detection

The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands
in an image. You can use this task to localize key points of the hands and
render visual effects over the hands.

```
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const handLandmarker = await HandLandmarker.createFromModelPath(vision,
    "https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task"
);
const image = document.getElementById("image") as HTMLImageElement;
const landmarks = handLandmarker.detect(image);
```

For more information, refer to the [Hand Landmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation.

## Image Classification

@@ -56,40 +75,21 @@ imageSegmenter.segment(image, (masks, width, height) => {
});
```

## Object Detection

The MediaPipe Object Detector task lets you detect the presence and location of
multiple classes of objects within images or videos.

```
const vision = await FilesetResolver.forVisionTasks(
    "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
);
const objectDetector = await ObjectDetector.createFromModelPath(vision,
    "https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite"
);
const image = document.getElementById("image") as HTMLImageElement;
const detections = objectDetector.detect(image);
```

For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation.

View File

@@ -21,6 +21,14 @@ mediapipe_ts_declaration(
    ],
)

+mediapipe_ts_declaration(
+    name = "types",
+    srcs = ["types.d.ts"],
+    deps = [
+        "//mediapipe/tasks/web/components/containers:keypoint",
+    ],
+)
+
mediapipe_ts_library(
    name = "vision_task_runner",
    srcs = ["vision_task_runner.ts"],
@@ -51,6 +59,11 @@ mediapipe_ts_library(
    ],
)

+mediapipe_ts_library(
+    name = "render_utils",
+    srcs = ["render_utils.ts"],
+)
+
jasmine_node_test(
    name = "vision_task_runner_test",
    deps = [":vision_task_runner_test_lib"],

View File

@@ -0,0 +1,78 @@
/** @fileoverview Utility functions used in the vision demos. */
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Pre-baked color table for a maximum of 12 classes.
const CM_ALPHA = 128;
const COLOR_MAP = [
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
[255, 0, 0, CM_ALPHA], // class 1 is red
[0, 255, 0, CM_ALPHA], // class 2 is light green
[0, 0, 255, CM_ALPHA], // class 3 is blue
[255, 255, 0, CM_ALPHA], // class 4 is yellow
[255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta
[0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua
[128, 128, 128, CM_ALPHA], // class 7 is gray
[255, 128, 0, CM_ALPHA], // class 8 is orange
[128, 0, 255, CM_ALPHA], // class 9 is dark purple
[0, 128, 0, CM_ALPHA], // class 10 is dark green
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
];
/** Helper function to draw a confidence mask */
export function drawConfidenceMask(
ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
width: number, height: number): void {
const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
for (let i = 0; i < image.length; i++) {
uint8ClampedArray[4 * i] = 128;
uint8ClampedArray[4 * i + 1] = 0;
uint8ClampedArray[4 * i + 2] = 0;
uint8ClampedArray[4 * i + 3] = image[i] * 255;
}
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
}
/**
* Helper function to draw a category mask. For GPU, we only have F32Arrays
* for now.
*/
export function drawCategoryMask(
ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
width: number, height: number): void {
const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
const isFloatArray = image instanceof Float32Array;
for (let i = 0; i < image.length; i++) {
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
const color = COLOR_MAP[colorIndex];
// When we're given a confidence mask by accident, we just log and return.
// TODO: We should fix this.
if (!color) {
console.warn('No color for ', colorIndex);
return;
}
uint8ClampedArray[4 * i] = color[0];
uint8ClampedArray[4 * i + 1] = color[1];
uint8ClampedArray[4 * i + 2] = color[2];
uint8ClampedArray[4 * i + 3] = color[3];
}
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
}
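A hedged usage sketch for the two helpers above: wire a segmentation callback to a canvas of matching dimensions and hand the first mask to `drawCategoryMask`. The `'output'` canvas id and the `onMasks` name are illustrative, not part of the change:

```
// Illustrative wiring: draw a category mask returned by a segmentation task
// onto a canvas sized to the mask dimensions.
const canvas = document.getElementById('output') as HTMLCanvasElement;
const ctx = canvas.getContext('2d')!;

function onMasks(masks: Array<Float32Array|Uint8Array>, width: number,
                 height: number): void {
  canvas.width = width;
  canvas.height = height;
  drawCategoryMask(ctx, masks[0], width, height);
}
```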

View File

@@ -0,0 +1,42 @@
/**
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint';
/**
* The segmentation tasks return the segmentation result as a Uint8Array
* (when the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
* for future usage.
*/
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
/**
* A callback that receives the computed masks from the segmentation tasks. The
* callback either receives a single element array with a category mask (as a
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type SegmentationMaskCallback =
(masks: SegmentationMask[], width: number, height: number) => void;
/** A Region-Of-Interest (ROI) to represent a region within an image. */
export declare interface RegionOfInterest {
/** The ROI in keypoint format. */
keypoint: NormalizedKeypoint;
}
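Since the comment on `SegmentationMaskCallback` stresses that mask data is only valid for the duration of the callback, a sketch of a callback that copies the typed arrays before deferring work may help (illustrative only; the `WebGLTexture` branch is passed through because textures cannot be sliced):

```
// Copy masks inside the callback before any asynchronous processing;
// the originals may be recycled as soon as the callback returns.
const callback: SegmentationMaskCallback = (masks, width, height) => {
  const copies = masks.map(
      mask => mask instanceof WebGLTexture ? mask : mask.slice());
  setTimeout(() => {
    // Safe to use `copies` here; `masks` is no longer valid.
    console.log(`Received ${copies.length} mask(s) of size ${width}x${height}`);
  });
};
```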

View File

@@ -74,11 +74,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
           'Task is not initialized with image mode. ' +
           '\'runningMode\' must be set to \'IMAGE\'.');
     }
-
-    // Increment the timestamp by 1 millisecond to guarantee that we send
-    // monotonically increasing timestamps to the graph.
-    const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
-    this.process(image, imageProcessingOptions, syntheticTimestamp);
+    this.process(image, imageProcessingOptions, this.getSynctheticTimestamp());
   }

   /** Sends a single video frame to the graph and awaits results. */
/** Sends a single video frame to the graph and awaits results. */ /** Sends a single video frame to the graph and awaits results. */

View File

@@ -19,8 +19,8 @@ mediapipe_ts_library(
        "//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
        "//mediapipe/tasks/web/core",
        "//mediapipe/tasks/web/vision/core:image_processing_options",
+       "//mediapipe/tasks/web/vision/core:types",
        "//mediapipe/tasks/web/vision/core:vision_task_runner",
-       "//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
        "//mediapipe/web/graph_runner:graph_runner_ts",
    ],
)

View File

@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
// Placeholder for internal dependency on trusted resource url
@ -28,27 +29,9 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
import {ImageSegmenterOptions} from './image_segmenter_options';
export * from './image_segmenter_options';
export {SegmentationMask, SegmentationMaskCallback};
export {ImageSource}; // Used in the public API
/**
* The ImageSegmenter returns the segmentation result as a Uint8Array (when
* the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
* for future usage.
*/
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
/**
* A callback that receives the computed masks from the image segmenter. The
* callback either receives a single element array with a category mask (as a
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
* The returned data is only valid for the duration of the callback. If
* asynchronous processing is needed, all data needs to be copied before the
* callback returns.
*/
export type SegmentationMaskCallback =
(masks: SegmentationMask[], width: number, height: number) => void;
const IMAGE_STREAM = 'image_in';
const NORM_RECT_STREAM = 'norm_rect';
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';

third_party/BUILD
View File

@ -112,6 +112,7 @@ cmake_external(
"WITH_JPEG": "ON", "WITH_JPEG": "ON",
"WITH_PNG": "ON", "WITH_PNG": "ON",
"WITH_TIFF": "ON", "WITH_TIFF": "ON",
"WITH_OPENCL": "OFF",
"WITH_WEBP": "OFF", "WITH_WEBP": "OFF",
# Optimization flags # Optimization flags
"CV_ENABLE_INTRINSICS": "ON", "CV_ENABLE_INTRINSICS": "ON",

View File

@ -67,7 +67,7 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_BUILD",
        sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1678323576393653"],
    )
    http_file(
@ -318,8 +318,8 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_face_landmarker_with_blendshapes_task",
        sha256 = "a75c1ba70e4b8568000af2ad0b355ed559ab5d5793db50fa9ad241f8dc4fad5f",
        sha256 = "b44e4cae6f5822456d60f33e7c852640d78c7e342aee7eacc22589451a0b9dc2",
        urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678323586260800"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678504998301299"],
    )
    http_file(
@ -822,8 +822,8 @@ def external_files():
    http_file(
        name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt",
        sha256 = "5cc57b8da3ad0527dce581fe1309f6b36043e5837e3f4f5af5e24005a99dc52a",
        sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16",
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678323601064393"],
        urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678505004840652"],
    )
    http_file(