Merge branch 'google:master' into audio-record-api-python

This commit is contained in:
Kinar R 2023-03-22 11:46:28 +05:30 committed by GitHub
commit 3c39aca52b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
253 changed files with 14296 additions and 859 deletions

17
LICENSE
View File

@ -199,3 +199,20 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
===========================================================================
For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
===========================================================================
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/

View File

@ -270,7 +270,7 @@ new_local_repository(
# For local MacOS builds, the path should point to an opencv@3 installation. # For local MacOS builds, the path should point to an opencv@3 installation.
# If you edit the path here, you will also need to update the corresponding # If you edit the path here, you will also need to update the corresponding
# prefix in "opencv_macos.BUILD". # prefix in "opencv_macos.BUILD".
path = "/usr/local", path = "/usr/local", # e.g. /usr/local/Cellar for HomeBrew
) )
new_local_repository( new_local_repository(
@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool")
# Node dependencies # Node dependencies
http_archive( http_archive(
name = "build_bazel_rules_nodejs", name = "build_bazel_rules_nodejs",
sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda", sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a",
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"], urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"],
) )
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies") load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")

View File

@ -108,6 +108,8 @@ one over the other.
* [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite) * [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite) * [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
* [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
* [Model information](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)
### [Objectron](https://google.github.io/mediapipe/solutions/objectron) ### [Objectron](https://google.github.io/mediapipe/solutions/objectron)

View File

@ -118,9 +118,9 @@ on how to build MediaPipe examples.
* With a TensorFlow Model * With a TensorFlow Model
This uses the This uses the
[TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
( see also ( see also
[model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)), [model info](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)),
and the pipeline is implemented in this and the pipeline is implemented in this
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt). [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).

View File

@ -0,0 +1,62 @@
## TensorFlow/TFLite Object Detection Model
### TensorFlow model
The model is trained on [MSCOCO 2014](http://cocodataset.org) dataset using [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with 0.5 depth multiplier. Detailed training configuration is in the provided `pipeline.config`. The model is a relatively compact model which has `0.171 mAP` to achieve real-time performance on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).
### TFLite model
The TFLite model is converted from the TensorFlow above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) with minor modifications. Assuming now we have a trained TensorFlow model which includes the checkpoint files and the training configuration file, for example the files provided in this repo:
* `model.ckpt.index`
* `model.ckpt.meta`
* `model.ckpt.data-00000-of-00001`
* `pipeline.config`
Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:
```bash
$ PATH_TO_MODEL=path/to/the/model
$ bazel run object_detection:export_tflite_ssd_graph -- \
--pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
--trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
--output_directory ${PATH_TO_MODEL} \
--add_postprocessing_op=False
```
The exported model contains two files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. In MediaPipe, we have provided all the calculators needed for post-processing such that we can exclude the custom TFLite ops for post-processing in the original graph, e.g., non-maximum suppression. This enables the flexibility to integrate with different post-processing algorithms and implementations.
Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model:
```bash
$ bazel run graph_transforms:summarize_graph -- \
--in_graph=${PATH_TO_MODEL}/tflite_graph.pb
```
You should be able to see the input image size of the model is 320x320 and the outputs of the model are:
* `raw_outputs/box_encodings`
* `raw_outputs/class_predictions`
The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run:
```bash
$ tflite_convert -- \
--graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
--output_file=${PATH_TO_MODEL}/model.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=FLOAT \
--input_shapes=1,320,320,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
```
Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.

View File

@ -269,6 +269,7 @@ Supported configuration options:
```python ```python
import cv2 import cv2
import mediapipe as mp import mediapipe as mp
import numpy as np
mp_drawing = mp.solutions.drawing_utils mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose mp_pose = mp.solutions.pose

View File

@ -156,6 +156,7 @@ cc_library(
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
"//mediapipe/framework/port:opencv_imgproc",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [ "//conditions:default": [
@ -168,6 +169,25 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "set_alpha_calculator_test",
srcs = ["set_alpha_calculator_test.cc"],
deps = [
":set_alpha_calculator",
":set_alpha_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_googletest//:gtest_main",
],
)
cc_library( cc_library(
name = "bilateral_filter_calculator", name = "bilateral_filter_calculator",
srcs = ["bilateral_filter_calculator.cc"], srcs = ["bilateral_filter_calculator.cc"],
@ -748,6 +768,7 @@ cc_test(
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
], ],
tags = ["desktop_only_test"], tags = ["desktop_only_test"],

View File

@ -29,6 +29,9 @@ class AffineTransformation {
// pixels will be calculated. // pixels will be calculated.
enum class BorderMode { kZero, kReplicate }; enum class BorderMode { kZero, kReplicate };
// Pixel sampling interpolation method.
enum class Interpolation { kLinear, kCubic };
struct Size { struct Size {
int width; int width;
int height; int height;

View File

@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
std::unique_ptr<GpuBuffer>> { std::unique_ptr<GpuBuffer>> {
public: public:
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper, GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
GpuOrigin::Mode gpu_origin) GpuOrigin::Mode gpu_origin,
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {} AffineTransformation::Interpolation interpolation)
: gl_helper_(gl_helper),
gpu_origin_(gpu_origin),
interpolation_(interpolation) {}
absl::Status Init() { absl::Status Init() {
return gl_helper_->RunInGlContext([this]() -> absl::Status { return gl_helper_->RunInGlContext([this]() -> absl::Status {
const GLint attr_location[kNumAttributes] = { const GLint attr_location[kNumAttributes] = {
@ -103,10 +106,13 @@ class GlTextureWarpAffineRunner
} }
)"; )";
// TODO Move bicubic code to common shared place.
constexpr GLchar kFragShader[] = R"( constexpr GLchar kFragShader[] = R"(
DEFAULT_PRECISION(highp, float) DEFAULT_PRECISION(highp, float)
in vec2 sample_coordinate; in vec2 sample_coordinate;
uniform sampler2D input_texture; uniform sampler2D input_texture;
uniform vec2 input_size;
#ifdef GL_ES #ifdef GL_ES
#define fragColor gl_FragColor #define fragColor gl_FragColor
@ -114,8 +120,60 @@ class GlTextureWarpAffineRunner
out vec4 fragColor; out vec4 fragColor;
#endif // defined(GL_ES); #endif // defined(GL_ES);
#ifdef CUBIC_INTERPOLATION
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
const vec2 halve = vec2(0.5,0.5);
const vec2 one = vec2(1.0,1.0);
const vec2 two = vec2(2.0,2.0);
const vec2 three = vec2(3.0,3.0);
const vec2 six = vec2(6.0,6.0);
// Calculate the fraction and integer.
tex_coord = tex_coord * tex_size - halve;
vec2 frac = fract(tex_coord);
vec2 index = tex_coord - frac + halve;
// Calculate weights for Catmull-Rom filter.
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
vec2 w3 = frac * frac * (-halve + halve * frac);
// Calculate weights to take advantage of bilinear texture lookup.
vec2 w12 = w1 + w2;
vec2 offset12 = w2 / (w1 + w2);
vec2 index_tl = index - one;
vec2 index_br = index + two;
vec2 index_eq = index + offset12;
index_tl /= tex_size;
index_br /= tex_size;
index_eq /= tex_size;
// 9 texture lookup and linear blending.
vec4 color = vec4(0.0);
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
return color;
}
#else
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
return texture2D(tex, tex_coord);
}
#endif // defined(CUBIC_INTERPOLATION)
void main() { void main() {
vec4 color = texture2D(input_texture, sample_coordinate); vec4 color = sample(input_texture, sample_coordinate, input_size);
#ifdef CUSTOM_ZERO_BORDER_MODE #ifdef CUSTOM_ZERO_BORDER_MODE
float out_of_bounds = float out_of_bounds =
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
glUseProgram(program); glUseProgram(program);
glUniform1i(glGetUniformLocation(program, "input_texture"), 1); glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
GLint matrix_id = glGetUniformLocation(program, "transform_matrix"); GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
return Program{.id = program, .matrix_id = matrix_id}; GLint size_id = glGetUniformLocation(program, "input_size");
return Program{
.id = program, .matrix_id = matrix_id, .size_id = size_id};
}; };
const std::string vert_src = const std::string vert_src =
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader); absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
const std::string frag_src = absl::StrCat( std::string interpolation_def;
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader); switch (interpolation_) {
case AffineTransformation::Interpolation::kCubic:
interpolation_def = R"(
#define CUBIC_INTERPOLATION
)";
break;
case AffineTransformation::Interpolation::kLinear:
break;
}
const std::string frag_src =
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
interpolation_def, kFragShader);
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src)); ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
std::string custom_zero_border_mode_def = R"( std::string custom_zero_border_mode_def = R"(
#define CUSTOM_ZERO_BORDER_MODE #define CUSTOM_ZERO_BORDER_MODE
)"; )";
const std::string frag_custom_zero_src = const std::string frag_custom_zero_src = absl::StrCat(
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, mediapipe::kMediaPipeFragmentShaderPreamble,
custom_zero_border_mode_def, kFragShader); custom_zero_border_mode_def, interpolation_def, kFragShader);
return create_fn(vert_src, frag_custom_zero_src); return create_fn(vert_src, frag_custom_zero_src);
}; };
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED #if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
} }
glUseProgram(program->id); glUseProgram(program->id);
// uniforms
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data()); Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) { if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
// @matrix describes affine transformation in terms of TOP LEFT origin, so // @matrix describes affine transformation in terms of TOP LEFT origin, so
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
eigen_mat.transposeInPlace(); eigen_mat.transposeInPlace();
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data()); glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
glUniform2f(program->size_id, texture.width(), texture.height());
}
// vao // vao
glBindVertexArray(vao_); glBindVertexArray(vao_);
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
struct Program { struct Program {
GLuint id; GLuint id;
GLint matrix_id; GLint matrix_id;
GLint size_id;
}; };
std::shared_ptr<GlCalculatorHelper> gl_helper_; std::shared_ptr<GlCalculatorHelper> gl_helper_;
GpuOrigin::Mode gpu_origin_; GpuOrigin::Mode gpu_origin_;
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
Program program_; Program program_;
std::optional<Program> program_custom_zero_; std::optional<Program> program_custom_zero_;
GLuint framebuffer_ = 0; GLuint framebuffer_ = 0;
AffineTransformation::Interpolation interpolation_ =
AffineTransformation::Interpolation::kLinear;
}; };
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED #undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
absl::StatusOr<std::unique_ptr< absl::StatusOr<std::unique_ptr<
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>> AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
CreateAffineTransformationGlRunner( CreateAffineTransformationGlRunner(
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) { std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
auto runner = AffineTransformation::Interpolation interpolation) {
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin); auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
gl_helper, gpu_origin, interpolation);
MP_RETURN_IF_ERROR(runner->Init()); MP_RETURN_IF_ERROR(runner->Init());
return runner; return runner;
} }

View File

@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>> mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
CreateAffineTransformationGlRunner( CreateAffineTransformationGlRunner(
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper, std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
mediapipe::GpuOrigin::Mode gpu_origin); mediapipe::GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe } // namespace mediapipe

View File

@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
} }
} }
int GetInterpolationForOpenCv(
AffineTransformation::Interpolation interpolation) {
switch (interpolation) {
case AffineTransformation::Interpolation::kLinear:
return cv::INTER_LINEAR;
case AffineTransformation::Interpolation::kCubic:
return cv::INTER_CUBIC;
}
}
class OpenCvRunner class OpenCvRunner
: public AffineTransformation::Runner<ImageFrame, ImageFrame> { : public AffineTransformation::Runner<ImageFrame, ImageFrame> {
public: public:
OpenCvRunner(AffineTransformation::Interpolation interpolation)
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
absl::StatusOr<ImageFrame> Run( absl::StatusOr<ImageFrame> Run(
const ImageFrame& input, const std::array<float, 16>& matrix, const ImageFrame& input, const std::array<float, 16>& matrix,
const AffineTransformation::Size& size, const AffineTransformation::Size& size,
@ -142,19 +155,23 @@ class OpenCvRunner
cv::warpAffine(in_mat, out_mat, cv_affine_transform, cv::warpAffine(in_mat, out_mat, cv_affine_transform,
cv::Size(out_mat.cols, out_mat.rows), cv::Size(out_mat.cols, out_mat.rows),
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP, /*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
GetBorderModeForOpenCv(border_mode)); GetBorderModeForOpenCv(border_mode));
return out_image; return out_image;
} }
private:
int interpolation_ = cv::INTER_LINEAR;
}; };
} // namespace } // namespace
absl::StatusOr< absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>> std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner() { CreateAffineTransformationOpenCvRunner(
return absl::make_unique<OpenCvRunner>(); AffineTransformation::Interpolation interpolation) {
return absl::make_unique<OpenCvRunner>(interpolation);
} }
} // namespace mediapipe } // namespace mediapipe

View File

@ -25,7 +25,8 @@ namespace mediapipe {
absl::StatusOr< absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>> std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner(); CreateAffineTransformationOpenCvRunner(
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe } // namespace mediapipe

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
@ -53,24 +54,16 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat // range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
// must be uchar. // must be uchar.
template <typename AlphaType> template <typename AlphaType>
absl::Status MergeRGBA8Image(const cv::Mat input_mat, const cv::Mat& alpha_mat, absl::Status CopyAlphaImage(const cv::Mat& alpha_mat, cv::Mat& output_mat) {
cv::Mat& output_mat) { RET_CHECK_EQ(output_mat.rows, alpha_mat.rows);
RET_CHECK_EQ(input_mat.rows, alpha_mat.rows); RET_CHECK_EQ(output_mat.cols, alpha_mat.cols);
RET_CHECK_EQ(input_mat.cols, alpha_mat.cols);
RET_CHECK_EQ(input_mat.rows, output_mat.rows);
RET_CHECK_EQ(input_mat.cols, output_mat.cols);
for (int i = 0; i < output_mat.rows; ++i) { for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i); const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i); uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) { for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA; const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
const int alpha_idx = j * alpha_mat.channels(); const int alpha_idx = j * alpha_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
if constexpr (std::is_same<AlphaType, uchar>::value) { if constexpr (std::is_same<AlphaType, uchar>::value) {
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
} else { } else {
@ -273,7 +266,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup source image // Setup source image
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>(); const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame); const cv::Mat input_mat = formats::MatView(&input_frame);
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) { if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported"; LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
} }
@ -281,38 +274,38 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup destination image // Setup destination image
auto output_frame = absl::make_unique<ImageFrame>( auto output_frame = absl::make_unique<ImageFrame>(
ImageFormat::SRGBA, input_mat.cols, input_mat.rows); ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get()); cv::Mat output_mat = formats::MatView(output_frame.get());
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) && const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty(); !cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask; const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
// Setup alpha image and Update image in CPU. // Copy rgb part of the image in CPU
if (input_mat.channels() == 3) {
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA);
} else {
input_mat.copyTo(output_mat);
}
// Setup alpha image in CPU.
if (use_alpha_mask) { if (use_alpha_mask) {
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>(); const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask); cv::Mat alpha_mat = formats::MatView(&alpha_mask);
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F; const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U); RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
if (alpha_is_float) { if (alpha_is_float) {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(CopyAlphaImage<float>(alpha_mat, output_mat));
MergeRGBA8Image<float>(input_mat, alpha_mat, output_mat));
} else { } else {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(CopyAlphaImage<uchar>(alpha_mat, output_mat));
MergeRGBA8Image<uchar>(input_mat, alpha_mat, output_mat));
} }
} else { } else {
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f); const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
for (int i = 0; i < output_mat.rows; ++i) { for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i); uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) { for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA; const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
out_ptr[out_idx + 3] = alpha_value; // use value from options out_ptr[out_idx + 3] = alpha_value; // use value from options
} }
} }

View File

@ -0,0 +1,156 @@
#include <cstdint>
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "testing/base/public/benchmark.h"
namespace mediapipe {
namespace {
constexpr int input_width = 100;
constexpr int input_height = 100;
std::unique_ptr<ImageFrame> GetInputFrame(int width, int height, int channel) {
const int total_size = width * height * channel;
ImageFormat::Format image_format;
if (channel == 4) {
image_format = ImageFormat::SRGBA;
} else if (channel == 3) {
image_format = ImageFormat::SRGB;
} else {
image_format = ImageFormat::GRAY8;
}
auto input_frame = std::make_unique<ImageFrame>(image_format, width, height,
/*alignment_boundary =*/1);
for (int i = 0; i < total_size; ++i) {
input_frame->MutablePixelData()[i] = i % 256;
}
return input_frame;
}
// Test SetAlphaCalculator with RGB IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgb) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 3);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 3, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
// Test SetAlphaCalculator with RGBA IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgba) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 4);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 4);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 4, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
static void BM_SetAlpha3ChannelImage(benchmark::State& state) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(1, outputs.NumEntries());
for (const auto _ : state) {
MP_ASSERT_OK(runner.Run());
}
}
BENCHMARK(BM_SetAlpha3ChannelImage);
} // namespace
} // namespace mediapipe

View File

@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
} }
} }
AffineTransformation::Interpolation GetInterpolation(
mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
switch (interpolation) {
case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
return AffineTransformation::Interpolation::kLinear;
case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
return AffineTransformation::Interpolation::kCubic;
}
}
template <typename ImageT> template <typename ImageT>
class WarpAffineRunnerHolder {}; class WarpAffineRunnerHolder {};
@ -61,16 +72,22 @@ template <>
class WarpAffineRunnerHolder<ImageFrame> { class WarpAffineRunnerHolder<ImageFrame> {
public: public:
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>; using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) {
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return absl::OkStatus();
}
absl::StatusOr<RunnerType*> GetRunner() { absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) { if (!runner_) {
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner()); ASSIGN_OR_RETURN(runner_,
CreateAffineTransformationOpenCvRunner(interpolation_));
} }
return runner_.get(); return runner_.get();
} }
private: private:
std::unique_ptr<RunnerType> runner_; std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
}; };
#endif // !MEDIAPIPE_DISABLE_OPENCV #endif // !MEDIAPIPE_DISABLE_OPENCV
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
gpu_origin_ = gpu_origin_ =
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin(); cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>(); gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return gl_helper_->Open(cc); return gl_helper_->Open(cc);
} }
absl::StatusOr<RunnerType*> GetRunner() { absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) { if (!runner_) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_)); gl_helper_, gpu_origin_, interpolation_));
} }
return runner_.get(); return runner_.get();
} }
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
mediapipe::GpuOrigin::Mode gpu_origin_; mediapipe::GpuOrigin::Mode gpu_origin_;
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_; std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
std::unique_ptr<RunnerType> runner_; std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
}; };
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU

View File

@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
BORDER_REPLICATE = 2; BORDER_REPLICATE = 2;
} }
// Pixel sampling interpolation methods. See @interpolation.
enum Interpolation {
INTER_UNSPECIFIED = 0;
INTER_LINEAR = 1;
INTER_CUBIC = 2;
}
// Pixel extrapolation method. // Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read // When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such // pixels outside image boundaries. Border mode helps to specify how such
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
// to be flipped vertically as tensors are expected to start at top. // to be flipped vertically as tensors are expected to start at top.
// (DEFAULT or unset interpreted as CONVENTIONAL.) // (DEFAULT or unset interpreted as CONVENTIONAL.)
optional GpuOrigin.Mode gpu_origin = 2; optional GpuOrigin.Mode gpu_origin = 2;
// Sampling method for neighboring pixels.
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
optional Interpolation interpolation = 3;
} }

View File

@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
const cv::Mat& input, cv::Mat expected_result, const cv::Mat& input, cv::Mat expected_result,
float similarity_threshold, std::array<float, 16> matrix, float similarity_threshold, std::array<float, 16> matrix,
int out_width, int out_height, int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) { std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
std::string border_mode_str; std::string border_mode_str;
if (border_mode) { if (border_mode) {
switch (*border_mode) { switch (*border_mode) {
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
break; break;
} }
} }
std::string interpolation_str;
if (interpolation) {
switch (*interpolation) {
case AffineTransformation::Interpolation::kLinear:
interpolation_str = "interpolation: INTER_LINEAR";
break;
case AffineTransformation::Interpolation::kCubic:
interpolation_str = "interpolation: INTER_CUBIC";
break;
}
}
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(graph_text, /*$0=*/border_mode_str)); absl::Substitute(graph_text, /*$0=*/border_mode_str,
/*$1=*/interpolation_str));
std::vector<Packet> output_packets; std::vector<Packet> output_packets;
tool::AddVectorSink("output_image", &graph_config, &output_packets); tool::AddVectorSink("output_image", &graph_config, &output_packets);
@ -132,7 +145,8 @@ struct SimilarityConfig {
void RunTest(cv::Mat input, cv::Mat expected_result, void RunTest(cv::Mat input, cv::Mat expected_result,
const SimilarityConfig& similarity, std::array<float, 16> matrix, const SimilarityConfig& similarity, std::array<float, 16> matrix,
int out_width, int out_height, int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) { std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
input_stream: "output_size" input_stream: "output_size"
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
} }
} }
} }
)", )",
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix, "cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
} }
} }
} }
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
} }
)", )",
"cpu_image", input, expected_result, similarity.threshold_on_cpu, "cpu_image", input, expected_result, similarity.threshold_on_cpu,
matrix, out_width, out_height, border_mode); matrix, out_width, out_height, border_mode, interpolation);
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT gpu_origin: TOP_LEFT
} }
} }
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
} }
)", )",
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix, "gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT gpu_origin: TOP_LEFT
} }
} }
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
} }
)", )",
"gpu_image", input, expected_result, similarity.threshold_on_gpu, "gpu_image", input, expected_result, similarity.threshold_on_gpu,
matrix, out_width, out_height, border_mode); matrix, out_width, out_height, border_mode, interpolation);
} }
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi, std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
int out_height = 256; int out_height = 256;
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {}; std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
mediapipe::NormalizedRect roi;
roi.set_x_center(0.65f);
roi.set_y_center(0.4f);
roi.set_width(0.5f);
roi.set_height(0.5f);
roi.set_rotation(M_PI * -45.0f / 180.0f);
auto input = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/input.jpg");
auto expected_output = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
int out_width = 256;
int out_height = 256;
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation =
AffineTransformation::Interpolation::kCubic;
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRect) { TEST(WarpAffineCalculatorTest, LargeSubRect) {
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
int out_height = 128; int out_height = 128;
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {}; std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, NoOp) { TEST(WarpAffineCalculatorTest, NoOp) {
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, NoOpBorderZero) { TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
} // namespace } // namespace

View File

@ -997,17 +997,20 @@ cc_library(
":image_to_tensor_converter_gl_buffer", ":image_to_tensor_converter_gl_buffer",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
], ],
"//mediapipe:apple": [ "//mediapipe:apple": [
":image_to_tensor_converter_metal", ":image_to_tensor_converter_metal",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
], ],
"//conditions:default": [ "//conditions:default": [
":image_to_tensor_converter_gl_buffer", ":image_to_tensor_converter_gl_buffer",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
], ],
}), }),
) )
@ -1045,6 +1048,10 @@ cc_test(
":image_to_tensor_calculator", ":image_to_tensor_calculator",
":image_to_tensor_converter", ":image_to_tensor_converter",
":image_to_tensor_utils", ":image_to_tensor_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
@ -1061,11 +1068,10 @@ cc_test(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/util:image_test_utils", "//mediapipe/util:image_test_utils",
"@com_google_absl//absl/flags:flag", ] + select({
"@com_google_absl//absl/memory", "//mediapipe:apple": [],
"@com_google_absl//absl/strings", "//conditions:default": ["//mediapipe/gpu:gl_context"],
"@com_google_absl//absl/strings:str_format", }),
],
) )
cc_library( cc_library(

View File

@ -45,9 +45,11 @@
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_service.h"
#else #else
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_service.h"
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node {
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#else #else
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); cc->UseService(kGpuService).Optional();
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#endif // MEDIAPIPE_DISABLE_GPU #endif // MEDIAPIPE_DISABLE_GPU

View File

@ -41,6 +41,10 @@
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/image_test_utils.h" #include "mediapipe/util/image_test_utils.h"
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
#include "mediapipe/gpu/gl_context.h"
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
namespace mediapipe { namespace mediapipe {
namespace { namespace {
@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) {
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt, /*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
/*keep_aspect=*/false, BorderMode::kZero, roi); /*keep_aspect=*/false, BorderMode::kZero, roi);
} }
TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) {
auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_float_range { min: 0.0f max: 1.0f }
}
}
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
MP_ASSERT_OK(graph.StartRun({}));
auto image_frame =
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
Image image = Image(std::move(image_frame));
Packet packet = MakePacket<Image>(std::move(image));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
TEST(ImageToTensorCalculatorTest,
FailsGracefullyWhenGpuServiceNeededButNotAvailable) {
auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_float_range { min: 0.0f max: 1.0f }
}
}
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK_AND_ASSIGN(auto context,
GlContext::Create(nullptr, /*create_thread=*/true));
Packet packet;
context->Run([&packet]() {
auto image_frame =
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
Image image = Image(std::move(image_frame));
// Ensure image is available on GPU to force ImageToTensorCalculator to
// run on GPU.
ASSERT_TRUE(image.ConvertToGpu());
packet = MakePacket<Image>(std::move(image));
});
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
EXPECT_THAT(graph.WaitUntilIdle(),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("GPU service not available")));
}
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -141,7 +141,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
} }
// Run inference. // Run inference.
{ {
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc); MEDIAPIPE_PROFILING(GPU_TASK_INVOKE_ADVANCED, cc);
return tflite_gpu_runner_->Invoke(); return tflite_gpu_runner_->Invoke();
} }
})); }));

View File

@ -0,0 +1,41 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The option proto for the TensorsReadbackCalculator.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorsReadbackCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorsReadbackCalculatorOptions ext = 514750372;
}
// Expected shapes of the input tensors.
// The calculator uses these shape to build the GPU programs during
// initialization, and check the actual tensor shapes against the expected
// shapes during runtime.
// Batch size of the tensor is set to be 1. `TensorShape` here can be C, WC,
// or HWC.
// For example {dims: 1 dims: 2} represents a tensor with batch_size = 1,
// width = 1, and num_channels = 2.
message TensorShape {
repeated int32 dims = 1 [packed = true];
}
// tensor_shape specifies the shape of each input tensors.
repeated TensorShape tensor_shape = 1;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

View File

@ -14,6 +14,7 @@
# #
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//lib:selects.bzl", "selects")
licenses(["notice"]) licenses(["notice"])
@ -312,15 +313,19 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library( # TODO: Re-evaluate which of these libraries we can avoid making
# cc_library_with_tflite and can be changed back to cc_library.
cc_library_with_tflite(
name = "tflite_model_calculator", name = "tflite_model_calculator",
srcs = ["tflite_model_calculator.cc"], srcs = ["tflite_model_calculator.cc"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite:framework_stable",
],
deps = [ deps = [
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -66,7 +66,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
} else { } else {
cc->OutputSidePackets() cc->OutputSidePackets()
.Index(0) .Index(0)
.Set<tflite_shims::ops::builtin::BuiltinOpResolver>(); .Set<tflite::ops::builtin::BuiltinOpResolver>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
@ -77,7 +77,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
const TfLiteCustomOpResolverCalculatorOptions& options = const TfLiteCustomOpResolverCalculatorOptions& options =
cc->Options<TfLiteCustomOpResolverCalculatorOptions>(); cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver> op_resolver; std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> op_resolver;
if (options.use_gpu()) { if (options.use_gpu()) {
op_resolver = absl::make_unique<mediapipe::OpResolver>(); op_resolver = absl::make_unique<mediapipe::OpResolver>();
} else { } else {

View File

@ -21,7 +21,7 @@
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/core/shims/cc/model.h" #include "tensorflow/lite/model.h"
namespace mediapipe { namespace mediapipe {
@ -82,7 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
} }
if (cc->InputSidePackets().HasTag("MODEL_FD")) { if (cc->InputSidePackets().HasTag("MODEL_FD")) {
#ifdef ABSL_HAVE_MMAP #if defined(ABSL_HAVE_MMAP) && !TFLITE_WITH_STABLE_ABI
model_packet = cc->InputSidePackets().Tag("MODEL_FD"); model_packet = cc->InputSidePackets().Tag("MODEL_FD");
const auto& model_fd = const auto& model_fd =
model_packet.Get<std::tuple<int, size_t, size_t>>(); model_packet.Get<std::tuple<int, size_t, size_t>>();

View File

@ -1270,6 +1270,50 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
mediapipe_proto_library(
name = "flat_color_image_calculator_proto",
srcs = ["flat_color_image_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:color_proto",
],
)
cc_library(
name = "flat_color_image_calculator",
srcs = ["flat_color_image_calculator.cc"],
deps = [
":flat_color_image_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/util:color_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "flat_color_image_calculator_test",
srcs = ["flat_color_image_calculator_test.cc"],
deps = [
":flat_color_image_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/util:color_cc_proto",
],
)
cc_library( cc_library(
name = "from_image_calculator", name = "from_image_calculator",
srcs = ["from_image_calculator.cc"], srcs = ["from_image_calculator.cc"],

View File

@ -0,0 +1,138 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
} // namespace
// A calculator for generating an image filled with a single color.
//
// Inputs:
// IMAGE (Image, optional)
// If provided, the output will have the same size
// COLOR (Color proto, optional)
// Color to paint the output with. Takes precedence over the equivalent
// calculator options.
//
// Outputs:
// IMAGE (Image)
// Image filled with the requested color.
//
// Example useage:
// node {
// calculator: "FlatColorImageCalculator"
// input_stream: "IMAGE:image"
// input_stream: "COLOR:color"
// output_stream: "IMAGE:blank_image"
// options {
// [mediapipe.FlatColorImageCalculatorOptions.ext] {
// color: {
// r: 255
// g: 255
// b: 255
// }
// }
// }
// }
class FlatColorImageCalculator : public Node {
public:
static constexpr Input<Image>::Optional kInImage{"IMAGE"};
static constexpr Input<Color>::Optional kInColor{"COLOR"};
static constexpr Output<Image> kOutImage{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
static absl::Status UpdateContract(CalculatorContract* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
RET_CHECK(kInImage(cc).IsConnected() ^
(options.has_output_height() || options.has_output_width()))
<< "Either set IMAGE input stream, or set through options";
RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
<< "Either set COLOR input stream, or set through options";
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
bool use_dimension_from_option_ = false;
bool use_color_from_option_ = false;
};
MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
use_dimension_from_option_ = !kInImage(cc).IsConnected();
use_color_from_option_ = !kInColor(cc).IsConnected();
return absl::OkStatus();
}
absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
int output_height = -1;
int output_width = -1;
if (use_dimension_from_option_) {
output_height = options.output_height();
output_width = options.output_width();
} else if (!kInImage(cc).IsEmpty()) {
const Image& input_image = kInImage(cc).Get();
output_height = input_image.height();
output_width = input_image.width();
} else {
return absl::OkStatus();
}
Color color;
if (use_color_from_option_) {
color = options.color();
} else if (!kInColor(cc).IsEmpty()) {
color = kInColor(cc).Get();
} else {
return absl::OkStatus();
}
auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
output_width, output_height);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
kOutImage(cc).Send(Image(output_frame));
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";
message FlatColorImageCalculatorOptions {
extend CalculatorOptions {
optional FlatColorImageCalculatorOptions ext = 515548435;
}
// Output dimensions.
optional int32 output_width = 1;
optional int32 output_height = 2;
// The color to fill with in the output image.
optional Color color = 3;
}

View File

@ -0,0 +1,210 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::testing::HasSubstr;
constexpr char kImageTag[] = "IMAGE";
constexpr char kColorTag[] = "COLOR";
constexpr int kImageWidth = 256;
constexpr int kImageHeight = 256;
TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), kImageWidth);
EXPECT_EQ(image.height(), kImageHeight);
auto image_frame = image.GetImageFrameSharedPtr();
auto* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 100);
EXPECT_EQ(pixel_data[1], 200);
EXPECT_EQ(pixel_data[2], 255);
}
}
TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), 7);
EXPECT_EQ(image.height(), 13);
auto image_frame = image.GetImageFrameSharedPtr();
const uint8_t* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 0);
EXPECT_EQ(pixel_data[1], 5);
EXPECT_EQ(pixel_data[2], 0);
}
}
TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureMissingColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
} // namespace
} // namespace mediapipe

View File

@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio,
std::string result_image; std::string result_image;
MP_ASSERT_OK( MP_ASSERT_OK(
mediapipe::file::GetContents(result_string_path, &result_image)); mediapipe::file::GetContents(result_string_path, &result_image));
EXPECT_EQ(result_image, output_string); if (result_image != output_string) {
// There may be slight differences due to the way the JPEG was encoded or
// the OpenCV version used to generate the reference files. Compare
// pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
cv::Mat result_mat =
cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1,
const_cast<char*>(result_image.data())),
cv::IMREAD_UNCHANGED);
cv::Mat output_mat =
cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1,
const_cast<char*>(output_string.data())),
cv::IMREAD_UNCHANGED);
ASSERT_EQ(result_mat.rows, output_mat.rows);
ASSERT_EQ(result_mat.cols, output_mat.cols);
ASSERT_EQ(result_mat.type(), output_mat.type());
EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0);
}
} else { } else {
std::string output_string_path = mediapipe::file::JoinPath( std::string output_string_path = mediapipe::file::JoinPath(
absl::GetFlag(FLAGS_output_folder), absl::GetFlag(FLAGS_output_folder),

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.1 KiB

After

Width:  |  Height:  |  Size: 6.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.2 KiB

After

Width:  |  Height:  |  Size: 8.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.6 KiB

After

Width:  |  Height:  |  Size: 7.6 KiB

View File

@ -136,6 +136,8 @@ message GraphTrace {
GPU_TASK_INVOKE = 16; GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17; TPU_TASK_INVOKE = 17;
CPU_TASK_INVOKE = 18; CPU_TASK_INVOKE = 18;
GPU_TASK_INVOKE_ADVANCED = 19;
TPU_TASK_INVOKE_ASYNC = 20;
} }
// The timing for one packet set being processed at one caclulator node. // The timing for one packet set being processed at one caclulator node.

View File

@ -112,6 +112,10 @@ struct TraceEvent {
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE; static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE; static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE; static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
static constexpr EventType GPU_TASK_INVOKE_ADVANCED =
GraphTrace::GPU_TASK_INVOKE_ADVANCED;
static constexpr EventType TPU_TASK_INVOKE_ASYNC =
GraphTrace::TPU_TASK_INVOKE_ASYNC;
}; };
// Packet trace log buffer. // Packet trace log buffer.

View File

@ -57,7 +57,6 @@ struct hash<mediapipe::TaskId> {
namespace mediapipe { namespace mediapipe {
namespace { namespace {
void BasicTraceEventTypes(TraceEventRegistry* result) { void BasicTraceEventTypes(TraceEventRegistry* result) {
// The initializer arguments below are: event_type, description, // The initializer arguments below are: event_type, description,
// is_packet_event, is_stream_event, id_event_data. // is_packet_event, is_stream_event, id_event_data.
@ -84,6 +83,15 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
"A time measured by GPU clock and by CPU clock.", true, false}, "A time measured by GPU clock and by CPU clock.", true, false},
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.", {TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
true, true, false}, true, true, false},
{TraceEvent::GPU_TASK_INVOKE, "CPU timing for initiating a GPU task."},
{TraceEvent::TPU_TASK_INVOKE, "CPU timing for initiating a TPU task."},
{TraceEvent::CPU_TASK_INVOKE, "CPU timing for initiating a CPU task."},
{TraceEvent::GPU_TASK_INVOKE_ADVANCED,
"CPU timing for initiating a GPU task bypassing the TFLite "
"interpreter."},
{TraceEvent::TPU_TASK_INVOKE_ASYNC,
"CPU timing for async initiation of a TPU task."},
}; };
for (const TraceEventType& t : basic_types) { for (const TraceEventType& t : basic_types) {
(*result)[t.event_type()] = t; (*result)[t.event_type()] = t;

View File

@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs):
'import public "' + join_path + '";', 'import public "' + join_path + '";',
) )
rewrite_ref = SubsituteCommand( rewrite_ref = SubsituteCommand(
r"mediapipe\\.(" + rewrite_message_regex + ")", r"mediapipe\.(" + rewrite_message_regex + ")",
r"mediapipe.\\1", r"mediapipe.\\1",
) )
rewrite_objc = SubsituteCommand( rewrite_objc = SubsituteCommand(
@ -284,7 +284,7 @@ def mediapipe_proto_library(
def_jspb_proto: define the jspb_proto_library target def_jspb_proto: define the jspb_proto_library target
def_go_proto: define the go_proto_library target def_go_proto: define the go_proto_library target
def_options_lib: define the mediapipe_options_library target def_options_lib: define the mediapipe_options_library target
def_rewrite: define a sibbling mediapipe_proto_library with package "mediapipe" def_rewrite: define a sibling mediapipe_proto_library with package "mediapipe"
""" """
mediapipe_proto_library_impl( mediapipe_proto_library_impl(

View File

@ -183,12 +183,13 @@ absl::Status FindCorrespondingStreams(
// name, calculator, input_stream, output_stream, input_side_packet, // name, calculator, input_stream, output_stream, input_side_packet,
// output_side_packet, options. // output_side_packet, options.
// All other fields are only applicable to calculators. // All other fields are only applicable to calculators.
// TODO: Check whether executor is not set in the subgraph node
// after this issues is properly solved.
absl::Status ValidateSubgraphFields( absl::Status ValidateSubgraphFields(
const CalculatorGraphConfig::Node& subgraph_node) { const CalculatorGraphConfig::Node& subgraph_node) {
if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() || if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() ||
subgraph_node.has_output_stream_handler() || subgraph_node.has_output_stream_handler() ||
subgraph_node.input_stream_info_size() != 0 || subgraph_node.input_stream_info_size() != 0) {
!subgraph_node.executor().empty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Subgraph \"" << subgraph_node.name() << "Subgraph \"" << subgraph_node.name()
<< "\" has a field that is only applicable to calculators."; << "\" has a field that is only applicable to calculators.";

View File

@ -467,6 +467,7 @@ cc_library(
"//mediapipe/framework/formats:frame_buffer", "//mediapipe/framework/formats:frame_buffer",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:yuv_image", "//mediapipe/framework/formats:yuv_image",
"//mediapipe/util/frame_buffer:frame_buffer_util",
"//third_party/libyuv", "//third_party/libyuv",
"@com_google_absl//absl/log", "@com_google_absl//absl/log",
"@com_google_absl//absl/log:check", "@com_google_absl//absl/log:check",
@ -932,7 +933,7 @@ mediapipe_proto_library(
], ],
) )
proto_library( mediapipe_proto_library(
name = "gl_scaler_calculator_proto", name = "gl_scaler_calculator_proto",
srcs = ["gl_scaler_calculator.proto"], srcs = ["gl_scaler_calculator.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@ -942,17 +943,6 @@ proto_library(
], ],
) )
mediapipe_cc_proto_library(
name = "gl_scaler_calculator_cc_proto",
srcs = ["gl_scaler_calculator.proto"],
cc_deps = [
":scale_mode_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":gl_scaler_calculator_proto"],
)
cc_library( cc_library(
name = "gl_scaler_calculator", name = "gl_scaler_calculator",
srcs = ["gl_scaler_calculator.cc"], srcs = ["gl_scaler_calculator.cc"],

View File

@ -66,7 +66,8 @@ public class GraphTextureFrame implements TextureFrame {
if (nativeBufferHandle == 0) { if (nativeBufferHandle == 0) {
return 0; return 0;
} }
if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) { long contextHandle = nativeGetCurrentExternalContextHandle();
if (contextHandle != 0 && activeConsumerContextHandleSet.add(contextHandle)) {
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
// PacketGetter.getTextureFrameDeferredSync(). // PacketGetter.getTextureFrameDeferredSync().
if (deferredSync) { if (deferredSync) {
@ -116,7 +117,14 @@ public class GraphTextureFrame implements TextureFrame {
GlSyncToken consumerToken = null; GlSyncToken consumerToken = null;
// Note that this remove should be moved to the other overload of release when b/68808951 is // Note that this remove should be moved to the other overload of release when b/68808951 is
// addressed. // addressed.
if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) { final long contextHandle = nativeGetCurrentExternalContextHandle();
if (contextHandle == 0 && !activeConsumerContextHandleSet.isEmpty()) {
logger.atWarning().log(
"GraphTextureFrame is being released on non GL thread while having active consumers,"
+ " which may lead to external / internal GL contexts synchronization issues.");
}
if (contextHandle != 0 && activeConsumerContextHandleSet.remove(contextHandle)) {
consumerToken = consumerToken =
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
} }
@ -169,7 +177,9 @@ public class GraphTextureFrame implements TextureFrame {
private native void nativeReleaseBuffer(long nativeHandle); private native void nativeReleaseBuffer(long nativeHandle);
private native int nativeGetTextureName(long nativeHandle); private native int nativeGetTextureName(long nativeHandle);
private native int nativeGetWidth(long nativeHandle); private native int nativeGetWidth(long nativeHandle);
private native int nativeGetHeight(long nativeHandle); private native int nativeGetHeight(long nativeHandle);
private native void nativeGpuWait(long nativeHandle); private native void nativeGpuWait(long nativeHandle);

View File

@ -357,6 +357,22 @@ def mediapipe_java_proto_srcs(name = ""):
target = "//mediapipe/framework/formats:rect_java_proto_lite", target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java", src_out = "com/google/mediapipe/formats/proto/RectProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:color_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/ColorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:label_map_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/LabelMapProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:render_data_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/RenderDataProto.java",
))
return proto_src_list return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""): def mediapipe_logging_java_proto_srcs(name = ""):

View File

@ -0,0 +1,24 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/model_maker/python/vision/face_stylizer:__subpackages__"])
filegroup(
name = "models",
srcs = glob([
"**",
]),
)

View File

@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools import functools
from typing import Callable, Optional, Tuple, TypeVar from typing import Any, Callable, Optional, Tuple, TypeVar
# Dependency imports # Dependency imports
import tensorflow as tf import tensorflow as tf
@ -66,12 +66,14 @@ class Dataset(object):
""" """
return self._size return self._size
def gen_tf_dataset(self, def gen_tf_dataset(
self,
batch_size: int = 1, batch_size: int = 1,
is_training: bool = False, is_training: bool = False,
shuffle: bool = False, shuffle: bool = False,
preprocess: Optional[Callable[..., bool]] = None, preprocess: Optional[Callable[..., Any]] = None,
drop_remainder: bool = False) -> tf.data.Dataset: drop_remainder: bool = False,
) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation. """Generates a batched tf.data.Dataset for training/evaluation.
Args: Args:

View File

@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
self._hparams: hp.BaseHParams = None self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None self._history: tf.keras.callbacks.History = None
def _train_model(self, def _train_model(
self,
train_data: classification_ds.ClassificationDataset, train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset, validation_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None, preprocessor: Optional[Callable[..., Any]] = None,
checkpoint_path: Optional[str] = None): checkpoint_path: Optional[str] = None,
):
"""Trains the classifier model. """Trains the classifier model.
Compiles and fits the tf.keras `_model` and records the `_history`. Compiles and fits the tf.keras `_model` and records the `_history`.

View File

@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
def convert_to_tflite( def convert_to_tflite(
model: tf.keras.Model, model: tf.keras.Model,
quantization_config: Optional[quantization.QuantizationConfig] = None, quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet, supported_ops: Tuple[tf.lite.OpsSet, ...] = (
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), tf.lite.OpsSet.TFLITE_BUILTINS,
preprocess: Optional[Callable[..., bool]] = None) -> bytearray: ),
preprocess: Optional[Callable[..., Any]] = None,
) -> bytearray:
"""Converts the input Keras model to TFLite format. """Converts the input Keras model to TFLite format.
Args: Args:

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":testdata",
],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,14 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Face Stylization."""

View File

@ -0,0 +1,98 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer dataset library."""
import logging
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
# TODO: Change to a unlabeled dataset if it makes sense.
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for face stylizer fine tuning."""
@classmethod
def from_folder(
cls, dirname: str
) -> classification_dataset.ClassificationDataset:
"""Loads images from the given directory.
The style image dataset directory is expected to contain one subdirectory
whose name represents the label of the style. There can be one or multiple
images of the same style in that subdirectory. Supported input image formats
include 'jpg', 'jpeg', 'png'.
Args:
dirname: Name of the directory containing the image files.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
"""
data_root = os.path.abspath(dirname)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
raise ValueError('Invalid input data directory')
if not any(
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
):
raise ValueError('No images found under given directory')
label_names = sorted(
name
for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))
)
all_label_size = len(label_names)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names)
)
# Get the style label from the subdirectory name.
all_image_labels = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64)
)
# Create a dataset of (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
logging.info(
'Load images dataset with size: %d, num_label: %d, labels: %s.',
all_image_size,
all_label_size,
', '.join(label_names),
)
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names
)

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils
class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# TODO: Replace the stylize image dataset with licensed images.
self._test_data_dirname = 'testdata'
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
self.assertLen(data, 2)
def test_from_folder_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
dataset.Dataset.from_folder(dirname='invalid')
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
input_data_dir = test_utils.get_test_data_path('face_stylizer')
with self.assertRaisesRegex(
ValueError, 'No images found under given directory'
):
dataset.Dataset.from_folder(dirname=input_data_dir)
if __name__ == '__main__':
tf.test.main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 347 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 336 KiB

View File

@ -15,6 +15,7 @@
import io import io
import os import os
import tempfile import tempfile
import unittest
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
import zipfile import zipfile
@ -31,6 +32,7 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat
tf.keras.backend.experimental.enable_tf_random_generator() tf.keras.backend.experimental.enable_tf_random_generator()
@unittest.skip('b/273818271')
class GestureRecognizerTest(tf.test.TestCase): class GestureRecognizerTest(tf.test.TestCase):
def _load_data(self): def _load_data(self):
@ -72,8 +74,10 @@ class GestureRecognizerTest(tf.test.TestCase):
self._test_accuracy(model) self._test_accuracy(model)
@unittest.skip('b/273818271')
@unittest_mock.patch.object( @unittest_mock.patch.object(
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense
)
def test_gesture_recognizer_model_layer_widths(self, mock_dense): def test_gesture_recognizer_model_layer_widths(self, mock_dense):
layer_widths = [64, 32] layer_widths = [64, 32]
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths) mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
@ -143,12 +147,14 @@ class GestureRecognizerTest(tf.test.TestCase):
hyperparameters, hyperparameters,
'HParams', 'HParams',
autospec=True, autospec=True,
return_value=gesture_recognizer.HParams(epochs=1)) return_value=gesture_recognizer.HParams(epochs=1),
)
@unittest_mock.patch.object( @unittest_mock.patch.object(
model_options, model_options,
'GestureRecognizerModelOptions', 'GestureRecognizerModelOptions',
autospec=True, autospec=True,
return_value=gesture_recognizer.ModelOptions()) return_value=gesture_recognizer.ModelOptions(),
)
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
self, mock_hparams, mock_model_options): self, mock_hparams, mock_model_options):
options = gesture_recognizer.GestureRecognizerOptions() options = gesture_recognizer.GestureRecognizerOptions()

View File

@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kLabelsTag[] = "LABELS"; constexpr char kLabelsTag[] = "LABELS";
constexpr char kLabelsCsvTag[] = "LABELS_CSV"; constexpr char kLabelsCsvTag[] = "LABELS_CSV";
constexpr char kLabelMapTag[] = "LABEL_MAP";
using mediapipe::RE2; using mediapipe::RE2;
using Detections = std::vector<Detection>; using Detections = std::vector<Detection>;
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>(); cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
} }
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
cc->InputSidePackets()
.Tag(kLabelMapTag)
.Set<std::unique_ptr<std::map<int, std::string>>>();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<FilterDetectionCalculatorOptions>(); options_ = cc->Options<FilterDetectionCalculatorOptions>();
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) || limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
cc->InputSidePackets().HasTag(kLabelsCsvTag); cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
cc->InputSidePackets().HasTag(kLabelMapTag);
if (limit_labels_) { if (limit_labels_) {
Strings allowlist_labels; Strings allowlist_labels;
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
for (auto& e : allowlist_labels) { for (auto& e : allowlist_labels) {
absl::StripAsciiWhitespace(&e); absl::StripAsciiWhitespace(&e);
} }
} else { } else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>(); allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
auto label_map = cc->InputSidePackets()
.Tag(kLabelMapTag)
.Get<std::unique_ptr<std::map<int, std::string>>>()
.get();
for (const auto& [_, v] : *label_map) {
allowlist_labels.push_back(v);
}
} }
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end()); allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
} }

View File

@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
)); ));
} }
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
auto runner = std::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "FilterDetectionCalculator"
input_stream: "DETECTION:input"
input_side_packet: "LABEL_MAP:input_map"
output_stream: "DETECTION:output"
options {
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
}
)pb"));
runner->MutableInputs()->Tag("DETECTION").packets = {
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "d"
score: 1
score: 0.8
score: 0.3
score: 0.9
)pb"))
.At(Timestamp(20)),
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "e"
score: 0.6
score: 0.4
score: 0.2
score: 0.7
)pb"))
.At(Timestamp(40)),
};
auto label_map = std::make_unique<std::map<int, std::string>>();
(*label_map)[0] = "a";
(*label_map)[1] = "b";
(*label_map)[2] = "c";
runner->MutableSidePackets()->Tag("LABEL_MAP") =
AdoptAsUniquePtr(label_map.release());
// Run graph.
MP_ASSERT_OK(runner->Run());
// Check output.
EXPECT_THAT(
runner->Outputs().Tag("DETECTION").packets,
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(20)),
EqualsProto(R"pb(
label: "a" label: "b" score: 1 score: 0.8
)pb")), // Packet 1 at timestamp 20.
PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(40)),
EqualsProto(R"pb(
label: "a" score: 0.6
)pb")) // Packet 2 at timestamp 40.
));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -57,6 +57,7 @@ pybind_extension(
"//mediapipe/framework/formats:landmark_registration", "//mediapipe/framework/formats:landmark_registration",
"//mediapipe/framework/formats:rect_registration", "//mediapipe/framework/formats:rect_registration",
"//mediapipe/modules/objectron/calculators:annotation_registration", "//mediapipe/modules/objectron/calculators:annotation_registration",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
], ],
) )
@ -94,6 +95,8 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph", "//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
] + select({ ] + select({
# TODO: Build text_classifier_graph and text_embedder_graph on Windows. # TODO: Build text_classifier_graph and text_embedder_graph on Windows.
"//mediapipe:windows": [], "//mediapipe:windows": [],

View File

@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus";
// //
// At runtime, such codes are meant to be attached (where applicable) to a // At runtime, such codes are meant to be attached (where applicable) to a
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and // `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
// stringifed error code as value (aka payload). This logic is encapsulated in // stringified error code as value (aka payload). This logic is encapsulated in
// the `CreateStatusWithPayload` helper below for convenience. // the `CreateStatusWithPayload` helper below for convenience.
// //
// The returned status includes: // The returned status includes:

View File

@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
auto model_bundle_resources = absl::WrapUnique( auto model_bundle_resources = absl::WrapUnique(
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file))); new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
model_bundle_resources->ExtractModelFilesFromExternalFileProto()); model_bundle_resources->ExtractFilesFromExternalFileProto());
return model_bundle_resources; return model_bundle_resources;
} }
absl::Status absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
if (model_asset_bundle_file_->has_file_name()) { if (model_asset_bundle_file_->has_file_name()) {
// If the model asset bundle file name is a relative path, searches the file // If the model asset bundle file name is a relative path, searches the file
// in a platform-specific location and returns the absolute path on success. // in a platform-specific location and returns the absolute path on success.
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
model_asset_bundle_file_handler_->GetFileContent().data(); model_asset_bundle_file_handler_->GetFileContent().data();
size_t buffer_size = size_t buffer_size =
model_asset_bundle_file_handler_->GetFileContent().size(); model_asset_bundle_file_handler_->GetFileContent().size();
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
&model_files_);
} }
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile( absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
const std::string& filename) const { const std::string& filename) const {
auto it = model_files_.find(filename); auto it = files_.find(filename);
if (it == model_files_.end()) { if (it == files_.end()) {
auto model_files = ListModelFiles(); auto files = ListFiles();
std::string all_model_files = std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
return CreateStatusWithPayload( return CreateStatusWithPayload(
StatusCode::kNotFound, StatusCode::kNotFound,
absl::StrFormat("No model file with name: %s. All model files in the " absl::StrFormat("No file with name: %s. All files in the model asset "
"model asset bundle are: %s.", "bundle are: %s.",
filename, all_model_files), filename, all_files),
MediaPipeTasksStatus::kFileNotFoundError); MediaPipeTasksStatus::kFileNotFoundError);
} }
return it->second; return it->second;
} }
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const { std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
std::vector<std::string> model_names; std::vector<std::string> file_names;
for (const auto& [model_name, _] : model_files_) { for (const auto& [file_name, _] : files_) {
model_names.push_back(model_name); file_names.push_back(file_name);
} }
return model_names; return file_names;
} }
} // namespace core } // namespace core

View File

@ -28,8 +28,8 @@ namespace core {
// The mediapipe task model asset bundle resources class. // The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto, // A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the // contains model asset bundle related resources and the method to extract the
// tflite models or model asset bundles for the mediapipe sub-tasks. As the // tflite models, resource files or model asset bundles for the mediapipe
// resources are owned by the ModelAssetBundleResources object // sub-tasks. As the resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the // callers must keep ModelAssetBundleResources alive while using any of the
// resources. // resources.
class ModelAssetBundleResources { class ModelAssetBundleResources {
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
// Returns the model asset bundle resources tag. // Returns the model asset bundle resources tag.
std::string GetTag() const { return tag_; } std::string GetTag() const { return tag_; }
// Gets the contents of the model file (either tflite model file or model // Gets the contents of the model file (either tflite model file, resource
// bundle file) with the provided name. An error is returned if there is no // file or model bundle file) with the provided name. An error is returned if
// such model file. // there is no such model file.
absl::StatusOr<absl::string_view> GetModelFile( absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
const std::string& filename) const;
// Lists all the model file names in the model asset model. // Lists all the file names in the model asset model.
std::vector<std::string> ListModelFiles() const; std::vector<std::string> ListFiles() const;
private: private:
// Constructor. // Constructor.
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
const std::string& tag, const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file); std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
// Extracts the model files (either tflite model file or model bundle file) // Extracts the model files (either tflite model file, resource file or model
// from the external file proto. // bundle file) from the external file proto.
absl::Status ExtractModelFilesFromExternalFileProto(); absl::Status ExtractFilesFromExternalFileProto();
// The model asset bundle resources tag. // The model asset bundle resources tag.
const std::string tag_; const std::string tag_;
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
// The ExternalFileHandler for the model asset bundle. // The ExternalFileHandler for the model asset bundle.
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_; std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
// The model files bundled in model asset bundle, as a map with the filename // The files bundled in model asset bundle, as a map with the filename
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and // (corresponding to a basename, e.g. "hand_detector.tflite") as key and
// a pointer to the file contents as value. Each model file can be either // a pointer to the file contents as value. Each file can be either a TFLite
// a TFLite model file or a model bundle file for sub-task. // model file, resource file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> model_files_; absl::flat_hash_map<std::string, absl::string_view> files_;
}; };
} // namespace core } // namespace core

View File

@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
#endif // _WIN32 #endif // _WIN32
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status_or_model_bundle_file = auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task"); model_bundle_resources->GetFile("dummy_hand_landmarker.task");
MP_EXPECT_OK(status_or_model_bundle_file.status()); MP_EXPECT_OK(status_or_model_bundle_file.status());
// Creates sub-task model asset bundle resources. // Creates sub-task model asset bundle resources.
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(hand_landmaker_model_file))); std::move(hand_landmaker_model_file)));
MP_EXPECT_OK(hand_landmaker_model_bundle_resources MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_detector.tflite") ->GetFile("dummy_hand_detector.tflite")
.status()); .status());
MP_EXPECT_OK(hand_landmaker_model_bundle_resources MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_landmarker.tflite") ->GetFile("dummy_hand_landmarker.tflite")
.status()); .status());
} }
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status_or_model_bundle_file = auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite"); model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
MP_EXPECT_OK(status_or_model_bundle_file.status()); MP_EXPECT_OK(status_or_model_bundle_file.status());
// Verify tflite model works. // Verify tflite model works.
@ -200,11 +196,11 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status = model_bundle_resources->GetModelFile("not_found.task").status(); auto status = model_bundle_resources->GetFile("not_found.task").status();
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_THAT(status.message(), EXPECT_THAT(
testing::HasSubstr( status.message(),
"No model file with name: not_found.task. All model files in " testing::HasSubstr("No file with name: not_found.task. All files in "
"the model asset bundle are: ")); "the model asset bundle are: "));
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
testing::Optional(absl::Cord( testing::Optional(absl::Cord(
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto model_files = model_bundle_resources->ListModelFiles(); auto model_files = model_bundle_resources->ListFiles();
std::vector<std::string> expected_model_files = { std::vector<std::string> expected_model_files = {
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"}; "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
std::sort(model_files.begin(), model_files.end()); std::sort(model_files.begin(), model_files.end());

View File

@ -77,9 +77,11 @@ class ModelResourcesCalculator : public api2::Node {
if (options.has_model_file()) { if (options.has_model_file()) {
RET_CHECK(options.model_file().has_file_content() || RET_CHECK(options.model_file().has_file_content() ||
options.model_file().has_file_descriptor_meta() || options.model_file().has_file_descriptor_meta() ||
options.model_file().has_file_name()) options.model_file().has_file_name() ||
options.model_file().has_file_pointer_meta())
<< "'model_file' must specify at least one of " << "'model_file' must specify at least one of "
"'file_content', 'file_descriptor_meta', or 'file_name'"; "'file_content', 'file_descriptor_meta', 'file_name', or "
"'file_pointer_meta'";
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -179,9 +179,9 @@ TEST_F(ModelResourcesCalculatorTest, EmptyExternalFileProto) {
auto status = graph.Initialize(graph_config); auto status = graph.Initialize(graph_config);
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(), EXPECT_THAT(status.message(),
testing::HasSubstr( testing::HasSubstr("'model_file' must specify at least one of "
"'model_file' must specify at least one of " "'file_content', 'file_descriptor_meta', "
"'file_content', 'file_descriptor_meta', or 'file_name'")); "'file_name', or 'file_pointer_meta'"));
} }
TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) { TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) {

View File

@ -138,7 +138,7 @@ class InferenceSubgraph : public Subgraph {
delegate.mutable_tflite()->CopyFrom(acceleration.tflite()); delegate.mutable_tflite()->CopyFrom(acceleration.tflite());
break; break;
case Acceleration::DELEGATE_NOT_SET: case Acceleration::DELEGATE_NOT_SET:
// Deafult inference calculator setting. // Default inference calculator setting.
break; break;
} }
return delegate; return delegate;

View File

@ -124,10 +124,10 @@ class ModelTaskGraph : public Subgraph {
// Inserts a mediapipe task inference subgraph into the provided // Inserts a mediapipe task inference subgraph into the provided
// GraphBuilder. The returned node provides the following interfaces to the // GraphBuilder. The returned node provides the following interfaces to the
// the rest of the graph: // the rest of the graph:
// - a tensor vector (std::vector<MeidaPipe::Tensor>) input stream with tag // - a tensor vector (std::vector<mediapipe::Tensor>) input stream with tag
// "TENSORS", representing the input tensors to be consumed by the // "TENSORS", representing the input tensors to be consumed by the
// inference engine. // inference engine.
// - a tensor vector (std::vector<MeidaPipe::Tensor>) output stream with tag // - a tensor vector (std::vector<mediapipe::Tensor>) output stream with tag
// "TENSORS", representing the output tensors generated by the inference // "TENSORS", representing the output tensors generated by the inference
// engine. // engine.
// - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR". // - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR".

View File

@ -301,7 +301,7 @@ absl::Status TaskRunner::Close() {
} }
is_running_ = false; is_running_ = false;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
AddPayload(graph_.CloseAllInputStreams(), "Fail to close intput streams", AddPayload(graph_.CloseAllInputStreams(), "Fail to close input streams",
MediaPipeTasksStatus::kRunnerFailsToCloseError)); MediaPipeTasksStatus::kRunnerFailsToCloseError));
MP_RETURN_IF_ERROR(AddPayload( MP_RETURN_IF_ERROR(AddPayload(
graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.", graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.",

View File

@ -65,7 +65,7 @@ class TaskRunner {
// Creates the task runner with a CalculatorGraphConfig proto. // Creates the task runner with a CalculatorGraphConfig proto.
// If a tflite op resolver object is provided, the task runner will take // If a tflite op resolver object is provided, the task runner will take
// it as the global op resolver for all models running within this task. // it as the global op resolver for all models running within this task.
// The op resolver's owernship will be transferred into the pipeleine runner. // The op resolver's ownership will be transferred into the pipeleine runner.
// When a user-defined PacketsCallback is provided, clients must use the // When a user-defined PacketsCallback is provided, clients must use the
// asynchronous method, Send(), to provide the input packets. If the packets // asynchronous method, Send(), to provide the input packets. If the packets
// callback is absent, clients must use the synchronous method, Process(), to // callback is absent, clients must use the synchronous method, Process(), to
@ -84,7 +84,7 @@ class TaskRunner {
// frames from a video file and an audio file. The call blocks the current // frames from a video file and an audio file. The call blocks the current
// thread until a failure status or a successful result is returned. // thread until a failure status or a successful result is returned.
// If the input packets have no timestamp, an internal timestamp will be // If the input packets have no timestamp, an internal timestamp will be
// assigend per invocation. Otherwise, when the timestamp is set in the // assigned per invocation. Otherwise, when the timestamp is set in the
// input packets, the caller must ensure that the input packet timestamps are // input packets, the caller must ensure that the input packet timestamps are
// greater than the timestamps of the previous invocation. This method is // greater than the timestamps of the previous invocation. This method is
// thread-unsafe and it is the caller's responsibility to synchronize access // thread-unsafe and it is the caller's responsibility to synchronize access

View File

@ -64,7 +64,7 @@ class ModelMetadataPopulator {
// Loads associated files into the TFLite FlatBuffer model. The input is a map // Loads associated files into the TFLite FlatBuffer model. The input is a map
// of {filename, file contents}. // of {filename, file contents}.
// //
// Warning: this method removes any previoulsy present associated files. // Warning: this method removes any previously present associated files.
// Calling this method multiple time removes any associated files from // Calling this method multiple time removes any associated files from
// previous calls, so this method should usually be called only once. // previous calls, so this method should usually be called only once.
void LoadAssociatedFiles( void LoadAssociatedFiles(

View File

@ -213,7 +213,7 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
Version* min_version) { Version* min_version) {
if (table == nullptr) return; if (table == nullptr) return;
// Checks the ContenProperties field. // Checks the ContentProperties field.
if (table->content_properties_type() == ContentProperties_AudioProperties) { if (table->content_properties_type() == ContentProperties_AudioProperties) {
UpdateMinimumVersion( UpdateMinimumVersion(
GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties), GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),

View File

@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) {
// Using pybind11 type conversions to convert between Python and native // Using pybind11 type conversions to convert between Python and native
// C++ types. There are other options to provide access to native Python types // C++ types. There are other options to provide access to native Python types
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details. // in C++ and vice versa. See the pybind 11 instruction [1] for more details.
// Type converstions is recommended by pybind11, though the main downside // Type conversions is recommended by pybind11, though the main downside
// is that a copy of the data must be made on every Python to C++ transition: // is that a copy of the data must be made on every Python to C++ transition:
// this is needed since the C++ and Python versions of the same type generally // this is needed since the C++ and Python versions of the same type generally
// wont have the same memory layout. // wont have the same memory layout.

View File

@ -79,7 +79,7 @@ TEST(MetadataVersionTest,
auto metadata = metadata_builder.Finish(); auto metadata = metadata_builder.Finish();
FinishModelMetadataBuffer(builder, metadata); FinishModelMetadataBuffer(builder, metadata);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -100,7 +100,7 @@ TEST(MetadataVersionTest,
auto metadata = metadata_builder.Finish(); auto metadata = metadata_builder.Finish();
builder.Finish(metadata); builder.Finish(metadata);
// Gets the mimimum metadata parser version and triggers error. // Gets the minimum metadata parser version and triggers error.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -121,7 +121,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files); metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -147,7 +147,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -172,7 +172,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -203,7 +203,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -234,7 +234,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -265,7 +265,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -294,7 +294,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -323,7 +323,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -348,7 +348,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -373,7 +373,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -404,7 +404,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -431,7 +431,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -453,7 +453,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files); metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -476,7 +476,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files); metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),

View File

@ -42,3 +42,36 @@ cc_test(
"@org_tensorflow//tensorflow/lite/kernels:test_util", "@org_tensorflow//tensorflow/lite/kernels:test_util",
], ],
) )
cc_library(
name = "ngram_hash",
srcs = ["ngram_hash.cc"],
hdrs = ["ngram_hash.h"],
copts = tflite_copts(),
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
],
alwayslink = 1,
)
cc_test(
name = "ngram_hash_test",
srcs = ["ngram_hash_test.cc"],
deps = [
":ngram_hash",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@com_google_absl//absl/types:optional",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
],
)

View File

@ -0,0 +1,264 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <string>
#include <vector>
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace ngram_op {
namespace {
using ::flexbuffers::GetRoot;
using ::flexbuffers::Map;
using ::flexbuffers::TypedVector;
using ::mediapipe::tasks::text::language_detector::custom_ops::
LowercaseUnicodeStr;
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::tflite::GetString;
using ::tflite::StringRef;
constexpr int kInputMessage = 0;
constexpr int kOutputLabel = 0;
constexpr int kDefaultMaxSplits = 128;
// This op takes in a string, finds the character ngrams for it and then
// maps each of these ngrams to an index using the specified vocabulary sizes.
// Input(s):
// - input: Input string.
// - seeds: Seed for the random number generator.
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
// be interpreted as generating unigrams, bigrams, and trigrams.
// - vocab_sizes: Size of the vocabulary for each of the ngram features
// respectively. The op would generate vocab ids to be less than or equal to
// the vocab size. The index 0 implies an invalid ngram.
// - max_splits: Maximum number of tokens in the output. If this is unset, the
// limit is `kDefaultMaxSplits`.
// - lower_case_input: If this is set to true, the input string would be
// lower-cased before any processing.
// Output(s):
// - output: A tensor of size [number of ngrams, number of tokens + 2],
// where 2 tokens are reserved for the padding. If `max_splits` is set, this
// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
// Helper class used for pre-processing the input.
class NGramHashParams {
public:
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes, int max_splits,
bool lower_case_input)
: seed_(seed),
ngram_lengths_(ngram_lengths),
vocab_sizes_(vocab_sizes),
max_splits_(max_splits),
lower_case_input_(lower_case_input) {}
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
TfLiteContext* context) {
if (input_t->bytes == 0) {
context->ReportError(context, "Empty input not supported.");
return kTfLiteError;
}
// Do sanity checks on the input.
if (ngram_lengths_.empty()) {
context->ReportError(context, "`ngram_lengths` must be non-empty.");
return kTfLiteError;
}
if (vocab_sizes_.empty()) {
context->ReportError(context, "`vocab_sizes` must be non-empty.");
return kTfLiteError;
}
if (ngram_lengths_.size() != vocab_sizes_.size()) {
context->ReportError(
context,
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
return kTfLiteError;
}
if (max_splits_ <= 0) {
context->ReportError(context, "`max_splits` must be > 0.");
return kTfLiteError;
}
// Obtain and tokenize the input.
StringRef inputref = GetString(input_t, /*string_index=*/0);
if (lower_case_input_) {
std::string lower_cased_str;
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
tokenized_output_ =
Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
} else {
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
}
return kTfLiteOk;
}
uint64_t GetSeed() const { return seed_; }
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
int GetNumNGrams() const { return ngram_lengths_.size(); }
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
const TokenizedOutput& GetTokenizedOutput() const {
return tokenized_output_;
}
TokenizedOutput tokenized_output_;
private:
const uint64_t seed_;
std::vector<int> ngram_lengths_;
std::vector<int> vocab_sizes_;
const int max_splits_;
const bool lower_case_input_;
};
// Convert the TypedVector into a regular std::vector.
std::vector<int> GetIntVector(TypedVector typed_vec) {
std::vector<int> vec(typed_vec.size());
for (int j = 0; j < typed_vec.size(); j++) {
vec[j] = typed_vec[j].AsInt32();
}
return vec;
}
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
const int max_unicode_length = params->GetNumTokens();
const auto ngram_lengths = params->GetNGramLengths();
const auto vocab_sizes = params->GetVocabSizes();
const auto& tokenized_output = params->GetTokenizedOutput();
const auto seed = params->GetSeed();
// Compute for each ngram.
for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
const int vocab_size = vocab_sizes[ngram];
const int ngram_length = ngram_lengths[ngram];
// Compute for each token within the input.
for (int start = 0; start < tokenized_output.tokens.size(); start++) {
// Compute the number of bytes for the ngram starting at the given
// token.
int num_bytes = 0;
for (int i = start;
i < tokenized_output.tokens.size() && i < (start + ngram_length);
i++) {
num_bytes += tokenized_output.tokens[i].second;
}
// Compute the hash for the ngram starting at the token.
const auto str_hash = MurmurHash64WithSeed(
tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
num_bytes, seed);
// Map the hash to an index in the vocab.
data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
}
}
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const Map& m = GetRoot(buffer_t, length).AsMap();
const uint64_t seed = m["seed"].AsUInt64();
const std::vector<int> ngram_lengths =
GetIntVector(m["ngram_lengths"].AsTypedVector());
const std::vector<int> vocab_sizes =
GetIntVector(m["vocab_sizes"].AsTypedVector());
const int max_splits =
m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
const bool lowercase_input =
m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();
return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
lowercase_input);
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<NGramHashParams*>(buffer);
}
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
SetTensorToDynamic(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
TF_LITE_ENSURE_OK(
context,
params->PreprocessInput(GetInput(context, node, kInputMessage), context));
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
if (IsDynamicTensor(output)) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = 1;
output_size->data[1] = params->GetNumNGrams();
output_size->data[2] = params->GetNumTokens();
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
} else {
context->ReportError(context, "Output must by dynamic.");
return kTfLiteError;
}
if (output->type == kTfLiteInt32) {
GetNGramHashIndices(params, output->data.i32);
} else {
context->ReportError(context, "Output type must be Int32.");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace ngram_op
TfLiteRegistration* Register_NGRAM_HASH() {
static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
ngram_op::Resize, ngram_op::Eval};
return &r;
}
} // namespace tflite::ops::custom

View File

@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#include "tensorflow/lite/kernels/register.h"
namespace tflite::ops::custom {
TfLiteRegistration* Register_NGRAM_HASH();
} // namespace tflite::ops::custom
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_

View File

@ -0,0 +1,313 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace {
using ::flexbuffers::Builder;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::testing::ElementsAreArray;
using ::testing::Message;
// Helper class for testing the op.
class NGramHashModel : public SingleOpModel {
public:
explicit NGramHashModel(const uint64_t seed,
const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes,
const absl::optional<int> max_splits = std::nullopt) {
// Setup the model inputs.
Builder fbb;
size_t start = fbb.StartMap();
fbb.UInt("seed", seed);
{
size_t start = fbb.StartVector("ngram_lengths");
for (const int& ngram_len : ngram_lengths) {
fbb.Int(ngram_len);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
{
size_t start = fbb.StartVector("vocab_sizes");
for (const int& vocab_size : vocab_sizes) {
fbb.Int(vocab_size);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
if (max_splits) {
fbb.Int("max_splits", *max_splits);
}
fbb.EndMap(start);
fbb.Finish();
output_ = AddOutput({TensorType_INT32, {}});
SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
BuildInterpreter({GetShape(input_)});
}
void SetupInputTensor(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
}
void Invoke(const std::string& input) {
SetupInputTensor(input);
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
}
TfLiteStatus InvokeUnchecked(const std::string& input) {
SetupInputTensor(input);
return SingleOpModel::Invoke();
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_ = AddInput(TensorType_STRING);
int output_;
};
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
// Checks that the op returns the expected value when the input is sane.
// Also checks that when `max_splits` is not specified, the entire string is
// tokenized.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::vector<std::string> testcase_inputs({
"hi",
"wow",
"!",
"HI",
});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "hi".
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow".
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "!" (which will get replaced by " ").
hash("^", 0),
hash(" ", 0),
hash("$", 0),
hash("^ ", 1),
hash(" $", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "HI" (which will get lower-cased).
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
}});
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
const string& testcase_input = testcase_inputs[test_idx];
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where the testcases' input is: "
<< testcase_input);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
}
}
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
// Checks that the op returns the expected value when the input is correct
// when `max_splits` is specified.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::string testcase_input = "wow";
const std::vector<int> max_splits({2, 3, 4, 5, 6});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "wow", when `max_splits` == 2.
// We cannot include any of the actual tokens, since `max_splits`
// only allows enough space for the delimiters.
hash("^", 0),
hash("$", 0),
hash("^$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 3.
// We can start to include some tokens from the input string.
hash("^", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 4.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("o$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 5.
// We can include the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 6.
// `max_splits` is more than the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
}});
for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
const int testcase_max_splits = max_splits[test_idx];
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(
m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
std::min(
// Longest possible tokenization when using the entire
// input.
static_cast<int>(testcase_input.size()) + /*padding*/ 2,
// Longest possible string when the `max_splits` value
// is < testcase_input.size() + 2 for padding.
testcase_max_splits)}));
}
}
TEST(NGramHashTest, InvalidMaxSplitsValue) {
// Check that the op errors out when given an invalid max splits value.
const std::vector<int> invalid_max_splits({0, -1, -5, -100});
for (const int max_splits : invalid_max_splits) {
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
// Check that the op errors out when ngram lengths and vocab sizes mistmatch.
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
/*vocab_sizes=*/{1, 2});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2, 3});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
} // namespace
} // namespace tflite::ops::custom

View File

@ -0,0 +1,42 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "ngram_hash_ops_utils",
srcs = [
"ngram_hash_ops_utils.cc",
],
hdrs = [
"ngram_hash_ops_utils.h",
],
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf",
],
)
cc_test(
name = "ngram_hash_ops_utils_test",
size = "small",
srcs = [
"ngram_hash_ops_utils_test.cc",
],
deps = [
":ngram_hash_ops_utils",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -0,0 +1,38 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "murmur",
srcs = ["murmur.cc"],
hdrs = ["murmur.h"],
deps = [
"//mediapipe/framework/port:integral_types",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
],
)
cc_test(
name = "murmur_test",
srcs = ["murmur_test.cc"],
deps = [
":murmur",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
],
)

View File

@ -0,0 +1,95 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: aappleby@google.com (Austin Appleby)
// jyrki@google.com (Jyrki Alakuijala)
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include <cstdint>
#include "absl/base/internal/endian.h"
#include "absl/base/optimization.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
namespace {
using ::absl::little_endian::Load64;
// Murmur 2.0 multiplication constant.
static const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
// We need to mix some of the bits that get propagated and mixed into the
// high bits by multiplication back into the low bits. 17 last bits get
// a more efficiently mixed with this.
inline uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); }
// Accumulate 8 bytes into 64-bit Murmur hash
inline uint64_t MurmurStep(uint64_t hash, uint64_t data) {
hash ^= ShiftMix(data * kMul) * kMul;
hash *= kMul;
return hash;
}
// Build a uint64 from 1-8 bytes.
// 8 * len least significant bits are loaded from the memory with
// LittleEndian order. The 64 - 8 * len most significant bits are
// set all to 0.
// In latex-friendly words, this function returns:
// $\sum_{i=0}^{len-1} p[i] 256^{i}$, where p[i] is unsigned.
//
// This function is equivalent to:
// uint64 val = 0;
// memcpy(&val, p, len);
// return ToHost64(val);
//
// The caller needs to guarantee that 0 <= len <= 8.
uint64_t Load64VariableLength(const void* const p, int len) {
ABSL_ASSUME(len >= 0 && len <= 8);
uint64_t val = 0;
const uint8_t* const src = static_cast<const uint8_t*>(p);
for (int i = 0; i < len; ++i) {
val |= static_cast<uint64_t>(src[i]) << (8 * i);
}
return val;
}
} // namespace
unsigned long long MurmurHash64WithSeed(const char* buf, // NOLINT
const size_t len, const uint64_t seed) {
// Let's remove the bytes not divisible by the sizeof(uint64).
// This allows the inner loop to process the data as 64 bit integers.
const size_t len_aligned = len & ~0x7;
const char* const end = buf + len_aligned;
uint64_t hash = seed ^ (len * kMul);
for (const char* p = buf; p != end; p += 8) {
hash = MurmurStep(hash, Load64(p));
}
if ((len & 0x7) != 0) {
const uint64_t data = Load64VariableLength(end, len & 0x7);
hash ^= data;
hash *= kMul;
}
hash = ShiftMix(hash) * kMul;
hash = ShiftMix(hash);
return hash;
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash

View File

@ -0,0 +1,43 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: aappleby@google.com (Austin Appelby)
// jyrki@google.com (Jyrki Alakuijala)
//
// MurmurHash is a fast multiplication and shifting based algorithm,
// based on Austin Appleby's MurmurHash 2.0 algorithm.
#ifndef UTIL_HASH_MURMUR_H_
#define UTIL_HASH_MURMUR_H_
#include <stddef.h>
#include <stdlib.h> // for size_t.
#include <cstdint>
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
// Hash function for a byte array. Has a seed which allows this hash function to
// be used in algorithms that need a family of parameterized hash functions.
// e.g. Minhash.
unsigned long long MurmurHash64WithSeed(const char* buf, size_t len, // NOLINT
uint64_t seed);
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
#endif // UTIL_HASH_MURMUR_H_

View File

@ -0,0 +1,66 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a test library written by Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: jyrki@google.com (Jyrki Alakuijala)
//
// Tests for the fast hashing algorithm based on Austin Appleby's
// MurmurHash 2.0 algorithm. See http://murmurhash.googlepages.com/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include <string.h>
#include <cstdint>
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
TEST(Murmur, EmptyData64) {
EXPECT_EQ(uint64_t{0}, MurmurHash64WithSeed(nullptr, uint64_t{0}, 0));
}
TEST(Murmur, VaryWithDifferentSeeds) {
// While in theory different seeds could return the same
// hash for the same data this is unlikely.
char data1 = 'x';
EXPECT_NE(MurmurHash64WithSeed(&data1, 1, 100),
MurmurHash64WithSeed(&data1, 1, 101));
}
// Hashes don't change.
TEST(Murmur, Idempotence) {
const char data[] = "deadbeef";
const size_t dlen = strlen(data);
for (int i = 0; i < 10; i++) {
EXPECT_EQ(MurmurHash64WithSeed(data, dlen, i),
MurmurHash64WithSeed(data, dlen, i));
}
const char next_data[] = "deadbeef000---";
const size_t next_dlen = strlen(next_data);
for (int i = 0; i < 10; i++) {
EXPECT_EQ(MurmurHash64WithSeed(next_data, next_dlen, i),
MurmurHash64WithSeed(next_data, next_dlen, i));
}
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash

View File

@ -0,0 +1,96 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens) {
const std::string kPrefix = "^";
const std::string kSuffix = "$";
const std::string kReplacementToken = " ";
TokenizedOutput output;
size_t token_start = 0;
output.str.reserve(len + 2);
output.tokens.reserve(len + 2);
output.str.append(kPrefix);
output.tokens.push_back(std::make_pair(token_start, kPrefix.size()));
token_start += kPrefix.size();
Rune token;
for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) {
// Use the standard UTF-8 library to find the next token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
// Stop processing, if we can't read any more tokens, or we have reached
// maximum allowed tokens, allocating one token for the suffix.
if (bytes_read == 0) {
break;
}
// If `exclude_nonalphaspace_tokens` is set to true, and the token is not
// alphanumeric, replace it with a replacement token.
if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) {
output.str.append(kReplacementToken);
output.tokens.push_back(
std::make_pair(token_start, kReplacementToken.size()));
token_start += kReplacementToken.size();
i += bytes_read;
continue;
}
// Append the token in the output string, and note its position and the
// number of bytes that token consumed.
output.str.append(input_str + i, bytes_read);
output.tokens.push_back(std::make_pair(token_start, bytes_read));
token_start += bytes_read;
i += bytes_read;
}
output.str.append(kSuffix);
output.tokens.push_back(std::make_pair(token_start, kSuffix.size()));
token_start += kSuffix.size();
return output;
}
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str) {
for (int i = 0; i < len;) {
Rune token;
// Tokenize the given string, and get the appropriate lowercase token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
token = utf_isalpharune(token) ? utf_tolowerrune(token) : token;
// Write back the token to the output string.
char token_buf[UTFmax];
size_t bytes_to_write = utf_runetochar(token_buf, &token);
output_str->append(token_buf, bytes_to_write);
i += bytes_read;
}
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops

View File

@ -0,0 +1,56 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#include <string>
#include <utility>
#include <vector>
namespace mediapipe::tasks::text::language_detector::custom_ops {
struct TokenizedOutput {
// The processed string (with necessary prefix, suffix, skipped tokens, etc.).
std::string str;
// This vector contains pairs, where each pair has two members. The first
// denoting the starting index of the token in the `str` string, and the
// second denoting the length of that token in bytes.
std::vector<std::pair<const size_t, const size_t>> tokens;
};
// Tokenizes the given input string on Unicode token boundaries, with a maximum
// of `max_tokens` tokens.
//
// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores
// non-alphanumeric tokens, and replaces them with a replacement token (" ").
//
// The method returns the output in the `TokenizedOutput` struct, which stores
// both, the processed input string, and the indices and sizes of each token
// within that string.
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens);
// Converts the given unicode string (`input_str`) with the specified length
// (`len`) to a lowercase string.
//
// The method populates the lowercased string in `output_str`.
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str);
} // namespace mediapipe::tasks::text::language_detector::custom_ops
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_

View File

@ -0,0 +1,135 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
namespace {
using ::testing::Values;
std::string ReconstructStringFromTokens(TokenizedOutput output) {
std::string reconstructed_str;
for (int i = 0; i < output.tokens.size(); i++) {
reconstructed_str.append(
output.str.c_str() + output.tokens[i].first,
output.str.c_str() + output.tokens[i].first + output.tokens[i].second);
}
return reconstructed_str;
}
struct TokenizeTestParams {
std::string input_str;
size_t max_tokens;
bool exclude_nonalphaspace_tokens;
std::string expected_output_str;
};
class TokenizeParameterizedTest
: public ::testing::Test,
public testing::WithParamInterface<TokenizeTestParams> {};
TEST_P(TokenizeParameterizedTest, Tokenize) {
// Checks that the Tokenize method returns the expected value.
const TokenizeTestParams params = TokenizeParameterizedTest::GetParam();
const TokenizedOutput output = Tokenize(
/*input_str=*/params.input_str.c_str(),
/*len=*/params.input_str.size(),
/*max_tokens=*/params.max_tokens,
/*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens);
// The output string should have the necessary prefixes, and the "!" token
// should have been replaced with a " ".
EXPECT_EQ(output.str, params.expected_output_str);
EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str);
}
INSTANTIATE_TEST_SUITE_P(
TokenizeParameterizedTests, TokenizeParameterizedTest,
Values(
// Test including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_alphanonspace=*/false,
/*expected_output_str=*/"^hi!$"}),
// Test not including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^hi $"}),
// Test with a maximum of 3 tokens.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^h$"}),
// Test with non-latin characters.
TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^ありがと$"})));
TEST(LowercaseUnicodeTest, TestLowercaseUnicode) {
{
// Check that the method is a no-op when the string is lowercase.
std::string input_str = "hello";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method has uppercase characters.
std::string input_str = "hElLo";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method works with non-latin scripts.
// Cyrillic has the concept of cases, so it should change the input.
std::string input_str = "БЙп";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "бйп");
}
{
// Check that the method works with non-latin scripts.
// Japanese doesn't have the concept of cases, so it should not change.
std::string input_str = "ありがと";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "ありがと");
}
}
} // namespace
} // namespace mediapipe::tasks::text::language_detector::custom_ops

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "utf",
srcs = [
"rune.c",
"runetype.c",
"runetypebody.h",
],
hdrs = ["utf.h"],
)

View File

@ -0,0 +1,233 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/
#include <stdarg.h>
#include <string.h>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
enum
{
Bit1 = 7,
Bitx = 6,
Bit2 = 5,
Bit3 = 4,
Bit4 = 3,
Bit5 = 2,
T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */
Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */
T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */
T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */
T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */
T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */
Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */
Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */
Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */
Rune4 = (1<<(Bit4+3*Bitx))-1,
/* 0001 1111 1111 1111 1111 1111 */
Maskx = (1<<Bitx)-1, /* 0011 1111 */
Testx = Maskx ^ 0xFF, /* 1100 0000 */
Bad = Runeerror,
};
/*
* Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
* This is a slower but "safe" version of the old chartorune
* that works on strings that are not necessarily null-terminated.
*
* If you know for sure that your string is null-terminated,
* chartorune will be a bit faster.
*
* It is guaranteed not to attempt to access "length"
* past the incoming pointer. This is to avoid
* possible access violations. If the string appears to be
* well-formed but incomplete (i.e., to get the whole Rune
* we'd need to read past str+length) then we'll set the Rune
* to Bad and return 0.
*
* Note that if we have decoding problems for other
* reasons, we return 1 instead of 0.
*/
int
utf_charntorune(Rune *rune, const char *str, int length)
{
int c, c1, c2, c3;
long l;
/* When we're not allowed to read anything */
if(length <= 0) {
goto badlen;
}
/*
* one character sequence (7-bit value)
* 00000-0007F => T1
*/
c = *(uchar*)str;
if(c < Tx) {
*rune = c;
return 1;
}
// If we can't read more than one character we must stop
if(length <= 1) {
goto badlen;
}
/*
* two character sequence (11-bit value)
* 0080-07FF => T2 Tx
*/
c1 = *(uchar*)(str+1) ^ Tx;
if(c1 & Testx)
goto bad;
if(c < T3) {
if(c < T2)
goto bad;
l = ((c << Bitx) | c1) & Rune2;
if(l <= Rune1)
goto bad;
*rune = l;
return 2;
}
// If we can't read more than two characters we must stop
if(length <= 2) {
goto badlen;
}
/*
* three character sequence (16-bit value)
* 0800-FFFF => T3 Tx Tx
*/
c2 = *(uchar*)(str+2) ^ Tx;
if(c2 & Testx)
goto bad;
if(c < T4) {
l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
if(l <= Rune2)
goto bad;
*rune = l;
return 3;
}
if (length <= 3)
goto badlen;
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
c3 = *(uchar*)(str+3) ^ Tx;
if (c3 & Testx)
goto bad;
if (c < T5) {
l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
if (l <= Rune3)
goto bad;
if (l > Runemax)
goto bad;
*rune = l;
return 4;
}
// Support for 5-byte or longer UTF-8 would go here, but
// since we don't have that, we'll just fall through to bad.
/*
* bad decoding
*/
bad:
*rune = Bad;
return 1;
badlen:
*rune = Bad;
return 0;
}
int
utf_runetochar(char *str, const Rune *rune)
{
/* Runes are signed, so convert to unsigned for range check. */
unsigned long c;
/*
* one character sequence
* 00000-0007F => 00-7F
*/
c = *rune;
if(c <= Rune1) {
str[0] = c;
return 1;
}
/*
* two character sequence
* 0080-07FF => T2 Tx
*/
if(c <= Rune2) {
str[0] = T2 | (c >> 1*Bitx);
str[1] = Tx | (c & Maskx);
return 2;
}
/*
* If the Rune is out of range, convert it to the error rune.
* Do this test here because the error rune encodes to three bytes.
* Doing it earlier would duplicate work, since an out of range
* Rune wouldn't have fit in one or two bytes.
*/
if (c > Runemax)
c = Runeerror;
/*
* three character sequence
* 0800-FFFF => T3 Tx Tx
*/
if (c <= Rune3) {
str[0] = T3 | (c >> 2*Bitx);
str[1] = Tx | ((c >> 1*Bitx) & Maskx);
str[2] = Tx | (c & Maskx);
return 3;
}
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
str[0] = T4 | (c >> 3*Bitx);
str[1] = Tx | ((c >> 2*Bitx) & Maskx);
str[2] = Tx | ((c >> 1*Bitx) & Maskx);
str[3] = Tx | (c & Maskx);
return 4;
}

View File

@ -0,0 +1,54 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
static
Rune*
rbsearch(Rune c, Rune *t, int n, int ne)
{
Rune *p;
int m;
while(n > 1) {
m = n >> 1;
p = t + m*ne;
if(c >= p[0]) {
t = p;
n = n-m;
} else
n = m;
}
if(n && c >= t[0])
return t;
return 0;
}
#define RUNETYPEBODY
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h"

View File

@ -0,0 +1,212 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef RUNETYPEBODY
static Rune __isalphar[] = {
0x0041, 0x005a, 0x0061, 0x007a, 0x00c0, 0x00d6, 0x00d8, 0x00f6,
0x00f8, 0x02c1, 0x02c6, 0x02d1, 0x02e0, 0x02e4, 0x0370, 0x0374,
0x0376, 0x0377, 0x037a, 0x037d, 0x0388, 0x038a, 0x038e, 0x03a1,
0x03a3, 0x03f5, 0x03f7, 0x0481, 0x048a, 0x0527, 0x0531, 0x0556,
0x0561, 0x0587, 0x05d0, 0x05ea, 0x05f0, 0x05f2, 0x0620, 0x064a,
0x066e, 0x066f, 0x0671, 0x06d3, 0x06e5, 0x06e6, 0x06ee, 0x06ef,
0x06fa, 0x06fc, 0x0712, 0x072f, 0x074d, 0x07a5, 0x07ca, 0x07ea,
0x07f4, 0x07f5, 0x0800, 0x0815, 0x0840, 0x0858, 0x08a2, 0x08ac,
0x0904, 0x0939, 0x0958, 0x0961, 0x0971, 0x0977, 0x0979, 0x097f,
0x0985, 0x098c, 0x098f, 0x0990, 0x0993, 0x09a8, 0x09aa, 0x09b0,
0x09b6, 0x09b9, 0x09dc, 0x09dd, 0x09df, 0x09e1, 0x09f0, 0x09f1,
0x0a05, 0x0a0a, 0x0a0f, 0x0a10, 0x0a13, 0x0a28, 0x0a2a, 0x0a30,
0x0a32, 0x0a33, 0x0a35, 0x0a36, 0x0a38, 0x0a39, 0x0a59, 0x0a5c,
0x0a72, 0x0a74, 0x0a85, 0x0a8d, 0x0a8f, 0x0a91, 0x0a93, 0x0aa8,
0x0aaa, 0x0ab0, 0x0ab2, 0x0ab3, 0x0ab5, 0x0ab9, 0x0ae0, 0x0ae1,
0x0b05, 0x0b0c, 0x0b0f, 0x0b10, 0x0b13, 0x0b28, 0x0b2a, 0x0b30,
0x0b32, 0x0b33, 0x0b35, 0x0b39, 0x0b5c, 0x0b5d, 0x0b5f, 0x0b61,
0x0b85, 0x0b8a, 0x0b8e, 0x0b90, 0x0b92, 0x0b95, 0x0b99, 0x0b9a,
0x0b9e, 0x0b9f, 0x0ba3, 0x0ba4, 0x0ba8, 0x0baa, 0x0bae, 0x0bb9,
0x0c05, 0x0c0c, 0x0c0e, 0x0c10, 0x0c12, 0x0c28, 0x0c2a, 0x0c33,
0x0c35, 0x0c39, 0x0c58, 0x0c59, 0x0c60, 0x0c61, 0x0c85, 0x0c8c,
0x0c8e, 0x0c90, 0x0c92, 0x0ca8, 0x0caa, 0x0cb3, 0x0cb5, 0x0cb9,
0x0ce0, 0x0ce1, 0x0cf1, 0x0cf2, 0x0d05, 0x0d0c, 0x0d0e, 0x0d10,
0x0d12, 0x0d3a, 0x0d60, 0x0d61, 0x0d7a, 0x0d7f, 0x0d85, 0x0d96,
0x0d9a, 0x0db1, 0x0db3, 0x0dbb, 0x0dc0, 0x0dc6, 0x0e01, 0x0e30,
0x0e32, 0x0e33, 0x0e40, 0x0e46, 0x0e81, 0x0e82, 0x0e87, 0x0e88,
0x0e94, 0x0e97, 0x0e99, 0x0e9f, 0x0ea1, 0x0ea3, 0x0eaa, 0x0eab,
0x0ead, 0x0eb0, 0x0eb2, 0x0eb3, 0x0ec0, 0x0ec4, 0x0edc, 0x0edf,
0x0f40, 0x0f47, 0x0f49, 0x0f6c, 0x0f88, 0x0f8c, 0x1000, 0x102a,
0x1050, 0x1055, 0x105a, 0x105d, 0x1065, 0x1066, 0x106e, 0x1070,
0x1075, 0x1081, 0x10a0, 0x10c5, 0x10d0, 0x10fa, 0x10fc, 0x1248,
0x124a, 0x124d, 0x1250, 0x1256, 0x125a, 0x125d, 0x1260, 0x1288,
0x128a, 0x128d, 0x1290, 0x12b0, 0x12b2, 0x12b5, 0x12b8, 0x12be,
0x12c2, 0x12c5, 0x12c8, 0x12d6, 0x12d8, 0x1310, 0x1312, 0x1315,
0x1318, 0x135a, 0x1380, 0x138f, 0x13a0, 0x13f4, 0x1401, 0x166c,
0x166f, 0x167f, 0x1681, 0x169a, 0x16a0, 0x16ea, 0x1700, 0x170c,
0x170e, 0x1711, 0x1720, 0x1731, 0x1740, 0x1751, 0x1760, 0x176c,
0x176e, 0x1770, 0x1780, 0x17b3, 0x1820, 0x1877, 0x1880, 0x18a8,
0x18b0, 0x18f5, 0x1900, 0x191c, 0x1950, 0x196d, 0x1970, 0x1974,
0x1980, 0x19ab, 0x19c1, 0x19c7, 0x1a00, 0x1a16, 0x1a20, 0x1a54,
0x1b05, 0x1b33, 0x1b45, 0x1b4b, 0x1b83, 0x1ba0, 0x1bae, 0x1baf,
0x1bba, 0x1be5, 0x1c00, 0x1c23, 0x1c4d, 0x1c4f, 0x1c5a, 0x1c7d,
0x1ce9, 0x1cec, 0x1cee, 0x1cf1, 0x1cf5, 0x1cf6, 0x1d00, 0x1dbf,
0x1e00, 0x1f15, 0x1f18, 0x1f1d, 0x1f20, 0x1f45, 0x1f48, 0x1f4d,
0x1f50, 0x1f57, 0x1f5f, 0x1f7d, 0x1f80, 0x1fb4, 0x1fb6, 0x1fbc,
0x1fc2, 0x1fc4, 0x1fc6, 0x1fcc, 0x1fd0, 0x1fd3, 0x1fd6, 0x1fdb,
0x1fe0, 0x1fec, 0x1ff2, 0x1ff4, 0x1ff6, 0x1ffc, 0x2090, 0x209c,
0x210a, 0x2113, 0x2119, 0x211d, 0x212a, 0x212d, 0x212f, 0x2139,
0x213c, 0x213f, 0x2145, 0x2149, 0x2183, 0x2184, 0x2c00, 0x2c2e,
0x2c30, 0x2c5e, 0x2c60, 0x2ce4, 0x2ceb, 0x2cee, 0x2cf2, 0x2cf3,
0x2d00, 0x2d25, 0x2d30, 0x2d67, 0x2d80, 0x2d96, 0x2da0, 0x2da6,
0x2da8, 0x2dae, 0x2db0, 0x2db6, 0x2db8, 0x2dbe, 0x2dc0, 0x2dc6,
0x2dc8, 0x2dce, 0x2dd0, 0x2dd6, 0x2dd8, 0x2dde, 0x3005, 0x3006,
0x3031, 0x3035, 0x303b, 0x303c, 0x3041, 0x3096, 0x309d, 0x309f,
0x30a1, 0x30fa, 0x30fc, 0x30ff, 0x3105, 0x312d, 0x3131, 0x318e,
0x31a0, 0x31ba, 0x31f0, 0x31ff, 0x3400, 0x4db5, 0x4e00, 0x9fcc,
0xa000, 0xa48c, 0xa4d0, 0xa4fd, 0xa500, 0xa60c, 0xa610, 0xa61f,
0xa62a, 0xa62b, 0xa640, 0xa66e, 0xa67f, 0xa697, 0xa6a0, 0xa6e5,
0xa717, 0xa71f, 0xa722, 0xa788, 0xa78b, 0xa78e, 0xa790, 0xa793,
0xa7a0, 0xa7aa, 0xa7f8, 0xa801, 0xa803, 0xa805, 0xa807, 0xa80a,
0xa80c, 0xa822, 0xa840, 0xa873, 0xa882, 0xa8b3, 0xa8f2, 0xa8f7,
0xa90a, 0xa925, 0xa930, 0xa946, 0xa960, 0xa97c, 0xa984, 0xa9b2,
0xaa00, 0xaa28, 0xaa40, 0xaa42, 0xaa44, 0xaa4b, 0xaa60, 0xaa76,
0xaa80, 0xaaaf, 0xaab5, 0xaab6, 0xaab9, 0xaabd, 0xaadb, 0xaadd,
0xaae0, 0xaaea, 0xaaf2, 0xaaf4, 0xab01, 0xab06, 0xab09, 0xab0e,
0xab11, 0xab16, 0xab20, 0xab26, 0xab28, 0xab2e, 0xabc0, 0xabe2,
0xac00, 0xd7a3, 0xd7b0, 0xd7c6, 0xd7cb, 0xd7fb, 0xf900, 0xfa6d,
0xfa70, 0xfad9, 0xfb00, 0xfb06, 0xfb13, 0xfb17, 0xfb1f, 0xfb28,
0xfb2a, 0xfb36, 0xfb38, 0xfb3c, 0xfb40, 0xfb41, 0xfb43, 0xfb44,
0xfb46, 0xfbb1, 0xfbd3, 0xfd3d, 0xfd50, 0xfd8f, 0xfd92, 0xfdc7,
0xfdf0, 0xfdfb, 0xfe70, 0xfe74, 0xfe76, 0xfefc, 0xff21, 0xff3a,
0xff41, 0xff5a, 0xff66, 0xffbe, 0xffc2, 0xffc7, 0xffca, 0xffcf,
0xffd2, 0xffd7, 0xffda, 0xffdc, 0x10000, 0x1000b, 0x1000d, 0x10026,
0x10028, 0x1003a, 0x1003c, 0x1003d, 0x1003f, 0x1004d, 0x10050, 0x1005d,
0x10080, 0x100fa, 0x10280, 0x1029c, 0x102a0, 0x102d0, 0x10300, 0x1031e,
0x10330, 0x10340, 0x10342, 0x10349, 0x10380, 0x1039d, 0x103a0, 0x103c3,
0x103c8, 0x103cf, 0x10400, 0x1049d, 0x10800, 0x10805, 0x1080a, 0x10835,
0x10837, 0x10838, 0x1083f, 0x10855, 0x10900, 0x10915, 0x10920, 0x10939,
0x10980, 0x109b7, 0x109be, 0x109bf, 0x10a10, 0x10a13, 0x10a15, 0x10a17,
0x10a19, 0x10a33, 0x10a60, 0x10a7c, 0x10b00, 0x10b35, 0x10b40, 0x10b55,
0x10b60, 0x10b72, 0x10c00, 0x10c48, 0x11003, 0x11037, 0x11083, 0x110af,
0x110d0, 0x110e8, 0x11103, 0x11126, 0x11183, 0x111b2, 0x111c1, 0x111c4,
0x11680, 0x116aa, 0x12000, 0x1236e, 0x13000, 0x1342e, 0x16800, 0x16a38,
0x16f00, 0x16f44, 0x16f93, 0x16f9f, 0x1b000, 0x1b001, 0x1d400, 0x1d454,
0x1d456, 0x1d49c, 0x1d49e, 0x1d49f, 0x1d4a5, 0x1d4a6, 0x1d4a9, 0x1d4ac,
0x1d4ae, 0x1d4b9, 0x1d4bd, 0x1d4c3, 0x1d4c5, 0x1d505, 0x1d507, 0x1d50a,
0x1d50d, 0x1d514, 0x1d516, 0x1d51c, 0x1d51e, 0x1d539, 0x1d53b, 0x1d53e,
0x1d540, 0x1d544, 0x1d54a, 0x1d550, 0x1d552, 0x1d6a5, 0x1d6a8, 0x1d6c0,
0x1d6c2, 0x1d6da, 0x1d6dc, 0x1d6fa, 0x1d6fc, 0x1d714, 0x1d716, 0x1d734,
0x1d736, 0x1d74e, 0x1d750, 0x1d76e, 0x1d770, 0x1d788, 0x1d78a, 0x1d7a8,
0x1d7aa, 0x1d7c2, 0x1d7c4, 0x1d7cb, 0x1ee00, 0x1ee03, 0x1ee05, 0x1ee1f,
0x1ee21, 0x1ee22, 0x1ee29, 0x1ee32, 0x1ee34, 0x1ee37, 0x1ee4d, 0x1ee4f,
0x1ee51, 0x1ee52, 0x1ee61, 0x1ee62, 0x1ee67, 0x1ee6a, 0x1ee6c, 0x1ee72,
0x1ee74, 0x1ee77, 0x1ee79, 0x1ee7c, 0x1ee80, 0x1ee89, 0x1ee8b, 0x1ee9b,
0x1eea1, 0x1eea3, 0x1eea5, 0x1eea9, 0x1eeab, 0x1eebb, 0x20000, 0x2a6d6,
0x2a700, 0x2b734, 0x2b740, 0x2b81d, 0x2f800, 0x2fa1d,
};
static Rune __isalphas[] = {
0x00aa, 0x00b5, 0x00ba, 0x02ec, 0x02ee, 0x0386, 0x038c, 0x0559,
0x06d5, 0x06ff, 0x0710, 0x07b1, 0x07fa, 0x081a, 0x0824, 0x0828,
0x08a0, 0x093d, 0x0950, 0x09b2, 0x09bd, 0x09ce, 0x0a5e, 0x0abd,
0x0ad0, 0x0b3d, 0x0b71, 0x0b83, 0x0b9c, 0x0bd0, 0x0c3d, 0x0cbd,
0x0cde, 0x0d3d, 0x0d4e, 0x0dbd, 0x0e84, 0x0e8a, 0x0e8d, 0x0ea5,
0x0ea7, 0x0ebd, 0x0ec6, 0x0f00, 0x103f, 0x1061, 0x108e, 0x10c7,
0x10cd, 0x1258, 0x12c0, 0x17d7, 0x17dc, 0x18aa, 0x1aa7, 0x1f59,
0x1f5b, 0x1f5d, 0x1fbe, 0x2071, 0x207f, 0x2102, 0x2107, 0x2115,
0x2124, 0x2126, 0x2128, 0x214e, 0x2d27, 0x2d2d, 0x2d6f, 0x2e2f,
0xa8fb, 0xa9cf, 0xaa7a, 0xaab1, 0xaac0, 0xaac2, 0xfb1d, 0xfb3e,
0x10808, 0x1083c, 0x10a00, 0x16f50, 0x1d4a2, 0x1d4bb, 0x1d546, 0x1ee24,
0x1ee27, 0x1ee39, 0x1ee3b, 0x1ee42, 0x1ee47, 0x1ee49, 0x1ee4b, 0x1ee54,
0x1ee57, 0x1ee59, 0x1ee5b, 0x1ee5d, 0x1ee5f, 0x1ee64, 0x1ee7e,
};
int utf_isalpharune(Rune c) {
Rune *p;
p = rbsearch(c, __isalphar, nelem(__isalphar) / 2, 2);
if (p && c >= p[0] && c <= p[1]) return 1;
p = rbsearch(c, __isalphas, nelem(__isalphas), 1);
if (p && c == p[0]) return 1;
return 0;
}
static Rune __tolowerr[] = {
0x0041, 0x005a, 1048608, 0x00c0, 0x00d6, 1048608, 0x00d8, 0x00de, 1048608,
0x0189, 0x018a, 1048781, 0x01b1, 0x01b2, 1048793, 0x0388, 0x038a, 1048613,
0x038e, 0x038f, 1048639, 0x0391, 0x03a1, 1048608, 0x03a3, 0x03ab, 1048608,
0x03fd, 0x03ff, 1048446, 0x0400, 0x040f, 1048656, 0x0410, 0x042f, 1048608,
0x0531, 0x0556, 1048624, 0x10a0, 0x10c5, 1055840, 0x1f08, 0x1f0f, 1048568,
0x1f18, 0x1f1d, 1048568, 0x1f28, 0x1f2f, 1048568, 0x1f38, 0x1f3f, 1048568,
0x1f48, 0x1f4d, 1048568, 0x1f68, 0x1f6f, 1048568, 0x1f88, 0x1f8f, 1048568,
0x1f98, 0x1f9f, 1048568, 0x1fa8, 0x1faf, 1048568, 0x1fb8, 0x1fb9, 1048568,
0x1fba, 0x1fbb, 1048502, 0x1fc8, 0x1fcb, 1048490, 0x1fd8, 0x1fd9, 1048568,
0x1fda, 0x1fdb, 1048476, 0x1fe8, 0x1fe9, 1048568, 0x1fea, 0x1feb, 1048464,
0x1ff8, 0x1ff9, 1048448, 0x1ffa, 0x1ffb, 1048450, 0x2160, 0x216f, 1048592,
0x24b6, 0x24cf, 1048602, 0x2c00, 0x2c2e, 1048624, 0x2c7e, 0x2c7f, 1037761,
0xff21, 0xff3a, 1048608, 0x10400, 0x10427, 1048616,
};
static Rune __tolowerp[] = {
0x0100, 0x012e, 1048577, 0x0132, 0x0136, 1048577, 0x0139, 0x0147, 1048577,
0x014a, 0x0176, 1048577, 0x017b, 0x017d, 1048577, 0x01a2, 0x01a4, 1048577,
0x01b3, 0x01b5, 1048577, 0x01cd, 0x01db, 1048577, 0x01de, 0x01ee, 1048577,
0x01f8, 0x021e, 1048577, 0x0222, 0x0232, 1048577, 0x0248, 0x024e, 1048577,
0x0370, 0x0372, 1048577, 0x03d8, 0x03ee, 1048577, 0x0460, 0x0480, 1048577,
0x048a, 0x04be, 1048577, 0x04c3, 0x04cd, 1048577, 0x04d0, 0x0526, 1048577,
0x1e00, 0x1e94, 1048577, 0x1ea0, 0x1efe, 1048577, 0x1f59, 0x1f5f, 1048568,
0x2c67, 0x2c6b, 1048577, 0x2c80, 0x2ce2, 1048577, 0x2ceb, 0x2ced, 1048577,
0xa640, 0xa66c, 1048577, 0xa680, 0xa696, 1048577, 0xa722, 0xa72e, 1048577,
0xa732, 0xa76e, 1048577, 0xa779, 0xa77b, 1048577, 0xa780, 0xa786, 1048577,
0xa790, 0xa792, 1048577, 0xa7a0, 0xa7a8, 1048577,
};
static Rune __tolowers[] = {
0x0130, 1048377, 0x0178, 1048455, 0x0179, 1048577, 0x0181, 1048786,
0x0182, 1048577, 0x0184, 1048577, 0x0186, 1048782, 0x0187, 1048577,
0x018b, 1048577, 0x018e, 1048655, 0x018f, 1048778, 0x0190, 1048779,
0x0191, 1048577, 0x0193, 1048781, 0x0194, 1048783, 0x0196, 1048787,
0x0197, 1048785, 0x0198, 1048577, 0x019c, 1048787, 0x019d, 1048789,
0x019f, 1048790, 0x01a0, 1048577, 0x01a6, 1048794, 0x01a7, 1048577,
0x01a9, 1048794, 0x01ac, 1048577, 0x01ae, 1048794, 0x01af, 1048577,
0x01b7, 1048795, 0x01b8, 1048577, 0x01bc, 1048577, 0x01c4, 1048578,
0x01c5, 1048577, 0x01c7, 1048578, 0x01c8, 1048577, 0x01ca, 1048578,
0x01cb, 1048577, 0x01f1, 1048578, 0x01f2, 1048577, 0x01f4, 1048577,
0x01f6, 1048479, 0x01f7, 1048520, 0x0220, 1048446, 0x023a, 1059371,
0x023b, 1048577, 0x023d, 1048413, 0x023e, 1059368, 0x0241, 1048577,
0x0243, 1048381, 0x0244, 1048645, 0x0245, 1048647, 0x0246, 1048577,
0x0376, 1048577, 0x0386, 1048614, 0x038c, 1048640, 0x03cf, 1048584,
0x03f4, 1048516, 0x03f7, 1048577, 0x03f9, 1048569, 0x03fa, 1048577,
0x04c0, 1048591, 0x04c1, 1048577, 0x10c7, 1055840, 0x10cd, 1055840,
0x1e9e, 1040961, 0x1fbc, 1048567, 0x1fcc, 1048567, 0x1fec, 1048569,
0x1ffc, 1048567, 0x2126, 1041059, 0x212a, 1040193, 0x212b, 1040314,
0x2132, 1048604, 0x2183, 1048577, 0x2c60, 1048577, 0x2c62, 1037833,
0x2c63, 1044762, 0x2c64, 1037849, 0x2c6d, 1037796, 0x2c6e, 1037827,
0x2c6f, 1037793, 0x2c70, 1037794, 0x2c72, 1048577, 0x2c75, 1048577,
0x2cf2, 1048577, 0xa77d, 1013244, 0xa77e, 1048577, 0xa78b, 1048577,
0xa78d, 1006296, 0xa7aa, 1006268,
};
Rune utf_tolowerrune(Rune c) {
Rune *p;
p = rbsearch(c, __tolowerr, nelem(__tolowerr) / 3, 3);
if (p && c >= p[0] && c <= p[1]) return c + p[2] - 1048576;
p = rbsearch(c, __tolowerp, nelem(__tolowerp) / 3, 3);
if (p && c >= p[0] && c <= p[1] && !((c - p[0]) & 1))
return c + p[2] - 1048576;
p = rbsearch(c, __tolowers, nelem(__tolowers) / 2, 2);
if (p && c == p[0]) return c + p[1] - 1048576;
return c;
}
#endif

View File

@ -0,0 +1,98 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Fork of several UTF utils originally written by Rob Pike and Ken Thompson.
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ 1
#include <stdint.h>
// Code-point values in Unicode 4.0 are 21 bits wide.
typedef signed int Rune;
#define uchar _utfuchar
typedef unsigned char uchar;
#define nelem(x) (sizeof(x) / sizeof((x)[0]))
enum {
UTFmax = 4, // maximum bytes per rune
Runeerror = 0xFFFD, // decoding error in UTF
Runemax = 0x10FFFF, // maximum rune value
};
#ifdef __cplusplus
extern "C" {
#endif
/*
* rune routines
*/
/*
* These routines were written by Rob Pike and Ken Thompson
* and first appeared in Plan 9.
* SEE ALSO
* utf (7)
* tcs (1)
*/
// utf_runetochar copies (encodes) one rune, pointed to by r, to at most
// UTFmax bytes starting at s and returns the number of bytes generated.
int utf_runetochar(char* s, const Rune* r);
// utf_charntorune copies (decodes) at most UTFmax bytes starting at `str` to
// one rune, pointed to by `rune`, access at most `length` bytes of `str`, and
// returns the number of bytes consumed.
// If the UTF sequence is incomplete within n bytes,
// utf_charntorune will set *r to Runeerror and return 0. If it is complete
// but not in UTF format, it will set *r to Runeerror and return 1.
//
// Added 2004-09-24 by Wei-Hwa Huang
int utf_charntorune(Rune* rune, const char* str, int length);
// Unicode defines some characters as letters and
// specifies three cases: upper, lower, and title. Mappings among the
// cases are also defined, although they are not exhaustive: some
// upper case letters have no lower case mapping, and so on. Unicode
// also defines several character properties, a subset of which are
// checked by these routines. These routines are based on Unicode
// version 3.0.0.
//
// NOTE: The routines are implemented in C, so isalpharrune returns 0 for false
// and 1 for true.
//
// utf_tolowerrune is the Unicode case mapping. It returns the character
// unchanged if it has no defined mapping.
Rune utf_tolowerrune(Rune r);
// utf_isalpharune tests for Unicode letters; this includes ideographs in
// addition to alphabetic characters.
int utf_isalpharune(Rune r);
// (The comments in this file were copied from the manpage files rune.3,
// isalpharune.3, and runestrcat.3. Some formatting changes were also made
// to conform to Google style. /JRM 11/11/05)
#ifdef __cplusplus
}
#endif
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_

View File

@ -34,7 +34,7 @@ constexpr char kTestSPModelPath[] =
std::unique_ptr<SentencePieceTokenizer> CreateSentencePieceTokenizer( std::unique_ptr<SentencePieceTokenizer> CreateSentencePieceTokenizer(
absl::string_view model_path) { absl::string_view model_path) {
// We are using `LoadBinaryContent()` instead of loading the model direclty // We are using `LoadBinaryContent()` instead of loading the model directly
// via `SentencePieceTokenizer` so that the file can be located on Windows // via `SentencePieceTokenizer` so that the file can be located on Windows
std::string buffer = LoadBinaryContent(kTestSPModelPath); std::string buffer = LoadBinaryContent(kTestSPModelPath);
return absl::make_unique<SentencePieceTokenizer>(buffer.data(), return absl::make_unique<SentencePieceTokenizer>(buffer.data(),

View File

@ -74,7 +74,7 @@ class FaceDetector : core::BaseVisionTaskApi {
// three running modes: // three running modes:
// 1) Image mode for detecting faces on single image inputs. Users // 1) Image mode for detecting faces on single image inputs. Users
// provide mediapipe::Image to the `Detect` method, and will receive the // provide mediapipe::Image to the `Detect` method, and will receive the
// deteced face detection results as the return value. // detected face detection results as the return value.
// 2) Video mode for detecting faces on the decoded frames of a // 2) Video mode for detecting faces on the decoded frames of a
// video. Users call `DetectForVideo` method, and will receive the detected // video. Users call `DetectForVideo` method, and will receive the detected
// face detection results as the return value. // face detection results as the return value.

View File

@ -19,9 +19,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
cc_library( cc_library(
name = "face_geometry_from_landmarks_graph", name = "face_geometry_from_landmarks_graph",
srcs = ["face_geometry_from_landmarks_graph.cc"], srcs = ["face_geometry_from_landmarks_graph.cc"],
data = [
"//mediapipe/tasks/cc/vision/face_geometry/data:geometry_pipeline_metadata_landmarks",
],
deps = [ deps = [
"//mediapipe/calculators/core:begin_loop_calculator", "//mediapipe/calculators/core:begin_loop_calculator",
"//mediapipe/calculators/core:end_loop_calculator", "//mediapipe/calculators/core:end_loop_calculator",
@ -39,6 +36,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto",
"//mediapipe/util:graph_builder_utils", "//mediapipe/util:graph_builder_utils",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
], ],

View File

@ -45,6 +45,7 @@ mediapipe_proto_library(
srcs = ["geometry_pipeline_calculator.proto"], srcs = ["geometry_pipeline_calculator.proto"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/tasks/cc/core/proto:external_file_proto",
], ],
) )
@ -59,6 +60,9 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"//mediapipe/tasks/cc:common",
"//mediapipe/tasks/cc/core:external_file_handler",
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
"//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline", "//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
"//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils", "//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils",
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
@ -66,6 +70,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto", "//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:str_format",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -18,12 +18,16 @@
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_macros.h" #include "mediapipe/framework/port/status_macros.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/tasks/cc/common.h"
#include "mediapipe/tasks/cc/core/external_file_handler.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h" #include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h" #include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
@ -39,13 +43,50 @@ static constexpr char kEnvironmentTag[] = "ENVIRONMENT";
static constexpr char kImageSizeTag[] = "IMAGE_SIZE"; static constexpr char kImageSizeTag[] = "IMAGE_SIZE";
static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY"; static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY";
static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS"; static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS";
static constexpr char kFaceGeometryTag[] = "FACE_GEOMETRY";
static constexpr char kFaceLandmarksTag[] = "FACE_LANDMARKS";
using ::mediapipe::tasks::vision::face_geometry::proto::Environment; using ::mediapipe::tasks::vision::face_geometry::proto::Environment;
using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry; using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry;
using ::mediapipe::tasks::vision::face_geometry::proto:: using ::mediapipe::tasks::vision::face_geometry::proto::
GeometryPipelineMetadata; GeometryPipelineMetadata;
// A calculator that renders a visual effect for multiple faces. absl::Status SanityCheck(CalculatorContract* cc) {
if (!(cc->Inputs().HasTag(kFaceLandmarksTag) ^
cc->Inputs().HasTag(kMultiFaceLandmarksTag))) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Only one of %s and %s can be set at a time.",
kFaceLandmarksTag, kMultiFaceLandmarksTag));
}
if (!(cc->Outputs().HasTag(kFaceGeometryTag) ^
cc->Outputs().HasTag(kMultiFaceGeometryTag))) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat("Only one of %s and %s can be set at a time.",
kFaceGeometryTag, kMultiFaceGeometryTag));
}
if (cc->Inputs().HasTag(kFaceLandmarksTag) !=
cc->Outputs().HasTag(kFaceGeometryTag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"%s and %s must both be set or neither be set and a time.",
kFaceLandmarksTag, kFaceGeometryTag));
}
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag) !=
cc->Outputs().HasTag(kMultiFaceGeometryTag)) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
absl::StrFormat(
"%s and %s must both be set or neither be set and a time.",
kMultiFaceLandmarksTag, kMultiFaceGeometryTag));
}
return absl::OkStatus();
}
// A calculator that renders a visual effect for multiple faces. Support single
// face landmarks or multiple face landmarks.
// //
// Inputs: // Inputs:
// IMAGE_SIZE (`std::pair<int, int>`, required): // IMAGE_SIZE (`std::pair<int, int>`, required):
@ -56,8 +97,12 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
// ratio. If used as-is, the resulting face geometry visualization should be // ratio. If used as-is, the resulting face geometry visualization should be
// happening on a frame with the same ratio as well. // happening on a frame with the same ratio as well.
// //
// MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, required): // MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, optional):
// A vector of face landmark lists. // A vector of face landmark lists. If connected, the output stream
// MULTI_FACE_GEOMETRY must be connected.
// FACE_LANDMARKS (NormalizedLandmarkList, optional):
// A NormalizedLandmarkList of single face landmark lists. If connected, the
// output stream FACE_GEOMETRY must be connected.
// //
// Input side packets: // Input side packets:
// ENVIRONMENT (`proto::Environment`, required) // ENVIRONMENT (`proto::Environment`, required)
@ -65,12 +110,14 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
// as well as virtual camera parameters. // as well as virtual camera parameters.
// //
// Output: // Output:
// MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, required): // MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, optional):
// A vector of face geometry data. // A vector of face geometry data if MULTI_FACE_LANDMARKS is connected .
// FACE_GEOMETRY (FaceGeometry, optional):
// A FaceGeometry of the face landmarks if FACE_LANDMARKS is connected.
// //
// Options: // Options:
// metadata_path (`string`, optional): // metadata_file (`ExternalFile`, optional):
// Defines a path for the geometry pipeline metadata file. // Defines an ExternalFile for the geometry pipeline metadata file.
// //
// The geometry pipeline metadata file format must be the binary // The geometry pipeline metadata file format must be the binary
// `GeometryPipelineMetadata` proto. // `GeometryPipelineMetadata` proto.
@ -79,13 +126,21 @@ class GeometryPipelineCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>(); cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>();
MP_RETURN_IF_ERROR(SanityCheck(cc));
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
cc->Inputs() cc->Inputs()
.Tag(kMultiFaceLandmarksTag) .Tag(kMultiFaceLandmarksTag)
.Set<std::vector<mediapipe::NormalizedLandmarkList>>(); .Set<std::vector<mediapipe::NormalizedLandmarkList>>();
cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>(); cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>();
return absl::OkStatus(); return absl::OkStatus();
} else {
cc->Inputs()
.Tag(kFaceLandmarksTag)
.Set<mediapipe::NormalizedLandmarkList>();
cc->Outputs().Tag(kFaceGeometryTag).Set<FaceGeometry>();
return absl::OkStatus();
}
} }
absl::Status Open(CalculatorContext* cc) override { absl::Status Open(CalculatorContext* cc) override {
@ -95,7 +150,7 @@ class GeometryPipelineCalculator : public CalculatorBase {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(
GeometryPipelineMetadata metadata, GeometryPipelineMetadata metadata,
ReadMetadataFromFile(options.metadata_path()), ReadMetadataFromFile(options.metadata_file()),
_ << "Failed to read the geometry pipeline metadata from file!"); _ << "Failed to read the geometry pipeline metadata from file!");
MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata)) MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata))
@ -110,21 +165,26 @@ class GeometryPipelineCalculator : public CalculatorBase {
ASSIGN_OR_RETURN(geometry_pipeline_, ASSIGN_OR_RETURN(geometry_pipeline_,
CreateGeometryPipeline(environment, metadata), CreateGeometryPipeline(environment, metadata),
_ << "Failed to create a geometry pipeline!"); _ << "Failed to create a geometry pipeline!");
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status Process(CalculatorContext* cc) override { absl::Status Process(CalculatorContext* cc) override {
// Both the `IMAGE_SIZE` and the `MULTI_FACE_LANDMARKS` streams are required // Both the `IMAGE_SIZE` and either the `FACE_LANDMARKS` or
// to have a non-empty packet. In case this requirement is not met, there's // `MULTI_FACE_LANDMARKS` streams are required to have a non-empty packet.
// nothing to be processed at the current timestamp. // In case this requirement is not met, there's nothing to be processed at
if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() || // the current timestamp and we return early (checked here and below).
cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) { if (cc->Inputs().Tag(kImageSizeTag).IsEmpty()) {
return absl::OkStatus(); return absl::OkStatus();
} }
const auto& image_size = const auto& image_size =
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>(); cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
if (cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) {
return absl::OkStatus();
}
const auto& multi_face_landmarks = const auto& multi_face_landmarks =
cc->Inputs() cc->Inputs()
.Tag(kMultiFaceLandmarksTag) .Tag(kMultiFaceLandmarksTag)
@ -145,6 +205,29 @@ class GeometryPipelineCalculator : public CalculatorBase {
.AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>( .AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>(
multi_face_geometry.release()) multi_face_geometry.release())
.At(cc->InputTimestamp())); .At(cc->InputTimestamp()));
} else if (cc->Inputs().HasTag(kFaceLandmarksTag)) {
if (cc->Inputs().Tag(kFaceLandmarksTag).IsEmpty()) {
return absl::OkStatus();
}
const auto& face_landmarks =
cc->Inputs()
.Tag(kFaceLandmarksTag)
.Get<mediapipe::NormalizedLandmarkList>();
ASSIGN_OR_RETURN(
std::vector<FaceGeometry> multi_face_geometry,
geometry_pipeline_->EstimateFaceGeometry(
{face_landmarks}, //
/*frame_width*/ image_size.first,
/*frame_height*/ image_size.second),
_ << "Failed to estimate face geometry for multiple faces!");
cc->Outputs()
.Tag(kFaceGeometryTag)
.AddPacket(mediapipe::MakePacket<FaceGeometry>(multi_face_geometry[0])
.At(cc->InputTimestamp()));
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -155,32 +238,19 @@ class GeometryPipelineCalculator : public CalculatorBase {
private: private:
static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile( static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile(
const std::string& metadata_path) { const core::proto::ExternalFile& metadata_file) {
ASSIGN_OR_RETURN(std::string metadata_blob, ASSIGN_OR_RETURN(
ReadContentBlobFromFile(metadata_path), const auto file_handler,
_ << "Failed to read a metadata blob from file!"); core::ExternalFileHandler::CreateFromExternalFile(&metadata_file));
GeometryPipelineMetadata metadata; GeometryPipelineMetadata metadata;
RET_CHECK(metadata.ParseFromString(metadata_blob)) RET_CHECK(
metadata.ParseFromString(std::string(file_handler->GetFileContent())))
<< "Failed to parse a metadata proto from a binary blob!"; << "Failed to parse a metadata proto from a binary blob!";
return metadata; return metadata;
} }
static absl::StatusOr<std::string> ReadContentBlobFromFile(
const std::string& unresolved_path) {
ASSIGN_OR_RETURN(std::string resolved_path,
mediapipe::PathToResourceAsFile(unresolved_path),
_ << "Failed to resolve path! Path = " << unresolved_path);
std::string content_blob;
MP_RETURN_IF_ERROR(
mediapipe::GetResourceContents(resolved_path, &content_blob))
<< "Failed to read content blob! Resolved path = " << resolved_path;
return content_blob;
}
std::unique_ptr<GeometryPipeline> geometry_pipeline_; std::unique_ptr<GeometryPipeline> geometry_pipeline_;
}; };

View File

@ -17,11 +17,12 @@ syntax = "proto2";
package mediapipe.tasks.vision.face_geometry; package mediapipe.tasks.vision.face_geometry;
import "mediapipe/framework/calculator_options.proto"; import "mediapipe/framework/calculator_options.proto";
import "mediapipe/tasks/cc/core/proto/external_file.proto";
message FaceGeometryPipelineCalculatorOptions { message FaceGeometryPipelineCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional FaceGeometryPipelineCalculatorOptions ext = 512499200; optional FaceGeometryPipelineCalculatorOptions ext = 512499200;
} }
optional string metadata_path = 1; optional core.proto.ExternalFile metadata_file = 1;
} }

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h"
#include "mediapipe/util/graph_builder_utils.h" #include "mediapipe/util/graph_builder_utils.h"
namespace mediapipe::tasks::vision::face_geometry { namespace mediapipe::tasks::vision::face_geometry {
@ -49,10 +50,6 @@ constexpr char kIterableTag[] = "ITERABLE";
constexpr char kBatchEndTag[] = "BATCH_END"; constexpr char kBatchEndTag[] = "BATCH_END";
constexpr char kItemTag[] = "ITEM"; constexpr char kItemTag[] = "ITEM";
constexpr char kGeometryPipelineMetadataPath[] =
"mediapipe/tasks/cc/vision/face_geometry/data/"
"geometry_pipeline_metadata_landmarks.binarypb";
struct FaceGeometryOuts { struct FaceGeometryOuts {
Stream<std::vector<FaceGeometry>> multi_face_geometry; Stream<std::vector<FaceGeometry>> multi_face_geometry;
}; };
@ -127,6 +124,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
} }
ASSIGN_OR_RETURN(auto outs, ASSIGN_OR_RETURN(auto outs,
BuildFaceGeometryFromLandmarksGraph( BuildFaceGeometryFromLandmarksGraph(
*sc->MutableOptions<proto::FaceGeometryGraphOptions>(),
graph.In(kFaceLandmarksTag) graph.In(kFaceLandmarksTag)
.Cast<std::vector<NormalizedLandmarkList>>(), .Cast<std::vector<NormalizedLandmarkList>>(),
graph.In(kImageSizeTag).Cast<std::pair<int, int>>(), graph.In(kImageSizeTag).Cast<std::pair<int, int>>(),
@ -138,6 +136,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
private: private:
absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph( absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph(
proto::FaceGeometryGraphOptions& graph_options,
Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks, Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks,
Stream<std::pair<int, int>> image_size, Stream<std::pair<int, int>> image_size,
std::optional<SidePacket<Environment>> environment, Graph& graph) { std::optional<SidePacket<Environment>> environment, Graph& graph) {
@ -185,7 +184,8 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
"mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator"); "mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator");
auto& geometry_pipeline_options = auto& geometry_pipeline_options =
geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>(); geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>();
geometry_pipeline_options.set_metadata_path(kGeometryPipelineMetadataPath); geometry_pipeline_options.Swap(
graph_options.mutable_geometry_pipeline_options());
image_size >> geometry_pipeline.In(kImageSizeTag); image_size >> geometry_pipeline.In(kImageSizeTag);
multi_face_landmarks_no_iris >> multi_face_landmarks_no_iris >>
geometry_pipeline.In(kMultiFaceLandmarksTag); geometry_pipeline.In(kMultiFaceLandmarksTag);

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "mediapipe/framework/api2/port.h" #include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h" #include "mediapipe/framework/calculator_runner.h"
@ -31,6 +32,7 @@ limitations under the License.
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/sink.h"
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h" #include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
@ -49,6 +51,9 @@ constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
constexpr char kFaceLandmarksFileName[] = constexpr char kFaceLandmarksFileName[] =
"face_blendshapes_in_landmarks.prototxt"; "face_blendshapes_in_landmarks.prototxt";
constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt"; constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt";
constexpr char kGeometryPipelineMetadataPath[] =
"mediapipe/tasks/cc/vision/face_geometry/data/"
"geometry_pipeline_metadata_landmarks.binarypb";
std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) { std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) {
NormalizedLandmarkList landmarks; NormalizedLandmarkList landmarks;
@ -89,7 +94,8 @@ void MakeInputPacketsAndRunGraph(CalculatorGraph& graph) {
TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) { TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
CalculatorGraphConfig graph_config = ParseTextProtoOrDie< CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
CalculatorGraphConfig>(R"pb( CalculatorGraphConfig>(absl::Substitute(
R"pb(
input_stream: "FACE_LANDMARKS:face_landmarks" input_stream: "FACE_LANDMARKS:face_landmarks"
input_stream: "IMAGE_SIZE:image_size" input_stream: "IMAGE_SIZE:image_size"
output_stream: "FACE_GEOMETRY:face_geometry" output_stream: "FACE_GEOMETRY:face_geometry"
@ -98,8 +104,15 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
input_stream: "FACE_LANDMARKS:face_landmarks" input_stream: "FACE_LANDMARKS:face_landmarks"
input_stream: "IMAGE_SIZE:image_size" input_stream: "IMAGE_SIZE:image_size"
output_stream: "FACE_GEOMETRY:face_geometry" output_stream: "FACE_GEOMETRY:face_geometry"
options: {
[mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
.ext] {
geometry_pipeline_options { metadata_file { file_name: "$0" } }
} }
)pb"); }
}
)pb",
kGeometryPipelineMetadataPath));
std::vector<Packet> output_packets; std::vector<Packet> output_packets;
tool::AddVectorSink("face_geometry", &graph_config, &output_packets); tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
@ -116,7 +129,8 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) { TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
CalculatorGraphConfig graph_config = ParseTextProtoOrDie< CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
CalculatorGraphConfig>(R"pb( CalculatorGraphConfig>(absl::Substitute(
R"pb(
input_stream: "FACE_LANDMARKS:face_landmarks" input_stream: "FACE_LANDMARKS:face_landmarks"
input_stream: "IMAGE_SIZE:image_size" input_stream: "IMAGE_SIZE:image_size"
input_side_packet: "ENVIRONMENT:environment" input_side_packet: "ENVIRONMENT:environment"
@ -127,8 +141,15 @@ TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
input_stream: "IMAGE_SIZE:image_size" input_stream: "IMAGE_SIZE:image_size"
input_side_packet: "ENVIRONMENT:environment" input_side_packet: "ENVIRONMENT:environment"
output_stream: "FACE_GEOMETRY:face_geometry" output_stream: "FACE_GEOMETRY:face_geometry"
options: {
[mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
.ext] {
geometry_pipeline_options { metadata_file { file_name: "$0" } }
} }
)pb"); }
}
)pb",
kGeometryPipelineMetadataPath));
std::vector<Packet> output_packets; std::vector<Packet> output_packets;
tool::AddVectorSink("face_geometry", &graph_config, &output_packets); tool::AddVectorSink("face_geometry", &graph_config, &output_packets);

View File

@ -99,7 +99,7 @@ class ScreenToMetricSpaceConverter {
// //
// (3) Use the canonical-to-runtime scale from (2) to unproject the screen // (3) Use the canonical-to-runtime scale from (2) to unproject the screen
// landmarks. The result is referenced as "intermediate landmarks" because // landmarks. The result is referenced as "intermediate landmarks" because
// they are the first estimation of the resuling metric landmarks, but are // they are the first estimation of the resulting metric landmarks,but are
// not quite there yet. // not quite there yet.
// //
// (4) Estimate a canonical-to-runtime landmark set scale by running the // (4) Estimate a canonical-to-runtime landmark set scale by running the
@ -347,7 +347,7 @@ class GeometryPipelineImpl : public GeometryPipeline {
proto::Mesh3d* mutable_mesh = face_geometry.mutable_mesh(); proto::Mesh3d* mutable_mesh = face_geometry.mutable_mesh();
// Copy the canonical face mesh as the face geometry mesh. // Copy the canonical face mesh as the face geometry mesh.
mutable_mesh->CopyFrom(canonical_mesh_); mutable_mesh->CopyFrom(canonical_mesh_);
// Replace XYZ vertex mesh coodinates with the metric landmark positions. // Replace XYZ vertex mesh coordinates with the metric landmark positions.
for (int i = 0; i < canonical_mesh_num_vertices_; ++i) { for (int i = 0; i < canonical_mesh_num_vertices_; ++i) {
uint32_t vertex_buffer_offset = canonical_mesh_vertex_size_ * i + uint32_t vertex_buffer_offset = canonical_mesh_vertex_size_ * i +
canonical_mesh_vertex_position_offset_; canonical_mesh_vertex_position_offset_;

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type")
licenses(["notice"]) licenses(["notice"])
@ -23,6 +24,16 @@ mediapipe_proto_library(
srcs = ["environment.proto"], srcs = ["environment.proto"],
) )
mediapipe_register_type(
base_name = "face_geometry",
include_headers = ["mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"],
types = [
"::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry",
"::std::vector<::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry>",
],
deps = [":face_geometry_cc_proto"],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "face_geometry_proto", name = "face_geometry_proto",
srcs = ["face_geometry.proto"], srcs = ["face_geometry.proto"],
@ -44,3 +55,12 @@ mediapipe_proto_library(
name = "mesh_3d_proto", name = "mesh_3d_proto",
srcs = ["mesh_3d.proto"], srcs = ["mesh_3d.proto"],
) )
mediapipe_proto_library(
name = "face_geometry_graph_options_proto",
srcs = ["face_geometry_graph_options.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_proto",
],
)

View File

@ -16,7 +16,7 @@ syntax = "proto2";
package mediapipe.tasks.vision.face_geometry.proto; package mediapipe.tasks.vision.face_geometry.proto;
option java_package = "mediapipe.tasks.vision.facegeometry.proto"; option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto";
option java_outer_classname = "EnvironmentProto"; option java_outer_classname = "EnvironmentProto";
// Defines the (0, 0) origin point location of the environment. // Defines the (0, 0) origin point location of the environment.

View File

@ -19,7 +19,7 @@ package mediapipe.tasks.vision.face_geometry.proto;
import "mediapipe/framework/formats/matrix_data.proto"; import "mediapipe/framework/formats/matrix_data.proto";
import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto"; import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto";
option java_package = "mediapipe.tasks.vision.facegeometry.proto"; option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto";
option java_outer_classname = "FaceGeometryProto"; option java_outer_classname = "FaceGeometryProto";
// Defines the face geometry pipeline estimation result format. // Defines the face geometry pipeline estimation result format.
@ -28,7 +28,7 @@ message FaceGeometry {
// the face landmark IDs. // the face landmark IDs.
// //
// XYZ coordinates exist in the right-handed Metric 3D space configured by an // XYZ coordinates exist in the right-handed Metric 3D space configured by an
// environment. UV coodinates are taken from the canonical face mesh model. // environment. UV coordinates are taken from the canonical face mesh model.
// //
// XY coordinates are guaranteed to match the screen positions of // XY coordinates are guaranteed to match the screen positions of
// the input face landmarks after (1) being multiplied by the face pose // the input face landmarks after (1) being multiplied by the face pose

Some files were not shown because too many files have changed in this diff Show More