Merge branch 'master' into face-stylizer-python

This commit is contained in:
Kinar R 2023-03-23 09:27:52 +05:30 committed by GitHub
commit 3afe4cafc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
305 changed files with 21319 additions and 1142 deletions

17
LICENSE
View File

@ -199,3 +199,20 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
===========================================================================
For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
===========================================================================
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/

View File

@ -270,7 +270,7 @@ new_local_repository(
# For local MacOS builds, the path should point to an opencv@3 installation. # For local MacOS builds, the path should point to an opencv@3 installation.
# If you edit the path here, you will also need to update the corresponding # If you edit the path here, you will also need to update the corresponding
# prefix in "opencv_macos.BUILD". # prefix in "opencv_macos.BUILD".
path = "/usr/local", path = "/usr/local", # e.g. /usr/local/Cellar for HomeBrew
) )
new_local_repository( new_local_repository(
@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool")
# Node dependencies # Node dependencies
http_archive( http_archive(
name = "build_bazel_rules_nodejs", name = "build_bazel_rules_nodejs",
sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda", sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a",
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"], urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"],
) )
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies") load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
@ -543,3 +543,43 @@ external_files()
load("@//third_party:wasm_files.bzl", "wasm_files") load("@//third_party:wasm_files.bzl", "wasm_files")
wasm_files() wasm_files()
# Halide
new_local_repository(
name = "halide",
build_file = "@//third_party/halide:BUILD.bazel",
path = "third_party/halide"
)
http_archive(
name = "linux_halide",
sha256 = "f62b2914823d6e33d18693f5b74484f274523bf5402ce51988e24393d123b375",
strip_prefix = "Halide-15.0.0-x86-64-linux",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-linux-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
build_file = "@//third_party:halide.BUILD",
)
http_archive(
name = "macos_x86_64_halide",
sha256 = "3d832aed942080ea89aa832462c68fbb906f3055c440b7b6d35093d7c52f6aab",
strip_prefix = "Halide-15.0.0-x86-64-osx",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
build_file = "@//third_party:halide.BUILD",
)
http_archive(
name = "macos_arm_64_halide",
sha256 = "b1fad3c9810122b187303d7031d9e35fb43761f345d18cc4492c00ed5877f641",
strip_prefix = "Halide-15.0.0-arm-64-osx",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-arm-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
build_file = "@//third_party:halide.BUILD",
)
http_archive(
name = "windows_halide",
sha256 = "5acf6fe161dd375856a2b43f4bb0a32815ba958b0585ee312c44e008aa7b0b64",
strip_prefix = "Halide-15.0.0-x86-64-windows",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-windows-d7651f4b32f9dbd764f243134001f7554378d62d.zip"],
build_file = "@//third_party:halide.BUILD",
)

View File

@ -113,14 +113,14 @@ Warning: On the other hand, it is not guaranteed that an input packet will
always be available for all streams. always be available for all streams.
To explain how it works, we need to introduce the definition of a settled To explain how it works, we need to introduce the definition of a settled
timestamp. We say that a timestamp in a stream is *settled* if it lower than the timestamp. We say that a timestamp in a stream is *settled* if it is lower than
timestamp bound. In other words, a timestamp is settled for a stream once the the timestamp bound. In other words, a timestamp is settled for a stream once
state of the input at that timestamp is irrevocably known: either there is a the state of the input at that timestamp is irrevocably known: either there is a
packet, or there is the certainty that a packet with that timestamp will not packet, or there is the certainty that a packet with that timestamp will not
arrive. arrive.
Note: For this reason, MediaPipe also allows a stream producer to explicitly Note: For this reason, MediaPipe also allows a stream producer to explicitly
advance the timestamp bound farther that what the last packet implies, i.e. to advance the timestamp bound farther than what the last packet implies, i.e. to
provide a tighter bound. This can allow the downstream nodes to settle their provide a tighter bound. This can allow the downstream nodes to settle their
inputs sooner. inputs sooner.

View File

@ -108,6 +108,8 @@ one over the other.
* [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite) * [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite) * [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
* [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
* [Model information](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)
### [Objectron](https://google.github.io/mediapipe/solutions/objectron) ### [Objectron](https://google.github.io/mediapipe/solutions/objectron)

View File

@ -118,9 +118,9 @@ on how to build MediaPipe examples.
* With a TensorFlow Model * With a TensorFlow Model
This uses the This uses the
[TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model) [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
( see also ( see also
[model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)), [model info](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)),
and the pipeline is implemented in this and the pipeline is implemented in this
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt). [graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).

View File

@ -0,0 +1,62 @@
## TensorFlow/TFLite Object Detection Model
### TensorFlow model
The model is trained on [MSCOCO 2014](http://cocodataset.org) dataset using [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with 0.5 depth multiplier. Detailed training configuration is in the provided `pipeline.config`. The model is a relatively compact model which has `0.171 mAP` to achieve real-time performance on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).
### TFLite model
The TFLite model is converted from the TensorFlow above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) with minor modifications. Assuming now we have a trained TensorFlow model which includes the checkpoint files and the training configuration file, for example the files provided in this repo:
* `model.ckpt.index`
* `model.ckpt.meta`
* `model.ckpt.data-00000-of-00001`
* `pipeline.config`
Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:
```bash
$ PATH_TO_MODEL=path/to/the/model
$ bazel run object_detection:export_tflite_ssd_graph -- \
--pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
--trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
--output_directory ${PATH_TO_MODEL} \
--add_postprocessing_op=False
```
The exported model contains two files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. In MediaPipe, we have provided all the calculators needed for post-processing such that we can exclude the custom TFLite ops for post-processing in the original graph, e.g., non-maximum suppression. This enables the flexibility to integrate with different post-processing algorithms and implementations.
Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model:
```bash
$ bazel run graph_transforms:summarize_graph -- \
--in_graph=${PATH_TO_MODEL}/tflite_graph.pb
```
You should be able to see the input image size of the model is 320x320 and the outputs of the model are:
* `raw_outputs/box_encodings`
* `raw_outputs/class_predictions`
The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run:
```bash
$ tflite_convert -- \
--graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
--output_file=${PATH_TO_MODEL}/model.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=FLOAT \
--input_shapes=1,320,320,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
```
Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.

View File

@ -269,6 +269,7 @@ Supported configuration options:
```python ```python
import cv2 import cv2
import mediapipe as mp import mediapipe as mp
import numpy as np
mp_drawing = mp.solutions.drawing_utils mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose mp_pose = mp.solutions.pose

View File

@ -156,6 +156,7 @@ cc_library(
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
"//mediapipe/framework/port:opencv_imgproc",
] + select({ ] + select({
"//mediapipe/gpu:disable_gpu": [], "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [ "//conditions:default": [
@ -168,6 +169,25 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "set_alpha_calculator_test",
srcs = ["set_alpha_calculator_test.cc"],
deps = [
":set_alpha_calculator",
":set_alpha_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_googletest//:gtest_main",
],
)
cc_library( cc_library(
name = "bilateral_filter_calculator", name = "bilateral_filter_calculator",
srcs = ["bilateral_filter_calculator.cc"], srcs = ["bilateral_filter_calculator.cc"],
@ -748,6 +768,7 @@ cc_test(
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png", "//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
], ],
tags = ["desktop_only_test"], tags = ["desktop_only_test"],

View File

@ -29,6 +29,9 @@ class AffineTransformation {
// pixels will be calculated. // pixels will be calculated.
enum class BorderMode { kZero, kReplicate }; enum class BorderMode { kZero, kReplicate };
// Pixel sampling interpolation method.
enum class Interpolation { kLinear, kCubic };
struct Size { struct Size {
int width; int width;
int height; int height;

View File

@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
std::unique_ptr<GpuBuffer>> { std::unique_ptr<GpuBuffer>> {
public: public:
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper, GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
GpuOrigin::Mode gpu_origin) GpuOrigin::Mode gpu_origin,
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {} AffineTransformation::Interpolation interpolation)
: gl_helper_(gl_helper),
gpu_origin_(gpu_origin),
interpolation_(interpolation) {}
absl::Status Init() { absl::Status Init() {
return gl_helper_->RunInGlContext([this]() -> absl::Status { return gl_helper_->RunInGlContext([this]() -> absl::Status {
const GLint attr_location[kNumAttributes] = { const GLint attr_location[kNumAttributes] = {
@ -103,28 +106,83 @@ class GlTextureWarpAffineRunner
} }
)"; )";
// TODO Move bicubic code to common shared place.
constexpr GLchar kFragShader[] = R"( constexpr GLchar kFragShader[] = R"(
DEFAULT_PRECISION(highp, float) DEFAULT_PRECISION(highp, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture;
#ifdef GL_ES in vec2 sample_coordinate;
#define fragColor gl_FragColor uniform sampler2D input_texture;
#else uniform vec2 input_size;
out vec4 fragColor;
#endif // defined(GL_ES);
void main() { #ifdef GL_ES
vec4 color = texture2D(input_texture, sample_coordinate); #define fragColor gl_FragColor
#ifdef CUSTOM_ZERO_BORDER_MODE #else
float out_of_bounds = out vec4 fragColor;
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 || #endif // defined(GL_ES);
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds); #ifdef CUBIC_INTERPOLATION
#endif // defined(CUSTOM_ZERO_BORDER_MODE) vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
fragColor = color; const vec2 halve = vec2(0.5,0.5);
} const vec2 one = vec2(1.0,1.0);
)"; const vec2 two = vec2(2.0,2.0);
const vec2 three = vec2(3.0,3.0);
const vec2 six = vec2(6.0,6.0);
// Calculate the fraction and integer.
tex_coord = tex_coord * tex_size - halve;
vec2 frac = fract(tex_coord);
vec2 index = tex_coord - frac + halve;
// Calculate weights for Catmull-Rom filter.
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
vec2 w3 = frac * frac * (-halve + halve * frac);
// Calculate weights to take advantage of bilinear texture lookup.
vec2 w12 = w1 + w2;
vec2 offset12 = w2 / (w1 + w2);
vec2 index_tl = index - one;
vec2 index_br = index + two;
vec2 index_eq = index + offset12;
index_tl /= tex_size;
index_br /= tex_size;
index_eq /= tex_size;
// 9 texture lookup and linear blending.
vec4 color = vec4(0.0);
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
return color;
}
#else
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
return texture2D(tex, tex_coord);
}
#endif // defined(CUBIC_INTERPOLATION)
void main() {
vec4 color = sample(input_texture, sample_coordinate, input_size);
#ifdef CUSTOM_ZERO_BORDER_MODE
float out_of_bounds =
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
sample_coordinate.y < 0.0 || sample_coordinate.y > 1.0);
color = mix(color, vec4(0.0, 0.0, 0.0, 0.0), out_of_bounds);
#endif // defined(CUSTOM_ZERO_BORDER_MODE)
fragColor = color;
}
)";
// Create program and set parameters. // Create program and set parameters.
auto create_fn = [&](const std::string& vs, auto create_fn = [&](const std::string& vs,
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
glUseProgram(program); glUseProgram(program);
glUniform1i(glGetUniformLocation(program, "input_texture"), 1); glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
GLint matrix_id = glGetUniformLocation(program, "transform_matrix"); GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
return Program{.id = program, .matrix_id = matrix_id}; GLint size_id = glGetUniformLocation(program, "input_size");
return Program{
.id = program, .matrix_id = matrix_id, .size_id = size_id};
}; };
const std::string vert_src = const std::string vert_src =
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader); absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
const std::string frag_src = absl::StrCat( std::string interpolation_def;
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader); switch (interpolation_) {
case AffineTransformation::Interpolation::kCubic:
interpolation_def = R"(
#define CUBIC_INTERPOLATION
)";
break;
case AffineTransformation::Interpolation::kLinear:
break;
}
const std::string frag_src =
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
interpolation_def, kFragShader);
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src)); ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
std::string custom_zero_border_mode_def = R"( std::string custom_zero_border_mode_def = R"(
#define CUSTOM_ZERO_BORDER_MODE #define CUSTOM_ZERO_BORDER_MODE
)"; )";
const std::string frag_custom_zero_src = const std::string frag_custom_zero_src = absl::StrCat(
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble, mediapipe::kMediaPipeFragmentShaderPreamble,
custom_zero_border_mode_def, kFragShader); custom_zero_border_mode_def, interpolation_def, kFragShader);
return create_fn(vert_src, frag_custom_zero_src); return create_fn(vert_src, frag_custom_zero_src);
}; };
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED #if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
} }
glUseProgram(program->id); glUseProgram(program->id);
// uniforms
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data()); Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) { if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
// @matrix describes affine transformation in terms of TOP LEFT origin, so // @matrix describes affine transformation in terms of TOP LEFT origin, so
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
eigen_mat.transposeInPlace(); eigen_mat.transposeInPlace();
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data()); glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
glUniform2f(program->size_id, texture.width(), texture.height());
}
// vao // vao
glBindVertexArray(vao_); glBindVertexArray(vao_);
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
struct Program { struct Program {
GLuint id; GLuint id;
GLint matrix_id; GLint matrix_id;
GLint size_id;
}; };
std::shared_ptr<GlCalculatorHelper> gl_helper_; std::shared_ptr<GlCalculatorHelper> gl_helper_;
GpuOrigin::Mode gpu_origin_; GpuOrigin::Mode gpu_origin_;
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
Program program_; Program program_;
std::optional<Program> program_custom_zero_; std::optional<Program> program_custom_zero_;
GLuint framebuffer_ = 0; GLuint framebuffer_ = 0;
AffineTransformation::Interpolation interpolation_ =
AffineTransformation::Interpolation::kLinear;
}; };
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED #undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
absl::StatusOr<std::unique_ptr< absl::StatusOr<std::unique_ptr<
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>> AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
CreateAffineTransformationGlRunner( CreateAffineTransformationGlRunner(
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) { std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
auto runner = AffineTransformation::Interpolation interpolation) {
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin); auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
gl_helper, gpu_origin, interpolation);
MP_RETURN_IF_ERROR(runner->Init()); MP_RETURN_IF_ERROR(runner->Init());
return runner; return runner;
} }

View File

@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>> mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
CreateAffineTransformationGlRunner( CreateAffineTransformationGlRunner(
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper, std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
mediapipe::GpuOrigin::Mode gpu_origin); mediapipe::GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe } // namespace mediapipe

View File

@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
} }
} }
int GetInterpolationForOpenCv(
AffineTransformation::Interpolation interpolation) {
switch (interpolation) {
case AffineTransformation::Interpolation::kLinear:
return cv::INTER_LINEAR;
case AffineTransformation::Interpolation::kCubic:
return cv::INTER_CUBIC;
}
}
class OpenCvRunner class OpenCvRunner
: public AffineTransformation::Runner<ImageFrame, ImageFrame> { : public AffineTransformation::Runner<ImageFrame, ImageFrame> {
public: public:
OpenCvRunner(AffineTransformation::Interpolation interpolation)
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
absl::StatusOr<ImageFrame> Run( absl::StatusOr<ImageFrame> Run(
const ImageFrame& input, const std::array<float, 16>& matrix, const ImageFrame& input, const std::array<float, 16>& matrix,
const AffineTransformation::Size& size, const AffineTransformation::Size& size,
@ -142,19 +155,23 @@ class OpenCvRunner
cv::warpAffine(in_mat, out_mat, cv_affine_transform, cv::warpAffine(in_mat, out_mat, cv_affine_transform,
cv::Size(out_mat.cols, out_mat.rows), cv::Size(out_mat.cols, out_mat.rows),
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP, /*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
GetBorderModeForOpenCv(border_mode)); GetBorderModeForOpenCv(border_mode));
return out_image; return out_image;
} }
private:
int interpolation_ = cv::INTER_LINEAR;
}; };
} // namespace } // namespace
absl::StatusOr< absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>> std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner() { CreateAffineTransformationOpenCvRunner(
return absl::make_unique<OpenCvRunner>(); AffineTransformation::Interpolation interpolation) {
return absl::make_unique<OpenCvRunner>(interpolation);
} }
} // namespace mediapipe } // namespace mediapipe

View File

@ -25,7 +25,8 @@ namespace mediapipe {
absl::StatusOr< absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>> std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner(); CreateAffineTransformationOpenCvRunner(
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe } // namespace mediapipe

View File

@ -81,7 +81,8 @@ class ImageCloneCalculator : public Node {
absl::Status Process(CalculatorContext* cc) override { absl::Status Process(CalculatorContext* cc) override {
std::unique_ptr<Image> output; std::unique_ptr<Image> output;
const auto& input = *kIn(cc); const auto& input = *kIn(cc);
if (input.UsesGpu()) { bool input_on_gpu = input.UsesGpu();
if (input_on_gpu) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
// Create an output Image that co-owns the underlying texture buffer as // Create an output Image that co-owns the underlying texture buffer as
// the input Image. // the input Image.
@ -97,15 +98,15 @@ class ImageCloneCalculator : public Node {
// Image. This ensures a correct life span of the shared pixel data. // Image. This ensures a correct life span of the shared pixel data.
output = std::make_unique<Image>(std::make_unique<mediapipe::ImageFrame>( output = std::make_unique<Image>(std::make_unique<mediapipe::ImageFrame>(
input.image_format(), input.width(), input.height(), input.step(), input.image_format(), input.width(), input.height(), input.step(),
const_cast<uint8*>(input.GetImageFrameSharedPtr()->PixelData()), const_cast<uint8_t*>(input.GetImageFrameSharedPtr()->PixelData()),
[packet_copy_ptr](uint8*) { delete packet_copy_ptr; })); [packet_copy_ptr](uint8_t*) { delete packet_copy_ptr; }));
} }
if (output_on_gpu_) { if (output_on_gpu_ && !input_on_gpu) {
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
gpu_helper_.RunInGlContext([&output]() { output->ConvertToGpu(); }); gpu_helper_.RunInGlContext([&output]() { output->ConvertToGpu(); });
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else if (!output_on_gpu_ && input_on_gpu) {
output->ConvertToCpu(); output->ConvertToCpu();
} }
kOut(cc).Send(std::move(output)); kOut(cc).Send(std::move(output));

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/formats/image_frame_opencv.h" #include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
@ -53,24 +54,16 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat // range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
// must be uchar. // must be uchar.
template <typename AlphaType> template <typename AlphaType>
absl::Status MergeRGBA8Image(const cv::Mat input_mat, const cv::Mat& alpha_mat, absl::Status CopyAlphaImage(const cv::Mat& alpha_mat, cv::Mat& output_mat) {
cv::Mat& output_mat) { RET_CHECK_EQ(output_mat.rows, alpha_mat.rows);
RET_CHECK_EQ(input_mat.rows, alpha_mat.rows); RET_CHECK_EQ(output_mat.cols, alpha_mat.cols);
RET_CHECK_EQ(input_mat.cols, alpha_mat.cols);
RET_CHECK_EQ(input_mat.rows, output_mat.rows);
RET_CHECK_EQ(input_mat.cols, output_mat.cols);
for (int i = 0; i < output_mat.rows; ++i) { for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i); const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i); uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) { for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA; const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
const int alpha_idx = j * alpha_mat.channels(); const int alpha_idx = j * alpha_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
if constexpr (std::is_same<AlphaType, uchar>::value) { if constexpr (std::is_same<AlphaType, uchar>::value) {
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
} else { } else {
@ -273,7 +266,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup source image // Setup source image
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>(); const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame); const cv::Mat input_mat = formats::MatView(&input_frame);
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) { if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported"; LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
} }
@ -281,38 +274,38 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup destination image // Setup destination image
auto output_frame = absl::make_unique<ImageFrame>( auto output_frame = absl::make_unique<ImageFrame>(
ImageFormat::SRGBA, input_mat.cols, input_mat.rows); ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get()); cv::Mat output_mat = formats::MatView(output_frame.get());
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) && const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty(); !cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask; const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
// Setup alpha image and Update image in CPU. // Copy rgb part of the image in CPU
if (input_mat.channels() == 3) {
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA);
} else {
input_mat.copyTo(output_mat);
}
// Setup alpha image in CPU.
if (use_alpha_mask) { if (use_alpha_mask) {
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>(); const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask); cv::Mat alpha_mat = formats::MatView(&alpha_mask);
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F; const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U); RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
if (alpha_is_float) { if (alpha_is_float) {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(CopyAlphaImage<float>(alpha_mat, output_mat));
MergeRGBA8Image<float>(input_mat, alpha_mat, output_mat));
} else { } else {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(CopyAlphaImage<uchar>(alpha_mat, output_mat));
MergeRGBA8Image<uchar>(input_mat, alpha_mat, output_mat));
} }
} else { } else {
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f); const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
for (int i = 0; i < output_mat.rows; ++i) { for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i); uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) { for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA; const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
out_ptr[out_idx + 3] = alpha_value; // use value from options out_ptr[out_idx + 3] = alpha_value; // use value from options
} }
} }

View File

@ -0,0 +1,156 @@
#include <cstdint>
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "testing/base/public/benchmark.h"
namespace mediapipe {
namespace {
constexpr int input_width = 100;
constexpr int input_height = 100;
std::unique_ptr<ImageFrame> GetInputFrame(int width, int height, int channel) {
const int total_size = width * height * channel;
ImageFormat::Format image_format;
if (channel == 4) {
image_format = ImageFormat::SRGBA;
} else if (channel == 3) {
image_format = ImageFormat::SRGB;
} else {
image_format = ImageFormat::GRAY8;
}
auto input_frame = std::make_unique<ImageFrame>(image_format, width, height,
/*alignment_boundary =*/1);
for (int i = 0; i < total_size; ++i) {
input_frame->MutablePixelData()[i] = i % 256;
}
return input_frame;
}
// Test SetAlphaCalculator with RGB IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgb) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 3);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 3, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
// Test SetAlphaCalculator with RGBA IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgba) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 4);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 4);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 4, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
static void BM_SetAlpha3ChannelImage(benchmark::State& state) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(1, outputs.NumEntries());
for (const auto _ : state) {
MP_ASSERT_OK(runner.Run());
}
}
BENCHMARK(BM_SetAlpha3ChannelImage);
} // namespace
} // namespace mediapipe

View File

@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
} }
} }
AffineTransformation::Interpolation GetInterpolation(
mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
switch (interpolation) {
case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
return AffineTransformation::Interpolation::kLinear;
case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
return AffineTransformation::Interpolation::kCubic;
}
}
template <typename ImageT> template <typename ImageT>
class WarpAffineRunnerHolder {}; class WarpAffineRunnerHolder {};
@ -61,16 +72,22 @@ template <>
class WarpAffineRunnerHolder<ImageFrame> { class WarpAffineRunnerHolder<ImageFrame> {
public: public:
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>; using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); } absl::Status Open(CalculatorContext* cc) {
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return absl::OkStatus();
}
absl::StatusOr<RunnerType*> GetRunner() { absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) { if (!runner_) {
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner()); ASSIGN_OR_RETURN(runner_,
CreateAffineTransformationOpenCvRunner(interpolation_));
} }
return runner_.get(); return runner_.get();
} }
private: private:
std::unique_ptr<RunnerType> runner_; std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
}; };
#endif // !MEDIAPIPE_DISABLE_OPENCV #endif // !MEDIAPIPE_DISABLE_OPENCV
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
gpu_origin_ = gpu_origin_ =
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin(); cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>(); gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return gl_helper_->Open(cc); return gl_helper_->Open(cc);
} }
absl::StatusOr<RunnerType*> GetRunner() { absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) { if (!runner_) {
ASSIGN_OR_RETURN( ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_)); gl_helper_, gpu_origin_, interpolation_));
} }
return runner_.get(); return runner_.get();
} }
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
mediapipe::GpuOrigin::Mode gpu_origin_; mediapipe::GpuOrigin::Mode gpu_origin_;
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_; std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
std::unique_ptr<RunnerType> runner_; std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
}; };
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU

View File

@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
BORDER_REPLICATE = 2; BORDER_REPLICATE = 2;
} }
// Pixel sampling interpolation methods. See @interpolation.
enum Interpolation {
INTER_UNSPECIFIED = 0;
INTER_LINEAR = 1;
INTER_CUBIC = 2;
}
// Pixel extrapolation method. // Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read // When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such // pixels outside image boundaries. Border mode helps to specify how such
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
// to be flipped vertically as tensors are expected to start at top. // to be flipped vertically as tensors are expected to start at top.
// (DEFAULT or unset interpreted as CONVENTIONAL.) // (DEFAULT or unset interpreted as CONVENTIONAL.)
optional GpuOrigin.Mode gpu_origin = 2; optional GpuOrigin.Mode gpu_origin = 2;
// Sampling method for neighboring pixels.
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
optional Interpolation interpolation = 3;
} }

View File

@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
const cv::Mat& input, cv::Mat expected_result, const cv::Mat& input, cv::Mat expected_result,
float similarity_threshold, std::array<float, 16> matrix, float similarity_threshold, std::array<float, 16> matrix,
int out_width, int out_height, int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) { std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
std::string border_mode_str; std::string border_mode_str;
if (border_mode) { if (border_mode) {
switch (*border_mode) { switch (*border_mode) {
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
break; break;
} }
} }
std::string interpolation_str;
if (interpolation) {
switch (*interpolation) {
case AffineTransformation::Interpolation::kLinear:
interpolation_str = "interpolation: INTER_LINEAR";
break;
case AffineTransformation::Interpolation::kCubic:
interpolation_str = "interpolation: INTER_CUBIC";
break;
}
}
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>( auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(graph_text, /*$0=*/border_mode_str)); absl::Substitute(graph_text, /*$0=*/border_mode_str,
/*$1=*/interpolation_str));
std::vector<Packet> output_packets; std::vector<Packet> output_packets;
tool::AddVectorSink("output_image", &graph_config, &output_packets); tool::AddVectorSink("output_image", &graph_config, &output_packets);
@ -132,7 +145,8 @@ struct SimilarityConfig {
void RunTest(cv::Mat input, cv::Mat expected_result, void RunTest(cv::Mat input, cv::Mat expected_result,
const SimilarityConfig& similarity, std::array<float, 16> matrix, const SimilarityConfig& similarity, std::array<float, 16> matrix,
int out_width, int out_height, int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) { std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
input_stream: "output_size" input_stream: "output_size"
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
} }
} }
} }
)", )",
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix, "cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
} }
} }
} }
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
} }
)", )",
"cpu_image", input, expected_result, similarity.threshold_on_cpu, "cpu_image", input, expected_result, similarity.threshold_on_cpu,
matrix, out_width, out_height, border_mode); matrix, out_width, out_height, border_mode, interpolation);
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT gpu_origin: TOP_LEFT
} }
} }
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
} }
)", )",
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix, "gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
RunTest(R"( RunTest(R"(
input_stream: "input_image" input_stream: "input_image"
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options { options {
[mediapipe.WarpAffineCalculatorOptions.ext] { [mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode $0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT gpu_origin: TOP_LEFT
} }
} }
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
} }
)", )",
"gpu_image", input, expected_result, similarity.threshold_on_gpu, "gpu_image", input, expected_result, similarity.threshold_on_gpu,
matrix, out_width, out_height, border_mode); matrix, out_width, out_height, border_mode, interpolation);
} }
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi, std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
int out_height = 256; int out_height = 256;
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {}; std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) { TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) { TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) { TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) { TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) { TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
mediapipe::NormalizedRect roi;
roi.set_x_center(0.65f);
roi.set_y_center(0.4f);
roi.set_width(0.5f);
roi.set_height(0.5f);
roi.set_rotation(M_PI * -45.0f / 180.0f);
auto input = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/input.jpg");
auto expected_output = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
int out_width = 256;
int out_height = 256;
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation =
AffineTransformation::Interpolation::kCubic;
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRect) { TEST(WarpAffineCalculatorTest, LargeSubRect) {
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) { TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
bool keep_aspect_ratio = false; bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
int out_height = 128; int out_height = 128;
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {}; std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) { TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, NoOp) { TEST(WarpAffineCalculatorTest, NoOp) {
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate; AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
TEST(WarpAffineCalculatorTest, NoOpBorderZero) { TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
bool keep_aspect_ratio = true; bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero; AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output, RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99}, {.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height), GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode); out_width, out_height, border_mode, interpolation);
} }
} // namespace } // namespace

View File

@ -997,17 +997,20 @@ cc_library(
":image_to_tensor_converter_gl_buffer", ":image_to_tensor_converter_gl_buffer",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
], ],
"//mediapipe:apple": [ "//mediapipe:apple": [
":image_to_tensor_converter_metal", ":image_to_tensor_converter_metal",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
], ],
"//conditions:default": [ "//conditions:default": [
":image_to_tensor_converter_gl_buffer", ":image_to_tensor_converter_gl_buffer",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
], ],
}), }),
) )
@ -1045,6 +1048,10 @@ cc_test(
":image_to_tensor_calculator", ":image_to_tensor_calculator",
":image_to_tensor_converter", ":image_to_tensor_converter",
":image_to_tensor_utils", ":image_to_tensor_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner", "//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/deps:file_path",
@ -1061,11 +1068,10 @@ cc_test(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/util:image_test_utils", "//mediapipe/util:image_test_utils",
"@com_google_absl//absl/flags:flag", ] + select({
"@com_google_absl//absl/memory", "//mediapipe:apple": [],
"@com_google_absl//absl/strings", "//conditions:default": ["//mediapipe/gpu:gl_context"],
"@com_google_absl//absl/strings:str_format", }),
],
) )
cc_library( cc_library(

View File

@ -45,9 +45,11 @@
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_service.h"
#else #else
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h" #include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_service.h"
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node {
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#else #else
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); cc->UseService(kGpuService).Optional();
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#endif // MEDIAPIPE_DISABLE_GPU #endif // MEDIAPIPE_DISABLE_GPU

View File

@ -41,6 +41,10 @@
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/image_test_utils.h" #include "mediapipe/util/image_test_utils.h"
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
#include "mediapipe/gpu/gl_context.h"
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
namespace mediapipe { namespace mediapipe {
namespace { namespace {
@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) {
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt, /*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
/*keep_aspect=*/false, BorderMode::kZero, roi); /*keep_aspect=*/false, BorderMode::kZero, roi);
} }
TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) {
auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_float_range { min: 0.0f max: 1.0f }
}
}
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
MP_ASSERT_OK(graph.StartRun({}));
auto image_frame =
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
Image image = Image(std::move(image_frame));
Packet packet = MakePacket<Image>(std::move(image));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
TEST(ImageToTensorCalculatorTest,
FailsGracefullyWhenGpuServiceNeededButNotAvailable) {
auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_float_range { min: 0.0f max: 1.0f }
}
}
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK_AND_ASSIGN(auto context,
GlContext::Create(nullptr, /*create_thread=*/true));
Packet packet;
context->Run([&packet]() {
auto image_frame =
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
Image image = Image(std::move(image_frame));
// Ensure image is available on GPU to force ImageToTensorCalculator to
// run on GPU.
ASSERT_TRUE(image.ConvertToGpu());
packet = MakePacket<Image>(std::move(image));
});
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
EXPECT_THAT(graph.WaitUntilIdle(),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("GPU service not available")));
}
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -141,7 +141,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
} }
// Run inference. // Run inference.
{ {
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc); MEDIAPIPE_PROFILING(GPU_TASK_INVOKE_ADVANCED, cc);
return tflite_gpu_runner_->Invoke(); return tflite_gpu_runner_->Invoke();
} }
})); }));

View File

@ -0,0 +1,41 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The option proto for the TensorsReadbackCalculator.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorsReadbackCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorsReadbackCalculatorOptions ext = 514750372;
}
// Expected shapes of the input tensors.
// The calculator uses these shape to build the GPU programs during
// initialization, and check the actual tensor shapes against the expected
// shapes during runtime.
// Batch size of the tensor is set to be 1. `TensorShape` here can be C, WC,
// or HWC.
// For example {dims: 1 dims: 2} represents a tensor with batch_size = 1,
// width = 1, and num_channels = 2.
message TensorShape {
repeated int32 dims = 1 [packed = true];
}
// tensor_shape specifies the shape of each input tensors.
repeated TensorShape tensor_shape = 1;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

View File

@ -14,6 +14,7 @@
# #
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//lib:selects.bzl", "selects")
licenses(["notice"]) licenses(["notice"])
@ -312,15 +313,19 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library( # TODO: Re-evaluate which of these libraries we can avoid making
# cc_library_with_tflite and can be changed back to cc_library.
cc_library_with_tflite(
name = "tflite_model_calculator", name = "tflite_model_calculator",
srcs = ["tflite_model_calculator.cc"], srcs = ["tflite_model_calculator.cc"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite:framework_stable",
],
deps = [ deps = [
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -66,7 +66,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
} else { } else {
cc->OutputSidePackets() cc->OutputSidePackets()
.Index(0) .Index(0)
.Set<tflite_shims::ops::builtin::BuiltinOpResolver>(); .Set<tflite::ops::builtin::BuiltinOpResolver>();
} }
return absl::OkStatus(); return absl::OkStatus();
} }
@ -77,7 +77,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
const TfLiteCustomOpResolverCalculatorOptions& options = const TfLiteCustomOpResolverCalculatorOptions& options =
cc->Options<TfLiteCustomOpResolverCalculatorOptions>(); cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver> op_resolver; std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> op_resolver;
if (options.use_gpu()) { if (options.use_gpu()) {
op_resolver = absl::make_unique<mediapipe::OpResolver>(); op_resolver = absl::make_unique<mediapipe::OpResolver>();
} else { } else {

View File

@ -21,7 +21,7 @@
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/allocation.h" #include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/core/shims/cc/model.h" #include "tensorflow/lite/model.h"
namespace mediapipe { namespace mediapipe {
@ -82,7 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
} }
if (cc->InputSidePackets().HasTag("MODEL_FD")) { if (cc->InputSidePackets().HasTag("MODEL_FD")) {
#ifdef ABSL_HAVE_MMAP #if defined(ABSL_HAVE_MMAP) && !TFLITE_WITH_STABLE_ABI
model_packet = cc->InputSidePackets().Tag("MODEL_FD"); model_packet = cc->InputSidePackets().Tag("MODEL_FD");
const auto& model_fd = const auto& model_fd =
model_packet.Get<std::tuple<int, size_t, size_t>>(); model_packet.Get<std::tuple<int, size_t, size_t>>();

View File

@ -1270,6 +1270,50 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
mediapipe_proto_library(
name = "flat_color_image_calculator_proto",
srcs = ["flat_color_image_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:color_proto",
],
)
cc_library(
name = "flat_color_image_calculator",
srcs = ["flat_color_image_calculator.cc"],
deps = [
":flat_color_image_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/util:color_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "flat_color_image_calculator_test",
srcs = ["flat_color_image_calculator_test.cc"],
deps = [
":flat_color_image_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/util:color_cc_proto",
],
)
cc_library( cc_library(
name = "from_image_calculator", name = "from_image_calculator",
srcs = ["from_image_calculator.cc"], srcs = ["from_image_calculator.cc"],

View File

@ -0,0 +1,138 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
} // namespace
// A calculator for generating an image filled with a single color.
//
// Inputs:
// IMAGE (Image, optional)
// If provided, the output will have the same size
// COLOR (Color proto, optional)
// Color to paint the output with. Takes precedence over the equivalent
// calculator options.
//
// Outputs:
// IMAGE (Image)
// Image filled with the requested color.
//
// Example useage:
// node {
// calculator: "FlatColorImageCalculator"
// input_stream: "IMAGE:image"
// input_stream: "COLOR:color"
// output_stream: "IMAGE:blank_image"
// options {
// [mediapipe.FlatColorImageCalculatorOptions.ext] {
// color: {
// r: 255
// g: 255
// b: 255
// }
// }
// }
// }
class FlatColorImageCalculator : public Node {
public:
static constexpr Input<Image>::Optional kInImage{"IMAGE"};
static constexpr Input<Color>::Optional kInColor{"COLOR"};
static constexpr Output<Image> kOutImage{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
static absl::Status UpdateContract(CalculatorContract* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
RET_CHECK(kInImage(cc).IsConnected() ^
(options.has_output_height() || options.has_output_width()))
<< "Either set IMAGE input stream, or set through options";
RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
<< "Either set COLOR input stream, or set through options";
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
bool use_dimension_from_option_ = false;
bool use_color_from_option_ = false;
};
MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
use_dimension_from_option_ = !kInImage(cc).IsConnected();
use_color_from_option_ = !kInColor(cc).IsConnected();
return absl::OkStatus();
}
absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
int output_height = -1;
int output_width = -1;
if (use_dimension_from_option_) {
output_height = options.output_height();
output_width = options.output_width();
} else if (!kInImage(cc).IsEmpty()) {
const Image& input_image = kInImage(cc).Get();
output_height = input_image.height();
output_width = input_image.width();
} else {
return absl::OkStatus();
}
Color color;
if (use_color_from_option_) {
color = options.color();
} else if (!kInColor(cc).IsEmpty()) {
color = kInColor(cc).Get();
} else {
return absl::OkStatus();
}
auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
output_width, output_height);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
kOutImage(cc).Send(Image(output_frame));
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";
message FlatColorImageCalculatorOptions {
extend CalculatorOptions {
optional FlatColorImageCalculatorOptions ext = 515548435;
}
// Output dimensions.
optional int32 output_width = 1;
optional int32 output_height = 2;
// The color to fill with in the output image.
optional Color color = 3;
}

View File

@ -0,0 +1,210 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::testing::HasSubstr;
constexpr char kImageTag[] = "IMAGE";
constexpr char kColorTag[] = "COLOR";
constexpr int kImageWidth = 256;
constexpr int kImageHeight = 256;
TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), kImageWidth);
EXPECT_EQ(image.height(), kImageHeight);
auto image_frame = image.GetImageFrameSharedPtr();
auto* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 100);
EXPECT_EQ(pixel_data[1], 200);
EXPECT_EQ(pixel_data[2], 255);
}
}
TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), 7);
EXPECT_EQ(image.height(), 13);
auto image_frame = image.GetImageFrameSharedPtr();
const uint8_t* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 0);
EXPECT_EQ(pixel_data[1], 5);
EXPECT_EQ(pixel_data[2], 0);
}
}
TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureMissingColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
} // namespace
} // namespace mediapipe

View File

@ -1,5 +1,6 @@
distributionBase=GRADLE_USER_HOME distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists zipStorePath=wrapper/dists

View File

@ -1,7 +1,7 @@
#!/usr/bin/env sh #!/bin/sh
# #
# Copyright 2015 the original author or authors. # Copyright © 2015-2021 the original authors.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,67 +17,101 @@
# #
############################################################################## ##############################################################################
## #
## Gradle start up script for UN*X # Gradle start up script for POSIX generated by Gradle.
## #
# Important for running:
#
# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
# noncompliant, but you have some other compliant shell such as ksh or
# bash, then to run this script, type that shell name before the whole
# command line, like:
#
# ksh Gradle
#
# Busybox and similar reduced shells will NOT work, because this script
# requires all of these POSIX shell features:
# * functions;
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
# * compound commands having a testable exit status, especially «case»;
# * various built-in commands including «command», «set», and «ulimit».
#
# Important for patching:
#
# (2) This script targets any POSIX shell, so it avoids extensions provided
# by Bash, Ksh, etc; in particular arrays are avoided.
#
# The "traditional" practice of packing multiple parameters into a
# space-separated string is a well documented source of bugs and security
# problems, so this is (mostly) avoided, by progressively accumulating
# options in "$@", and eventually passing that to Java.
#
# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
# see the in-line comments for details.
#
# There are tweaks for specific operating systems such as AIX, CygWin,
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
#
############################################################################## ##############################################################################
# Attempt to set APP_HOME # Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle" # Resolve links: $0 may be a link
APP_BASE_NAME=`basename "$0"` app_path=$0
# Need this for daisy-chained symlinks.
while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
do
ls=$( ls -ld "$app_path" )
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
done
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value. # Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum" MAX_FD=maximum
warn () { warn () {
echo "$*" echo "$*"
} } >&2
die () { die () {
echo echo
echo "$*" echo "$*"
echo echo
exit 1 exit 1
} } >&2
# OS specific support (must be 'true' or 'false'). # OS specific support (must be 'true' or 'false').
cygwin=false cygwin=false
msys=false msys=false
darwin=false darwin=false
nonstop=false nonstop=false
case "`uname`" in case "$( uname )" in #(
CYGWIN* ) CYGWIN* ) cygwin=true ;; #(
cygwin=true Darwin* ) darwin=true ;; #(
;; MSYS* | MINGW* ) msys=true ;; #(
Darwin* ) NONSTOP* ) nonstop=true ;;
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
@ -87,9 +121,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
if [ -n "$JAVA_HOME" ] ; then if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables # IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java" JAVACMD=$JAVA_HOME/jre/sh/java
else else
JAVACMD="$JAVA_HOME/bin/java" JAVACMD=$JAVA_HOME/bin/java
fi fi
if [ ! -x "$JAVACMD" ] ; then if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
@ -98,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the
location of your Java installation." location of your Java installation."
fi fi
else else
JAVACMD="java" JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the Please set the JAVA_HOME variable in your environment to match the
@ -106,80 +140,105 @@ location of your Java installation."
fi fi
# Increase the maximum file descriptors if we can. # Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
MAX_FD_LIMIT=`ulimit -H -n` case $MAX_FD in #(
if [ $? -eq 0 ] ; then max*)
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
MAX_FD="$MAX_FD_LIMIT" # shellcheck disable=SC3045
fi MAX_FD=$( ulimit -H -n ) ||
ulimit -n $MAX_FD warn "Could not query maximum file descriptor limit"
if [ $? -ne 0 ] ; then esac
warn "Could not set maximum file descriptor limit: $MAX_FD" case $MAX_FD in #(
fi '' | soft) :;; #(
else *)
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
fi # shellcheck disable=SC3045
fi ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac esac
fi fi
# Escape application args # Collect all arguments for the java command, stacking in reverse order:
save () { # * args from the command line
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done # * the main class name
echo " " # * -classpath
} # * -D...appname settings
APP_ARGS=`save "$@"` # * --module-path (only if needed)
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# Collect all arguments for the java command, following the shell quoting and substitution rules # For Cygwin or MSYS, switch paths to Windows format before running java
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" if "$cygwin" || "$msys" ; then
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ] ;; #(
*) false ;;
esac
then
arg=$( cygpath --path --ignore --mixed "$arg" )
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
fi
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in
# double quotes to make sure that they get re-expanded; and
# * put everything else in single quotes, so that it's not re-expanded.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1
then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
#
# In Bash we could simply go:
#
# readarray ARGS < <( xargs -n1 <<<"$var" ) &&
# set -- "${ARGS[@]}" "$@"
#
# but POSIX shell has neither arrays nor command substitution, so instead we
# post-process each arg (as a line of input to sed) to backslash-escape any
# character that might be a shell metacharacter, then use eval to reverse
# that process (while maintaining the separation between arguments), and wrap
# the whole thing up as a single "set" statement.
#
# This will of course break if any of these variables contains a newline or
# an unmatched quote.
#
eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
exec "$JAVACMD" "$@" exec "$JAVACMD" "$@"

View File

@ -14,7 +14,7 @@
@rem limitations under the License. @rem limitations under the License.
@rem @rem
@if "%DEBUG%" == "" @echo off @if "%DEBUG%"=="" @echo off
@rem ########################################################################## @rem ##########################################################################
@rem @rem
@rem Gradle startup script for Windows @rem Gradle startup script for Windows
@ -25,7 +25,8 @@
if "%OS%"=="Windows_NT" setlocal if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0 set DIRNAME=%~dp0
if "%DIRNAME%" == "" set DIRNAME=. if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0 set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME% set APP_HOME=%DIRNAME%
@ -40,7 +41,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1 %JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute if %ERRORLEVEL% equ 0 goto execute
echo. echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@ -75,13 +76,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
:end :end
@rem End local scope for the variables with windows NT shell @rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd if %ERRORLEVEL% equ 0 goto mainEnd
:fail :fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code! rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 set EXIT_CODE=%ERRORLEVEL%
exit /b 1 if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd :mainEnd
if "%OS%"=="Windows_NT" endlocal if "%OS%"=="Windows_NT" endlocal

View File

@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio,
std::string result_image; std::string result_image;
MP_ASSERT_OK( MP_ASSERT_OK(
mediapipe::file::GetContents(result_string_path, &result_image)); mediapipe::file::GetContents(result_string_path, &result_image));
EXPECT_EQ(result_image, output_string); if (result_image != output_string) {
// There may be slight differences due to the way the JPEG was encoded or
// the OpenCV version used to generate the reference files. Compare
// pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
cv::Mat result_mat =
cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1,
const_cast<char*>(result_image.data())),
cv::IMREAD_UNCHANGED);
cv::Mat output_mat =
cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1,
const_cast<char*>(output_string.data())),
cv::IMREAD_UNCHANGED);
ASSERT_EQ(result_mat.rows, output_mat.rows);
ASSERT_EQ(result_mat.cols, output_mat.cols);
ASSERT_EQ(result_mat.type(), output_mat.type());
EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0);
}
} else { } else {
std::string output_string_path = mediapipe::file::JoinPath( std::string output_string_path = mediapipe::file::JoinPath(
absl::GetFlag(FLAGS_output_folder), absl::GetFlag(FLAGS_output_folder),

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.1 KiB

After

Width:  |  Height:  |  Size: 6.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.2 KiB

After

Width:  |  Height:  |  Size: 8.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.6 KiB

After

Width:  |  Height:  |  Size: 7.6 KiB

View File

@ -56,5 +56,6 @@ objc_library(
deps = [ deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary", "//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
"//mediapipe/graphs/edge_detection:mobile_calculators", "//mediapipe/graphs/edge_detection:mobile_calculators",
"//third_party/apple_frameworks:Metal",
], ],
) )

View File

@ -631,7 +631,13 @@ absl::Status CalculatorGraph::PrepareServices() {
for (const auto& [key, request] : node->Contract().ServiceRequests()) { for (const auto& [key, request] : node->Contract().ServiceRequests()) {
auto packet = service_manager_.GetServicePacket(request.Service()); auto packet = service_manager_.GetServicePacket(request.Service());
if (!packet.IsEmpty()) continue; if (!packet.IsEmpty()) continue;
auto packet_or = request.Service().CreateDefaultObject(); absl::StatusOr<Packet> packet_or;
if (allow_service_default_initialization_) {
packet_or = request.Service().CreateDefaultObject();
} else {
packet_or = absl::FailedPreconditionError(
"Service default initialization is disallowed.");
}
if (packet_or.ok()) { if (packet_or.ok()) {
MP_RETURN_IF_ERROR(service_manager_.SetServicePacket( MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
request.Service(), std::move(packet_or).value())); request.Service(), std::move(packet_or).value()));

View File

@ -405,6 +405,34 @@ class CalculatorGraph {
return service_manager_.GetServiceObject(service); return service_manager_.GetServiceObject(service);
} }
// Disallows/disables default initialization of MediaPipe graph services.
//
// IMPORTANT: MediaPipe graph serices, essentially a graph-level singletons,
// are designed in the way, so they may provide default initialization. For
// example, this allows to run OpenGL processing wihtin the graph without
// provinging a praticular OpenGL context as it can be provided by
// default-initializable `kGpuService`. (One caveat here, you may still need
// to initialize it manually to share graph context with external context.)
//
// Even if calculators require some service optionally
// (`calculator_contract->UseService(kSomeService).Optional()`), it will be
// still initialized if it allows default initialization.
//
// So far, in rare cases, this may be unwanted and strict control of what
// services are allowed in the graph can be achieved by calling this method,
// following `SetServiceObject` call for services which are allowed in the
// graph.
//
// Recommendation: do not use unless you have to (for example, default
// initialization has side effects)
//
// NOTE: must be called before `StartRun`/`Run`, where services are checked
// and can be default-initialized.
absl::Status DisallowServiceDefaultInitialization() {
allow_service_default_initialization_ = false;
return absl::OkStatus();
}
// Sets a service object, essentially a graph-level singleton, which can be // Sets a service object, essentially a graph-level singleton, which can be
// accessed by calculators and subgraphs without requiring an explicit // accessed by calculators and subgraphs without requiring an explicit
// connection. // connection.
@ -644,6 +672,9 @@ class CalculatorGraph {
// Object to manage graph services. // Object to manage graph services.
GraphServiceManager service_manager_; GraphServiceManager service_manager_;
// Indicates whether service default initialization is allowed.
bool allow_service_default_initialization_ = true;
// Vector of errors encountered while running graph. Always use RecordError() // Vector of errors encountered while running graph. Always use RecordError()
// to add an error to this vector. // to add an error to this vector.
std::vector<absl::Status> errors_ ABSL_GUARDED_BY(error_mutex_); std::vector<absl::Status> errors_ ABSL_GUARDED_BY(error_mutex_);

View File

@ -136,6 +136,8 @@ message GraphTrace {
GPU_TASK_INVOKE = 16; GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17; TPU_TASK_INVOKE = 17;
CPU_TASK_INVOKE = 18; CPU_TASK_INVOKE = 18;
GPU_TASK_INVOKE_ADVANCED = 19;
TPU_TASK_INVOKE_ASYNC = 20;
} }
// The timing for one packet set being processed at one caclulator node. // The timing for one packet set being processed at one caclulator node.

View File

@ -315,11 +315,11 @@ cc_library(
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = [ deps = [
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/deps:file_path", "//mediapipe/framework/port:statusor",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
@ -328,8 +328,8 @@ cc_library(
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
], ],
"//mediapipe:android": [ "//mediapipe:android": [
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
], ],
"//mediapipe:apple": [ "//mediapipe:apple": [
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",

View File

@ -112,6 +112,10 @@ struct TraceEvent {
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE; static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE; static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE; static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
static constexpr EventType GPU_TASK_INVOKE_ADVANCED =
GraphTrace::GPU_TASK_INVOKE_ADVANCED;
static constexpr EventType TPU_TASK_INVOKE_ASYNC =
GraphTrace::TPU_TASK_INVOKE_ASYNC;
}; };
// Packet trace log buffer. // Packet trace log buffer.

View File

@ -57,7 +57,6 @@ struct hash<mediapipe::TaskId> {
namespace mediapipe { namespace mediapipe {
namespace { namespace {
void BasicTraceEventTypes(TraceEventRegistry* result) { void BasicTraceEventTypes(TraceEventRegistry* result) {
// The initializer arguments below are: event_type, description, // The initializer arguments below are: event_type, description,
// is_packet_event, is_stream_event, id_event_data. // is_packet_event, is_stream_event, id_event_data.
@ -84,6 +83,15 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
"A time measured by GPU clock and by CPU clock.", true, false}, "A time measured by GPU clock and by CPU clock.", true, false},
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.", {TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
true, true, false}, true, true, false},
{TraceEvent::GPU_TASK_INVOKE, "CPU timing for initiating a GPU task."},
{TraceEvent::TPU_TASK_INVOKE, "CPU timing for initiating a TPU task."},
{TraceEvent::CPU_TASK_INVOKE, "CPU timing for initiating a CPU task."},
{TraceEvent::GPU_TASK_INVOKE_ADVANCED,
"CPU timing for initiating a GPU task bypassing the TFLite "
"interpreter."},
{TraceEvent::TPU_TASK_INVOKE_ASYNC,
"CPU timing for async initiation of a TPU task."},
}; };
for (const TraceEventType& t : basic_types) { for (const TraceEventType& t : basic_types) {
(*result)[t.event_type()] = t; (*result)[t.event_type()] = t;

View File

@ -77,7 +77,6 @@ mediapipe_proto_library(
name = "calculator_graph_template_proto", name = "calculator_graph_template_proto",
srcs = ["calculator_graph_template.proto"], srcs = ["calculator_graph_template.proto"],
def_options_lib = False, def_options_lib = False,
def_py_proto = False,
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",

View File

@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs):
'import public "' + join_path + '";', 'import public "' + join_path + '";',
) )
rewrite_ref = SubsituteCommand( rewrite_ref = SubsituteCommand(
r"mediapipe\\.(" + rewrite_message_regex + ")", r"mediapipe\.(" + rewrite_message_regex + ")",
r"mediapipe.\\1", r"mediapipe.\\1",
) )
rewrite_objc = SubsituteCommand( rewrite_objc = SubsituteCommand(
@ -284,7 +284,7 @@ def mediapipe_proto_library(
def_jspb_proto: define the jspb_proto_library target def_jspb_proto: define the jspb_proto_library target
def_go_proto: define the go_proto_library target def_go_proto: define the go_proto_library target
def_options_lib: define the mediapipe_options_library target def_options_lib: define the mediapipe_options_library target
def_rewrite: define a sibbling mediapipe_proto_library with package "mediapipe" def_rewrite: define a sibling mediapipe_proto_library with package "mediapipe"
""" """
mediapipe_proto_library_impl( mediapipe_proto_library_impl(

View File

@ -183,12 +183,13 @@ absl::Status FindCorrespondingStreams(
// name, calculator, input_stream, output_stream, input_side_packet, // name, calculator, input_stream, output_stream, input_side_packet,
// output_side_packet, options. // output_side_packet, options.
// All other fields are only applicable to calculators. // All other fields are only applicable to calculators.
// TODO: Check whether executor is not set in the subgraph node
// after this issues is properly solved.
absl::Status ValidateSubgraphFields( absl::Status ValidateSubgraphFields(
const CalculatorGraphConfig::Node& subgraph_node) { const CalculatorGraphConfig::Node& subgraph_node) {
if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() || if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() ||
subgraph_node.has_output_stream_handler() || subgraph_node.has_output_stream_handler() ||
subgraph_node.input_stream_info_size() != 0 || subgraph_node.input_stream_info_size() != 0) {
!subgraph_node.executor().empty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Subgraph \"" << subgraph_node.name() << "Subgraph \"" << subgraph_node.name()
<< "\" has a field that is only applicable to calculators."; << "\" has a field that is only applicable to calculators.";

View File

@ -272,14 +272,6 @@ selects.config_setting_group(
], ],
) )
selects.config_setting_group(
name = "platform_ios_without_gpu",
match_all = [
":disable_gpu",
"//mediapipe:ios",
],
)
selects.config_setting_group( selects.config_setting_group(
name = "platform_macos_with_gpu", name = "platform_macos_with_gpu",
match_all = [ match_all = [
@ -296,32 +288,33 @@ cc_library(
deps = [ deps = [
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_storage", ":gpu_buffer_storage",
":gpu_buffer_storage_image_frame",
"@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
":gpu_buffer_storage_image_frame",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
":gl_texture_view",
":gl_texture_buffer", ":gl_texture_buffer",
":gl_texture_view",
], ],
":platform_ios_with_gpu": [ ":platform_ios_with_gpu": [
":gl_texture_view", ":gl_texture_view",
":gpu_buffer_storage_cv_pixel_buffer", ":gpu_buffer_storage_cv_pixel_buffer",
"//mediapipe/objc:util",
"//mediapipe/objc:CFHolder", "//mediapipe/objc:CFHolder",
], ],
":platform_macos_with_gpu": [ ":platform_macos_with_gpu": [
"//mediapipe/objc:CFHolder",
":gl_texture_view",
":gl_texture_buffer", ":gl_texture_buffer",
], ":gl_texture_view",
":platform_ios_without_gpu": [ "//mediapipe/objc:CFHolder",
"//mediapipe/objc:util",
], ],
":disable_gpu": [], ":disable_gpu": [],
}) + select({
"//conditions:default": [],
"//mediapipe:ios": [
"//mediapipe/objc:util",
],
}), }),
) )
@ -331,9 +324,9 @@ cc_library(
hdrs = ["gpu_buffer_format.h"], hdrs = ["gpu_buffer_format.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework/formats:image_format_cc_proto",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"//mediapipe/framework/deps:no_destructor", "//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
@ -474,6 +467,7 @@ cc_library(
"//mediapipe/framework/formats:frame_buffer", "//mediapipe/framework/formats:frame_buffer",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:yuv_image", "//mediapipe/framework/formats:yuv_image",
"//mediapipe/util/frame_buffer:frame_buffer_util",
"//third_party/libyuv", "//third_party/libyuv",
"@com_google_absl//absl/log", "@com_google_absl//absl/log",
"@com_google_absl//absl/log:check", "@com_google_absl//absl/log:check",
@ -619,22 +613,22 @@ cc_library(
}), }),
visibility = ["//visibility:private"], visibility = ["//visibility:private"],
deps = [ deps = [
":gl_context_options_cc_proto",
":graph_support",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:executor",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/deps:no_destructor",
":gl_base", ":gl_base",
":gl_context", ":gl_context",
":gl_context_options_cc_proto",
":gpu_buffer_multi_pool", ":gpu_buffer_multi_pool",
":gpu_shared_data_header", ":gpu_shared_data_header",
":graph_support",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework:executor",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/port:ret_check",
] + select({ ] + select({
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:apple": [ "//mediapipe:apple": [
":metal_shared_resources",
":cv_texture_cache_manager", ":cv_texture_cache_manager",
":metal_shared_resources",
], ],
}), }),
) )
@ -703,13 +697,13 @@ cc_library(
":gpu_buffer", ":gpu_buffer",
":gpu_shared_data_header", ":gpu_shared_data_header",
":multi_pool", ":multi_pool",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node", "//mediapipe/framework:calculator_node",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/util:resource_cache", "//mediapipe/util:resource_cache",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
":gl_texture_buffer", ":gl_texture_buffer",
@ -725,9 +719,9 @@ cc_library(
"//mediapipe:macos": [ "//mediapipe:macos": [
":cv_pixel_buffer_pool_wrapper", ":cv_pixel_buffer_pool_wrapper",
":cv_texture_cache_manager", ":cv_texture_cache_manager",
":pixel_buffer_pool_util",
":gl_texture_buffer", ":gl_texture_buffer",
":gl_texture_buffer_pool", ":gl_texture_buffer_pool",
":pixel_buffer_pool_util",
], ],
}), }),
) )
@ -795,31 +789,31 @@ cc_library(
":gpu_buffer", ":gpu_buffer",
":gpu_buffer_format", ":gpu_buffer_format",
":gpu_buffer_multi_pool", ":gpu_buffer_multi_pool",
":gpu_shared_data_internal",
":gpu_service", ":gpu_service",
":gpu_shared_data_internal",
":graph_support", ":graph_support",
":image_frame_view", ":image_frame_view",
":shader_util", ":shader_util",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework:demangle", "//mediapipe/framework:demangle",
"//mediapipe/framework:legacy_calculator_support", "//mediapipe/framework:legacy_calculator_support",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework:packet_set", "//mediapipe/framework:packet_set",
"//mediapipe/framework:packet_type", "//mediapipe/framework:packet_type",
"//mediapipe/framework:timestamp", "//mediapipe/framework:timestamp",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:map_util",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
], ],
@ -918,8 +912,6 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":gl_calculator_helper", ":gl_calculator_helper",
":gpu_buffer_storage_image_frame",
"//mediapipe/framework/api2:node",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -941,7 +933,7 @@ mediapipe_proto_library(
], ],
) )
proto_library( mediapipe_proto_library(
name = "gl_scaler_calculator_proto", name = "gl_scaler_calculator_proto",
srcs = ["gl_scaler_calculator.proto"], srcs = ["gl_scaler_calculator.proto"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@ -951,17 +943,6 @@ proto_library(
], ],
) )
mediapipe_cc_proto_library(
name = "gl_scaler_calculator_cc_proto",
srcs = ["gl_scaler_calculator.proto"],
cc_deps = [
":scale_mode_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":gl_scaler_calculator_proto"],
)
cc_library( cc_library(
name = "gl_scaler_calculator", name = "gl_scaler_calculator",
srcs = ["gl_scaler_calculator.cc"], srcs = ["gl_scaler_calculator.cc"],

View File

@ -12,63 +12,73 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#ifdef __APPLE__
#include "mediapipe/objc/util.h"
#endif
namespace mediapipe { namespace mediapipe {
namespace api2 {
class ImageFrameToGpuBufferCalculator // Convert ImageFrame to GpuBuffer.
: public RegisteredNode<ImageFrameToGpuBufferCalculator> { class ImageFrameToGpuBufferCalculator : public CalculatorBase {
public: public:
static constexpr Input<ImageFrame> kIn{""}; ImageFrameToGpuBufferCalculator() {}
static constexpr Output<GpuBuffer> kOut{""};
MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut); static absl::Status GetContract(CalculatorContract* cc);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override; absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
private: private:
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
GlCalculatorHelper helper_; GlCalculatorHelper helper_;
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
}; };
REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator);
// static // static
absl::Status ImageFrameToGpuBufferCalculator::UpdateContract( absl::Status ImageFrameToGpuBufferCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
cc->Inputs().Index(0).Set<ImageFrame>();
cc->Outputs().Index(0).Set<GpuBuffer>();
// Note: we call this method even on platforms where we don't use the helper, // Note: we call this method even on platforms where we don't use the helper,
// to ensure the calculator's contract is the same. In particular, the helper // to ensure the calculator's contract is the same. In particular, the helper
// enables support for the legacy side packet, which several graphs still use. // enables support for the legacy side packet, which several graphs still use.
return GlCalculatorHelper::UpdateContract(cc); MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
return absl::OkStatus();
} }
absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) { absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) {
// Inform the framework that we always output at the same timestamp
// as we receive a packet at.
cc->SetOffset(TimestampDiff(0));
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
MP_RETURN_IF_ERROR(helper_.Open(cc)); MP_RETURN_IF_ERROR(helper_.Open(cc));
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) { absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) {
auto image_frame = std::const_pointer_cast<ImageFrame>( #if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
mediapipe::SharedPtrWithPacket<ImageFrame>(kIn(cc).packet())); CFHolder<CVPixelBufferRef> buffer;
auto gpu_buffer = api2::MakePacket<GpuBuffer>( MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket(
std::make_shared<mediapipe::GpuBufferStorageImageFrame>( cc->Inputs().Index(0).Value(), &buffer));
std::move(image_frame))) cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp());
.At(cc->InputTimestamp()); #else
// This calculator's behavior has been to do the texture upload eagerly, and const auto& input = cc->Inputs().Index(0).Get<ImageFrame>();
// some graphs may rely on running this on a separate GL context to avoid helper_.RunInGlContext([this, &input, &cc]() {
// blocking another context with the read operation. So let's request GPU auto src = helper_.CreateSourceTexture(input);
// access here to ensure that the behavior stays the same. auto output = src.GetFrame<GpuBuffer>();
// TODO: have a better way to do this, or defer until later. glFlush();
helper_.RunInGlContext( cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
[&gpu_buffer] { auto view = gpu_buffer->GetReadView<GlTextureView>(0); }); src.Release();
kOut(cc).Send(std::move(gpu_buffer)); });
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return absl::OkStatus(); return absl::OkStatus();
} }
} // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -66,7 +66,8 @@ public class GraphTextureFrame implements TextureFrame {
if (nativeBufferHandle == 0) { if (nativeBufferHandle == 0) {
return 0; return 0;
} }
if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) { long contextHandle = nativeGetCurrentExternalContextHandle();
if (contextHandle != 0 && activeConsumerContextHandleSet.add(contextHandle)) {
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
// PacketGetter.getTextureFrameDeferredSync(). // PacketGetter.getTextureFrameDeferredSync().
if (deferredSync) { if (deferredSync) {
@ -116,7 +117,14 @@ public class GraphTextureFrame implements TextureFrame {
GlSyncToken consumerToken = null; GlSyncToken consumerToken = null;
// Note that this remove should be moved to the other overload of release when b/68808951 is // Note that this remove should be moved to the other overload of release when b/68808951 is
// addressed. // addressed.
if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) { final long contextHandle = nativeGetCurrentExternalContextHandle();
if (contextHandle == 0 && !activeConsumerContextHandleSet.isEmpty()) {
logger.atWarning().log(
"GraphTextureFrame is being released on non GL thread while having active consumers,"
+ " which may lead to external / internal GL contexts synchronization issues.");
}
if (contextHandle != 0 && activeConsumerContextHandleSet.remove(contextHandle)) {
consumerToken = consumerToken =
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle)); new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
} }
@ -169,7 +177,9 @@ public class GraphTextureFrame implements TextureFrame {
private native void nativeReleaseBuffer(long nativeHandle); private native void nativeReleaseBuffer(long nativeHandle);
private native int nativeGetTextureName(long nativeHandle); private native int nativeGetTextureName(long nativeHandle);
private native int nativeGetWidth(long nativeHandle); private native int nativeGetWidth(long nativeHandle);
private native int nativeGetHeight(long nativeHandle); private native int nativeGetHeight(long nativeHandle);
private native void nativeGpuWait(long nativeHandle); private native void nativeGpuWait(long nativeHandle);

View File

@ -30,11 +30,11 @@ cc_library(
"compat_jni.cc", "compat_jni.cc",
"graph.cc", "graph.cc",
"graph_jni.cc", "graph_jni.cc",
"graph_profiler_jni.cc",
"graph_service_jni.cc", "graph_service_jni.cc",
"packet_context_jni.cc", "packet_context_jni.cc",
"packet_creator_jni.cc", "packet_creator_jni.cc",
"packet_getter_jni.cc", "packet_getter_jni.cc",
"graph_profiler_jni.cc",
] + select({ ] + select({
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:android": [ "//mediapipe:android": [
@ -54,11 +54,11 @@ cc_library(
"compat_jni.h", "compat_jni.h",
"graph.h", "graph.h",
"graph_jni.h", "graph_jni.h",
"graph_profiler_jni.h",
"graph_service_jni.h", "graph_service_jni.h",
"packet_context_jni.h", "packet_context_jni.h",
"packet_creator_jni.h", "packet_creator_jni.h",
"packet_getter_jni.h", "packet_getter_jni.h",
"graph_profiler_jni.h",
] + select({ ] + select({
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:android": [ "//mediapipe:android": [
@ -84,40 +84,40 @@ cc_library(
deps = [ deps = [
":class_registry", ":class_registry",
":jni_util", ":jni_util",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
"@eigen_archive//:eigen3", "@eigen_archive//:eigen3",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:camera_intrinsics", "//mediapipe/framework:camera_intrinsics",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/formats:video_stream_header", "//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:name_util",
"//mediapipe/framework/tool:executor_util",
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework/port:singleton", "//mediapipe/framework/port:singleton",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:executor_util",
"//mediapipe/framework/tool:name_util",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
"//mediapipe/framework/port:file_helpers", "//mediapipe/framework/port:file_helpers",
], ],
"//mediapipe:android": [ "//mediapipe:android": [
"//mediapipe/util/android/file/base",
"//mediapipe/util/android:asset_manager_util", "//mediapipe/util/android:asset_manager_util",
"//mediapipe/util/android/file/base",
], ],
}) + select({ }) + select({
"//conditions:default": [ "//conditions:default": [
"//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:gl_surface_sink_calculator", "//mediapipe/gpu:gl_surface_sink_calculator",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_shared_data_internal", "//mediapipe/gpu:gpu_shared_data_internal",
@ -153,9 +153,9 @@ cc_library(
srcs = ["class_registry.cc"], srcs = ["class_registry.cc"],
hdrs = ["class_registry.h"], hdrs = ["class_registry.h"],
deps = [ deps = [
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/container:node_hash_map",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
], ],
@ -172,9 +172,9 @@ cc_library(
":class_registry", ":class_registry",
":loose_headers", ":loose_headers",
":mediapipe_framework_jni", ":mediapipe_framework_jni",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/container:node_hash_map",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [

View File

@ -357,6 +357,22 @@ def mediapipe_java_proto_srcs(name = ""):
target = "//mediapipe/framework/formats:rect_java_proto_lite", target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java", src_out = "com/google/mediapipe/formats/proto/RectProto.java",
)) ))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:color_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/ColorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:label_map_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/LabelMapProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:render_data_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/RenderDataProto.java",
))
return proto_src_list return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""): def mediapipe_logging_java_proto_srcs(name = ""):

View File

@ -0,0 +1,24 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/model_maker/python/vision/face_stylizer:__subpackages__"])
filegroup(
name = "models",
srcs = glob([
"**",
]),
)

View File

@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools import functools
from typing import Callable, Optional, Tuple, TypeVar from typing import Any, Callable, Optional, Tuple, TypeVar
# Dependency imports # Dependency imports
import tensorflow as tf import tensorflow as tf
@ -66,12 +66,14 @@ class Dataset(object):
""" """
return self._size return self._size
def gen_tf_dataset(self, def gen_tf_dataset(
batch_size: int = 1, self,
is_training: bool = False, batch_size: int = 1,
shuffle: bool = False, is_training: bool = False,
preprocess: Optional[Callable[..., bool]] = None, shuffle: bool = False,
drop_remainder: bool = False) -> tf.data.Dataset: preprocess: Optional[Callable[..., Any]] = None,
drop_remainder: bool = False,
) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation. """Generates a batched tf.data.Dataset for training/evaluation.
Args: Args:

View File

@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
self._hparams: hp.BaseHParams = None self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None self._history: tf.keras.callbacks.History = None
def _train_model(self, def _train_model(
train_data: classification_ds.ClassificationDataset, self,
validation_data: classification_ds.ClassificationDataset, train_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None, validation_data: classification_ds.ClassificationDataset,
checkpoint_path: Optional[str] = None): preprocessor: Optional[Callable[..., Any]] = None,
checkpoint_path: Optional[str] = None,
):
"""Trains the classifier model. """Trains the classifier model.
Compiles and fits the tf.keras `_model` and records the `_history`. Compiles and fits the tf.keras `_model` and records the `_history`.

View File

@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
def convert_to_tflite( def convert_to_tflite(
model: tf.keras.Model, model: tf.keras.Model,
quantization_config: Optional[quantization.QuantizationConfig] = None, quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet, supported_ops: Tuple[tf.lite.OpsSet, ...] = (
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,), tf.lite.OpsSet.TFLITE_BUILTINS,
preprocess: Optional[Callable[..., bool]] = None) -> bytearray: ),
preprocess: Optional[Callable[..., Any]] = None,
) -> bytearray:
"""Converts the input Keras model to TFLite format. """Converts the input Keras model to TFLite format.
Args: Args:

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":testdata",
],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,14 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Face Stylization."""

View File

@ -0,0 +1,98 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer dataset library."""
import logging
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
# TODO: Change to a unlabeled dataset if it makes sense.
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for face stylizer fine tuning."""
@classmethod
def from_folder(
cls, dirname: str
) -> classification_dataset.ClassificationDataset:
"""Loads images from the given directory.
The style image dataset directory is expected to contain one subdirectory
whose name represents the label of the style. There can be one or multiple
images of the same style in that subdirectory. Supported input image formats
include 'jpg', 'jpeg', 'png'.
Args:
dirname: Name of the directory containing the image files.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
"""
data_root = os.path.abspath(dirname)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
raise ValueError('Invalid input data directory')
if not any(
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
):
raise ValueError('No images found under given directory')
label_names = sorted(
name
for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))
)
all_label_size = len(label_names)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names)
)
# Get the style label from the subdirectory name.
all_image_labels = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64)
)
# Create a dataset of (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
logging.info(
'Load images dataset with size: %d, num_label: %d, labels: %s.',
all_image_size,
all_label_size,
', '.join(label_names),
)
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names
)

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils
class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# TODO: Replace the stylize image dataset with licensed images.
self._test_data_dirname = 'testdata'
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
self.assertLen(data, 2)
def test_from_folder_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
dataset.Dataset.from_folder(dirname='invalid')
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
input_data_dir = test_utils.get_test_data_path('face_stylizer')
with self.assertRaisesRegex(
ValueError, 'No images found under given directory'
):
dataset.Dataset.from_folder(dirname=input_data_dir)
if __name__ == '__main__':
tf.test.main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 347 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 336 KiB

View File

@ -15,6 +15,7 @@
import io import io
import os import os
import tempfile import tempfile
import unittest
from unittest import mock as unittest_mock from unittest import mock as unittest_mock
import zipfile import zipfile
@ -31,6 +32,7 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat
tf.keras.backend.experimental.enable_tf_random_generator() tf.keras.backend.experimental.enable_tf_random_generator()
@unittest.skip('b/273818271')
class GestureRecognizerTest(tf.test.TestCase): class GestureRecognizerTest(tf.test.TestCase):
def _load_data(self): def _load_data(self):
@ -72,8 +74,10 @@ class GestureRecognizerTest(tf.test.TestCase):
self._test_accuracy(model) self._test_accuracy(model)
@unittest.skip('b/273818271')
@unittest_mock.patch.object( @unittest_mock.patch.object(
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense) tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense
)
def test_gesture_recognizer_model_layer_widths(self, mock_dense): def test_gesture_recognizer_model_layer_widths(self, mock_dense):
layer_widths = [64, 32] layer_widths = [64, 32]
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths) mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
@ -143,12 +147,14 @@ class GestureRecognizerTest(tf.test.TestCase):
hyperparameters, hyperparameters,
'HParams', 'HParams',
autospec=True, autospec=True,
return_value=gesture_recognizer.HParams(epochs=1)) return_value=gesture_recognizer.HParams(epochs=1),
)
@unittest_mock.patch.object( @unittest_mock.patch.object(
model_options, model_options,
'GestureRecognizerModelOptions', 'GestureRecognizerModelOptions',
autospec=True, autospec=True,
return_value=gesture_recognizer.ModelOptions()) return_value=gesture_recognizer.ModelOptions(),
)
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options( def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
self, mock_hparams, mock_model_options): self, mock_hparams, mock_model_options):
options = gesture_recognizer.GestureRecognizerOptions() options = gesture_recognizer.GestureRecognizerOptions()

View File

@ -28,7 +28,7 @@ class ModelSpec(object):
uri: str, uri: str,
input_image_shape: Optional[List[int]] = None, input_image_shape: Optional[List[int]] = None,
name: str = ''): name: str = ''):
"""Initializes a new instance of the `ImageModelSpec` class. """Initializes a new instance of the image classifier `ModelSpec` class.
Args: Args:
uri: str, URI to the pretrained model. uri: str, URI to the pretrained model.

View File

@ -5,4 +5,4 @@ opencv-python
tensorflow>=2.10 tensorflow>=2.10
tensorflow-datasets tensorflow-datasets
tensorflow-hub tensorflow-hub
tf-models-official>=2.10.1 tf-models-official>=2.11.4

View File

@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS"; constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kLabelsTag[] = "LABELS"; constexpr char kLabelsTag[] = "LABELS";
constexpr char kLabelsCsvTag[] = "LABELS_CSV"; constexpr char kLabelsCsvTag[] = "LABELS_CSV";
constexpr char kLabelMapTag[] = "LABEL_MAP";
using mediapipe::RE2; using mediapipe::RE2;
using Detections = std::vector<Detection>; using Detections = std::vector<Detection>;
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>(); cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
} }
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
cc->InputSidePackets()
.Tag(kLabelMapTag)
.Set<std::unique_ptr<std::map<int, std::string>>>();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<FilterDetectionCalculatorOptions>(); options_ = cc->Options<FilterDetectionCalculatorOptions>();
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) || limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
cc->InputSidePackets().HasTag(kLabelsCsvTag); cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
cc->InputSidePackets().HasTag(kLabelMapTag);
if (limit_labels_) { if (limit_labels_) {
Strings allowlist_labels; Strings allowlist_labels;
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) { if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
for (auto& e : allowlist_labels) { for (auto& e : allowlist_labels) {
absl::StripAsciiWhitespace(&e); absl::StripAsciiWhitespace(&e);
} }
} else { } else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>(); allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
auto label_map = cc->InputSidePackets()
.Tag(kLabelMapTag)
.Get<std::unique_ptr<std::map<int, std::string>>>()
.get();
for (const auto& [_, v] : *label_map) {
allowlist_labels.push_back(v);
}
} }
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end()); allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
} }

View File

@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
)); ));
} }
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
auto runner = std::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "FilterDetectionCalculator"
input_stream: "DETECTION:input"
input_side_packet: "LABEL_MAP:input_map"
output_stream: "DETECTION:output"
options {
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
}
)pb"));
runner->MutableInputs()->Tag("DETECTION").packets = {
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "d"
score: 1
score: 0.8
score: 0.3
score: 0.9
)pb"))
.At(Timestamp(20)),
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "e"
score: 0.6
score: 0.4
score: 0.2
score: 0.7
)pb"))
.At(Timestamp(40)),
};
auto label_map = std::make_unique<std::map<int, std::string>>();
(*label_map)[0] = "a";
(*label_map)[1] = "b";
(*label_map)[2] = "c";
runner->MutableSidePackets()->Tag("LABEL_MAP") =
AdoptAsUniquePtr(label_map.release());
// Run graph.
MP_ASSERT_OK(runner->Run());
// Check output.
EXPECT_THAT(
runner->Outputs().Tag("DETECTION").packets,
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(20)),
EqualsProto(R"pb(
label: "a" label: "b" score: 1 score: 0.8
)pb")), // Packet 1 at timestamp 20.
PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(40)),
EqualsProto(R"pb(
label: "a" score: 0.6
)pb")) // Packet 2 at timestamp 40.
));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -57,6 +57,7 @@ pybind_extension(
"//mediapipe/framework/formats:landmark_registration", "//mediapipe/framework/formats:landmark_registration",
"//mediapipe/framework/formats:rect_registration", "//mediapipe/framework/formats:rect_registration",
"//mediapipe/modules/objectron/calculators:annotation_registration", "//mediapipe/modules/objectron/calculators:annotation_registration",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
], ],
) )
@ -95,6 +96,8 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph", "//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph", "//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph", "//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
] + select({ ] + select({
# TODO: Build text_classifier_graph and text_embedder_graph on Windows. # TODO: Build text_classifier_graph and text_embedder_graph on Windows.
"//mediapipe:windows": [], "//mediapipe:windows": [],

View File

@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus";
// //
// At runtime, such codes are meant to be attached (where applicable) to a // At runtime, such codes are meant to be attached (where applicable) to a
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and // `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
// stringifed error code as value (aka payload). This logic is encapsulated in // stringified error code as value (aka payload). This logic is encapsulated in
// the `CreateStatusWithPayload` helper below for convenience. // the `CreateStatusWithPayload` helper below for convenience.
// //
// The returned status includes: // The returned status includes:

View File

@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
auto model_bundle_resources = absl::WrapUnique( auto model_bundle_resources = absl::WrapUnique(
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file))); new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
model_bundle_resources->ExtractModelFilesFromExternalFileProto()); model_bundle_resources->ExtractFilesFromExternalFileProto());
return model_bundle_resources; return model_bundle_resources;
} }
absl::Status absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
if (model_asset_bundle_file_->has_file_name()) { if (model_asset_bundle_file_->has_file_name()) {
// If the model asset bundle file name is a relative path, searches the file // If the model asset bundle file name is a relative path, searches the file
// in a platform-specific location and returns the absolute path on success. // in a platform-specific location and returns the absolute path on success.
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
model_asset_bundle_file_handler_->GetFileContent().data(); model_asset_bundle_file_handler_->GetFileContent().data();
size_t buffer_size = size_t buffer_size =
model_asset_bundle_file_handler_->GetFileContent().size(); model_asset_bundle_file_handler_->GetFileContent().size();
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
&model_files_);
} }
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile( absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
const std::string& filename) const { const std::string& filename) const {
auto it = model_files_.find(filename); auto it = files_.find(filename);
if (it == model_files_.end()) { if (it == files_.end()) {
auto model_files = ListModelFiles(); auto files = ListFiles();
std::string all_model_files = std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
return CreateStatusWithPayload( return CreateStatusWithPayload(
StatusCode::kNotFound, StatusCode::kNotFound,
absl::StrFormat("No model file with name: %s. All model files in the " absl::StrFormat("No file with name: %s. All files in the model asset "
"model asset bundle are: %s.", "bundle are: %s.",
filename, all_model_files), filename, all_files),
MediaPipeTasksStatus::kFileNotFoundError); MediaPipeTasksStatus::kFileNotFoundError);
} }
return it->second; return it->second;
} }
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const { std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
std::vector<std::string> model_names; std::vector<std::string> file_names;
for (const auto& [model_name, _] : model_files_) { for (const auto& [file_name, _] : files_) {
model_names.push_back(model_name); file_names.push_back(file_name);
} }
return model_names; return file_names;
} }
} // namespace core } // namespace core

View File

@ -28,8 +28,8 @@ namespace core {
// The mediapipe task model asset bundle resources class. // The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto, // A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the // contains model asset bundle related resources and the method to extract the
// tflite models or model asset bundles for the mediapipe sub-tasks. As the // tflite models, resource files or model asset bundles for the mediapipe
// resources are owned by the ModelAssetBundleResources object // sub-tasks. As the resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the // callers must keep ModelAssetBundleResources alive while using any of the
// resources. // resources.
class ModelAssetBundleResources { class ModelAssetBundleResources {
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
// Returns the model asset bundle resources tag. // Returns the model asset bundle resources tag.
std::string GetTag() const { return tag_; } std::string GetTag() const { return tag_; }
// Gets the contents of the model file (either tflite model file or model // Gets the contents of the model file (either tflite model file, resource
// bundle file) with the provided name. An error is returned if there is no // file or model bundle file) with the provided name. An error is returned if
// such model file. // there is no such model file.
absl::StatusOr<absl::string_view> GetModelFile( absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
const std::string& filename) const;
// Lists all the model file names in the model asset model. // Lists all the file names in the model asset model.
std::vector<std::string> ListModelFiles() const; std::vector<std::string> ListFiles() const;
private: private:
// Constructor. // Constructor.
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
const std::string& tag, const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file); std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
// Extracts the model files (either tflite model file or model bundle file) // Extracts the model files (either tflite model file, resource file or model
// from the external file proto. // bundle file) from the external file proto.
absl::Status ExtractModelFilesFromExternalFileProto(); absl::Status ExtractFilesFromExternalFileProto();
// The model asset bundle resources tag. // The model asset bundle resources tag.
const std::string tag_; const std::string tag_;
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
// The ExternalFileHandler for the model asset bundle. // The ExternalFileHandler for the model asset bundle.
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_; std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
// The model files bundled in model asset bundle, as a map with the filename // The files bundled in model asset bundle, as a map with the filename
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and // (corresponding to a basename, e.g. "hand_detector.tflite") as key and
// a pointer to the file contents as value. Each model file can be either // a pointer to the file contents as value. Each file can be either a TFLite
// a TFLite model file or a model bundle file for sub-task. // model file, resource file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> model_files_; absl::flat_hash_map<std::string, absl::string_view> files_;
}; };
} // namespace core } // namespace core

View File

@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
#endif // _WIN32 #endif // _WIN32
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task") model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
.status());
MP_EXPECT_OK( MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite") model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status()); .status());
} }
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status_or_model_bundle_file = auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task"); model_bundle_resources->GetFile("dummy_hand_landmarker.task");
MP_EXPECT_OK(status_or_model_bundle_file.status()); MP_EXPECT_OK(status_or_model_bundle_file.status());
// Creates sub-task model asset bundle resources. // Creates sub-task model asset bundle resources.
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(hand_landmaker_model_file))); std::move(hand_landmaker_model_file)));
MP_EXPECT_OK(hand_landmaker_model_bundle_resources MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_detector.tflite") ->GetFile("dummy_hand_detector.tflite")
.status()); .status());
MP_EXPECT_OK(hand_landmaker_model_bundle_resources MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_landmarker.tflite") ->GetFile("dummy_hand_landmarker.tflite")
.status()); .status());
} }
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status_or_model_bundle_file = auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite"); model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
MP_EXPECT_OK(status_or_model_bundle_file.status()); MP_EXPECT_OK(status_or_model_bundle_file.status());
// Verify tflite model works. // Verify tflite model works.
@ -200,12 +196,12 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto status = model_bundle_resources->GetModelFile("not_found.task").status(); auto status = model_bundle_resources->GetFile("not_found.task").status();
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound); EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_THAT(status.message(), EXPECT_THAT(
testing::HasSubstr( status.message(),
"No model file with name: not_found.task. All model files in " testing::HasSubstr("No file with name: not_found.task. All files in "
"the model asset bundle are: ")); "the model asset bundle are: "));
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload), EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
testing::Optional(absl::Cord( testing::Optional(absl::Cord(
absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError)))); absl::StrCat(MediaPipeTasksStatus::kFileNotFoundError))));
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
auto model_bundle_resources, auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag, ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file))); std::move(model_file)));
auto model_files = model_bundle_resources->ListModelFiles(); auto model_files = model_bundle_resources->ListFiles();
std::vector<std::string> expected_model_files = { std::vector<std::string> expected_model_files = {
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"}; "dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
std::sort(model_files.begin(), model_files.end()); std::sort(model_files.begin(), model_files.end());

View File

@ -77,9 +77,11 @@ class ModelResourcesCalculator : public api2::Node {
if (options.has_model_file()) { if (options.has_model_file()) {
RET_CHECK(options.model_file().has_file_content() || RET_CHECK(options.model_file().has_file_content() ||
options.model_file().has_file_descriptor_meta() || options.model_file().has_file_descriptor_meta() ||
options.model_file().has_file_name()) options.model_file().has_file_name() ||
options.model_file().has_file_pointer_meta())
<< "'model_file' must specify at least one of " << "'model_file' must specify at least one of "
"'file_content', 'file_descriptor_meta', or 'file_name'"; "'file_content', 'file_descriptor_meta', 'file_name', or "
"'file_pointer_meta'";
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -179,9 +179,9 @@ TEST_F(ModelResourcesCalculatorTest, EmptyExternalFileProto) {
auto status = graph.Initialize(graph_config); auto status = graph.Initialize(graph_config);
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(), EXPECT_THAT(status.message(),
testing::HasSubstr( testing::HasSubstr("'model_file' must specify at least one of "
"'model_file' must specify at least one of " "'file_content', 'file_descriptor_meta', "
"'file_content', 'file_descriptor_meta', or 'file_name'")); "'file_name', or 'file_pointer_meta'"));
} }
TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) { TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) {

View File

@ -138,7 +138,7 @@ class InferenceSubgraph : public Subgraph {
delegate.mutable_tflite()->CopyFrom(acceleration.tflite()); delegate.mutable_tflite()->CopyFrom(acceleration.tflite());
break; break;
case Acceleration::DELEGATE_NOT_SET: case Acceleration::DELEGATE_NOT_SET:
// Deafult inference calculator setting. // Default inference calculator setting.
break; break;
} }
return delegate; return delegate;

View File

@ -124,10 +124,10 @@ class ModelTaskGraph : public Subgraph {
// Inserts a mediapipe task inference subgraph into the provided // Inserts a mediapipe task inference subgraph into the provided
// GraphBuilder. The returned node provides the following interfaces to the // GraphBuilder. The returned node provides the following interfaces to the
// the rest of the graph: // the rest of the graph:
// - a tensor vector (std::vector<MeidaPipe::Tensor>) input stream with tag // - a tensor vector (std::vector<mediapipe::Tensor>) input stream with tag
// "TENSORS", representing the input tensors to be consumed by the // "TENSORS", representing the input tensors to be consumed by the
// inference engine. // inference engine.
// - a tensor vector (std::vector<MeidaPipe::Tensor>) output stream with tag // - a tensor vector (std::vector<mediapipe::Tensor>) output stream with tag
// "TENSORS", representing the output tensors generated by the inference // "TENSORS", representing the output tensors generated by the inference
// engine. // engine.
// - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR". // - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR".

View File

@ -301,7 +301,7 @@ absl::Status TaskRunner::Close() {
} }
is_running_ = false; is_running_ = false;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
AddPayload(graph_.CloseAllInputStreams(), "Fail to close intput streams", AddPayload(graph_.CloseAllInputStreams(), "Fail to close input streams",
MediaPipeTasksStatus::kRunnerFailsToCloseError)); MediaPipeTasksStatus::kRunnerFailsToCloseError));
MP_RETURN_IF_ERROR(AddPayload( MP_RETURN_IF_ERROR(AddPayload(
graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.", graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.",

View File

@ -65,7 +65,7 @@ class TaskRunner {
// Creates the task runner with a CalculatorGraphConfig proto. // Creates the task runner with a CalculatorGraphConfig proto.
// If a tflite op resolver object is provided, the task runner will take // If a tflite op resolver object is provided, the task runner will take
// it as the global op resolver for all models running within this task. // it as the global op resolver for all models running within this task.
// The op resolver's owernship will be transferred into the pipeleine runner. // The op resolver's ownership will be transferred into the pipeleine runner.
// When a user-defined PacketsCallback is provided, clients must use the // When a user-defined PacketsCallback is provided, clients must use the
// asynchronous method, Send(), to provide the input packets. If the packets // asynchronous method, Send(), to provide the input packets. If the packets
// callback is absent, clients must use the synchronous method, Process(), to // callback is absent, clients must use the synchronous method, Process(), to
@ -84,7 +84,7 @@ class TaskRunner {
// frames from a video file and an audio file. The call blocks the current // frames from a video file and an audio file. The call blocks the current
// thread until a failure status or a successful result is returned. // thread until a failure status or a successful result is returned.
// If the input packets have no timestamp, an internal timestamp will be // If the input packets have no timestamp, an internal timestamp will be
// assigend per invocation. Otherwise, when the timestamp is set in the // assigned per invocation. Otherwise, when the timestamp is set in the
// input packets, the caller must ensure that the input packet timestamps are // input packets, the caller must ensure that the input packet timestamps are
// greater than the timestamps of the previous invocation. This method is // greater than the timestamps of the previous invocation. This method is
// thread-unsafe and it is the caller's responsibility to synchronize access // thread-unsafe and it is the caller's responsibility to synchronize access

View File

@ -64,7 +64,7 @@ class ModelMetadataPopulator {
// Loads associated files into the TFLite FlatBuffer model. The input is a map // Loads associated files into the TFLite FlatBuffer model. The input is a map
// of {filename, file contents}. // of {filename, file contents}.
// //
// Warning: this method removes any previoulsy present associated files. // Warning: this method removes any previously present associated files.
// Calling this method multiple time removes any associated files from // Calling this method multiple time removes any associated files from
// previous calls, so this method should usually be called only once. // previous calls, so this method should usually be called only once.
void LoadAssociatedFiles( void LoadAssociatedFiles(

View File

@ -213,7 +213,7 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
Version* min_version) { Version* min_version) {
if (table == nullptr) return; if (table == nullptr) return;
// Checks the ContenProperties field. // Checks the ContentProperties field.
if (table->content_properties_type() == ContentProperties_AudioProperties) { if (table->content_properties_type() == ContentProperties_AudioProperties) {
UpdateMinimumVersion( UpdateMinimumVersion(
GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties), GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),

View File

@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) {
// Using pybind11 type conversions to convert between Python and native // Using pybind11 type conversions to convert between Python and native
// C++ types. There are other options to provide access to native Python types // C++ types. There are other options to provide access to native Python types
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details. // in C++ and vice versa. See the pybind 11 instruction [1] for more details.
// Type converstions is recommended by pybind11, though the main downside // Type conversions is recommended by pybind11, though the main downside
// is that a copy of the data must be made on every Python to C++ transition: // is that a copy of the data must be made on every Python to C++ transition:
// this is needed since the C++ and Python versions of the same type generally // this is needed since the C++ and Python versions of the same type generally
// wont have the same memory layout. // wont have the same memory layout.

View File

@ -79,7 +79,7 @@ TEST(MetadataVersionTest,
auto metadata = metadata_builder.Finish(); auto metadata = metadata_builder.Finish();
FinishModelMetadataBuffer(builder, metadata); FinishModelMetadataBuffer(builder, metadata);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -100,7 +100,7 @@ TEST(MetadataVersionTest,
auto metadata = metadata_builder.Finish(); auto metadata = metadata_builder.Finish();
builder.Finish(metadata); builder.Finish(metadata);
// Gets the mimimum metadata parser version and triggers error. // Gets the minimum metadata parser version and triggers error.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -121,7 +121,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files); metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -147,7 +147,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -172,7 +172,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -203,7 +203,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -234,7 +234,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -265,7 +265,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -294,7 +294,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -323,7 +323,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -348,7 +348,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -373,7 +373,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -404,7 +404,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -431,7 +431,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()}); std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder); CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -453,7 +453,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files); metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -476,7 +476,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files); metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),
@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
metadata_builder.add_subgraph_metadata(subgraphs); metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish()); FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version. // Gets the minimum metadata parser version.
std::string min_version; std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(), EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version), builder.GetSize(), &min_version),

View File

@ -42,3 +42,36 @@ cc_test(
"@org_tensorflow//tensorflow/lite/kernels:test_util", "@org_tensorflow//tensorflow/lite/kernels:test_util",
], ],
) )
cc_library(
name = "ngram_hash",
srcs = ["ngram_hash.cc"],
hdrs = ["ngram_hash.h"],
copts = tflite_copts(),
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
],
alwayslink = 1,
)
cc_test(
name = "ngram_hash_test",
srcs = ["ngram_hash_test.cc"],
deps = [
":ngram_hash",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@com_google_absl//absl/types:optional",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
],
)

View File

@ -0,0 +1,264 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <string>
#include <vector>
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace ngram_op {
namespace {
using ::flexbuffers::GetRoot;
using ::flexbuffers::Map;
using ::flexbuffers::TypedVector;
using ::mediapipe::tasks::text::language_detector::custom_ops::
LowercaseUnicodeStr;
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::tflite::GetString;
using ::tflite::StringRef;
constexpr int kInputMessage = 0;
constexpr int kOutputLabel = 0;
constexpr int kDefaultMaxSplits = 128;
// This op takes in a string, finds the character ngrams for it and then
// maps each of these ngrams to an index using the specified vocabulary sizes.
// Input(s):
// - input: Input string.
// - seeds: Seed for the random number generator.
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
// be interpreted as generating unigrams, bigrams, and trigrams.
// - vocab_sizes: Size of the vocabulary for each of the ngram features
// respectively. The op would generate vocab ids to be less than or equal to
// the vocab size. The index 0 implies an invalid ngram.
// - max_splits: Maximum number of tokens in the output. If this is unset, the
// limit is `kDefaultMaxSplits`.
// - lower_case_input: If this is set to true, the input string would be
// lower-cased before any processing.
// Output(s):
// - output: A tensor of size [number of ngrams, number of tokens + 2],
// where 2 tokens are reserved for the padding. If `max_splits` is set, this
// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
// Helper class used for pre-processing the input.
class NGramHashParams {
public:
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes, int max_splits,
bool lower_case_input)
: seed_(seed),
ngram_lengths_(ngram_lengths),
vocab_sizes_(vocab_sizes),
max_splits_(max_splits),
lower_case_input_(lower_case_input) {}
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
TfLiteContext* context) {
if (input_t->bytes == 0) {
context->ReportError(context, "Empty input not supported.");
return kTfLiteError;
}
// Do sanity checks on the input.
if (ngram_lengths_.empty()) {
context->ReportError(context, "`ngram_lengths` must be non-empty.");
return kTfLiteError;
}
if (vocab_sizes_.empty()) {
context->ReportError(context, "`vocab_sizes` must be non-empty.");
return kTfLiteError;
}
if (ngram_lengths_.size() != vocab_sizes_.size()) {
context->ReportError(
context,
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
return kTfLiteError;
}
if (max_splits_ <= 0) {
context->ReportError(context, "`max_splits` must be > 0.");
return kTfLiteError;
}
// Obtain and tokenize the input.
StringRef inputref = GetString(input_t, /*string_index=*/0);
if (lower_case_input_) {
std::string lower_cased_str;
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
tokenized_output_ =
Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
} else {
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
}
return kTfLiteOk;
}
uint64_t GetSeed() const { return seed_; }
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
int GetNumNGrams() const { return ngram_lengths_.size(); }
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
const TokenizedOutput& GetTokenizedOutput() const {
return tokenized_output_;
}
TokenizedOutput tokenized_output_;
private:
const uint64_t seed_;
std::vector<int> ngram_lengths_;
std::vector<int> vocab_sizes_;
const int max_splits_;
const bool lower_case_input_;
};
// Convert the TypedVector into a regular std::vector.
std::vector<int> GetIntVector(TypedVector typed_vec) {
std::vector<int> vec(typed_vec.size());
for (int j = 0; j < typed_vec.size(); j++) {
vec[j] = typed_vec[j].AsInt32();
}
return vec;
}
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
const int max_unicode_length = params->GetNumTokens();
const auto ngram_lengths = params->GetNGramLengths();
const auto vocab_sizes = params->GetVocabSizes();
const auto& tokenized_output = params->GetTokenizedOutput();
const auto seed = params->GetSeed();
// Compute for each ngram.
for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
const int vocab_size = vocab_sizes[ngram];
const int ngram_length = ngram_lengths[ngram];
// Compute for each token within the input.
for (int start = 0; start < tokenized_output.tokens.size(); start++) {
// Compute the number of bytes for the ngram starting at the given
// token.
int num_bytes = 0;
for (int i = start;
i < tokenized_output.tokens.size() && i < (start + ngram_length);
i++) {
num_bytes += tokenized_output.tokens[i].second;
}
// Compute the hash for the ngram starting at the token.
const auto str_hash = MurmurHash64WithSeed(
tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
num_bytes, seed);
// Map the hash to an index in the vocab.
data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
}
}
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const Map& m = GetRoot(buffer_t, length).AsMap();
const uint64_t seed = m["seed"].AsUInt64();
const std::vector<int> ngram_lengths =
GetIntVector(m["ngram_lengths"].AsTypedVector());
const std::vector<int> vocab_sizes =
GetIntVector(m["vocab_sizes"].AsTypedVector());
const int max_splits =
m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
const bool lowercase_input =
m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();
return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
lowercase_input);
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<NGramHashParams*>(buffer);
}
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
SetTensorToDynamic(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
TF_LITE_ENSURE_OK(
context,
params->PreprocessInput(GetInput(context, node, kInputMessage), context));
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
if (IsDynamicTensor(output)) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = 1;
output_size->data[1] = params->GetNumNGrams();
output_size->data[2] = params->GetNumTokens();
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
} else {
context->ReportError(context, "Output must by dynamic.");
return kTfLiteError;
}
if (output->type == kTfLiteInt32) {
GetNGramHashIndices(params, output->data.i32);
} else {
context->ReportError(context, "Output type must be Int32.");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace ngram_op
TfLiteRegistration* Register_NGRAM_HASH() {
static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
ngram_op::Resize, ngram_op::Eval};
return &r;
}
} // namespace tflite::ops::custom

View File

@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#include "tensorflow/lite/kernels/register.h"
namespace tflite::ops::custom {
TfLiteRegistration* Register_NGRAM_HASH();
} // namespace tflite::ops::custom
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_

View File

@ -0,0 +1,313 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace {
using ::flexbuffers::Builder;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::testing::ElementsAreArray;
using ::testing::Message;
// Helper class for testing the op.
class NGramHashModel : public SingleOpModel {
public:
explicit NGramHashModel(const uint64_t seed,
const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes,
const absl::optional<int> max_splits = std::nullopt) {
// Setup the model inputs.
Builder fbb;
size_t start = fbb.StartMap();
fbb.UInt("seed", seed);
{
size_t start = fbb.StartVector("ngram_lengths");
for (const int& ngram_len : ngram_lengths) {
fbb.Int(ngram_len);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
{
size_t start = fbb.StartVector("vocab_sizes");
for (const int& vocab_size : vocab_sizes) {
fbb.Int(vocab_size);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
if (max_splits) {
fbb.Int("max_splits", *max_splits);
}
fbb.EndMap(start);
fbb.Finish();
output_ = AddOutput({TensorType_INT32, {}});
SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
BuildInterpreter({GetShape(input_)});
}
void SetupInputTensor(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
}
void Invoke(const std::string& input) {
SetupInputTensor(input);
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
}
TfLiteStatus InvokeUnchecked(const std::string& input) {
SetupInputTensor(input);
return SingleOpModel::Invoke();
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_ = AddInput(TensorType_STRING);
int output_;
};
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
// Checks that the op returns the expected value when the input is sane.
// Also checks that when `max_splits` is not specified, the entire string is
// tokenized.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::vector<std::string> testcase_inputs({
"hi",
"wow",
"!",
"HI",
});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "hi".
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow".
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "!" (which will get replaced by " ").
hash("^", 0),
hash(" ", 0),
hash("$", 0),
hash("^ ", 1),
hash(" $", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "HI" (which will get lower-cased).
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
}});
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
const string& testcase_input = testcase_inputs[test_idx];
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where the testcases' input is: "
<< testcase_input);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
}
}
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
// Checks that the op returns the expected value when the input is correct
// when `max_splits` is specified.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::string testcase_input = "wow";
const std::vector<int> max_splits({2, 3, 4, 5, 6});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "wow", when `max_splits` == 2.
// We cannot include any of the actual tokens, since `max_splits`
// only allows enough space for the delimiters.
hash("^", 0),
hash("$", 0),
hash("^$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 3.
// We can start to include some tokens from the input string.
hash("^", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 4.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("o$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 5.
// We can include the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 6.
// `max_splits` is more than the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
}});
for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
const int testcase_max_splits = max_splits[test_idx];
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(
m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
std::min(
// Longest possible tokenization when using the entire
// input.
static_cast<int>(testcase_input.size()) + /*padding*/ 2,
// Longest possible string when the `max_splits` value
// is < testcase_input.size() + 2 for padding.
testcase_max_splits)}));
}
}
TEST(NGramHashTest, InvalidMaxSplitsValue) {
// Check that the op errors out when given an invalid max splits value.
const std::vector<int> invalid_max_splits({0, -1, -5, -100});
for (const int max_splits : invalid_max_splits) {
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
// Check that the op errors out when ngram lengths and vocab sizes mistmatch.
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
/*vocab_sizes=*/{1, 2});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2, 3});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
} // namespace
} // namespace tflite::ops::custom

View File

@ -0,0 +1,42 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "ngram_hash_ops_utils",
srcs = [
"ngram_hash_ops_utils.cc",
],
hdrs = [
"ngram_hash_ops_utils.h",
],
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf",
],
)
cc_test(
name = "ngram_hash_ops_utils_test",
size = "small",
srcs = [
"ngram_hash_ops_utils_test.cc",
],
deps = [
":ngram_hash_ops_utils",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -0,0 +1,38 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "murmur",
srcs = ["murmur.cc"],
hdrs = ["murmur.h"],
deps = [
"//mediapipe/framework/port:integral_types",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
],
)
cc_test(
name = "murmur_test",
srcs = ["murmur_test.cc"],
deps = [
":murmur",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
],
)

View File

@ -0,0 +1,95 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: aappleby@google.com (Austin Appleby)
// jyrki@google.com (Jyrki Alakuijala)
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include <cstdint>
#include "absl/base/internal/endian.h"
#include "absl/base/optimization.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
namespace {
using ::absl::little_endian::Load64;
// Murmur 2.0 multiplication constant.
static const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
// We need to mix some of the bits that get propagated and mixed into the
// high bits by multiplication back into the low bits. 17 last bits get
// a more efficiently mixed with this.
inline uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); }
// Accumulate 8 bytes into 64-bit Murmur hash
inline uint64_t MurmurStep(uint64_t hash, uint64_t data) {
hash ^= ShiftMix(data * kMul) * kMul;
hash *= kMul;
return hash;
}
// Build a uint64 from 1-8 bytes.
// 8 * len least significant bits are loaded from the memory with
// LittleEndian order. The 64 - 8 * len most significant bits are
// set all to 0.
// In latex-friendly words, this function returns:
// $\sum_{i=0}^{len-1} p[i] 256^{i}$, where p[i] is unsigned.
//
// This function is equivalent to:
// uint64 val = 0;
// memcpy(&val, p, len);
// return ToHost64(val);
//
// The caller needs to guarantee that 0 <= len <= 8.
uint64_t Load64VariableLength(const void* const p, int len) {
ABSL_ASSUME(len >= 0 && len <= 8);
uint64_t val = 0;
const uint8_t* const src = static_cast<const uint8_t*>(p);
for (int i = 0; i < len; ++i) {
val |= static_cast<uint64_t>(src[i]) << (8 * i);
}
return val;
}
} // namespace
unsigned long long MurmurHash64WithSeed(const char* buf, // NOLINT
const size_t len, const uint64_t seed) {
// Let's remove the bytes not divisible by the sizeof(uint64).
// This allows the inner loop to process the data as 64 bit integers.
const size_t len_aligned = len & ~0x7;
const char* const end = buf + len_aligned;
uint64_t hash = seed ^ (len * kMul);
for (const char* p = buf; p != end; p += 8) {
hash = MurmurStep(hash, Load64(p));
}
if ((len & 0x7) != 0) {
const uint64_t data = Load64VariableLength(end, len & 0x7);
hash ^= data;
hash *= kMul;
}
hash = ShiftMix(hash) * kMul;
hash = ShiftMix(hash);
return hash;
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash

View File

@ -0,0 +1,43 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: aappleby@google.com (Austin Appelby)
// jyrki@google.com (Jyrki Alakuijala)
//
// MurmurHash is a fast multiplication and shifting based algorithm,
// based on Austin Appleby's MurmurHash 2.0 algorithm.
#ifndef UTIL_HASH_MURMUR_H_
#define UTIL_HASH_MURMUR_H_
#include <stddef.h>
#include <stdlib.h> // for size_t.
#include <cstdint>
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
// Hash function for a byte array. Has a seed which allows this hash function to
// be used in algorithms that need a family of parameterized hash functions.
// e.g. Minhash.
unsigned long long MurmurHash64WithSeed(const char* buf, size_t len, // NOLINT
uint64_t seed);
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
#endif // UTIL_HASH_MURMUR_H_

View File

@ -0,0 +1,66 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a test library written by Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: jyrki@google.com (Jyrki Alakuijala)
//
// Tests for the fast hashing algorithm based on Austin Appleby's
// MurmurHash 2.0 algorithm. See http://murmurhash.googlepages.com/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include <string.h>
#include <cstdint>
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
TEST(Murmur, EmptyData64) {
EXPECT_EQ(uint64_t{0}, MurmurHash64WithSeed(nullptr, uint64_t{0}, 0));
}
TEST(Murmur, VaryWithDifferentSeeds) {
// While in theory different seeds could return the same
// hash for the same data this is unlikely.
char data1 = 'x';
EXPECT_NE(MurmurHash64WithSeed(&data1, 1, 100),
MurmurHash64WithSeed(&data1, 1, 101));
}
// Hashes don't change.
TEST(Murmur, Idempotence) {
const char data[] = "deadbeef";
const size_t dlen = strlen(data);
for (int i = 0; i < 10; i++) {
EXPECT_EQ(MurmurHash64WithSeed(data, dlen, i),
MurmurHash64WithSeed(data, dlen, i));
}
const char next_data[] = "deadbeef000---";
const size_t next_dlen = strlen(next_data);
for (int i = 0; i < 10; i++) {
EXPECT_EQ(MurmurHash64WithSeed(next_data, next_dlen, i),
MurmurHash64WithSeed(next_data, next_dlen, i));
}
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash

View File

@ -0,0 +1,96 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens) {
const std::string kPrefix = "^";
const std::string kSuffix = "$";
const std::string kReplacementToken = " ";
TokenizedOutput output;
size_t token_start = 0;
output.str.reserve(len + 2);
output.tokens.reserve(len + 2);
output.str.append(kPrefix);
output.tokens.push_back(std::make_pair(token_start, kPrefix.size()));
token_start += kPrefix.size();
Rune token;
for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) {
// Use the standard UTF-8 library to find the next token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
// Stop processing, if we can't read any more tokens, or we have reached
// maximum allowed tokens, allocating one token for the suffix.
if (bytes_read == 0) {
break;
}
// If `exclude_nonalphaspace_tokens` is set to true, and the token is not
// alphanumeric, replace it with a replacement token.
if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) {
output.str.append(kReplacementToken);
output.tokens.push_back(
std::make_pair(token_start, kReplacementToken.size()));
token_start += kReplacementToken.size();
i += bytes_read;
continue;
}
// Append the token in the output string, and note its position and the
// number of bytes that token consumed.
output.str.append(input_str + i, bytes_read);
output.tokens.push_back(std::make_pair(token_start, bytes_read));
token_start += bytes_read;
i += bytes_read;
}
output.str.append(kSuffix);
output.tokens.push_back(std::make_pair(token_start, kSuffix.size()));
token_start += kSuffix.size();
return output;
}
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str) {
for (int i = 0; i < len;) {
Rune token;
// Tokenize the given string, and get the appropriate lowercase token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
token = utf_isalpharune(token) ? utf_tolowerrune(token) : token;
// Write back the token to the output string.
char token_buf[UTFmax];
size_t bytes_to_write = utf_runetochar(token_buf, &token);
output_str->append(token_buf, bytes_to_write);
i += bytes_read;
}
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops

View File

@ -0,0 +1,56 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#include <string>
#include <utility>
#include <vector>
namespace mediapipe::tasks::text::language_detector::custom_ops {
struct TokenizedOutput {
// The processed string (with necessary prefix, suffix, skipped tokens, etc.).
std::string str;
// This vector contains pairs, where each pair has two members. The first
// denoting the starting index of the token in the `str` string, and the
// second denoting the length of that token in bytes.
std::vector<std::pair<const size_t, const size_t>> tokens;
};
// Tokenizes the given input string on Unicode token boundaries, with a maximum
// of `max_tokens` tokens.
//
// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores
// non-alphanumeric tokens, and replaces them with a replacement token (" ").
//
// The method returns the output in the `TokenizedOutput` struct, which stores
// both, the processed input string, and the indices and sizes of each token
// within that string.
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens);
// Converts the given unicode string (`input_str`) with the specified length
// (`len`) to a lowercase string.
//
// The method populates the lowercased string in `output_str`.
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str);
} // namespace mediapipe::tasks::text::language_detector::custom_ops
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_

View File

@ -0,0 +1,135 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
namespace {
using ::testing::Values;
std::string ReconstructStringFromTokens(TokenizedOutput output) {
std::string reconstructed_str;
for (int i = 0; i < output.tokens.size(); i++) {
reconstructed_str.append(
output.str.c_str() + output.tokens[i].first,
output.str.c_str() + output.tokens[i].first + output.tokens[i].second);
}
return reconstructed_str;
}
struct TokenizeTestParams {
std::string input_str;
size_t max_tokens;
bool exclude_nonalphaspace_tokens;
std::string expected_output_str;
};
class TokenizeParameterizedTest
: public ::testing::Test,
public testing::WithParamInterface<TokenizeTestParams> {};
TEST_P(TokenizeParameterizedTest, Tokenize) {
// Checks that the Tokenize method returns the expected value.
const TokenizeTestParams params = TokenizeParameterizedTest::GetParam();
const TokenizedOutput output = Tokenize(
/*input_str=*/params.input_str.c_str(),
/*len=*/params.input_str.size(),
/*max_tokens=*/params.max_tokens,
/*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens);
// The output string should have the necessary prefixes, and the "!" token
// should have been replaced with a " ".
EXPECT_EQ(output.str, params.expected_output_str);
EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str);
}
INSTANTIATE_TEST_SUITE_P(
TokenizeParameterizedTests, TokenizeParameterizedTest,
Values(
// Test including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_alphanonspace=*/false,
/*expected_output_str=*/"^hi!$"}),
// Test not including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^hi $"}),
// Test with a maximum of 3 tokens.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^h$"}),
// Test with non-latin characters.
TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^ありがと$"})));
TEST(LowercaseUnicodeTest, TestLowercaseUnicode) {
{
// Check that the method is a no-op when the string is lowercase.
std::string input_str = "hello";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method has uppercase characters.
std::string input_str = "hElLo";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method works with non-latin scripts.
// Cyrillic has the concept of cases, so it should change the input.
std::string input_str = "БЙп";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "бйп");
}
{
// Check that the method works with non-latin scripts.
// Japanese doesn't have the concept of cases, so it should not change.
std::string input_str = "ありがと";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "ありがと");
}
}
} // namespace
} // namespace mediapipe::tasks::text::language_detector::custom_ops

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "utf",
srcs = [
"rune.c",
"runetype.c",
"runetypebody.h",
],
hdrs = ["utf.h"],
)

View File

@ -0,0 +1,233 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/
#include <stdarg.h>
#include <string.h>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
enum
{
Bit1 = 7,
Bitx = 6,
Bit2 = 5,
Bit3 = 4,
Bit4 = 3,
Bit5 = 2,
T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */
Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */
T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */
T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */
T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */
T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */
Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */
Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */
Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */
Rune4 = (1<<(Bit4+3*Bitx))-1,
/* 0001 1111 1111 1111 1111 1111 */
Maskx = (1<<Bitx)-1, /* 0011 1111 */
Testx = Maskx ^ 0xFF, /* 1100 0000 */
Bad = Runeerror,
};
/*
* Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
* This is a slower but "safe" version of the old chartorune
* that works on strings that are not necessarily null-terminated.
*
* If you know for sure that your string is null-terminated,
* chartorune will be a bit faster.
*
* It is guaranteed not to attempt to access "length"
* past the incoming pointer. This is to avoid
* possible access violations. If the string appears to be
* well-formed but incomplete (i.e., to get the whole Rune
* we'd need to read past str+length) then we'll set the Rune
* to Bad and return 0.
*
* Note that if we have decoding problems for other
* reasons, we return 1 instead of 0.
*/
int
utf_charntorune(Rune *rune, const char *str, int length)
{
int c, c1, c2, c3;
long l;
/* When we're not allowed to read anything */
if(length <= 0) {
goto badlen;
}
/*
* one character sequence (7-bit value)
* 00000-0007F => T1
*/
c = *(uchar*)str;
if(c < Tx) {
*rune = c;
return 1;
}
// If we can't read more than one character we must stop
if(length <= 1) {
goto badlen;
}
/*
* two character sequence (11-bit value)
* 0080-07FF => T2 Tx
*/
c1 = *(uchar*)(str+1) ^ Tx;
if(c1 & Testx)
goto bad;
if(c < T3) {
if(c < T2)
goto bad;
l = ((c << Bitx) | c1) & Rune2;
if(l <= Rune1)
goto bad;
*rune = l;
return 2;
}
// If we can't read more than two characters we must stop
if(length <= 2) {
goto badlen;
}
/*
* three character sequence (16-bit value)
* 0800-FFFF => T3 Tx Tx
*/
c2 = *(uchar*)(str+2) ^ Tx;
if(c2 & Testx)
goto bad;
if(c < T4) {
l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
if(l <= Rune2)
goto bad;
*rune = l;
return 3;
}
if (length <= 3)
goto badlen;
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
c3 = *(uchar*)(str+3) ^ Tx;
if (c3 & Testx)
goto bad;
if (c < T5) {
l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
if (l <= Rune3)
goto bad;
if (l > Runemax)
goto bad;
*rune = l;
return 4;
}
// Support for 5-byte or longer UTF-8 would go here, but
// since we don't have that, we'll just fall through to bad.
/*
* bad decoding
*/
bad:
*rune = Bad;
return 1;
badlen:
*rune = Bad;
return 0;
}
int
utf_runetochar(char *str, const Rune *rune)
{
/* Runes are signed, so convert to unsigned for range check. */
unsigned long c;
/*
* one character sequence
* 00000-0007F => 00-7F
*/
c = *rune;
if(c <= Rune1) {
str[0] = c;
return 1;
}
/*
* two character sequence
* 0080-07FF => T2 Tx
*/
if(c <= Rune2) {
str[0] = T2 | (c >> 1*Bitx);
str[1] = Tx | (c & Maskx);
return 2;
}
/*
* If the Rune is out of range, convert it to the error rune.
* Do this test here because the error rune encodes to three bytes.
* Doing it earlier would duplicate work, since an out of range
* Rune wouldn't have fit in one or two bytes.
*/
if (c > Runemax)
c = Runeerror;
/*
* three character sequence
* 0800-FFFF => T3 Tx Tx
*/
if (c <= Rune3) {
str[0] = T3 | (c >> 2*Bitx);
str[1] = Tx | ((c >> 1*Bitx) & Maskx);
str[2] = Tx | (c & Maskx);
return 3;
}
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
str[0] = T4 | (c >> 3*Bitx);
str[1] = Tx | ((c >> 2*Bitx) & Maskx);
str[2] = Tx | ((c >> 1*Bitx) & Maskx);
str[3] = Tx | (c & Maskx);
return 4;
}

Some files were not shown because too many files have changed in this diff Show More