Merge branch 'google:master' into face-landmarker-python
|
@ -118,9 +118,9 @@ on how to build MediaPipe examples.
|
||||||
* With a TensorFlow Model
|
* With a TensorFlow Model
|
||||||
|
|
||||||
This uses the
|
This uses the
|
||||||
[TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model)
|
[TensorFlow model](https://github.com/google/mediapipe/tree/v0.8.10/mediapipe/models/object_detection_saved_model)
|
||||||
( see also
|
( see also
|
||||||
[model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)),
|
[model info](https://github.com/google/mediapipe/tree/master/mediapipe/g3doc/solutions/object_detection_saved_model.md)),
|
||||||
and the pipeline is implemented in this
|
and the pipeline is implemented in this
|
||||||
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).
|
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).
|
||||||
|
|
||||||
|
|
62
docs/solutions/object_detection_saved_model.md
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
## TensorFlow/TFLite Object Detection Model
|
||||||
|
|
||||||
|
### TensorFlow model
|
||||||
|
|
||||||
|
The model is trained on [MSCOCO 2014](http://cocodataset.org) dataset using [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with 0.5 depth multiplier. Detailed training configuration is in the provided `pipeline.config`. The model is a relatively compact model which has `0.171 mAP` to achieve real-time performance on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).
|
||||||
|
|
||||||
|
|
||||||
|
### TFLite model
|
||||||
|
|
||||||
|
The TFLite model is converted from the TensorFlow above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) with minor modifications. Assuming now we have a trained TensorFlow model which includes the checkpoint files and the training configuration file, for example the files provided in this repo:
|
||||||
|
|
||||||
|
* `model.ckpt.index`
|
||||||
|
* `model.ckpt.meta`
|
||||||
|
* `model.ckpt.data-00000-of-00001`
|
||||||
|
* `pipeline.config`
|
||||||
|
|
||||||
|
Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ PATH_TO_MODEL=path/to/the/model
|
||||||
|
$ bazel run object_detection:export_tflite_ssd_graph -- \
|
||||||
|
--pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
|
||||||
|
--trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
|
||||||
|
--output_directory ${PATH_TO_MODEL} \
|
||||||
|
--add_postprocessing_op=False
|
||||||
|
```
|
||||||
|
|
||||||
|
The exported model contains two files:
|
||||||
|
|
||||||
|
* `tflite_graph.pb`
|
||||||
|
* `tflite_graph.pbtxt`
|
||||||
|
|
||||||
|
The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. In MediaPipe, we have provided all the calculators needed for post-processing such that we can exclude the custom TFLite ops for post-processing in the original graph, e.g., non-maximum suppression. This enables the flexibility to integrate with different post-processing algorithms and implementations.
|
||||||
|
|
||||||
|
Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ bazel run graph_transforms:summarize_graph -- \
|
||||||
|
--in_graph=${PATH_TO_MODEL}/tflite_graph.pb
|
||||||
|
```
|
||||||
|
|
||||||
|
You should be able to see the input image size of the model is 320x320 and the outputs of the model are:
|
||||||
|
|
||||||
|
* `raw_outputs/box_encodings`
|
||||||
|
* `raw_outputs/class_predictions`
|
||||||
|
|
||||||
|
The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ tflite_convert -- \
|
||||||
|
--graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
|
||||||
|
--output_file=${PATH_TO_MODEL}/model.tflite \
|
||||||
|
--input_format=TENSORFLOW_GRAPHDEF \
|
||||||
|
--output_format=TFLITE \
|
||||||
|
--inference_type=FLOAT \
|
||||||
|
--input_shapes=1,320,320,3 \
|
||||||
|
--input_arrays=normalized_input_image_tensor \
|
||||||
|
--output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.
|
|
@ -269,6 +269,7 @@ Supported configuration options:
|
||||||
```python
|
```python
|
||||||
import cv2
|
import cv2
|
||||||
import mediapipe as mp
|
import mediapipe as mp
|
||||||
|
import numpy as np
|
||||||
mp_drawing = mp.solutions.drawing_utils
|
mp_drawing = mp.solutions.drawing_utils
|
||||||
mp_drawing_styles = mp.solutions.drawing_styles
|
mp_drawing_styles = mp.solutions.drawing_styles
|
||||||
mp_pose = mp.solutions.pose
|
mp_pose = mp.solutions.pose
|
||||||
|
|
|
@ -748,6 +748,7 @@ cc_test(
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
|
||||||
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
|
||||||
],
|
],
|
||||||
tags = ["desktop_only_test"],
|
tags = ["desktop_only_test"],
|
||||||
|
|
|
@ -29,6 +29,9 @@ class AffineTransformation {
|
||||||
// pixels will be calculated.
|
// pixels will be calculated.
|
||||||
enum class BorderMode { kZero, kReplicate };
|
enum class BorderMode { kZero, kReplicate };
|
||||||
|
|
||||||
|
// Pixel sampling interpolation method.
|
||||||
|
enum class Interpolation { kLinear, kCubic };
|
||||||
|
|
||||||
struct Size {
|
struct Size {
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
|
|
|
@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
|
||||||
std::unique_ptr<GpuBuffer>> {
|
std::unique_ptr<GpuBuffer>> {
|
||||||
public:
|
public:
|
||||||
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
|
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
|
||||||
GpuOrigin::Mode gpu_origin)
|
GpuOrigin::Mode gpu_origin,
|
||||||
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {}
|
AffineTransformation::Interpolation interpolation)
|
||||||
|
: gl_helper_(gl_helper),
|
||||||
|
gpu_origin_(gpu_origin),
|
||||||
|
interpolation_(interpolation) {}
|
||||||
absl::Status Init() {
|
absl::Status Init() {
|
||||||
return gl_helper_->RunInGlContext([this]() -> absl::Status {
|
return gl_helper_->RunInGlContext([this]() -> absl::Status {
|
||||||
const GLint attr_location[kNumAttributes] = {
|
const GLint attr_location[kNumAttributes] = {
|
||||||
|
@ -103,28 +106,83 @@ class GlTextureWarpAffineRunner
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
// TODO Move bicubic code to common shared place.
|
||||||
constexpr GLchar kFragShader[] = R"(
|
constexpr GLchar kFragShader[] = R"(
|
||||||
DEFAULT_PRECISION(highp, float)
|
DEFAULT_PRECISION(highp, float)
|
||||||
in vec2 sample_coordinate;
|
|
||||||
uniform sampler2D input_texture;
|
|
||||||
|
|
||||||
#ifdef GL_ES
|
in vec2 sample_coordinate;
|
||||||
#define fragColor gl_FragColor
|
uniform sampler2D input_texture;
|
||||||
#else
|
uniform vec2 input_size;
|
||||||
out vec4 fragColor;
|
|
||||||
#endif // defined(GL_ES);
|
|
||||||
|
|
||||||
void main() {
|
#ifdef GL_ES
|
||||||
vec4 color = texture2D(input_texture, sample_coordinate);
|
#define fragColor gl_FragColor
|
||||||
#ifdef CUSTOM_ZERO_BORDER_MODE
|
#else
|
||||||
float out_of_bounds =
|
out vec4 fragColor;
|
||||||
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
#endif // defined(GL_ES);
|
||||||
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
|
|
||||||
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds);
|
#ifdef CUBIC_INTERPOLATION
|
||||||
#endif // defined(CUSTOM_ZERO_BORDER_MODE)
|
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||||
fragColor = color;
|
const vec2 halve = vec2(0.5,0.5);
|
||||||
}
|
const vec2 one = vec2(1.0,1.0);
|
||||||
)";
|
const vec2 two = vec2(2.0,2.0);
|
||||||
|
const vec2 three = vec2(3.0,3.0);
|
||||||
|
const vec2 six = vec2(6.0,6.0);
|
||||||
|
|
||||||
|
// Calculate the fraction and integer.
|
||||||
|
tex_coord = tex_coord * tex_size - halve;
|
||||||
|
vec2 frac = fract(tex_coord);
|
||||||
|
vec2 index = tex_coord - frac + halve;
|
||||||
|
|
||||||
|
// Calculate weights for Catmull-Rom filter.
|
||||||
|
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
|
||||||
|
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
|
||||||
|
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
|
||||||
|
vec2 w3 = frac * frac * (-halve + halve * frac);
|
||||||
|
|
||||||
|
// Calculate weights to take advantage of bilinear texture lookup.
|
||||||
|
vec2 w12 = w1 + w2;
|
||||||
|
vec2 offset12 = w2 / (w1 + w2);
|
||||||
|
|
||||||
|
vec2 index_tl = index - one;
|
||||||
|
vec2 index_br = index + two;
|
||||||
|
vec2 index_eq = index + offset12;
|
||||||
|
|
||||||
|
index_tl /= tex_size;
|
||||||
|
index_br /= tex_size;
|
||||||
|
index_eq /= tex_size;
|
||||||
|
|
||||||
|
// 9 texture lookup and linear blending.
|
||||||
|
vec4 color = vec4(0.0);
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
|
||||||
|
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
|
||||||
|
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
|
||||||
|
return color;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||||
|
return texture2D(tex, tex_coord);
|
||||||
|
}
|
||||||
|
#endif // defined(CUBIC_INTERPOLATION)
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
vec4 color = sample(input_texture, sample_coordinate, input_size);
|
||||||
|
#ifdef CUSTOM_ZERO_BORDER_MODE
|
||||||
|
float out_of_bounds =
|
||||||
|
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
||||||
|
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
|
||||||
|
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds);
|
||||||
|
#endif // defined(CUSTOM_ZERO_BORDER_MODE)
|
||||||
|
fragColor = color;
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
// Create program and set parameters.
|
// Create program and set parameters.
|
||||||
auto create_fn = [&](const std::string& vs,
|
auto create_fn = [&](const std::string& vs,
|
||||||
|
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
|
||||||
glUseProgram(program);
|
glUseProgram(program);
|
||||||
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
|
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
|
||||||
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
|
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
|
||||||
return Program{.id = program, .matrix_id = matrix_id};
|
GLint size_id = glGetUniformLocation(program, "input_size");
|
||||||
|
return Program{
|
||||||
|
.id = program, .matrix_id = matrix_id, .size_id = size_id};
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::string vert_src =
|
const std::string vert_src =
|
||||||
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
|
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
|
||||||
|
|
||||||
const std::string frag_src = absl::StrCat(
|
std::string interpolation_def;
|
||||||
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader);
|
switch (interpolation_) {
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
interpolation_def = R"(
|
||||||
|
#define CUBIC_INTERPOLATION
|
||||||
|
)";
|
||||||
|
break;
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string frag_src =
|
||||||
|
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||||
|
interpolation_def, kFragShader);
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
|
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
|
||||||
|
|
||||||
|
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
|
||||||
std::string custom_zero_border_mode_def = R"(
|
std::string custom_zero_border_mode_def = R"(
|
||||||
#define CUSTOM_ZERO_BORDER_MODE
|
#define CUSTOM_ZERO_BORDER_MODE
|
||||||
)";
|
)";
|
||||||
const std::string frag_custom_zero_src =
|
const std::string frag_custom_zero_src = absl::StrCat(
|
||||||
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||||
custom_zero_border_mode_def, kFragShader);
|
custom_zero_border_mode_def, interpolation_def, kFragShader);
|
||||||
return create_fn(vert_src, frag_custom_zero_src);
|
return create_fn(vert_src, frag_custom_zero_src);
|
||||||
};
|
};
|
||||||
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||||
|
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
|
||||||
}
|
}
|
||||||
glUseProgram(program->id);
|
glUseProgram(program->id);
|
||||||
|
|
||||||
|
// uniforms
|
||||||
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
|
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
|
||||||
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
|
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
|
||||||
// @matrix describes affine transformation in terms of TOP LEFT origin, so
|
// @matrix describes affine transformation in terms of TOP LEFT origin, so
|
||||||
|
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
|
||||||
eigen_mat.transposeInPlace();
|
eigen_mat.transposeInPlace();
|
||||||
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
|
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
|
||||||
|
|
||||||
|
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
|
||||||
|
glUniform2f(program->size_id, texture.width(), texture.height());
|
||||||
|
}
|
||||||
|
|
||||||
// vao
|
// vao
|
||||||
glBindVertexArray(vao_);
|
glBindVertexArray(vao_);
|
||||||
|
|
||||||
|
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
|
||||||
struct Program {
|
struct Program {
|
||||||
GLuint id;
|
GLuint id;
|
||||||
GLint matrix_id;
|
GLint matrix_id;
|
||||||
|
GLint size_id;
|
||||||
};
|
};
|
||||||
std::shared_ptr<GlCalculatorHelper> gl_helper_;
|
std::shared_ptr<GlCalculatorHelper> gl_helper_;
|
||||||
GpuOrigin::Mode gpu_origin_;
|
GpuOrigin::Mode gpu_origin_;
|
||||||
|
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
|
||||||
Program program_;
|
Program program_;
|
||||||
std::optional<Program> program_custom_zero_;
|
std::optional<Program> program_custom_zero_;
|
||||||
GLuint framebuffer_ = 0;
|
GLuint framebuffer_ = 0;
|
||||||
|
AffineTransformation::Interpolation interpolation_ =
|
||||||
|
AffineTransformation::Interpolation::kLinear;
|
||||||
};
|
};
|
||||||
|
|
||||||
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||||
|
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
|
||||||
absl::StatusOr<std::unique_ptr<
|
absl::StatusOr<std::unique_ptr<
|
||||||
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
|
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
|
||||||
CreateAffineTransformationGlRunner(
|
CreateAffineTransformationGlRunner(
|
||||||
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) {
|
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
|
||||||
auto runner =
|
AffineTransformation::Interpolation interpolation) {
|
||||||
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin);
|
auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
|
||||||
|
gl_helper, gpu_origin, interpolation);
|
||||||
MP_RETURN_IF_ERROR(runner->Init());
|
MP_RETURN_IF_ERROR(runner->Init());
|
||||||
return runner;
|
return runner;
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
|
||||||
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
|
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
|
||||||
CreateAffineTransformationGlRunner(
|
CreateAffineTransformationGlRunner(
|
||||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
|
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
|
||||||
mediapipe::GpuOrigin::Mode gpu_origin);
|
mediapipe::GpuOrigin::Mode gpu_origin,
|
||||||
|
AffineTransformation::Interpolation interpolation);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GetInterpolationForOpenCv(
|
||||||
|
AffineTransformation::Interpolation interpolation) {
|
||||||
|
switch (interpolation) {
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
return cv::INTER_LINEAR;
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
return cv::INTER_CUBIC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class OpenCvRunner
|
class OpenCvRunner
|
||||||
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
|
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
|
||||||
public:
|
public:
|
||||||
|
OpenCvRunner(AffineTransformation::Interpolation interpolation)
|
||||||
|
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
|
||||||
|
|
||||||
absl::StatusOr<ImageFrame> Run(
|
absl::StatusOr<ImageFrame> Run(
|
||||||
const ImageFrame& input, const std::array<float, 16>& matrix,
|
const ImageFrame& input, const std::array<float, 16>& matrix,
|
||||||
const AffineTransformation::Size& size,
|
const AffineTransformation::Size& size,
|
||||||
|
@ -142,19 +155,23 @@ class OpenCvRunner
|
||||||
|
|
||||||
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
|
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
|
||||||
cv::Size(out_mat.cols, out_mat.rows),
|
cv::Size(out_mat.cols, out_mat.rows),
|
||||||
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP,
|
/*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
|
||||||
GetBorderModeForOpenCv(border_mode));
|
GetBorderModeForOpenCv(border_mode));
|
||||||
|
|
||||||
return out_image;
|
return out_image;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int interpolation_ = cv::INTER_LINEAR;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::StatusOr<
|
absl::StatusOr<
|
||||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||||
CreateAffineTransformationOpenCvRunner() {
|
CreateAffineTransformationOpenCvRunner(
|
||||||
return absl::make_unique<OpenCvRunner>();
|
AffineTransformation::Interpolation interpolation) {
|
||||||
|
return absl::make_unique<OpenCvRunner>(interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -25,7 +25,8 @@ namespace mediapipe {
|
||||||
|
|
||||||
absl::StatusOr<
|
absl::StatusOr<
|
||||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||||
CreateAffineTransformationOpenCvRunner();
|
CreateAffineTransformationOpenCvRunner(
|
||||||
|
AffineTransformation::Interpolation interpolation);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AffineTransformation::Interpolation GetInterpolation(
|
||||||
|
mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
|
||||||
|
switch (interpolation) {
|
||||||
|
case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
|
||||||
|
case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
|
||||||
|
return AffineTransformation::Interpolation::kLinear;
|
||||||
|
case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
|
||||||
|
return AffineTransformation::Interpolation::kCubic;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename ImageT>
|
template <typename ImageT>
|
||||||
class WarpAffineRunnerHolder {};
|
class WarpAffineRunnerHolder {};
|
||||||
|
|
||||||
|
@ -61,16 +72,22 @@ template <>
|
||||||
class WarpAffineRunnerHolder<ImageFrame> {
|
class WarpAffineRunnerHolder<ImageFrame> {
|
||||||
public:
|
public:
|
||||||
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
|
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
|
||||||
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
|
absl::Status Open(CalculatorContext* cc) {
|
||||||
|
interpolation_ = GetInterpolation(
|
||||||
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
absl::StatusOr<RunnerType*> GetRunner() {
|
absl::StatusOr<RunnerType*> GetRunner() {
|
||||||
if (!runner_) {
|
if (!runner_) {
|
||||||
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner());
|
ASSIGN_OR_RETURN(runner_,
|
||||||
|
CreateAffineTransformationOpenCvRunner(interpolation_));
|
||||||
}
|
}
|
||||||
return runner_.get();
|
return runner_.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<RunnerType> runner_;
|
std::unique_ptr<RunnerType> runner_;
|
||||||
|
AffineTransformation::Interpolation interpolation_;
|
||||||
};
|
};
|
||||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||||
|
|
||||||
|
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
||||||
gpu_origin_ =
|
gpu_origin_ =
|
||||||
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
|
||||||
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
|
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
|
||||||
|
interpolation_ = GetInterpolation(
|
||||||
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||||
return gl_helper_->Open(cc);
|
return gl_helper_->Open(cc);
|
||||||
}
|
}
|
||||||
absl::StatusOr<RunnerType*> GetRunner() {
|
absl::StatusOr<RunnerType*> GetRunner() {
|
||||||
if (!runner_) {
|
if (!runner_) {
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
|
||||||
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_));
|
gl_helper_, gpu_origin_, interpolation_));
|
||||||
}
|
}
|
||||||
return runner_.get();
|
return runner_.get();
|
||||||
}
|
}
|
||||||
|
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
||||||
mediapipe::GpuOrigin::Mode gpu_origin_;
|
mediapipe::GpuOrigin::Mode gpu_origin_;
|
||||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
|
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
|
||||||
std::unique_ptr<RunnerType> runner_;
|
std::unique_ptr<RunnerType> runner_;
|
||||||
|
AffineTransformation::Interpolation interpolation_;
|
||||||
};
|
};
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
|
||||||
BORDER_REPLICATE = 2;
|
BORDER_REPLICATE = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pixel sampling interpolation methods. See @interpolation.
|
||||||
|
enum Interpolation {
|
||||||
|
INTER_UNSPECIFIED = 0;
|
||||||
|
INTER_LINEAR = 1;
|
||||||
|
INTER_CUBIC = 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Pixel extrapolation method.
|
// Pixel extrapolation method.
|
||||||
// When converting image to tensor it may happen that tensor needs to read
|
// When converting image to tensor it may happen that tensor needs to read
|
||||||
// pixels outside image boundaries. Border mode helps to specify how such
|
// pixels outside image boundaries. Border mode helps to specify how such
|
||||||
|
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
|
||||||
// to be flipped vertically as tensors are expected to start at top.
|
// to be flipped vertically as tensors are expected to start at top.
|
||||||
// (DEFAULT or unset interpreted as CONVENTIONAL.)
|
// (DEFAULT or unset interpreted as CONVENTIONAL.)
|
||||||
optional GpuOrigin.Mode gpu_origin = 2;
|
optional GpuOrigin.Mode gpu_origin = 2;
|
||||||
|
|
||||||
|
// Sampling method for neighboring pixels.
|
||||||
|
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
|
||||||
|
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
|
||||||
|
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
|
||||||
|
optional Interpolation interpolation = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
||||||
const cv::Mat& input, cv::Mat expected_result,
|
const cv::Mat& input, cv::Mat expected_result,
|
||||||
float similarity_threshold, std::array<float, 16> matrix,
|
float similarity_threshold, std::array<float, 16> matrix,
|
||||||
int out_width, int out_height,
|
int out_width, int out_height,
|
||||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||||
std::string border_mode_str;
|
std::string border_mode_str;
|
||||||
if (border_mode) {
|
if (border_mode) {
|
||||||
switch (*border_mode) {
|
switch (*border_mode) {
|
||||||
|
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
std::string interpolation_str;
|
||||||
|
if (interpolation) {
|
||||||
|
switch (*interpolation) {
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
interpolation_str = "interpolation: INTER_LINEAR";
|
||||||
|
break;
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
interpolation_str = "interpolation: INTER_CUBIC";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
absl::Substitute(graph_text, /*$0=*/border_mode_str));
|
absl::Substitute(graph_text, /*$0=*/border_mode_str,
|
||||||
|
/*$1=*/interpolation_str));
|
||||||
|
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("output_image", &graph_config, &output_packets);
|
tool::AddVectorSink("output_image", &graph_config, &output_packets);
|
||||||
|
@ -132,7 +145,8 @@ struct SimilarityConfig {
|
||||||
void RunTest(cv::Mat input, cv::Mat expected_result,
|
void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
const SimilarityConfig& similarity, std::array<float, 16> matrix,
|
const SimilarityConfig& similarity, std::array<float, 16> matrix,
|
||||||
int out_width, int out_height,
|
int out_width, int out_height,
|
||||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
input_stream: "output_size"
|
input_stream: "output_size"
|
||||||
|
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
|
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
|
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
|
||||||
matrix, out_width, out_height, border_mode);
|
matrix, out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
gpu_origin: TOP_LEFT
|
gpu_origin: TOP_LEFT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
|
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
gpu_origin: TOP_LEFT
|
gpu_origin: TOP_LEFT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
|
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
|
||||||
matrix, out_width, out_height, border_mode);
|
matrix, out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
|
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
|
||||||
|
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
|
||||||
int out_height = 256;
|
int out_height = 256;
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
||||||
|
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
||||||
|
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
||||||
|
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
||||||
|
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
||||||
|
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
|
||||||
|
mediapipe::NormalizedRect roi;
|
||||||
|
roi.set_x_center(0.65f);
|
||||||
|
roi.set_y_center(0.4f);
|
||||||
|
roi.set_width(0.5f);
|
||||||
|
roi.set_height(0.5f);
|
||||||
|
roi.set_rotation(M_PI * -45.0f / 180.0f);
|
||||||
|
auto input = GetRgb(
|
||||||
|
"/mediapipe/calculators/"
|
||||||
|
"tensor/testdata/image_to_tensor/input.jpg");
|
||||||
|
auto expected_output = GetRgb(
|
||||||
|
"/mediapipe/calculators/"
|
||||||
|
"tensor/testdata/image_to_tensor/"
|
||||||
|
"medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
|
||||||
|
int out_width = 256;
|
||||||
|
int out_height = 256;
|
||||||
|
bool keep_aspect_ratio = false;
|
||||||
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation =
|
||||||
|
AffineTransformation::Interpolation::kCubic;
|
||||||
|
RunTest(input, expected_output,
|
||||||
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
|
||||||
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
||||||
|
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
||||||
|
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
||||||
|
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
||||||
|
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
||||||
|
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
||||||
int out_height = 128;
|
int out_height = 128;
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||||
|
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, NoOp) {
|
TEST(WarpAffineCalculatorTest, NoOp) {
|
||||||
|
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
||||||
|
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -997,17 +997,20 @@ cc_library(
|
||||||
":image_to_tensor_converter_gl_buffer",
|
":image_to_tensor_converter_gl_buffer",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
"//mediapipe:apple": [
|
"//mediapipe:apple": [
|
||||||
":image_to_tensor_converter_metal",
|
":image_to_tensor_converter_metal",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:MPPMetalHelper",
|
"//mediapipe/gpu:MPPMetalHelper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
":image_to_tensor_converter_gl_buffer",
|
":image_to_tensor_converter_gl_buffer",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -1045,6 +1048,10 @@ cc_test(
|
||||||
":image_to_tensor_calculator",
|
":image_to_tensor_calculator",
|
||||||
":image_to_tensor_converter",
|
":image_to_tensor_converter",
|
||||||
":image_to_tensor_utils",
|
":image_to_tensor_utils",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
@ -1061,11 +1068,10 @@ cc_test(
|
||||||
"//mediapipe/framework/port:opencv_imgproc",
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/util:image_test_utils",
|
"//mediapipe/util:image_test_utils",
|
||||||
"@com_google_absl//absl/flags:flag",
|
] + select({
|
||||||
"@com_google_absl//absl/memory",
|
"//mediapipe:apple": [],
|
||||||
"@com_google_absl//absl/strings",
|
"//conditions:default": ["//mediapipe/gpu:gl_context"],
|
||||||
"@com_google_absl//absl/strings:str_format",
|
}),
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -45,9 +45,11 @@
|
||||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||||
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
|
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
#include "mediapipe/gpu/gpu_service.h"
|
||||||
#else
|
#else
|
||||||
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
|
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
#include "mediapipe/gpu/gpu_service.h"
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node {
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||||
#else
|
#else
|
||||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
cc->UseService(kGpuService).Optional();
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
#endif // MEDIAPIPE_DISABLE_GPU
|
#endif // MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,10 @@
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/util/image_test_utils.h"
|
#include "mediapipe/util/image_test_utils.h"
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
#include "mediapipe/gpu/gl_context.h"
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) {
|
||||||
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
|
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
|
||||||
/*keep_aspect=*/false, BorderMode::kZero, roi);
|
/*keep_aspect=*/false, BorderMode::kZero, roi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) {
|
||||||
|
auto graph_config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input_image"
|
||||||
|
node {
|
||||||
|
calculator: "ImageToTensorCalculator"
|
||||||
|
input_stream: "IMAGE:input_image"
|
||||||
|
output_stream: "TENSORS:tensor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||||
|
output_tensor_float_range { min: 0.0f max: 1.0f }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
auto image_frame =
|
||||||
|
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
|
||||||
|
Image image = Image(std::move(image_frame));
|
||||||
|
Packet packet = MakePacket<Image>(std::move(image));
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
|
TEST(ImageToTensorCalculatorTest,
|
||||||
|
FailsGracefullyWhenGpuServiceNeededButNotAvailable) {
|
||||||
|
auto graph_config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input_image"
|
||||||
|
node {
|
||||||
|
calculator: "ImageToTensorCalculator"
|
||||||
|
input_stream: "IMAGE:input_image"
|
||||||
|
output_stream: "TENSORS:tensor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||||
|
output_tensor_float_range { min: 0.0f max: 1.0f }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto context,
|
||||||
|
GlContext::Create(nullptr, /*create_thread=*/true));
|
||||||
|
Packet packet;
|
||||||
|
context->Run([&packet]() {
|
||||||
|
auto image_frame =
|
||||||
|
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
|
||||||
|
Image image = Image(std::move(image_frame));
|
||||||
|
// Ensure image is available on GPU to force ImageToTensorCalculator to
|
||||||
|
// run on GPU.
|
||||||
|
ASSERT_TRUE(image.ConvertToGpu());
|
||||||
|
packet = MakePacket<Image>(std::move(image));
|
||||||
|
});
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
|
||||||
|
EXPECT_THAT(graph.WaitUntilIdle(),
|
||||||
|
StatusIs(absl::StatusCode::kInternal,
|
||||||
|
HasSubstr("GPU service not available")));
|
||||||
|
}
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
After Width: | Height: | Size: 64 KiB |
|
@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio,
|
||||||
std::string result_image;
|
std::string result_image;
|
||||||
MP_ASSERT_OK(
|
MP_ASSERT_OK(
|
||||||
mediapipe::file::GetContents(result_string_path, &result_image));
|
mediapipe::file::GetContents(result_string_path, &result_image));
|
||||||
EXPECT_EQ(result_image, output_string);
|
if (result_image != output_string) {
|
||||||
|
// There may be slight differences due to the way the JPEG was encoded or
|
||||||
|
// the OpenCV version used to generate the reference files. Compare
|
||||||
|
// pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
|
||||||
|
cv::Mat result_mat =
|
||||||
|
cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1,
|
||||||
|
const_cast<char*>(result_image.data())),
|
||||||
|
cv::IMREAD_UNCHANGED);
|
||||||
|
cv::Mat output_mat =
|
||||||
|
cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1,
|
||||||
|
const_cast<char*>(output_string.data())),
|
||||||
|
cv::IMREAD_UNCHANGED);
|
||||||
|
ASSERT_EQ(result_mat.rows, output_mat.rows);
|
||||||
|
ASSERT_EQ(result_mat.cols, output_mat.cols);
|
||||||
|
ASSERT_EQ(result_mat.type(), output_mat.type());
|
||||||
|
EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
std::string output_string_path = mediapipe::file::JoinPath(
|
std::string output_string_path = mediapipe::file::JoinPath(
|
||||||
absl::GetFlag(FLAGS_output_folder),
|
absl::GetFlag(FLAGS_output_folder),
|
||||||
|
|
Before Width: | Height: | Size: 3.2 KiB After Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 6.1 KiB After Width: | Height: | Size: 6.1 KiB |
Before Width: | Height: | Size: 8.2 KiB After Width: | Height: | Size: 8.2 KiB |
Before Width: | Height: | Size: 7.6 KiB After Width: | Height: | Size: 7.6 KiB |
|
@ -19,9 +19,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "face_geometry_from_landmarks_graph",
|
name = "face_geometry_from_landmarks_graph",
|
||||||
srcs = ["face_geometry_from_landmarks_graph.cc"],
|
srcs = ["face_geometry_from_landmarks_graph.cc"],
|
||||||
data = [
|
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/data:geometry_pipeline_metadata_landmarks",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||||
"//mediapipe/calculators/core:end_loop_calculator",
|
"//mediapipe/calculators/core:end_loop_calculator",
|
||||||
|
@ -39,6 +36,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto",
|
||||||
"//mediapipe/util:graph_builder_utils",
|
"//mediapipe/util:graph_builder_utils",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
|
|
|
@ -45,6 +45,7 @@ mediapipe_proto_library(
|
||||||
srcs = ["geometry_pipeline_calculator.proto"],
|
srcs = ["geometry_pipeline_calculator.proto"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:external_file_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,6 +60,8 @@ cc_library(
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:statusor",
|
"//mediapipe/framework/port:statusor",
|
||||||
|
"//mediapipe/tasks/cc/core:external_file_handler",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
|
"//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils",
|
"//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/external_file_handler.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
|
||||||
|
@ -69,8 +71,8 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
|
||||||
// A vector of face geometry data.
|
// A vector of face geometry data.
|
||||||
//
|
//
|
||||||
// Options:
|
// Options:
|
||||||
// metadata_path (`string`, optional):
|
// metadata_file (`ExternalFile`, optional):
|
||||||
// Defines a path for the geometry pipeline metadata file.
|
// Defines an ExternalFile for the geometry pipeline metadata file.
|
||||||
//
|
//
|
||||||
// The geometry pipeline metadata file format must be the binary
|
// The geometry pipeline metadata file format must be the binary
|
||||||
// `GeometryPipelineMetadata` proto.
|
// `GeometryPipelineMetadata` proto.
|
||||||
|
@ -95,7 +97,7 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
GeometryPipelineMetadata metadata,
|
GeometryPipelineMetadata metadata,
|
||||||
ReadMetadataFromFile(options.metadata_path()),
|
ReadMetadataFromFile(options.metadata_file()),
|
||||||
_ << "Failed to read the geometry pipeline metadata from file!");
|
_ << "Failed to read the geometry pipeline metadata from file!");
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata))
|
MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata))
|
||||||
|
@ -155,32 +157,19 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile(
|
static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile(
|
||||||
const std::string& metadata_path) {
|
const core::proto::ExternalFile& metadata_file) {
|
||||||
ASSIGN_OR_RETURN(std::string metadata_blob,
|
ASSIGN_OR_RETURN(
|
||||||
ReadContentBlobFromFile(metadata_path),
|
const auto file_handler,
|
||||||
_ << "Failed to read a metadata blob from file!");
|
core::ExternalFileHandler::CreateFromExternalFile(&metadata_file));
|
||||||
|
|
||||||
GeometryPipelineMetadata metadata;
|
GeometryPipelineMetadata metadata;
|
||||||
RET_CHECK(metadata.ParseFromString(metadata_blob))
|
RET_CHECK(
|
||||||
|
metadata.ParseFromString(std::string(file_handler->GetFileContent())))
|
||||||
<< "Failed to parse a metadata proto from a binary blob!";
|
<< "Failed to parse a metadata proto from a binary blob!";
|
||||||
|
|
||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
static absl::StatusOr<std::string> ReadContentBlobFromFile(
|
|
||||||
const std::string& unresolved_path) {
|
|
||||||
ASSIGN_OR_RETURN(std::string resolved_path,
|
|
||||||
mediapipe::PathToResourceAsFile(unresolved_path),
|
|
||||||
_ << "Failed to resolve path! Path = " << unresolved_path);
|
|
||||||
|
|
||||||
std::string content_blob;
|
|
||||||
MP_RETURN_IF_ERROR(
|
|
||||||
mediapipe::GetResourceContents(resolved_path, &content_blob))
|
|
||||||
<< "Failed to read content blob! Resolved path = " << resolved_path;
|
|
||||||
|
|
||||||
return content_blob;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<GeometryPipeline> geometry_pipeline_;
|
std::unique_ptr<GeometryPipeline> geometry_pipeline_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -17,11 +17,12 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.vision.face_geometry;
|
package mediapipe.tasks.vision.face_geometry;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator_options.proto";
|
import "mediapipe/framework/calculator_options.proto";
|
||||||
|
import "mediapipe/tasks/cc/core/proto/external_file.proto";
|
||||||
|
|
||||||
message FaceGeometryPipelineCalculatorOptions {
|
message FaceGeometryPipelineCalculatorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional FaceGeometryPipelineCalculatorOptions ext = 512499200;
|
optional FaceGeometryPipelineCalculatorOptions ext = 512499200;
|
||||||
}
|
}
|
||||||
|
|
||||||
optional string metadata_path = 1;
|
optional core.proto.ExternalFile metadata_file = 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h"
|
||||||
#include "mediapipe/util/graph_builder_utils.h"
|
#include "mediapipe/util/graph_builder_utils.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::vision::face_geometry {
|
namespace mediapipe::tasks::vision::face_geometry {
|
||||||
|
@ -49,10 +50,6 @@ constexpr char kIterableTag[] = "ITERABLE";
|
||||||
constexpr char kBatchEndTag[] = "BATCH_END";
|
constexpr char kBatchEndTag[] = "BATCH_END";
|
||||||
constexpr char kItemTag[] = "ITEM";
|
constexpr char kItemTag[] = "ITEM";
|
||||||
|
|
||||||
constexpr char kGeometryPipelineMetadataPath[] =
|
|
||||||
"mediapipe/tasks/cc/vision/face_geometry/data/"
|
|
||||||
"geometry_pipeline_metadata_landmarks.binarypb";
|
|
||||||
|
|
||||||
struct FaceGeometryOuts {
|
struct FaceGeometryOuts {
|
||||||
Stream<std::vector<FaceGeometry>> multi_face_geometry;
|
Stream<std::vector<FaceGeometry>> multi_face_geometry;
|
||||||
};
|
};
|
||||||
|
@ -127,6 +124,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
|
||||||
}
|
}
|
||||||
ASSIGN_OR_RETURN(auto outs,
|
ASSIGN_OR_RETURN(auto outs,
|
||||||
BuildFaceGeometryFromLandmarksGraph(
|
BuildFaceGeometryFromLandmarksGraph(
|
||||||
|
*sc->MutableOptions<proto::FaceGeometryGraphOptions>(),
|
||||||
graph.In(kFaceLandmarksTag)
|
graph.In(kFaceLandmarksTag)
|
||||||
.Cast<std::vector<NormalizedLandmarkList>>(),
|
.Cast<std::vector<NormalizedLandmarkList>>(),
|
||||||
graph.In(kImageSizeTag).Cast<std::pair<int, int>>(),
|
graph.In(kImageSizeTag).Cast<std::pair<int, int>>(),
|
||||||
|
@ -138,6 +136,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph(
|
absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph(
|
||||||
|
proto::FaceGeometryGraphOptions& graph_options,
|
||||||
Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks,
|
Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks,
|
||||||
Stream<std::pair<int, int>> image_size,
|
Stream<std::pair<int, int>> image_size,
|
||||||
std::optional<SidePacket<Environment>> environment, Graph& graph) {
|
std::optional<SidePacket<Environment>> environment, Graph& graph) {
|
||||||
|
@ -185,7 +184,8 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
|
||||||
"mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator");
|
"mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator");
|
||||||
auto& geometry_pipeline_options =
|
auto& geometry_pipeline_options =
|
||||||
geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>();
|
geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>();
|
||||||
geometry_pipeline_options.set_metadata_path(kGeometryPipelineMetadataPath);
|
geometry_pipeline_options.Swap(
|
||||||
|
graph_options.mutable_geometry_pipeline_options());
|
||||||
image_size >> geometry_pipeline.In(kImageSizeTag);
|
image_size >> geometry_pipeline.In(kImageSizeTag);
|
||||||
multi_face_landmarks_no_iris >>
|
multi_face_landmarks_no_iris >>
|
||||||
geometry_pipeline.In(kMultiFaceLandmarksTag);
|
geometry_pipeline.In(kMultiFaceLandmarksTag);
|
||||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
@ -31,6 +32,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/tool/sink.h"
|
#include "mediapipe/framework/tool/sink.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
||||||
|
|
||||||
|
@ -49,6 +51,9 @@ constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||||
constexpr char kFaceLandmarksFileName[] =
|
constexpr char kFaceLandmarksFileName[] =
|
||||||
"face_blendshapes_in_landmarks.prototxt";
|
"face_blendshapes_in_landmarks.prototxt";
|
||||||
constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt";
|
constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt";
|
||||||
|
constexpr char kGeometryPipelineMetadataPath[] =
|
||||||
|
"mediapipe/tasks/cc/vision/face_geometry/data/"
|
||||||
|
"geometry_pipeline_metadata_landmarks.binarypb";
|
||||||
|
|
||||||
std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) {
|
std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) {
|
||||||
NormalizedLandmarkList landmarks;
|
NormalizedLandmarkList landmarks;
|
||||||
|
@ -89,17 +94,25 @@ void MakeInputPacketsAndRunGraph(CalculatorGraph& graph) {
|
||||||
|
|
||||||
TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
|
TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
|
||||||
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
||||||
CalculatorGraphConfig>(R"pb(
|
CalculatorGraphConfig>(absl::Substitute(
|
||||||
input_stream: "FACE_LANDMARKS:face_landmarks"
|
R"pb(
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "FACE_LANDMARKS:face_landmarks"
|
||||||
output_stream: "FACE_GEOMETRY:face_geometry"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
node {
|
output_stream: "FACE_GEOMETRY:face_geometry"
|
||||||
calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph"
|
node {
|
||||||
input_stream: "FACE_LANDMARKS:face_landmarks"
|
calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph"
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "FACE_LANDMARKS:face_landmarks"
|
||||||
output_stream: "FACE_GEOMETRY:face_geometry"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
}
|
output_stream: "FACE_GEOMETRY:face_geometry"
|
||||||
)pb");
|
options: {
|
||||||
|
[mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
|
||||||
|
.ext] {
|
||||||
|
geometry_pipeline_options { metadata_file { file_name: "$0" } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb",
|
||||||
|
kGeometryPipelineMetadataPath));
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
@ -116,19 +129,27 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
|
||||||
|
|
||||||
TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
|
TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
|
||||||
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
||||||
CalculatorGraphConfig>(R"pb(
|
CalculatorGraphConfig>(absl::Substitute(
|
||||||
input_stream: "FACE_LANDMARKS:face_landmarks"
|
R"pb(
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "FACE_LANDMARKS:face_landmarks"
|
||||||
input_side_packet: "ENVIRONMENT:environment"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
output_stream: "FACE_GEOMETRY:face_geometry"
|
input_side_packet: "ENVIRONMENT:environment"
|
||||||
node {
|
output_stream: "FACE_GEOMETRY:face_geometry"
|
||||||
calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph"
|
node {
|
||||||
input_stream: "FACE_LANDMARKS:face_landmarks"
|
calculator: "mediapipe.tasks.vision.face_geometry.FaceGeometryFromLandmarksGraph"
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "FACE_LANDMARKS:face_landmarks"
|
||||||
input_side_packet: "ENVIRONMENT:environment"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
output_stream: "FACE_GEOMETRY:face_geometry"
|
input_side_packet: "ENVIRONMENT:environment"
|
||||||
}
|
output_stream: "FACE_GEOMETRY:face_geometry"
|
||||||
)pb");
|
options: {
|
||||||
|
[mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
|
||||||
|
.ext] {
|
||||||
|
geometry_pipeline_options { metadata_file { file_name: "$0" } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb",
|
||||||
|
kGeometryPipelineMetadataPath));
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
|
|
@ -44,3 +44,12 @@ mediapipe_proto_library(
|
||||||
name = "mesh_3d_proto",
|
name = "mesh_3d_proto",
|
||||||
srcs = ["mesh_3d.proto"],
|
srcs = ["mesh_3d.proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "face_geometry_graph_options_proto",
|
||||||
|
srcs = ["face_geometry_graph_options.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe.tasks.vision.face_geometry.proto;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator_options.proto";
|
||||||
|
import "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.proto";
|
||||||
|
|
||||||
|
message FaceGeometryGraphOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional FaceGeometryGraphOptions ext = 515723506;
|
||||||
|
}
|
||||||
|
|
||||||
|
optional FaceGeometryPipelineCalculatorOptions geometry_pipeline_options = 1;
|
||||||
|
}
|
|
@ -210,8 +210,10 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
|
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
|
||||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry:face_geometry_from_landmarks_graph",
|
"//mediapipe/tasks/cc/vision/face_geometry:face_geometry_from_landmarks_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_blendshapes_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarker_graph_options_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_landmarker/proto:face_landmarks_detector_graph_options_cc_proto",
|
||||||
|
|
|
@ -40,8 +40,10 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/core/utils.h"
|
#include "mediapipe/tasks/cc/core/utils.h"
|
||||||
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
#include "mediapipe/tasks/cc/metadata/utils/zip_utils.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_blendshapes_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarker_graph_options.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.pb.h"
|
||||||
|
@ -93,6 +95,8 @@ constexpr char kFaceDetectorTFLiteName[] = "face_detector.tflite";
|
||||||
constexpr char kFaceLandmarksDetectorTFLiteName[] =
|
constexpr char kFaceLandmarksDetectorTFLiteName[] =
|
||||||
"face_landmarks_detector.tflite";
|
"face_landmarks_detector.tflite";
|
||||||
constexpr char kFaceBlendshapeTFLiteName[] = "face_blendshapes.tflite";
|
constexpr char kFaceBlendshapeTFLiteName[] = "face_blendshapes.tflite";
|
||||||
|
constexpr char kFaceGeometryPipelineMetadataName[] =
|
||||||
|
"geometry_pipeline_metadata_landmarks.binarypb";
|
||||||
|
|
||||||
struct FaceLandmarkerOutputs {
|
struct FaceLandmarkerOutputs {
|
||||||
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
|
Source<std::vector<NormalizedLandmarkList>> landmark_lists;
|
||||||
|
@ -305,6 +309,7 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
absl::StatusOr<CalculatorGraphConfig> GetConfig(
|
||||||
SubgraphContext* sc) override {
|
SubgraphContext* sc) override {
|
||||||
Graph graph;
|
Graph graph;
|
||||||
|
bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag);
|
||||||
if (sc->Options<FaceLandmarkerGraphOptions>()
|
if (sc->Options<FaceLandmarkerGraphOptions>()
|
||||||
.base_options()
|
.base_options()
|
||||||
.has_model_asset()) {
|
.has_model_asset()) {
|
||||||
|
@ -318,6 +323,18 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
sc->MutableOptions<FaceLandmarkerGraphOptions>(),
|
sc->MutableOptions<FaceLandmarkerGraphOptions>(),
|
||||||
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
!sc->Service(::mediapipe::tasks::core::kModelResourcesCacheService)
|
||||||
.IsAvailable()));
|
.IsAvailable()));
|
||||||
|
if (output_geometry) {
|
||||||
|
// Set the face geometry metdata file for
|
||||||
|
// FaceGeometryFromLandmarksGraph.
|
||||||
|
ASSIGN_OR_RETURN(auto face_geometry_pipeline_metadata_file,
|
||||||
|
model_asset_bundle_resources->GetModelFile(
|
||||||
|
kFaceGeometryPipelineMetadataName));
|
||||||
|
SetExternalFile(face_geometry_pipeline_metadata_file,
|
||||||
|
sc->MutableOptions<FaceLandmarkerGraphOptions>()
|
||||||
|
->mutable_face_geometry_graph_options()
|
||||||
|
->mutable_geometry_pipeline_options()
|
||||||
|
->mutable_metadata_file());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::optional<SidePacket<Environment>> environment;
|
std::optional<SidePacket<Environment>> environment;
|
||||||
if (HasSideInput(sc->OriginalNode(), kEnvironmentTag)) {
|
if (HasSideInput(sc->OriginalNode(), kEnvironmentTag)) {
|
||||||
|
@ -338,7 +355,6 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
.face_landmarks_detector_graph_options()
|
.face_landmarks_detector_graph_options()
|
||||||
.has_face_blendshapes_graph_options()));
|
.has_face_blendshapes_graph_options()));
|
||||||
}
|
}
|
||||||
bool output_geometry = HasOutput(sc->OriginalNode(), kFaceGeometryTag);
|
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
auto outs,
|
auto outs,
|
||||||
BuildFaceLandmarkerGraph(
|
BuildFaceLandmarkerGraph(
|
||||||
|
@ -481,6 +497,9 @@ class FaceLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
auto& face_geometry_from_landmarks = graph.AddNode(
|
auto& face_geometry_from_landmarks = graph.AddNode(
|
||||||
"mediapipe.tasks.vision.face_geometry."
|
"mediapipe.tasks.vision.face_geometry."
|
||||||
"FaceGeometryFromLandmarksGraph");
|
"FaceGeometryFromLandmarksGraph");
|
||||||
|
face_geometry_from_landmarks
|
||||||
|
.GetOptions<face_geometry::proto::FaceGeometryGraphOptions>()
|
||||||
|
.Swap(tasks_options.mutable_face_geometry_graph_options());
|
||||||
if (environment.has_value()) {
|
if (environment.has_value()) {
|
||||||
*environment >> face_geometry_from_landmarks.SideIn(kEnvironmentTag);
|
*environment >> face_geometry_from_landmarks.SideIn(kEnvironmentTag);
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,5 +60,6 @@ mediapipe_proto_library(
|
||||||
"//mediapipe/framework:calculator_proto",
|
"//mediapipe/framework:calculator_proto",
|
||||||
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
"//mediapipe/tasks/cc/core/proto:base_options_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto",
|
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ import "mediapipe/framework/calculator.proto";
|
||||||
import "mediapipe/framework/calculator_options.proto";
|
import "mediapipe/framework/calculator_options.proto";
|
||||||
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
import "mediapipe/tasks/cc/core/proto/base_options.proto";
|
||||||
import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto";
|
import "mediapipe/tasks/cc/vision/face_detector/proto/face_detector_graph_options.proto";
|
||||||
|
import "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.proto";
|
||||||
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto";
|
import "mediapipe/tasks/cc/vision/face_landmarker/proto/face_landmarks_detector_graph_options.proto";
|
||||||
|
|
||||||
option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto";
|
option java_package = "com.google.mediapipe.tasks.vision.facelandmarker.proto";
|
||||||
|
@ -45,4 +46,8 @@ message FaceLandmarkerGraphOptions {
|
||||||
// Minimum confidence for face landmarks tracking to be considered
|
// Minimum confidence for face landmarks tracking to be considered
|
||||||
// successfully.
|
// successfully.
|
||||||
optional float min_tracking_confidence = 4 [default = 0.5];
|
optional float min_tracking_confidence = 4 [default = 0.5];
|
||||||
|
|
||||||
|
// Options for FaceGeometryGraph to get facial transformation matrix.
|
||||||
|
optional face_geometry.proto.FaceGeometryGraphOptions
|
||||||
|
face_geometry_graph_options = 5;
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,6 +47,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_stylizer/proto:face_stylizer_graph_options_cc_proto",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -294,7 +294,7 @@ absl::Status TensorsToImageCalculator::MetalProcess(CalculatorContext* cc) {
|
||||||
threadsPerThreadgroup:threads_per_group];
|
threadsPerThreadgroup:threads_per_group];
|
||||||
[compute_encoder endEncoding];
|
[compute_encoder endEncoding];
|
||||||
[command_buffer commit];
|
[command_buffer commit];
|
||||||
|
[command_buffer waitUntilCompleted];
|
||||||
kOutputImage(cc).Send(Image(output));
|
kOutputImage(cc).Send(Image(output));
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,6 +36,7 @@ cc_library(
|
||||||
":hand_association_calculator_cc_proto",
|
":hand_association_calculator_cc_proto",
|
||||||
"//mediapipe/calculators/util:association_calculator",
|
"//mediapipe/calculators/util:association_calculator",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:collection_item_id",
|
||||||
"//mediapipe/framework/api2:node",
|
"//mediapipe/framework/api2:node",
|
||||||
"//mediapipe/framework/formats:rect_cc_proto",
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
"//mediapipe/framework/port:rectangle",
|
"//mediapipe/framework/port:rectangle",
|
||||||
|
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/framework/api2/node.h"
|
#include "mediapipe/framework/api2/node.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/collection_item_id.h"
|
||||||
#include "mediapipe/framework/formats/rect.pb.h"
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
#include "mediapipe/framework/port/rectangle.h"
|
#include "mediapipe/framework/port/rectangle.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
|
@ -29,30 +30,55 @@ namespace mediapipe::api2 {
|
||||||
|
|
||||||
using ::mediapipe::NormalizedRect;
|
using ::mediapipe::NormalizedRect;
|
||||||
|
|
||||||
// HandAssociationCalculator accepts multiple inputs of vectors of
|
// Input:
|
||||||
// NormalizedRect. The output is a vector of NormalizedRect that contains
|
// BASE_RECTS - Vector of NormalizedRect.
|
||||||
// rects from the input vectors that don't overlap with each other. When two
|
// RECTS - Vector of NormalizedRect.
|
||||||
// rects overlap, the rect that comes in from an earlier input stream is
|
//
|
||||||
// kept in the output. If a rect has no ID (i.e. from detection stream),
|
// Output:
|
||||||
// then a unique rect ID is assigned for it.
|
// No tag - Vector of NormalizedRect.
|
||||||
|
//
|
||||||
// The rects in multiple input streams are effectively flattened to a single
|
// Example use:
|
||||||
// list. For example:
|
// node {
|
||||||
// Stream1 : rect 1, rect 2
|
// calculator: "HandAssociationCalculator"
|
||||||
// Stream2: rect 3, rect 4
|
// input_stream: "BASE_RECTS:base_rects"
|
||||||
// Stream3: rect 5, rect 6
|
// input_stream: "RECTS:0:rects0"
|
||||||
// (Conceptually) flattened list : rect 1, 2, 3, 4, 5, 6
|
// input_stream: "RECTS:1:rects1"
|
||||||
// In the flattened list, if a rect with a higher index overlaps with a rect a
|
// input_stream: "RECTS:2:rects2"
|
||||||
// lower index, beyond a specified IOU threshold, the rect with the lower
|
// output_stream: "output_rects"
|
||||||
// index will be in the output, and the rect with higher index will be
|
// options {
|
||||||
// discarded.
|
// [mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
// min_similarity_threshold: 0.1
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// IMPORTANT Notes:
|
||||||
|
// - Rects from input streams tagged with "BASE_RECTS" are always preserved.
|
||||||
|
// - This calculator checks for overlap among rects from input streams tagged
|
||||||
|
// with "RECTS". Rects are prioritized based on their index in the vector and
|
||||||
|
// input streams to the calculator. When two rects overlap, the rect that
|
||||||
|
// comes from an input stream with lower tag-index is kept in the output.
|
||||||
|
// - Example of inputs for the node above:
|
||||||
|
// "base_rects": rect 0, rect 1
|
||||||
|
// "rects0": rect 2, rect 3
|
||||||
|
// "rects1": rect 4, rect 5
|
||||||
|
// "rects2": rect 6, rect 7
|
||||||
|
// (Conceptually) flattened list: 0, 1, 2, 3, 4, 5, 6, 7.
|
||||||
|
// Rects 0, 1 will be preserved. Rects 2, 3, 4, 5, 6, 7 will be checked for
|
||||||
|
// overlap. If a rect with a higher index overlaps with a rect with lower
|
||||||
|
// index, beyond a specified IOU threshold, the rect with the lower index
|
||||||
|
// will be in the output, and the rect with higher index will be discarded.
|
||||||
// TODO: Upgrade this to latest API for calculators
|
// TODO: Upgrade this to latest API for calculators
|
||||||
class HandAssociationCalculator : public CalculatorBase {
|
class HandAssociationCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
// Initialize input and output streams.
|
// Initialize input and output streams.
|
||||||
for (auto& input_stream : cc->Inputs()) {
|
for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
|
||||||
input_stream.Set<std::vector<NormalizedRect>>();
|
id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
|
||||||
|
cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
|
||||||
|
}
|
||||||
|
for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
|
||||||
|
id != cc->Inputs().EndId("RECTS"); ++id) {
|
||||||
|
cc->Inputs().Get(id).Set<std::vector<NormalizedRect>>();
|
||||||
}
|
}
|
||||||
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
|
cc->Outputs().Index(0).Set<std::vector<NormalizedRect>>();
|
||||||
|
|
||||||
|
@ -89,7 +115,24 @@ class HandAssociationCalculator : public CalculatorBase {
|
||||||
CalculatorContext* cc) {
|
CalculatorContext* cc) {
|
||||||
std::vector<NormalizedRect> result;
|
std::vector<NormalizedRect> result;
|
||||||
|
|
||||||
for (const auto& input_stream : cc->Inputs()) {
|
for (CollectionItemId id = cc->Inputs().BeginId("BASE_RECTS");
|
||||||
|
id != cc->Inputs().EndId("BASE_RECTS"); ++id) {
|
||||||
|
const auto& input_stream = cc->Inputs().Get(id);
|
||||||
|
if (input_stream.IsEmpty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto rect : input_stream.Get<std::vector<NormalizedRect>>()) {
|
||||||
|
if (!rect.has_rect_id()) {
|
||||||
|
rect.set_rect_id(GetNextRectId());
|
||||||
|
}
|
||||||
|
result.push_back(rect);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (CollectionItemId id = cc->Inputs().BeginId("RECTS");
|
||||||
|
id != cc->Inputs().EndId("RECTS"); ++id) {
|
||||||
|
const auto& input_stream = cc->Inputs().Get(id);
|
||||||
if (input_stream.IsEmpty()) {
|
if (input_stream.IsEmpty()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,6 +27,8 @@ namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::mediapipe::NormalizedRect;
|
using ::mediapipe::NormalizedRect;
|
||||||
|
using ::testing::ElementsAre;
|
||||||
|
using ::testing::EqualsProto;
|
||||||
|
|
||||||
class HandAssociationCalculatorTest : public testing::Test {
|
class HandAssociationCalculatorTest : public testing::Test {
|
||||||
protected:
|
protected:
|
||||||
|
@ -87,9 +89,9 @@ class HandAssociationCalculatorTest : public testing::Test {
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec_0"
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
input_stream: "input_vec_1"
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
input_stream: "input_vec_2"
|
input_stream: "RECTS:1:input_vec_2"
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -103,20 +105,23 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
||||||
input_vec_0->push_back(nr_0_);
|
input_vec_0->push_back(nr_0_);
|
||||||
input_vec_0->push_back(nr_1_);
|
input_vec_0->push_back(nr_1_);
|
||||||
input_vec_0->push_back(nr_2_);
|
input_vec_0->push_back(nr_2_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 1: nr_3, nr_4.
|
// Input Stream 1: nr_3, nr_4.
|
||||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_1->push_back(nr_3_);
|
input_vec_1->push_back(nr_3_);
|
||||||
input_vec_1->push_back(nr_4_);
|
input_vec_1->push_back(nr_4_);
|
||||||
runner.MutableInputs()->Index(1).packets.push_back(
|
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 2: nr_5.
|
// Input Stream 2: nr_5.
|
||||||
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_2->push_back(nr_5_);
|
input_vec_2->push_back(nr_5_);
|
||||||
runner.MutableInputs()->Index(2).packets.push_back(
|
index_id = runner.MutableInputs()->GetId("RECTS", 1);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
@ -134,25 +139,18 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTest) {
|
||||||
EXPECT_EQ(3, assoc_rects.size());
|
EXPECT_EQ(3, assoc_rects.size());
|
||||||
|
|
||||||
// Check that IDs are filled in and contents match.
|
// Check that IDs are filled in and contents match.
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
nr_0_.set_rect_id(1);
|
||||||
assoc_rects[0].clear_rect_id();
|
nr_1_.set_rect_id(2);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
|
nr_2_.set_rect_id(3);
|
||||||
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
|
||||||
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
|
EqualsProto(nr_2_)));
|
||||||
assoc_rects[1].clear_rect_id();
|
|
||||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
|
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
|
|
||||||
assoc_rects[2].clear_rect_id();
|
|
||||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec_0"
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
input_stream: "input_vec_1"
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
input_stream: "input_vec_2"
|
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -169,14 +167,15 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
||||||
input_vec_0->push_back(nr_0_);
|
input_vec_0->push_back(nr_0_);
|
||||||
nr_1_.set_rect_id(-1);
|
nr_1_.set_rect_id(-1);
|
||||||
input_vec_0->push_back(nr_1_);
|
input_vec_0->push_back(nr_1_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 1: nr_2, nr_3. Newly detected palms.
|
// Input Stream 1: nr_2, nr_3. Newly detected palms.
|
||||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_1->push_back(nr_2_);
|
input_vec_1->push_back(nr_2_);
|
||||||
input_vec_1->push_back(nr_3_);
|
input_vec_1->push_back(nr_3_);
|
||||||
runner.MutableInputs()->Index(1).packets.push_back(
|
runner.MutableInputs()->Tag("RECTS").packets.push_back(
|
||||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
@ -192,23 +191,17 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestWithTrackedHands) {
|
||||||
EXPECT_EQ(3, assoc_rects.size());
|
EXPECT_EQ(3, assoc_rects.size());
|
||||||
|
|
||||||
// Check that IDs are filled in and contents match.
|
// Check that IDs are filled in and contents match.
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), -2);
|
nr_2_.set_rect_id(1);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_0_));
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_0_), EqualsProto(nr_1_),
|
||||||
|
EqualsProto(nr_2_)));
|
||||||
EXPECT_EQ(assoc_rects[1].rect_id(), -1);
|
|
||||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_1_));
|
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[2].rect_id(), 1);
|
|
||||||
assoc_rects[2].clear_rect_id();
|
|
||||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_2_));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec_0"
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
input_stream: "input_vec_1"
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
input_stream: "input_vec_2"
|
input_stream: "RECTS:1:input_vec_2"
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -220,14 +213,16 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
// Input Stream 0: nr_5.
|
// Input Stream 0: nr_5.
|
||||||
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_0->push_back(nr_5_);
|
input_vec_0->push_back(nr_5_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec_0.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 1: nr_4, nr_3
|
// Input Stream 1: nr_4, nr_3
|
||||||
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec_1->push_back(nr_4_);
|
input_vec_1->push_back(nr_4_);
|
||||||
input_vec_1->push_back(nr_3_);
|
input_vec_1->push_back(nr_3_);
|
||||||
runner.MutableInputs()->Index(1).packets.push_back(
|
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
// Input Stream 2: nr_2, nr_1, nr_0.
|
// Input Stream 2: nr_2, nr_1, nr_0.
|
||||||
|
@ -235,7 +230,8 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
input_vec_2->push_back(nr_2_);
|
input_vec_2->push_back(nr_2_);
|
||||||
input_vec_2->push_back(nr_1_);
|
input_vec_2->push_back(nr_1_);
|
||||||
input_vec_2->push_back(nr_0_);
|
input_vec_2->push_back(nr_0_);
|
||||||
runner.MutableInputs()->Index(2).packets.push_back(
|
index_id = runner.MutableInputs()->GetId("RECTS", 1);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
@ -253,23 +249,78 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReverse) {
|
||||||
EXPECT_EQ(3, assoc_rects.size());
|
EXPECT_EQ(3, assoc_rects.size());
|
||||||
|
|
||||||
// Outputs are in same order as inputs, and IDs are filled in.
|
// Outputs are in same order as inputs, and IDs are filled in.
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
nr_5_.set_rect_id(1);
|
||||||
assoc_rects[0].clear_rect_id();
|
nr_4_.set_rect_id(2);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_5_));
|
nr_0_.set_rect_id(3);
|
||||||
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_4_),
|
||||||
|
EqualsProto(nr_0_)));
|
||||||
|
}
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[1].rect_id(), 2);
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocTestReservesBaseRects) {
|
||||||
assoc_rects[1].clear_rect_id();
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
EXPECT_THAT(assoc_rects[1], testing::EqualsProto(nr_4_));
|
calculator: "HandAssociationCalculator"
|
||||||
|
input_stream: "BASE_RECTS:input_vec_0"
|
||||||
|
input_stream: "RECTS:0:input_vec_1"
|
||||||
|
input_stream: "RECTS:1:input_vec_2"
|
||||||
|
output_stream: "output_vec"
|
||||||
|
options {
|
||||||
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
min_similarity_threshold: 0.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[2].rect_id(), 3);
|
// Input Stream 0: nr_5, nr_3, nr_1.
|
||||||
assoc_rects[2].clear_rect_id();
|
auto input_vec_0 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
EXPECT_THAT(assoc_rects[2], testing::EqualsProto(nr_0_));
|
input_vec_0->push_back(nr_5_);
|
||||||
|
input_vec_0->push_back(nr_3_);
|
||||||
|
input_vec_0->push_back(nr_1_);
|
||||||
|
runner.MutableInputs()
|
||||||
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec_0.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
// Input Stream 1: nr_4.
|
||||||
|
auto input_vec_1 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
|
input_vec_1->push_back(nr_4_);
|
||||||
|
auto index_id = runner.MutableInputs()->GetId("RECTS", 0);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
|
Adopt(input_vec_1.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
// Input Stream 2: nr_2, nr_0.
|
||||||
|
auto input_vec_2 = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
|
input_vec_2->push_back(nr_2_);
|
||||||
|
input_vec_2->push_back(nr_0_);
|
||||||
|
index_id = runner.MutableInputs()->GetId("RECTS", 1);
|
||||||
|
runner.MutableInputs()->Get(index_id).packets.push_back(
|
||||||
|
Adopt(input_vec_2.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
|
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||||
|
EXPECT_EQ(1, output.size());
|
||||||
|
auto assoc_rects = output[0].Get<std::vector<NormalizedRect>>();
|
||||||
|
|
||||||
|
// Rectangles are added in the following sequence:
|
||||||
|
// nr_5 is added because it is in BASE_RECTS input stream.
|
||||||
|
// nr_3 is added because it is in BASE_RECTS input stream.
|
||||||
|
// nr_1 is added because it is in BASE_RECTS input stream.
|
||||||
|
// nr_4 is added because it does not overlap with nr_5.
|
||||||
|
// nr_2 is NOT added because it overlaps with nr_4.
|
||||||
|
// nr_0 is NOT added because it overlaps with nr_3.
|
||||||
|
EXPECT_EQ(4, assoc_rects.size());
|
||||||
|
|
||||||
|
// Outputs are in same order as inputs, and IDs are filled in.
|
||||||
|
nr_5_.set_rect_id(1);
|
||||||
|
nr_3_.set_rect_id(2);
|
||||||
|
nr_1_.set_rect_id(3);
|
||||||
|
nr_4_.set_rect_id(4);
|
||||||
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_5_), EqualsProto(nr_3_),
|
||||||
|
EqualsProto(nr_1_), EqualsProto(nr_4_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
||||||
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
calculator: "HandAssociationCalculator"
|
calculator: "HandAssociationCalculator"
|
||||||
input_stream: "input_vec"
|
input_stream: "BASE_RECTS:input_vec"
|
||||||
output_stream: "output_vec"
|
output_stream: "output_vec"
|
||||||
options {
|
options {
|
||||||
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
[mediapipe.HandAssociationCalculatorOptions.ext] {
|
||||||
|
@ -282,8 +333,9 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
||||||
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
|
auto input_vec = std::make_unique<std::vector<NormalizedRect>>();
|
||||||
input_vec->push_back(nr_3_);
|
input_vec->push_back(nr_3_);
|
||||||
input_vec->push_back(nr_5_);
|
input_vec->push_back(nr_5_);
|
||||||
runner.MutableInputs()->Index(0).packets.push_back(
|
runner.MutableInputs()
|
||||||
Adopt(input_vec.release()).At(Timestamp(1)));
|
->Tag("BASE_RECTS")
|
||||||
|
.packets.push_back(Adopt(input_vec.release()).At(Timestamp(1)));
|
||||||
|
|
||||||
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
MP_ASSERT_OK(runner.Run()) << "Calculator execution failed.";
|
||||||
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
const std::vector<Packet>& output = runner.Outputs().Index(0).packets;
|
||||||
|
@ -292,12 +344,12 @@ TEST_F(HandAssociationCalculatorTest, NormRectAssocSingleInputStream) {
|
||||||
|
|
||||||
// Rectangles are added in the following sequence:
|
// Rectangles are added in the following sequence:
|
||||||
// nr_3 is added 1st.
|
// nr_3 is added 1st.
|
||||||
// nr_5 is NOT added because it overlaps with nr_3.
|
// nr_5 is added 2nd. The calculator assumes it does not overlap with nr_3.
|
||||||
EXPECT_EQ(1, assoc_rects.size());
|
EXPECT_EQ(2, assoc_rects.size());
|
||||||
|
|
||||||
EXPECT_EQ(assoc_rects[0].rect_id(), 1);
|
nr_3_.set_rect_id(1);
|
||||||
assoc_rects[0].clear_rect_id();
|
nr_5_.set_rect_id(2);
|
||||||
EXPECT_THAT(assoc_rects[0], testing::EqualsProto(nr_3_));
|
EXPECT_THAT(assoc_rects, ElementsAre(EqualsProto(nr_3_), EqualsProto(nr_5_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -318,9 +318,9 @@ class HandLandmarkerGraph : public core::ModelTaskGraph {
|
||||||
.set_min_similarity_threshold(
|
.set_min_similarity_threshold(
|
||||||
tasks_options.min_tracking_confidence());
|
tasks_options.min_tracking_confidence());
|
||||||
prev_hand_rects_from_landmarks >>
|
prev_hand_rects_from_landmarks >>
|
||||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][0];
|
hand_association[Input<std::vector<NormalizedRect>>("BASE_RECTS")];
|
||||||
hand_rects_from_hand_detector >>
|
hand_rects_from_hand_detector >>
|
||||||
hand_association[Input<std::vector<NormalizedRect>>::Multiple("")][1];
|
hand_association[Input<std::vector<NormalizedRect>>("RECTS")];
|
||||||
auto hand_rects = hand_association.Out("");
|
auto hand_rects = hand_association.Out("");
|
||||||
hand_rects >> clip_hand_rects.In("");
|
hand_rects >> clip_hand_rects.In("");
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -34,18 +34,19 @@ _AUDIO_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
_VISION_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/face_detector/proto:face_detector_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
|
||||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_embedder_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:gesture_recognizer_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/gesture_recognizer/proto:hand_gesture_recognizer_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/image_classifier/proto:image_classifier_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_embedder/proto:image_embedder_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:image_segmenter_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/hand_detector/proto:hand_detector_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarker_graph_options_java_proto_lite",
|
||||||
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
"//mediapipe/tasks/cc/vision/hand_landmarker/proto:hand_landmarks_detector_graph_options_java_proto_lite",
|
||||||
|
"//mediapipe/tasks/cc/vision/object_detector/proto:object_detector_options_java_proto_lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
_TEXT_TASKS_JAVA_PROTO_LITE_TARGETS = [
|
||||||
|
|
|
@ -36,11 +36,9 @@ export abstract class AudioTaskRunner<T> extends TaskRunner {
|
||||||
|
|
||||||
/** Sends a single audio clip to the graph and awaits results. */
|
/** Sends a single audio clip to the graph and awaits results. */
|
||||||
protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
|
protected processAudioClip(audioData: Float32Array, sampleRate?: number): T {
|
||||||
// Increment the timestamp by 1 millisecond to guarantee that we send
|
|
||||||
// monotonically increasing timestamps to the graph.
|
|
||||||
const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
|
|
||||||
return this.process(
|
return this.process(
|
||||||
audioData, sampleRate ?? this.defaultSampleRate, syntheticTimestamp);
|
audioData, sampleRate ?? this.defaultSampleRate,
|
||||||
|
this.getSynctheticTimestamp());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,11 @@ mediapipe_ts_declaration(
|
||||||
deps = [":category"],
|
deps = [":category"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_declaration(
|
||||||
|
name = "keypoint",
|
||||||
|
srcs = ["keypoint.d.ts"],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_ts_declaration(
|
mediapipe_ts_declaration(
|
||||||
name = "landmark",
|
name = "landmark",
|
||||||
srcs = ["landmark.d.ts"],
|
srcs = ["landmark.d.ts"],
|
||||||
|
|
33
mediapipe/tasks/web/components/containers/keypoint.d.ts
vendored
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A keypoint, defined by the coordinates (x, y), normalized by the image
|
||||||
|
* dimensions.
|
||||||
|
*/
|
||||||
|
export declare interface NormalizedKeypoint {
|
||||||
|
/** X in normalized image coordinates. */
|
||||||
|
x: number;
|
||||||
|
|
||||||
|
/** Y in normalized image coordinates. */
|
||||||
|
y: number;
|
||||||
|
|
||||||
|
/** Optional label of the keypoint. */
|
||||||
|
label?: string;
|
||||||
|
|
||||||
|
/** Optional score of the keypoint. */
|
||||||
|
score?: number;
|
||||||
|
}
|
|
@ -175,9 +175,13 @@ export abstract class TaskRunner {
|
||||||
Math.max(this.latestOutputTimestamp, timestamp);
|
Math.max(this.latestOutputTimestamp, timestamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns the latest output timestamp. */
|
/**
|
||||||
protected getLatestOutputTimestamp() {
|
* Gets a syncthethic timestamp in ms that can be used to send data to the
|
||||||
return this.latestOutputTimestamp;
|
* next packet. The timestamp is one millisecond past the last timestamp
|
||||||
|
* received from the graph.
|
||||||
|
*/
|
||||||
|
protected getSynctheticTimestamp(): number {
|
||||||
|
return this.latestOutputTimestamp + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Throws the error from the error listener if an error was raised. */
|
/** Throws the error from the error listener if an error was raised. */
|
||||||
|
|
|
@ -131,11 +131,9 @@ export class TextClassifier extends TaskRunner {
|
||||||
* @return The classification result of the text
|
* @return The classification result of the text
|
||||||
*/
|
*/
|
||||||
classify(text: string): TextClassifierResult {
|
classify(text: string): TextClassifierResult {
|
||||||
// Increment the timestamp by 1 millisecond to guarantee that we send
|
|
||||||
// monotonically increasing timestamps to the graph.
|
|
||||||
const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
|
|
||||||
this.classificationResult = {classifications: []};
|
this.classificationResult = {classifications: []};
|
||||||
this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
|
this.graphRunner.addStringToStream(
|
||||||
|
text, INPUT_STREAM, this.getSynctheticTimestamp());
|
||||||
this.finishProcessing();
|
this.finishProcessing();
|
||||||
return this.classificationResult;
|
return this.classificationResult;
|
||||||
}
|
}
|
||||||
|
|
|
@ -135,10 +135,8 @@ export class TextEmbedder extends TaskRunner {
|
||||||
* @return The embedding resuls of the text
|
* @return The embedding resuls of the text
|
||||||
*/
|
*/
|
||||||
embed(text: string): TextEmbedderResult {
|
embed(text: string): TextEmbedderResult {
|
||||||
// Increment the timestamp by 1 millisecond to guarantee that we send
|
this.graphRunner.addStringToStream(
|
||||||
// monotonically increasing timestamps to the graph.
|
text, INPUT_STREAM, this.getSynctheticTimestamp());
|
||||||
const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
|
|
||||||
this.graphRunner.addStringToStream(text, INPUT_STREAM, syntheticTimestamp);
|
|
||||||
this.finishProcessing();
|
this.finishProcessing();
|
||||||
return this.embeddingResult;
|
return this.embeddingResult;
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,23 +2,42 @@
|
||||||
|
|
||||||
This package contains the vision tasks for MediaPipe.
|
This package contains the vision tasks for MediaPipe.
|
||||||
|
|
||||||
## Object Detection
|
## Gesture Recognition
|
||||||
|
|
||||||
The MediaPipe Object Detector task lets you detect the presence and location of
|
The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
|
||||||
multiple classes of objects within images or videos.
|
time, and provides the recognized hand gesture results along with the landmarks
|
||||||
|
of the detected hands. You can use this task to recognize specific hand gestures
|
||||||
|
from a user, and invoke application features that correspond to those gestures.
|
||||||
|
|
||||||
```
|
```
|
||||||
const vision = await FilesetResolver.forVisionTasks(
|
const vision = await FilesetResolver.forVisionTasks(
|
||||||
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
||||||
);
|
);
|
||||||
const objectDetector = await ObjectDetector.createFromModelPath(vision,
|
const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision,
|
||||||
"https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite"
|
"https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task"
|
||||||
);
|
);
|
||||||
const image = document.getElementById("image") as HTMLImageElement;
|
const image = document.getElementById("image") as HTMLImageElement;
|
||||||
const detections = objectDetector.detect(image);
|
const recognitions = gestureRecognizer.recognize(image);
|
||||||
```
|
```
|
||||||
|
|
||||||
For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation.
|
## Hand Landmark Detection
|
||||||
|
|
||||||
|
The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in
|
||||||
|
an image. You can use this Task to localize key points of the hands and render
|
||||||
|
visual effects over the hands.
|
||||||
|
|
||||||
|
```
|
||||||
|
const vision = await FilesetResolver.forVisionTasks(
|
||||||
|
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
||||||
|
);
|
||||||
|
const handLandmarker = await HandLandmarker.createFromModelPath(vision,
|
||||||
|
"https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task"
|
||||||
|
);
|
||||||
|
const image = document.getElementById("image") as HTMLImageElement;
|
||||||
|
const landmarks = handLandmarker.detect(image);
|
||||||
|
```
|
||||||
|
|
||||||
|
For more information, refer to the [Handlandmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation.
|
||||||
|
|
||||||
## Image Classification
|
## Image Classification
|
||||||
|
|
||||||
|
@ -56,40 +75,21 @@ imageSegmenter.segment(image, (masks, width, height) => {
|
||||||
});
|
});
|
||||||
```
|
```
|
||||||
|
|
||||||
## Gesture Recognition
|
## Object Detection
|
||||||
|
|
||||||
The MediaPipe Gesture Recognizer task lets you recognize hand gestures in real
|
The MediaPipe Object Detector task lets you detect the presence and location of
|
||||||
time, and provides the recognized hand gesture results along with the landmarks
|
multiple classes of objects within images or videos.
|
||||||
of the detected hands. You can use this task to recognize specific hand gestures
|
|
||||||
from a user, and invoke application features that correspond to those gestures.
|
|
||||||
|
|
||||||
```
|
```
|
||||||
const vision = await FilesetResolver.forVisionTasks(
|
const vision = await FilesetResolver.forVisionTasks(
|
||||||
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
||||||
);
|
);
|
||||||
const gestureRecognizer = await GestureRecognizer.createFromModelPath(vision,
|
const objectDetector = await ObjectDetector.createFromModelPath(vision,
|
||||||
"https://storage.googleapis.com/mediapipe-tasks/gesture_recognizer/gesture_recognizer.task"
|
"https://storage.googleapis.com/mediapipe-tasks/object_detector/efficientdet_lite0_uint8.tflite"
|
||||||
);
|
);
|
||||||
const image = document.getElementById("image") as HTMLImageElement;
|
const image = document.getElementById("image") as HTMLImageElement;
|
||||||
const recognitions = gestureRecognizer.recognize(image);
|
const detections = objectDetector.detect(image);
|
||||||
```
|
```
|
||||||
|
|
||||||
## Handlandmark Detection
|
For more information, refer to the [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector/web_js) documentation.
|
||||||
|
|
||||||
The MediaPipe Hand Landmarker task lets you detect the landmarks of the hands in
|
|
||||||
an image. You can use this Task to localize key points of the hands and render
|
|
||||||
visual effects over the hands.
|
|
||||||
|
|
||||||
```
|
|
||||||
const vision = await FilesetResolver.forVisionTasks(
|
|
||||||
"https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@latest/wasm"
|
|
||||||
);
|
|
||||||
const handLandmarker = await HandLandmarker.createFromModelPath(vision,
|
|
||||||
"https://storage.googleapis.com/mediapipe-tasks/hand_landmarker/hand_landmarker.task"
|
|
||||||
);
|
|
||||||
const image = document.getElementById("image") as HTMLImageElement;
|
|
||||||
const landmarks = handLandmarker.detect(image);
|
|
||||||
```
|
|
||||||
|
|
||||||
For more information, refer to the [Handlandmark Detection](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker/web_js) documentation.
|
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,14 @@ mediapipe_ts_declaration(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_declaration(
|
||||||
|
name = "types",
|
||||||
|
srcs = ["types.d.ts"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/web/components/containers:keypoint",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_ts_library(
|
mediapipe_ts_library(
|
||||||
name = "vision_task_runner",
|
name = "vision_task_runner",
|
||||||
srcs = ["vision_task_runner.ts"],
|
srcs = ["vision_task_runner.ts"],
|
||||||
|
@ -51,6 +59,11 @@ mediapipe_ts_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_ts_library(
|
||||||
|
name = "render_utils",
|
||||||
|
srcs = ["render_utils.ts"],
|
||||||
|
)
|
||||||
|
|
||||||
jasmine_node_test(
|
jasmine_node_test(
|
||||||
name = "vision_task_runner_test",
|
name = "vision_task_runner_test",
|
||||||
deps = [":vision_task_runner_test_lib"],
|
deps = [":vision_task_runner_test_lib"],
|
||||||
|
|
78
mediapipe/tasks/web/vision/core/render_utils.ts
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
/** @fileoverview Utility functions used in the vision demos. */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Pre-baked color table for a maximum of 12 classes.
|
||||||
|
const CM_ALPHA = 128;
|
||||||
|
const COLOR_MAP = [
|
||||||
|
[0, 0, 0, CM_ALPHA], // class 0 is BG = transparent
|
||||||
|
[255, 0, 0, CM_ALPHA], // class 1 is red
|
||||||
|
[0, 255, 0, CM_ALPHA], // class 2 is light green
|
||||||
|
[0, 0, 255, CM_ALPHA], // class 3 is blue
|
||||||
|
[255, 255, 0, CM_ALPHA], // class 4 is yellow
|
||||||
|
[255, 0, 255, CM_ALPHA], // class 5 is light purple / magenta
|
||||||
|
[0, 255, 255, CM_ALPHA], // class 6 is light blue / aqua
|
||||||
|
[128, 128, 128, CM_ALPHA], // class 7 is gray
|
||||||
|
[255, 128, 0, CM_ALPHA], // class 8 is orange
|
||||||
|
[128, 0, 255, CM_ALPHA], // class 9 is dark purple
|
||||||
|
[0, 128, 0, CM_ALPHA], // class 10 is dark green
|
||||||
|
[255, 255, 255, CM_ALPHA] // class 11 is white; could do black instead?
|
||||||
|
];
|
||||||
|
|
||||||
|
|
||||||
|
/** Helper function to draw a confidence mask */
|
||||||
|
export function drawConfidenceMask(
|
||||||
|
|
||||||
|
ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
|
||||||
|
width: number, height: number): void {
|
||||||
|
const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
|
||||||
|
for (let i = 0; i < image.length; i++) {
|
||||||
|
uint8ClampedArray[4 * i] = 128;
|
||||||
|
uint8ClampedArray[4 * i + 1] = 0;
|
||||||
|
uint8ClampedArray[4 * i + 2] = 0;
|
||||||
|
uint8ClampedArray[4 * i + 3] = image[i] * 255;
|
||||||
|
}
|
||||||
|
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper function to draw a category mask. For GPU, we only have F32Arrays
|
||||||
|
* for now.
|
||||||
|
*/
|
||||||
|
export function drawCategoryMask(
|
||||||
|
ctx: CanvasRenderingContext2D, image: Float32Array|Uint8Array,
|
||||||
|
width: number, height: number): void {
|
||||||
|
const uint8ClampedArray = new Uint8ClampedArray(width * height * 4);
|
||||||
|
const isFloatArray = image instanceof Float32Array;
|
||||||
|
for (let i = 0; i < image.length; i++) {
|
||||||
|
const colorIndex = isFloatArray ? Math.round(image[i] * 255) : image[i];
|
||||||
|
const color = COLOR_MAP[colorIndex];
|
||||||
|
|
||||||
|
// When we're given a confidence mask by accident, we just log and return.
|
||||||
|
// TODO: We should fix this.
|
||||||
|
if (!color) {
|
||||||
|
console.warn('No color for ', colorIndex);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8ClampedArray[4 * i] = color[0];
|
||||||
|
uint8ClampedArray[4 * i + 1] = color[1];
|
||||||
|
uint8ClampedArray[4 * i + 2] = color[2];
|
||||||
|
uint8ClampedArray[4 * i + 3] = color[3];
|
||||||
|
}
|
||||||
|
ctx.putImageData(new ImageData(uint8ClampedArray, width, height), 0, 0);
|
||||||
|
}
|
42
mediapipe/tasks/web/vision/core/types.d.ts
vendored
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {NormalizedKeypoint} from '../../../../tasks/web/components/containers/keypoint';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The segmentation tasks return the segmentation result as a Uint8Array
|
||||||
|
* (when the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
|
||||||
|
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
|
||||||
|
* for future usage.
|
||||||
|
*/
|
||||||
|
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A callback that receives the computed masks from the segmentation tasks. The
|
||||||
|
* callback either receives a single element array with a category mask (as a
|
||||||
|
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
|
||||||
|
* The returned data is only valid for the duration of the callback. If
|
||||||
|
* asynchronous processing is needed, all data needs to be copied before the
|
||||||
|
* callback returns.
|
||||||
|
*/
|
||||||
|
export type SegmentationMaskCallback =
|
||||||
|
(masks: SegmentationMask[], width: number, height: number) => void;
|
||||||
|
|
||||||
|
/** A Region-Of-Interest (ROI) to represent a region within an image. */
|
||||||
|
export declare interface RegionOfInterest {
|
||||||
|
/** The ROI in keypoint format. */
|
||||||
|
keypoint: NormalizedKeypoint;
|
||||||
|
}
|
|
@ -74,11 +74,7 @@ export abstract class VisionTaskRunner extends TaskRunner {
|
||||||
'Task is not initialized with image mode. ' +
|
'Task is not initialized with image mode. ' +
|
||||||
'\'runningMode\' must be set to \'IMAGE\'.');
|
'\'runningMode\' must be set to \'IMAGE\'.');
|
||||||
}
|
}
|
||||||
|
this.process(image, imageProcessingOptions, this.getSynctheticTimestamp());
|
||||||
// Increment the timestamp by 1 millisecond to guarantee that we send
|
|
||||||
// monotonically increasing timestamps to the graph.
|
|
||||||
const syntheticTimestamp = this.getLatestOutputTimestamp() + 1;
|
|
||||||
this.process(image, imageProcessingOptions, syntheticTimestamp);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Sends a single video frame to the graph and awaits results. */
|
/** Sends a single video frame to the graph and awaits results. */
|
||||||
|
|
|
@ -19,8 +19,8 @@ mediapipe_ts_library(
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_jspb_proto",
|
||||||
"//mediapipe/tasks/web/core",
|
"//mediapipe/tasks/web/core",
|
||||||
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
"//mediapipe/tasks/web/vision/core:image_processing_options",
|
||||||
|
"//mediapipe/tasks/web/vision/core:types",
|
||||||
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
"//mediapipe/tasks/web/vision/core:vision_task_runner",
|
||||||
"//mediapipe/web/graph_runner:graph_runner_image_lib_ts",
|
|
||||||
"//mediapipe/web/graph_runner:graph_runner_ts",
|
"//mediapipe/web/graph_runner:graph_runner_ts",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,6 +21,7 @@ import {ImageSegmenterGraphOptions as ImageSegmenterGraphOptionsProto} from '../
|
||||||
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
import {SegmenterOptions as SegmenterOptionsProto} from '../../../../tasks/cc/vision/image_segmenter/proto/segmenter_options_pb';
|
||||||
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
import {WasmFileset} from '../../../../tasks/web/core/wasm_fileset';
|
||||||
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
import {ImageProcessingOptions} from '../../../../tasks/web/vision/core/image_processing_options';
|
||||||
|
import {SegmentationMask, SegmentationMaskCallback} from '../../../../tasks/web/vision/core/types';
|
||||||
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
import {VisionGraphRunner, VisionTaskRunner} from '../../../../tasks/web/vision/core/vision_task_runner';
|
||||||
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner';
|
||||||
// Placeholder for internal dependency on trusted resource url
|
// Placeholder for internal dependency on trusted resource url
|
||||||
|
@ -28,27 +29,9 @@ import {ImageSource, WasmModule} from '../../../../web/graph_runner/graph_runner
|
||||||
import {ImageSegmenterOptions} from './image_segmenter_options';
|
import {ImageSegmenterOptions} from './image_segmenter_options';
|
||||||
|
|
||||||
export * from './image_segmenter_options';
|
export * from './image_segmenter_options';
|
||||||
|
export {SegmentationMask, SegmentationMaskCallback};
|
||||||
export {ImageSource}; // Used in the public API
|
export {ImageSource}; // Used in the public API
|
||||||
|
|
||||||
/**
|
|
||||||
* The ImageSegmenter returns the segmentation result as a Uint8Array (when
|
|
||||||
* the default mode of `CATEGORY_MASK` is used) or as a Float32Array (for
|
|
||||||
* output type `CONFIDENCE_MASK`). The `WebGLTexture` output type is reserved
|
|
||||||
* for future usage.
|
|
||||||
*/
|
|
||||||
export type SegmentationMask = Uint8Array|Float32Array|WebGLTexture;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A callback that receives the computed masks from the image segmenter. The
|
|
||||||
* callback either receives a single element array with a category mask (as a
|
|
||||||
* `[Uint8Array]`) or multiple confidence masks (as a `Float32Array[]`).
|
|
||||||
* The returned data is only valid for the duration of the callback. If
|
|
||||||
* asynchronous processing is needed, all data needs to be copied before the
|
|
||||||
* callback returns.
|
|
||||||
*/
|
|
||||||
export type SegmentationMaskCallback =
|
|
||||||
(masks: SegmentationMask[], width: number, height: number) => void;
|
|
||||||
|
|
||||||
const IMAGE_STREAM = 'image_in';
|
const IMAGE_STREAM = 'image_in';
|
||||||
const NORM_RECT_STREAM = 'norm_rect';
|
const NORM_RECT_STREAM = 'norm_rect';
|
||||||
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
|
const GROUPED_SEGMENTATIONS_STREAM = 'segmented_masks';
|
||||||
|
|
1
third_party/BUILD
vendored
|
@ -112,6 +112,7 @@ cmake_external(
|
||||||
"WITH_JPEG": "ON",
|
"WITH_JPEG": "ON",
|
||||||
"WITH_PNG": "ON",
|
"WITH_PNG": "ON",
|
||||||
"WITH_TIFF": "ON",
|
"WITH_TIFF": "ON",
|
||||||
|
"WITH_OPENCL": "OFF",
|
||||||
"WITH_WEBP": "OFF",
|
"WITH_WEBP": "OFF",
|
||||||
# Optimization flags
|
# Optimization flags
|
||||||
"CV_ENABLE_INTRINSICS": "ON",
|
"CV_ENABLE_INTRINSICS": "ON",
|
||||||
|
|
10
third_party/external_files.bzl
vendored
|
@ -67,7 +67,7 @@ def external_files():
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_BUILD",
|
name = "com_google_mediapipe_BUILD",
|
||||||
sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
|
sha256 = "d2b2a8346202691d7f831887c84e9642e974f64ed67851d9a58cf15c94b1f6b3",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=1661875663693976"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/BUILD?generation=16618756636939761678323576393653"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -318,8 +318,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_face_landmarker_with_blendshapes_task",
|
name = "com_google_mediapipe_face_landmarker_with_blendshapes_task",
|
||||||
sha256 = "a75c1ba70e4b8568000af2ad0b355ed559ab5d5793db50fa9ad241f8dc4fad5f",
|
sha256 = "b44e4cae6f5822456d60f33e7c852640d78c7e342aee7eacc22589451a0b9dc2",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678323586260800"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/face_landmarker_with_blendshapes.task?generation=1678504998301299"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
@ -822,8 +822,8 @@ def external_files():
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt",
|
name = "com_google_mediapipe_portrait_expected_face_geometry_with_attention_pbtxt",
|
||||||
sha256 = "5cc57b8da3ad0527dce581fe1309f6b36043e5837e3f4f5af5e24005a99dc52a",
|
sha256 = "7ed1eed98e61e0a10811bb611c895d87c8023f398a36db01b6d9ba2e1ab09e16",
|
||||||
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678323601064393"],
|
urls = ["https://storage.googleapis.com/mediapipe-assets/portrait_expected_face_geometry_with_attention.pbtxt?generation=1678505004840652"],
|
||||||
)
|
)
|
||||||
|
|
||||||
http_file(
|
http_file(
|
||||||
|
|