Merge branch 'google:master' into audio-record-api-python
17
LICENSE
|
@ -199,3 +199,20 @@
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
|
|
||||||
|
===========================================================================
|
||||||
|
For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
|
||||||
|
===========================================================================
|
||||||
|
/*
|
||||||
|
* The authors of this software are Rob Pike and Ken Thompson.
|
||||||
|
* Copyright (c) 2002 by Lucent Technologies.
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose without fee is hereby granted, provided that this entire notice
|
||||||
|
* is included in all copies of any software which is or includes a copy
|
||||||
|
* or modification of this software and in all copies of the supporting
|
||||||
|
* documentation for such software.
|
||||||
|
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
|
||||||
|
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
|
||||||
|
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
|
||||||
|
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
|
||||||
|
*/
|
||||||
|
|
|
@ -270,7 +270,7 @@ new_local_repository(
|
||||||
# For local MacOS builds, the path should point to an opencv@3 installation.
|
# For local MacOS builds, the path should point to an opencv@3 installation.
|
||||||
# If you edit the path here, you will also need to update the corresponding
|
# If you edit the path here, you will also need to update the corresponding
|
||||||
# prefix in "opencv_macos.BUILD".
|
# prefix in "opencv_macos.BUILD".
|
||||||
path = "/usr/local",
|
path = "/usr/local", # e.g. /usr/local/Cellar for HomeBrew
|
||||||
)
|
)
|
||||||
|
|
||||||
new_local_repository(
|
new_local_repository(
|
||||||
|
@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool")
|
||||||
# Node dependencies
|
# Node dependencies
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "build_bazel_rules_nodejs",
|
name = "build_bazel_rules_nodejs",
|
||||||
sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
|
sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a",
|
||||||
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
|
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"],
|
||||||
)
|
)
|
||||||
|
|
||||||
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
|
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
|
||||||
|
|
|
@ -108,6 +108,8 @@ one over the other.
|
||||||
|
|
||||||
* [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
|
* [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
|
||||||
* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
|
* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
|
||||||
|
* [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
|
||||||
|
* [Model information](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)
|
||||||
|
|
||||||
### [Objectron](https://google.github.io/mediapipe/solutions/objectron)
|
### [Objectron](https://google.github.io/mediapipe/solutions/objectron)
|
||||||
|
|
||||||
|
|
|
@ -118,9 +118,9 @@ on how to build MediaPipe examples.
|
||||||
* With a TensorFlow Model
|
* With a TensorFlow Model
|
||||||
|
|
||||||
This uses the
|
This uses the
|
||||||
[TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model)
|
[TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
|
||||||
( see also
|
( see also
|
||||||
[model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)),
|
[model info](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)),
|
||||||
and the pipeline is implemented in this
|
and the pipeline is implemented in this
|
||||||
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).
|
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).
|
||||||
|
|
||||||
|
|
62
docs/solutions/object_detection_saved_model.md
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
## TensorFlow/TFLite Object Detection Model
|
||||||
|
|
||||||
|
### TensorFlow model
|
||||||
|
|
||||||
|
The model is trained on the [MSCOCO 2014](http://cocodataset.org) dataset using the [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with a 0.5 depth multiplier. The detailed training configuration is in the provided `pipeline.config`. The model is kept relatively compact (`0.171 mAP`) so that it can achieve real-time performance on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).
|
||||||
|
|
||||||
|
|
||||||
|
### TFLite model
|
||||||
|
|
||||||
|
The TFLite model is converted from the TensorFlow model above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) with minor modifications. Assume we now have a trained TensorFlow model that includes the checkpoint files and the training configuration file, for example the files provided in this repo:
|
||||||
|
|
||||||
|
* `model.ckpt.index`
|
||||||
|
* `model.ckpt.meta`
|
||||||
|
* `model.ckpt.data-00000-of-00001`
|
||||||
|
* `pipeline.config`
|
||||||
|
|
||||||
|
Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ PATH_TO_MODEL=path/to/the/model
|
||||||
|
$ bazel run object_detection:export_tflite_ssd_graph -- \
|
||||||
|
--pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
|
||||||
|
--trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
|
||||||
|
--output_directory ${PATH_TO_MODEL} \
|
||||||
|
--add_postprocessing_op=False
|
||||||
|
```
|
||||||
|
|
||||||
|
The exported model contains two files:
|
||||||
|
|
||||||
|
* `tflite_graph.pb`
|
||||||
|
* `tflite_graph.pbtxt`
|
||||||
|
|
||||||
|
The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. MediaPipe provides all the calculators needed for post-processing, so we can exclude the custom TFLite ops for post-processing (e.g., non-maximum suppression) from the original graph. This gives the flexibility to integrate with different post-processing algorithms and implementations.
|
||||||
|
|
||||||
|
Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ bazel run graph_transforms:summarize_graph -- \
|
||||||
|
--in_graph=${PATH_TO_MODEL}/tflite_graph.pb
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see that the input image size of the model is 320x320 and that the outputs of the model are:
|
||||||
|
|
||||||
|
* `raw_outputs/box_encodings`
|
||||||
|
* `raw_outputs/class_predictions`
|
||||||
|
|
||||||
|
The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ tflite_convert -- \
|
||||||
|
--graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
|
||||||
|
--output_file=${PATH_TO_MODEL}/model.tflite \
|
||||||
|
--input_format=TENSORFLOW_GRAPHDEF \
|
||||||
|
--output_format=TFLITE \
|
||||||
|
--inference_type=FLOAT \
|
||||||
|
--input_shapes=1,320,320,3 \
|
||||||
|
--input_arrays=normalized_input_image_tensor \
|
||||||
|
--output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.
|
|
@ -269,6 +269,7 @@ Supported configuration options:
|
||||||
```python
|
```python
|
||||||
import cv2
|
import cv2
|
||||||
import mediapipe as mp
|
import mediapipe as mp
|
||||||
|
import numpy as np
|
||||||
mp_drawing = mp.solutions.drawing_utils
|
mp_drawing = mp.solutions.drawing_utils
|
||||||
mp_drawing_styles = mp.solutions.drawing_styles
|
mp_drawing_styles = mp.solutions.drawing_styles
|
||||||
mp_pose = mp.solutions.pose
|
mp_pose = mp.solutions.pose
|
||||||
|
|
|
@ -156,6 +156,7 @@ cc_library(
|
||||||
"//mediapipe/framework/port:opencv_core",
|
"//mediapipe/framework/port:opencv_core",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:vector",
|
"//mediapipe/framework/port:vector",
|
||||||
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
] + select({
|
] + select({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
"//mediapipe/gpu:disable_gpu": [],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
|
@ -168,6 +169,25 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "set_alpha_calculator_test",
|
||||||
|
srcs = ["set_alpha_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":set_alpha_calculator",
|
||||||
|
":set_alpha_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/formats:image_frame_opencv",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:opencv_core",
|
||||||
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "bilateral_filter_calculator",
|
name = "bilateral_filter_calculator",
|
||||||
srcs = ["bilateral_filter_calculator.cc"],
|
srcs = ["bilateral_filter_calculator.cc"],
|
||||||
|
@ -748,6 +768,7 @@ cc_test(
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
|
||||||
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
|
||||||
],
|
],
|
||||||
tags = ["desktop_only_test"],
|
tags = ["desktop_only_test"],
|
||||||
|
|
|
@ -29,6 +29,9 @@ class AffineTransformation {
|
||||||
// pixels will be calculated.
|
// pixels will be calculated.
|
||||||
enum class BorderMode { kZero, kReplicate };
|
enum class BorderMode { kZero, kReplicate };
|
||||||
|
|
||||||
|
// Pixel sampling interpolation method.
|
||||||
|
enum class Interpolation { kLinear, kCubic };
|
||||||
|
|
||||||
struct Size {
|
struct Size {
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
|
|
|
@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
|
||||||
std::unique_ptr<GpuBuffer>> {
|
std::unique_ptr<GpuBuffer>> {
|
||||||
public:
|
public:
|
||||||
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
|
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
|
||||||
GpuOrigin::Mode gpu_origin)
|
GpuOrigin::Mode gpu_origin,
|
||||||
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {}
|
AffineTransformation::Interpolation interpolation)
|
||||||
|
: gl_helper_(gl_helper),
|
||||||
|
gpu_origin_(gpu_origin),
|
||||||
|
interpolation_(interpolation) {}
|
||||||
absl::Status Init() {
|
absl::Status Init() {
|
||||||
return gl_helper_->RunInGlContext([this]() -> absl::Status {
|
return gl_helper_->RunInGlContext([this]() -> absl::Status {
|
||||||
const GLint attr_location[kNumAttributes] = {
|
const GLint attr_location[kNumAttributes] = {
|
||||||
|
@ -103,10 +106,13 @@ class GlTextureWarpAffineRunner
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
// TODO: Move bicubic code to a common shared place.
|
||||||
constexpr GLchar kFragShader[] = R"(
|
constexpr GLchar kFragShader[] = R"(
|
||||||
DEFAULT_PRECISION(highp, float)
|
DEFAULT_PRECISION(highp, float)
|
||||||
|
|
||||||
in vec2 sample_coordinate;
|
in vec2 sample_coordinate;
|
||||||
uniform sampler2D input_texture;
|
uniform sampler2D input_texture;
|
||||||
|
uniform vec2 input_size;
|
||||||
|
|
||||||
#ifdef GL_ES
|
#ifdef GL_ES
|
||||||
#define fragColor gl_FragColor
|
#define fragColor gl_FragColor
|
||||||
|
@ -114,8 +120,60 @@ class GlTextureWarpAffineRunner
|
||||||
out vec4 fragColor;
|
out vec4 fragColor;
|
||||||
#endif // defined(GL_ES);
|
#endif // defined(GL_ES);
|
||||||
|
|
||||||
|
#ifdef CUBIC_INTERPOLATION
|
||||||
|
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||||
|
const vec2 halve = vec2(0.5,0.5);
|
||||||
|
const vec2 one = vec2(1.0,1.0);
|
||||||
|
const vec2 two = vec2(2.0,2.0);
|
||||||
|
const vec2 three = vec2(3.0,3.0);
|
||||||
|
const vec2 six = vec2(6.0,6.0);
|
||||||
|
|
||||||
|
// Calculate the fraction and integer.
|
||||||
|
tex_coord = tex_coord * tex_size - halve;
|
||||||
|
vec2 frac = fract(tex_coord);
|
||||||
|
vec2 index = tex_coord - frac + halve;
|
||||||
|
|
||||||
|
// Calculate weights for Catmull-Rom filter.
|
||||||
|
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
|
||||||
|
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
|
||||||
|
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
|
||||||
|
vec2 w3 = frac * frac * (-halve + halve * frac);
|
||||||
|
|
||||||
|
// Calculate weights to take advantage of bilinear texture lookup.
|
||||||
|
vec2 w12 = w1 + w2;
|
||||||
|
vec2 offset12 = w2 / (w1 + w2);
|
||||||
|
|
||||||
|
vec2 index_tl = index - one;
|
||||||
|
vec2 index_br = index + two;
|
||||||
|
vec2 index_eq = index + offset12;
|
||||||
|
|
||||||
|
index_tl /= tex_size;
|
||||||
|
index_br /= tex_size;
|
||||||
|
index_eq /= tex_size;
|
||||||
|
|
||||||
|
// 9 texture lookup and linear blending.
|
||||||
|
vec4 color = vec4(0.0);
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
|
||||||
|
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
|
||||||
|
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
|
||||||
|
return color;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||||
|
return texture2D(tex, tex_coord);
|
||||||
|
}
|
||||||
|
#endif // defined(CUBIC_INTERPOLATION)
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
vec4 color = texture2D(input_texture, sample_coordinate);
|
vec4 color = sample(input_texture, sample_coordinate, input_size);
|
||||||
#ifdef CUSTOM_ZERO_BORDER_MODE
|
#ifdef CUSTOM_ZERO_BORDER_MODE
|
||||||
float out_of_bounds =
|
float out_of_bounds =
|
||||||
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
||||||
|
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
|
||||||
glUseProgram(program);
|
glUseProgram(program);
|
||||||
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
|
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
|
||||||
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
|
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
|
||||||
return Program{.id = program, .matrix_id = matrix_id};
|
GLint size_id = glGetUniformLocation(program, "input_size");
|
||||||
|
return Program{
|
||||||
|
.id = program, .matrix_id = matrix_id, .size_id = size_id};
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::string vert_src =
|
const std::string vert_src =
|
||||||
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
|
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
|
||||||
|
|
||||||
const std::string frag_src = absl::StrCat(
|
std::string interpolation_def;
|
||||||
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader);
|
switch (interpolation_) {
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
interpolation_def = R"(
|
||||||
|
#define CUBIC_INTERPOLATION
|
||||||
|
)";
|
||||||
|
break;
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string frag_src =
|
||||||
|
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||||
|
interpolation_def, kFragShader);
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
|
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
|
||||||
|
|
||||||
|
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
|
||||||
std::string custom_zero_border_mode_def = R"(
|
std::string custom_zero_border_mode_def = R"(
|
||||||
#define CUSTOM_ZERO_BORDER_MODE
|
#define CUSTOM_ZERO_BORDER_MODE
|
||||||
)";
|
)";
|
||||||
const std::string frag_custom_zero_src =
|
const std::string frag_custom_zero_src = absl::StrCat(
|
||||||
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||||
custom_zero_border_mode_def, kFragShader);
|
custom_zero_border_mode_def, interpolation_def, kFragShader);
|
||||||
return create_fn(vert_src, frag_custom_zero_src);
|
return create_fn(vert_src, frag_custom_zero_src);
|
||||||
};
|
};
|
||||||
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||||
|
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
|
||||||
}
|
}
|
||||||
glUseProgram(program->id);
|
glUseProgram(program->id);
|
||||||
|
|
||||||
|
// uniforms
|
||||||
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
|
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
|
||||||
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
|
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
|
||||||
// @matrix describes affine transformation in terms of TOP LEFT origin, so
|
// @matrix describes affine transformation in terms of TOP LEFT origin, so
|
||||||
|
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
|
||||||
eigen_mat.transposeInPlace();
|
eigen_mat.transposeInPlace();
|
||||||
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
|
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
|
||||||
|
|
||||||
|
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
|
||||||
|
glUniform2f(program->size_id, texture.width(), texture.height());
|
||||||
|
}
|
||||||
|
|
||||||
// vao
|
// vao
|
||||||
glBindVertexArray(vao_);
|
glBindVertexArray(vao_);
|
||||||
|
|
||||||
|
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
|
||||||
struct Program {
|
struct Program {
|
||||||
GLuint id;
|
GLuint id;
|
||||||
GLint matrix_id;
|
GLint matrix_id;
|
||||||
|
GLint size_id;
|
||||||
};
|
};
|
||||||
std::shared_ptr<GlCalculatorHelper> gl_helper_;
|
std::shared_ptr<GlCalculatorHelper> gl_helper_;
|
||||||
GpuOrigin::Mode gpu_origin_;
|
GpuOrigin::Mode gpu_origin_;
|
||||||
|
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
|
||||||
Program program_;
|
Program program_;
|
||||||
std::optional<Program> program_custom_zero_;
|
std::optional<Program> program_custom_zero_;
|
||||||
GLuint framebuffer_ = 0;
|
GLuint framebuffer_ = 0;
|
||||||
|
AffineTransformation::Interpolation interpolation_ =
|
||||||
|
AffineTransformation::Interpolation::kLinear;
|
||||||
};
|
};
|
||||||
|
|
||||||
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||||
|
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
|
||||||
absl::StatusOr<std::unique_ptr<
|
absl::StatusOr<std::unique_ptr<
|
||||||
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
|
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
|
||||||
CreateAffineTransformationGlRunner(
|
CreateAffineTransformationGlRunner(
|
||||||
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) {
|
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
|
||||||
auto runner =
|
AffineTransformation::Interpolation interpolation) {
|
||||||
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin);
|
auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
|
||||||
|
gl_helper, gpu_origin, interpolation);
|
||||||
MP_RETURN_IF_ERROR(runner->Init());
|
MP_RETURN_IF_ERROR(runner->Init());
|
||||||
return runner;
|
return runner;
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
|
||||||
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
|
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
|
||||||
CreateAffineTransformationGlRunner(
|
CreateAffineTransformationGlRunner(
|
||||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
|
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
|
||||||
mediapipe::GpuOrigin::Mode gpu_origin);
|
mediapipe::GpuOrigin::Mode gpu_origin,
|
||||||
|
AffineTransformation::Interpolation interpolation);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GetInterpolationForOpenCv(
|
||||||
|
AffineTransformation::Interpolation interpolation) {
|
||||||
|
switch (interpolation) {
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
return cv::INTER_LINEAR;
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
return cv::INTER_CUBIC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class OpenCvRunner
|
class OpenCvRunner
|
||||||
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
|
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
|
||||||
public:
|
public:
|
||||||
|
OpenCvRunner(AffineTransformation::Interpolation interpolation)
|
||||||
|
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
|
||||||
|
|
||||||
absl::StatusOr<ImageFrame> Run(
|
absl::StatusOr<ImageFrame> Run(
|
||||||
const ImageFrame& input, const std::array<float, 16>& matrix,
|
const ImageFrame& input, const std::array<float, 16>& matrix,
|
||||||
const AffineTransformation::Size& size,
|
const AffineTransformation::Size& size,
|
||||||
|
@ -142,19 +155,23 @@ class OpenCvRunner
|
||||||
|
|
||||||
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
|
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
|
||||||
cv::Size(out_mat.cols, out_mat.rows),
|
cv::Size(out_mat.cols, out_mat.rows),
|
||||||
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP,
|
/*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
|
||||||
GetBorderModeForOpenCv(border_mode));
|
GetBorderModeForOpenCv(border_mode));
|
||||||
|
|
||||||
return out_image;
|
return out_image;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int interpolation_ = cv::INTER_LINEAR;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::StatusOr<
|
absl::StatusOr<
|
||||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||||
CreateAffineTransformationOpenCvRunner() {
|
CreateAffineTransformationOpenCvRunner(
|
||||||
return absl::make_unique<OpenCvRunner>();
|
AffineTransformation::Interpolation interpolation) {
|
||||||
|
return absl::make_unique<OpenCvRunner>(interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -25,7 +25,8 @@ namespace mediapipe {
|
||||||
|
|
||||||
absl::StatusOr<
|
absl::StatusOr<
|
||||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||||
CreateAffineTransformationOpenCvRunner();
|
CreateAffineTransformationOpenCvRunner(
|
||||||
|
AffineTransformation::Interpolation interpolation);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/vector.h"
|
#include "mediapipe/framework/port/vector.h"
|
||||||
|
|
||||||
|
@ -53,24 +54,16 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
|
||||||
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
|
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
|
||||||
// must be uchar.
|
// must be uchar.
|
||||||
template <typename AlphaType>
|
template <typename AlphaType>
|
||||||
absl::Status MergeRGBA8Image(const cv::Mat input_mat, const cv::Mat& alpha_mat,
|
absl::Status CopyAlphaImage(const cv::Mat& alpha_mat, cv::Mat& output_mat) {
|
||||||
cv::Mat& output_mat) {
|
RET_CHECK_EQ(output_mat.rows, alpha_mat.rows);
|
||||||
RET_CHECK_EQ(input_mat.rows, alpha_mat.rows);
|
RET_CHECK_EQ(output_mat.cols, alpha_mat.cols);
|
||||||
RET_CHECK_EQ(input_mat.cols, alpha_mat.cols);
|
|
||||||
RET_CHECK_EQ(input_mat.rows, output_mat.rows);
|
|
||||||
RET_CHECK_EQ(input_mat.cols, output_mat.cols);
|
|
||||||
|
|
||||||
for (int i = 0; i < output_mat.rows; ++i) {
|
for (int i = 0; i < output_mat.rows; ++i) {
|
||||||
const uchar* in_ptr = input_mat.ptr<uchar>(i);
|
|
||||||
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
|
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
|
||||||
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
||||||
for (int j = 0; j < output_mat.cols; ++j) {
|
for (int j = 0; j < output_mat.cols; ++j) {
|
||||||
const int out_idx = j * kNumChannelsRGBA;
|
const int out_idx = j * kNumChannelsRGBA;
|
||||||
const int in_idx = j * input_mat.channels();
|
|
||||||
const int alpha_idx = j * alpha_mat.channels();
|
const int alpha_idx = j * alpha_mat.channels();
|
||||||
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
|
|
||||||
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
|
|
||||||
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
|
|
||||||
if constexpr (std::is_same<AlphaType, uchar>::value) {
|
if constexpr (std::is_same<AlphaType, uchar>::value) {
|
||||||
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
|
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
|
||||||
} else {
|
} else {
|
||||||
|
@ -273,7 +266,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
|
||||||
|
|
||||||
// Setup source image
|
// Setup source image
|
||||||
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
|
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
|
||||||
const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame);
|
const cv::Mat input_mat = formats::MatView(&input_frame);
|
||||||
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
|
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
|
||||||
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
|
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
|
||||||
}
|
}
|
||||||
|
@ -281,38 +274,38 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
|
||||||
// Setup destination image
|
// Setup destination image
|
||||||
auto output_frame = absl::make_unique<ImageFrame>(
|
auto output_frame = absl::make_unique<ImageFrame>(
|
||||||
ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
|
ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
|
||||||
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
|
cv::Mat output_mat = formats::MatView(output_frame.get());
|
||||||
|
|
||||||
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
|
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
|
||||||
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
|
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
|
||||||
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
|
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
|
||||||
|
|
||||||
// Setup alpha image and Update image in CPU.
|
// Copy the RGB part of the image on the CPU.
|
||||||
|
if (input_mat.channels() == 3) {
|
||||||
|
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA);
|
||||||
|
} else {
|
||||||
|
input_mat.copyTo(output_mat);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the alpha image on the CPU.
|
||||||
if (use_alpha_mask) {
|
if (use_alpha_mask) {
|
||||||
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
|
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
|
||||||
cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask);
|
cv::Mat alpha_mat = formats::MatView(&alpha_mask);
|
||||||
|
|
||||||
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
|
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
|
||||||
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
|
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
|
||||||
|
|
||||||
if (alpha_is_float) {
|
if (alpha_is_float) {
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(CopyAlphaImage<float>(alpha_mat, output_mat));
|
||||||
MergeRGBA8Image<float>(input_mat, alpha_mat, output_mat));
|
|
||||||
} else {
|
} else {
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(CopyAlphaImage<uchar>(alpha_mat, output_mat));
|
||||||
MergeRGBA8Image<uchar>(input_mat, alpha_mat, output_mat));
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
|
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
|
||||||
for (int i = 0; i < output_mat.rows; ++i) {
|
for (int i = 0; i < output_mat.rows; ++i) {
|
||||||
const uchar* in_ptr = input_mat.ptr<uchar>(i);
|
|
||||||
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
||||||
for (int j = 0; j < output_mat.cols; ++j) {
|
for (int j = 0; j < output_mat.cols; ++j) {
|
||||||
const int out_idx = j * kNumChannelsRGBA;
|
const int out_idx = j * kNumChannelsRGBA;
|
||||||
const int in_idx = j * input_mat.channels();
|
|
||||||
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
|
|
||||||
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
|
|
||||||
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
|
|
||||||
out_ptr[out_idx + 3] = alpha_value; // use value from options
|
out_ptr[out_idx + 3] = alpha_value; // use value from options
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
156
mediapipe/calculators/image/set_alpha_calculator_test.cc
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "testing/base/public/benchmark.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int input_width = 100;
|
||||||
|
constexpr int input_height = 100;
|
||||||
|
|
||||||
|
// Builds a test ImageFrame of the requested size (alignment boundary 1) whose
// pixel bytes follow the repeating pattern 0, 1, ..., 255, 0, 1, ...
//
// `channel` selects the pixel format:
//   4        -> ImageFormat::SRGBA
//   3        -> ImageFormat::SRGB
//   otherwise -> ImageFormat::GRAY8 (single channel)
//
// Returns the newly allocated frame.
std::unique_ptr<ImageFrame> GetInputFrame(int width, int height, int channel) {
  ImageFormat::Format image_format;
  int num_channels;
  if (channel == 4) {
    image_format = ImageFormat::SRGBA;
    num_channels = 4;
  } else if (channel == 3) {
    image_format = ImageFormat::SRGB;
    num_channels = 3;
  } else {
    image_format = ImageFormat::GRAY8;
    num_channels = 1;
  }

  // Size the fill loop from the channel count of the format that is actually
  // allocated rather than from the raw `channel` argument. Otherwise a value
  // such as channel == 2 would allocate a GRAY8 frame but write
  // width * height * 2 bytes into it.
  const int total_size = width * height * num_channels;

  auto input_frame = std::make_unique<ImageFrame>(image_format, width, height,
                                                  /*alignment_boundary =*/1);
  // NOTE(review): the flat indexing below assumes that alignment boundary 1
  // yields rows without padding -- confirm against ImageFrame.
  auto* pixel_data = input_frame->MutablePixelData();
  for (int i = 0; i < total_size; ++i) {
    pixel_data[i] = i % 256;
  }
  return input_frame;
}
|
||||||
|
|
||||||
|
// Test SetAlphaCalculator with RGB IMAGE input.
|
||||||
|
TEST(SetAlphaCalculatorTest, CpuRgb) {
|
||||||
|
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
|
||||||
|
R"pb(
|
||||||
|
calculator: "SetAlphaCalculator"
|
||||||
|
input_stream: "IMAGE:input_frames"
|
||||||
|
input_stream: "ALPHA:masks"
|
||||||
|
output_stream: "IMAGE:output_frames"
|
||||||
|
)pb");
|
||||||
|
CalculatorRunner runner(calculator_node);
|
||||||
|
|
||||||
|
// Input frames.
|
||||||
|
const auto input_frame = GetInputFrame(input_width, input_height, 3);
|
||||||
|
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
|
||||||
|
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
|
||||||
|
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
|
||||||
|
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||||
|
input_frame_packet.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
|
||||||
|
mask_frame_packet.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs();
|
||||||
|
EXPECT_EQ(outputs.NumEntries(), 1);
|
||||||
|
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
|
||||||
|
|
||||||
|
// Generate ground truth (expected_mat).
|
||||||
|
const auto image = GetInputFrame(input_width, input_height, 3);
|
||||||
|
const auto input_mat = formats::MatView(image.get());
|
||||||
|
const auto mask = GetInputFrame(input_width, input_height, 1);
|
||||||
|
const auto mask_mat = formats::MatView(mask.get());
|
||||||
|
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
|
||||||
|
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
|
||||||
|
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 3, 3});
|
||||||
|
|
||||||
|
cv::Mat output_mat = formats::MatView(&output_image);
|
||||||
|
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
|
||||||
|
EXPECT_FLOAT_EQ(max_diff, 0);
|
||||||
|
} // TEST
|
||||||
|
|
||||||
|
// Test SetAlphaCalculator with RGBA IMAGE input.
|
||||||
|
TEST(SetAlphaCalculatorTest, CpuRgba) {
|
||||||
|
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
|
||||||
|
R"pb(
|
||||||
|
calculator: "SetAlphaCalculator"
|
||||||
|
input_stream: "IMAGE:input_frames"
|
||||||
|
input_stream: "ALPHA:masks"
|
||||||
|
output_stream: "IMAGE:output_frames"
|
||||||
|
)pb");
|
||||||
|
CalculatorRunner runner(calculator_node);
|
||||||
|
|
||||||
|
// Input frames.
|
||||||
|
const auto input_frame = GetInputFrame(input_width, input_height, 4);
|
||||||
|
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
|
||||||
|
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
|
||||||
|
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
|
||||||
|
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||||
|
input_frame_packet.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
|
||||||
|
mask_frame_packet.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs();
|
||||||
|
EXPECT_EQ(outputs.NumEntries(), 1);
|
||||||
|
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
|
||||||
|
|
||||||
|
// Generate ground truth (expected_mat).
|
||||||
|
const auto image = GetInputFrame(input_width, input_height, 4);
|
||||||
|
const auto input_mat = formats::MatView(image.get());
|
||||||
|
const auto mask = GetInputFrame(input_width, input_height, 1);
|
||||||
|
const auto mask_mat = formats::MatView(mask.get());
|
||||||
|
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
|
||||||
|
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
|
||||||
|
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 4, 3});
|
||||||
|
|
||||||
|
cv::Mat output_mat = formats::MatView(&output_image);
|
||||||
|
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
|
||||||
|
EXPECT_FLOAT_EQ(max_diff, 0);
|
||||||
|
} // TEST
|
||||||
|
|
||||||
|
static void BM_SetAlpha3ChannelImage(benchmark::State& state) {
|
||||||
|
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
|
||||||
|
R"pb(
|
||||||
|
calculator: "SetAlphaCalculator"
|
||||||
|
input_stream: "IMAGE:input_frames"
|
||||||
|
input_stream: "ALPHA:masks"
|
||||||
|
output_stream: "IMAGE:output_frames"
|
||||||
|
)pb");
|
||||||
|
CalculatorRunner runner(calculator_node);
|
||||||
|
|
||||||
|
// Input frames.
|
||||||
|
const auto input_frame = GetInputFrame(input_width, input_height, 3);
|
||||||
|
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
|
||||||
|
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
|
||||||
|
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
|
||||||
|
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||||
|
input_frame_packet.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
|
||||||
|
mask_frame_packet.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
const auto& outputs = runner.Outputs();
|
||||||
|
ASSERT_EQ(1, outputs.NumEntries());
|
||||||
|
|
||||||
|
for (const auto _ : state) {
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(BM_SetAlpha3ChannelImage);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Maps the WarpAffineCalculatorOptions interpolation option onto the
// corresponding AffineTransformation::Interpolation value.
//
// INTER_UNSPECIFIED and INTER_LINEAR both map to kLinear; INTER_CUBIC maps to
// kCubic. Any value outside the declared enumerators falls back to kLinear.
AffineTransformation::Interpolation GetInterpolation(
    mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
  // Intentionally no `default:` label, so the compiler can still warn when a
  // new enumerator is added to the proto but not handled here.
  switch (interpolation) {
    case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
    case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
      return AffineTransformation::Interpolation::kLinear;
    case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
      return AffineTransformation::Interpolation::kCubic;
  }
  // Only reached when `interpolation` holds a value that is not one of the
  // enumerators above. Without this return, control would flow off the end of
  // a non-void function, which is undefined behaviour. Use the same default
  // as the unspecified case.
  return AffineTransformation::Interpolation::kLinear;
}
|
||||||
|
|
||||||
template <typename ImageT>
|
template <typename ImageT>
|
||||||
class WarpAffineRunnerHolder {};
|
class WarpAffineRunnerHolder {};
|
||||||
|
|
||||||
|
@ -61,16 +72,22 @@ template <>
|
||||||
class WarpAffineRunnerHolder<ImageFrame> {
|
class WarpAffineRunnerHolder<ImageFrame> {
|
||||||
public:
|
public:
|
||||||
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
|
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
|
||||||
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
|
absl::Status Open(CalculatorContext* cc) {
|
||||||
|
interpolation_ = GetInterpolation(
|
||||||
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
absl::StatusOr<RunnerType*> GetRunner() {
|
absl::StatusOr<RunnerType*> GetRunner() {
|
||||||
if (!runner_) {
|
if (!runner_) {
|
||||||
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner());
|
ASSIGN_OR_RETURN(runner_,
|
||||||
|
CreateAffineTransformationOpenCvRunner(interpolation_));
|
||||||
}
|
}
|
||||||
return runner_.get();
|
return runner_.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<RunnerType> runner_;
|
std::unique_ptr<RunnerType> runner_;
|
||||||
|
AffineTransformation::Interpolation interpolation_;
|
||||||
};
|
};
|
||||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||||
|
|
||||||
|
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
||||||
gpu_origin_ =
|
gpu_origin_ =
|
||||||
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
|
||||||
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
|
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
|
||||||
|
interpolation_ = GetInterpolation(
|
||||||
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||||
return gl_helper_->Open(cc);
|
return gl_helper_->Open(cc);
|
||||||
}
|
}
|
||||||
absl::StatusOr<RunnerType*> GetRunner() {
|
absl::StatusOr<RunnerType*> GetRunner() {
|
||||||
if (!runner_) {
|
if (!runner_) {
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
|
||||||
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_));
|
gl_helper_, gpu_origin_, interpolation_));
|
||||||
}
|
}
|
||||||
return runner_.get();
|
return runner_.get();
|
||||||
}
|
}
|
||||||
|
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
||||||
mediapipe::GpuOrigin::Mode gpu_origin_;
|
mediapipe::GpuOrigin::Mode gpu_origin_;
|
||||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
|
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
|
||||||
std::unique_ptr<RunnerType> runner_;
|
std::unique_ptr<RunnerType> runner_;
|
||||||
|
AffineTransformation::Interpolation interpolation_;
|
||||||
};
|
};
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
|
||||||
BORDER_REPLICATE = 2;
|
BORDER_REPLICATE = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pixel sampling interpolation methods. See @interpolation.
|
||||||
|
enum Interpolation {
|
||||||
|
INTER_UNSPECIFIED = 0;
|
||||||
|
INTER_LINEAR = 1;
|
||||||
|
INTER_CUBIC = 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Pixel extrapolation method.
|
// Pixel extrapolation method.
|
||||||
// When converting image to tensor it may happen that tensor needs to read
|
// When converting image to tensor it may happen that tensor needs to read
|
||||||
// pixels outside image boundaries. Border mode helps to specify how such
|
// pixels outside image boundaries. Border mode helps to specify how such
|
||||||
|
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
|
||||||
// to be flipped vertically as tensors are expected to start at top.
|
// to be flipped vertically as tensors are expected to start at top.
|
||||||
// (DEFAULT or unset interpreted as CONVENTIONAL.)
|
// (DEFAULT or unset interpreted as CONVENTIONAL.)
|
||||||
optional GpuOrigin.Mode gpu_origin = 2;
|
optional GpuOrigin.Mode gpu_origin = 2;
|
||||||
|
|
||||||
|
// Sampling method for neighboring pixels.
|
||||||
|
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
|
||||||
|
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
|
||||||
|
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
|
||||||
|
optional Interpolation interpolation = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
||||||
const cv::Mat& input, cv::Mat expected_result,
|
const cv::Mat& input, cv::Mat expected_result,
|
||||||
float similarity_threshold, std::array<float, 16> matrix,
|
float similarity_threshold, std::array<float, 16> matrix,
|
||||||
int out_width, int out_height,
|
int out_width, int out_height,
|
||||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||||
std::string border_mode_str;
|
std::string border_mode_str;
|
||||||
if (border_mode) {
|
if (border_mode) {
|
||||||
switch (*border_mode) {
|
switch (*border_mode) {
|
||||||
|
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
std::string interpolation_str;
|
||||||
|
if (interpolation) {
|
||||||
|
switch (*interpolation) {
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
interpolation_str = "interpolation: INTER_LINEAR";
|
||||||
|
break;
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
interpolation_str = "interpolation: INTER_CUBIC";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
absl::Substitute(graph_text, /*$0=*/border_mode_str));
|
absl::Substitute(graph_text, /*$0=*/border_mode_str,
|
||||||
|
/*$1=*/interpolation_str));
|
||||||
|
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("output_image", &graph_config, &output_packets);
|
tool::AddVectorSink("output_image", &graph_config, &output_packets);
|
||||||
|
@ -132,7 +145,8 @@ struct SimilarityConfig {
|
||||||
void RunTest(cv::Mat input, cv::Mat expected_result,
|
void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
const SimilarityConfig& similarity, std::array<float, 16> matrix,
|
const SimilarityConfig& similarity, std::array<float, 16> matrix,
|
||||||
int out_width, int out_height,
|
int out_width, int out_height,
|
||||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
input_stream: "output_size"
|
input_stream: "output_size"
|
||||||
|
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
|
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
|
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
|
||||||
matrix, out_width, out_height, border_mode);
|
matrix, out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
gpu_origin: TOP_LEFT
|
gpu_origin: TOP_LEFT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
|
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
gpu_origin: TOP_LEFT
|
gpu_origin: TOP_LEFT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
|
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
|
||||||
matrix, out_width, out_height, border_mode);
|
matrix, out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
|
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
|
||||||
|
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
|
||||||
int out_height = 256;
|
int out_height = 256;
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
||||||
|
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
||||||
|
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
||||||
|
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
||||||
|
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
||||||
|
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
|
||||||
|
mediapipe::NormalizedRect roi;
|
||||||
|
roi.set_x_center(0.65f);
|
||||||
|
roi.set_y_center(0.4f);
|
||||||
|
roi.set_width(0.5f);
|
||||||
|
roi.set_height(0.5f);
|
||||||
|
roi.set_rotation(M_PI * -45.0f / 180.0f);
|
||||||
|
auto input = GetRgb(
|
||||||
|
"/mediapipe/calculators/"
|
||||||
|
"tensor/testdata/image_to_tensor/input.jpg");
|
||||||
|
auto expected_output = GetRgb(
|
||||||
|
"/mediapipe/calculators/"
|
||||||
|
"tensor/testdata/image_to_tensor/"
|
||||||
|
"medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
|
||||||
|
int out_width = 256;
|
||||||
|
int out_height = 256;
|
||||||
|
bool keep_aspect_ratio = false;
|
||||||
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation =
|
||||||
|
AffineTransformation::Interpolation::kCubic;
|
||||||
|
RunTest(input, expected_output,
|
||||||
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
|
||||||
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
||||||
|
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
||||||
|
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
||||||
|
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
||||||
|
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
||||||
|
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
||||||
int out_height = 128;
|
int out_height = 128;
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||||
|
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, NoOp) {
|
TEST(WarpAffineCalculatorTest, NoOp) {
|
||||||
|
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
||||||
|
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -997,17 +997,20 @@ cc_library(
|
||||||
":image_to_tensor_converter_gl_buffer",
|
":image_to_tensor_converter_gl_buffer",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
"//mediapipe:apple": [
|
"//mediapipe:apple": [
|
||||||
":image_to_tensor_converter_metal",
|
":image_to_tensor_converter_metal",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:MPPMetalHelper",
|
"//mediapipe/gpu:MPPMetalHelper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
":image_to_tensor_converter_gl_buffer",
|
":image_to_tensor_converter_gl_buffer",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -1045,6 +1048,10 @@ cc_test(
|
||||||
":image_to_tensor_calculator",
|
":image_to_tensor_calculator",
|
||||||
":image_to_tensor_converter",
|
":image_to_tensor_converter",
|
||||||
":image_to_tensor_utils",
|
":image_to_tensor_utils",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
@ -1061,11 +1068,10 @@ cc_test(
|
||||||
"//mediapipe/framework/port:opencv_imgproc",
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/util:image_test_utils",
|
"//mediapipe/util:image_test_utils",
|
||||||
"@com_google_absl//absl/flags:flag",
|
] + select({
|
||||||
"@com_google_absl//absl/memory",
|
"//mediapipe:apple": [],
|
||||||
"@com_google_absl//absl/strings",
|
"//conditions:default": ["//mediapipe/gpu:gl_context"],
|
||||||
"@com_google_absl//absl/strings:str_format",
|
}),
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -45,9 +45,11 @@
|
||||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||||
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
|
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
#include "mediapipe/gpu/gpu_service.h"
|
||||||
#else
|
#else
|
||||||
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
|
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
#include "mediapipe/gpu/gpu_service.h"
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node {
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||||
#else
|
#else
|
||||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
cc->UseService(kGpuService).Optional();
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
#endif // MEDIAPIPE_DISABLE_GPU
|
#endif // MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,10 @@
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/util/image_test_utils.h"
|
#include "mediapipe/util/image_test_utils.h"
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
#include "mediapipe/gpu/gl_context.h"
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) {
|
||||||
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
|
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
|
||||||
/*keep_aspect=*/false, BorderMode::kZero, roi);
|
/*keep_aspect=*/false, BorderMode::kZero, roi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) {
|
||||||
|
auto graph_config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input_image"
|
||||||
|
node {
|
||||||
|
calculator: "ImageToTensorCalculator"
|
||||||
|
input_stream: "IMAGE:input_image"
|
||||||
|
output_stream: "TENSORS:tensor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||||
|
output_tensor_float_range { min: 0.0f max: 1.0f }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
auto image_frame =
|
||||||
|
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
|
||||||
|
Image image = Image(std::move(image_frame));
|
||||||
|
Packet packet = MakePacket<Image>(std::move(image));
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
|
TEST(ImageToTensorCalculatorTest,
|
||||||
|
FailsGracefullyWhenGpuServiceNeededButNotAvailable) {
|
||||||
|
auto graph_config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input_image"
|
||||||
|
node {
|
||||||
|
calculator: "ImageToTensorCalculator"
|
||||||
|
input_stream: "IMAGE:input_image"
|
||||||
|
output_stream: "TENSORS:tensor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||||
|
output_tensor_float_range { min: 0.0f max: 1.0f }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto context,
|
||||||
|
GlContext::Create(nullptr, /*create_thread=*/true));
|
||||||
|
Packet packet;
|
||||||
|
context->Run([&packet]() {
|
||||||
|
auto image_frame =
|
||||||
|
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
|
||||||
|
Image image = Image(std::move(image_frame));
|
||||||
|
// Ensure image is available on GPU to force ImageToTensorCalculator to
|
||||||
|
// run on GPU.
|
||||||
|
ASSERT_TRUE(image.ConvertToGpu());
|
||||||
|
packet = MakePacket<Image>(std::move(image));
|
||||||
|
});
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
|
||||||
|
EXPECT_THAT(graph.WaitUntilIdle(),
|
||||||
|
StatusIs(absl::StatusCode::kInternal,
|
||||||
|
HasSubstr("GPU service not available")));
|
||||||
|
}
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -141,7 +141,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
|
||||||
}
|
}
|
||||||
// Run inference.
|
// Run inference.
|
||||||
{
|
{
|
||||||
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
|
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE_ADVANCED, cc);
|
||||||
return tflite_gpu_runner_->Invoke();
|
return tflite_gpu_runner_->Invoke();
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// The option proto for the TensorsReadbackCalculator.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message TensorsReadbackCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional TensorsReadbackCalculatorOptions ext = 514750372;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expected shapes of the input tensors.
|
||||||
|
// The calculator uses these shapes to build the GPU programs during
|
||||||
|
// initialization, and checks the actual tensor shapes against the expected
|
||||||
|
// shapes during runtime.
|
||||||
|
// Batch size of the tensor is set to be 1. `TensorShape` here can be C, WC,
|
||||||
|
// or HWC.
|
||||||
|
// For example {dims: 1 dims: 2} represents a tensor with batch_size = 1,
|
||||||
|
// width = 1, and num_channels = 2.
|
||||||
|
message TensorShape {
|
||||||
|
repeated int32 dims = 1 [packed = true];
|
||||||
|
}
|
||||||
|
// tensor_shape specifies the shape of each input tensor.
|
||||||
|
repeated TensorShape tensor_shape = 1;
|
||||||
|
}
|
After Width: | Height: | Size: 64 KiB |
|
@ -14,6 +14,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||||
|
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
@ -312,15 +313,19 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
# TODO: Re-evaluate which of these libraries we can avoid making
|
||||||
|
# cc_library_with_tflite and can be changed back to cc_library.
|
||||||
|
cc_library_with_tflite(
|
||||||
name = "tflite_model_calculator",
|
name = "tflite_model_calculator",
|
||||||
srcs = ["tflite_model_calculator.cc"],
|
srcs = ["tflite_model_calculator.cc"],
|
||||||
|
tflite_deps = [
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -66,7 +66,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
|
||||||
} else {
|
} else {
|
||||||
cc->OutputSidePackets()
|
cc->OutputSidePackets()
|
||||||
.Index(0)
|
.Index(0)
|
||||||
.Set<tflite_shims::ops::builtin::BuiltinOpResolver>();
|
.Set<tflite::ops::builtin::BuiltinOpResolver>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -77,7 +77,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
|
||||||
const TfLiteCustomOpResolverCalculatorOptions& options =
|
const TfLiteCustomOpResolverCalculatorOptions& options =
|
||||||
cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
|
cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
|
||||||
|
|
||||||
std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver> op_resolver;
|
std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> op_resolver;
|
||||||
if (options.use_gpu()) {
|
if (options.use_gpu()) {
|
||||||
op_resolver = absl::make_unique<mediapipe::OpResolver>();
|
op_resolver = absl::make_unique<mediapipe::OpResolver>();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "tensorflow/lite/allocation.h"
|
#include "tensorflow/lite/allocation.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
|
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
|
||||||
#ifdef ABSL_HAVE_MMAP
|
#if defined(ABSL_HAVE_MMAP) && !TFLITE_WITH_STABLE_ABI
|
||||||
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
|
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
|
||||||
const auto& model_fd =
|
const auto& model_fd =
|
||||||
model_packet.Get<std::tuple<int, size_t, size_t>>();
|
model_packet.Get<std::tuple<int, size_t, size_t>>();
|
||||||
|
|
|
@ -1270,6 +1270,50 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "flat_color_image_calculator_proto",
|
||||||
|
srcs = ["flat_color_image_calculator.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/util:color_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "flat_color_image_calculator",
|
||||||
|
srcs = ["flat_color_image_calculator.cc"],
|
||||||
|
deps = [
|
||||||
|
":flat_color_image_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:image",
|
||||||
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
"//mediapipe/framework/formats:image_frame_opencv",
|
||||||
|
"//mediapipe/framework/port:opencv_core",
|
||||||
|
"//mediapipe/util:color_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "flat_color_image_calculator_test",
|
||||||
|
srcs = ["flat_color_image_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":flat_color_image_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/formats:image",
|
||||||
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/util:color_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "from_image_calculator",
|
name = "from_image_calculator",
|
||||||
srcs = ["from_image_calculator.cc"],
|
srcs = ["from_image_calculator.cc"],
|
||||||
|
|
138
mediapipe/calculators/util/flat_color_image_calculator.cc
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Node;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// A calculator for generating an image filled with a single color.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// IMAGE (Image, optional)
|
||||||
|
// If provided, the output will have the same size as this image.
|
||||||
|
// COLOR (Color proto, optional)
|
||||||
|
// Color to paint the output with. Takes precedence over the equivalent
|
||||||
|
// calculator options.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// IMAGE (Image)
|
||||||
|
// Image filled with the requested color.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
// node {
|
||||||
|
// calculator: "FlatColorImageCalculator"
|
||||||
|
// input_stream: "IMAGE:image"
|
||||||
|
// input_stream: "COLOR:color"
|
||||||
|
// output_stream: "IMAGE:blank_image"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
// color: {
|
||||||
|
// r: 255
|
||||||
|
// g: 255
|
||||||
|
// b: 255
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
class FlatColorImageCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<Image>::Optional kInImage{"IMAGE"};
|
||||||
|
static constexpr Input<Color>::Optional kInColor{"COLOR"};
|
||||||
|
static constexpr Output<Image> kOutImage{"IMAGE"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
|
||||||
|
|
||||||
|
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||||
|
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
|
||||||
|
|
||||||
|
RET_CHECK(kInImage(cc).IsConnected() ^
|
||||||
|
(options.has_output_height() || options.has_output_width()))
|
||||||
|
<< "Either set IMAGE input stream, or set through options";
|
||||||
|
RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
|
||||||
|
<< "Either set COLOR input stream, or set through options";
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool use_dimension_from_option_ = false;
|
||||||
|
bool use_color_from_option_ = false;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
|
||||||
|
|
||||||
|
absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
|
||||||
|
use_dimension_from_option_ = !kInImage(cc).IsConnected();
|
||||||
|
use_color_from_option_ = !kInColor(cc).IsConnected();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
|
||||||
|
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
|
||||||
|
|
||||||
|
int output_height = -1;
|
||||||
|
int output_width = -1;
|
||||||
|
if (use_dimension_from_option_) {
|
||||||
|
output_height = options.output_height();
|
||||||
|
output_width = options.output_width();
|
||||||
|
} else if (!kInImage(cc).IsEmpty()) {
|
||||||
|
const Image& input_image = kInImage(cc).Get();
|
||||||
|
output_height = input_image.height();
|
||||||
|
output_width = input_image.width();
|
||||||
|
} else {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
if (use_color_from_option_) {
|
||||||
|
color = options.color();
|
||||||
|
} else if (!kInColor(cc).IsEmpty()) {
|
||||||
|
color = kInColor(cc).Get();
|
||||||
|
} else {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
output_width, output_height);
|
||||||
|
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
|
||||||
|
|
||||||
|
output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
|
||||||
|
|
||||||
|
kOutImage(cc).Send(Image(output_frame));
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
32
mediapipe/calculators/util/flat_color_image_calculator.proto
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
import "mediapipe/util/color.proto";
|
||||||
|
|
||||||
|
message FlatColorImageCalculatorOptions {
|
||||||
|
extend CalculatorOptions {
|
||||||
|
optional FlatColorImageCalculatorOptions ext = 515548435;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output dimensions.
|
||||||
|
optional int32 output_width = 1;
|
||||||
|
optional int32 output_height = 2;
|
||||||
|
// The color to fill with in the output image.
|
||||||
|
optional Color color = 3;
|
||||||
|
}
|
210
mediapipe/calculators/util/flat_color_image_calculator_test.cc
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
constexpr char kColorTag[] = "COLOR";
|
||||||
|
constexpr int kImageWidth = 256;
|
||||||
|
constexpr int kImageHeight = 256;
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
color: {
|
||||||
|
r: 100,
|
||||||
|
g: 200,
|
||||||
|
b: 255,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
kImageWidth, kImageHeight);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
|
MakePacket<Image>(image_frame).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
|
||||||
|
ASSERT_EQ(outputs.size(), 3);
|
||||||
|
|
||||||
|
for (const auto& packet : outputs) {
|
||||||
|
const auto& image = packet.Get<Image>();
|
||||||
|
EXPECT_EQ(image.width(), kImageWidth);
|
||||||
|
EXPECT_EQ(image.height(), kImageHeight);
|
||||||
|
auto image_frame = image.GetImageFrameSharedPtr();
|
||||||
|
auto* pixel_data = image_frame->PixelData();
|
||||||
|
EXPECT_EQ(pixel_data[0], 100);
|
||||||
|
EXPECT_EQ(pixel_data[1], 200);
|
||||||
|
EXPECT_EQ(pixel_data[2], 255);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
output_width: 7,
|
||||||
|
output_height: 13,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
color.set_r(0);
|
||||||
|
color.set_g(5);
|
||||||
|
color.set_b(0);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
|
||||||
|
MakePacket<Color>(color).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
|
||||||
|
ASSERT_EQ(outputs.size(), 3);
|
||||||
|
|
||||||
|
for (const auto& packet : outputs) {
|
||||||
|
const auto& image = packet.Get<Image>();
|
||||||
|
EXPECT_EQ(image.width(), 7);
|
||||||
|
EXPECT_EQ(image.height(), 13);
|
||||||
|
auto image_frame = image.GetImageFrameSharedPtr();
|
||||||
|
const uint8_t* pixel_data = image_frame->PixelData();
|
||||||
|
EXPECT_EQ(pixel_data[0], 0);
|
||||||
|
EXPECT_EQ(pixel_data[1], 5);
|
||||||
|
EXPECT_EQ(pixel_data[2], 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
color.set_r(0);
|
||||||
|
color.set_g(5);
|
||||||
|
color.set_b(0);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
|
||||||
|
MakePacket<Color>(color).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set IMAGE input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureMissingColor) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
kImageWidth, kImageHeight);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
|
MakePacket<Image>(image_frame).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set COLOR input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
output_width: 7,
|
||||||
|
output_height: 13,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
kImageWidth, kImageHeight);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
|
MakePacket<Image>(image_frame).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set IMAGE input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
color: {
|
||||||
|
r: 100,
|
||||||
|
g: 200,
|
||||||
|
b: 255,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
color.set_r(0);
|
||||||
|
color.set_g(5);
|
||||||
|
color.set_b(0);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
|
||||||
|
MakePacket<Color>(color).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set COLOR input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio,
|
||||||
std::string result_image;
|
std::string result_image;
|
||||||
MP_ASSERT_OK(
|
MP_ASSERT_OK(
|
||||||
mediapipe::file::GetContents(result_string_path, &result_image));
|
mediapipe::file::GetContents(result_string_path, &result_image));
|
||||||
EXPECT_EQ(result_image, output_string);
|
if (result_image != output_string) {
|
||||||
|
// There may be slight differences due to the way the JPEG was encoded or
|
||||||
|
// the OpenCV version used to generate the reference files. Compare
|
||||||
|
// pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
|
||||||
|
cv::Mat result_mat =
|
||||||
|
cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1,
|
||||||
|
const_cast<char*>(result_image.data())),
|
||||||
|
cv::IMREAD_UNCHANGED);
|
||||||
|
cv::Mat output_mat =
|
||||||
|
cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1,
|
||||||
|
const_cast<char*>(output_string.data())),
|
||||||
|
cv::IMREAD_UNCHANGED);
|
||||||
|
ASSERT_EQ(result_mat.rows, output_mat.rows);
|
||||||
|
ASSERT_EQ(result_mat.cols, output_mat.cols);
|
||||||
|
ASSERT_EQ(result_mat.type(), output_mat.type());
|
||||||
|
EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
std::string output_string_path = mediapipe::file::JoinPath(
|
std::string output_string_path = mediapipe::file::JoinPath(
|
||||||
absl::GetFlag(FLAGS_output_folder),
|
absl::GetFlag(FLAGS_output_folder),
|
||||||
|
|
Before Width: | Height: | Size: 3.2 KiB After Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 6.1 KiB After Width: | Height: | Size: 6.1 KiB |
Before Width: | Height: | Size: 8.2 KiB After Width: | Height: | Size: 8.2 KiB |
Before Width: | Height: | Size: 7.6 KiB After Width: | Height: | Size: 7.6 KiB |
|
@ -136,6 +136,8 @@ message GraphTrace {
|
||||||
GPU_TASK_INVOKE = 16;
|
GPU_TASK_INVOKE = 16;
|
||||||
TPU_TASK_INVOKE = 17;
|
TPU_TASK_INVOKE = 17;
|
||||||
CPU_TASK_INVOKE = 18;
|
CPU_TASK_INVOKE = 18;
|
||||||
|
GPU_TASK_INVOKE_ADVANCED = 19;
|
||||||
|
TPU_TASK_INVOKE_ASYNC = 20;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The timing for one packet set being processed at one calculator node.
|
// The timing for one packet set being processed at one calculator node.
|
||||||
|
|
|
@ -112,6 +112,10 @@ struct TraceEvent {
|
||||||
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
|
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
|
||||||
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
|
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
|
||||||
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
|
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
|
||||||
|
static constexpr EventType GPU_TASK_INVOKE_ADVANCED =
|
||||||
|
GraphTrace::GPU_TASK_INVOKE_ADVANCED;
|
||||||
|
static constexpr EventType TPU_TASK_INVOKE_ASYNC =
|
||||||
|
GraphTrace::TPU_TASK_INVOKE_ASYNC;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Packet trace log buffer.
|
// Packet trace log buffer.
|
||||||
|
|
|
@ -57,7 +57,6 @@ struct hash<mediapipe::TaskId> {
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void BasicTraceEventTypes(TraceEventRegistry* result) {
|
void BasicTraceEventTypes(TraceEventRegistry* result) {
|
||||||
// The initializer arguments below are: event_type, description,
|
// The initializer arguments below are: event_type, description,
|
||||||
// is_packet_event, is_stream_event, id_event_data.
|
// is_packet_event, is_stream_event, id_event_data.
|
||||||
|
@ -84,6 +83,15 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
|
||||||
"A time measured by GPU clock and by CPU clock.", true, false},
|
"A time measured by GPU clock and by CPU clock.", true, false},
|
||||||
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
|
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
|
||||||
true, true, false},
|
true, true, false},
|
||||||
|
|
||||||
|
{TraceEvent::GPU_TASK_INVOKE, "CPU timing for initiating a GPU task."},
|
||||||
|
{TraceEvent::TPU_TASK_INVOKE, "CPU timing for initiating a TPU task."},
|
||||||
|
{TraceEvent::CPU_TASK_INVOKE, "CPU timing for initiating a CPU task."},
|
||||||
|
{TraceEvent::GPU_TASK_INVOKE_ADVANCED,
|
||||||
|
"CPU timing for initiating a GPU task bypassing the TFLite "
|
||||||
|
"interpreter."},
|
||||||
|
{TraceEvent::TPU_TASK_INVOKE_ASYNC,
|
||||||
|
"CPU timing for async initiation of a TPU task."},
|
||||||
};
|
};
|
||||||
for (const TraceEventType& t : basic_types) {
|
for (const TraceEventType& t : basic_types) {
|
||||||
(*result)[t.event_type()] = t;
|
(*result)[t.event_type()] = t;
|
||||||
|
|
|
@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs):
|
||||||
'import public "' + join_path + '";',
|
'import public "' + join_path + '";',
|
||||||
)
|
)
|
||||||
rewrite_ref = SubsituteCommand(
|
rewrite_ref = SubsituteCommand(
|
||||||
r"mediapipe\\.(" + rewrite_message_regex + ")",
|
r"mediapipe\.(" + rewrite_message_regex + ")",
|
||||||
r"mediapipe.\\1",
|
r"mediapipe.\\1",
|
||||||
)
|
)
|
||||||
rewrite_objc = SubsituteCommand(
|
rewrite_objc = SubsituteCommand(
|
||||||
|
@ -284,7 +284,7 @@ def mediapipe_proto_library(
|
||||||
def_jspb_proto: define the jspb_proto_library target
|
def_jspb_proto: define the jspb_proto_library target
|
||||||
def_go_proto: define the go_proto_library target
|
def_go_proto: define the go_proto_library target
|
||||||
def_options_lib: define the mediapipe_options_library target
|
def_options_lib: define the mediapipe_options_library target
|
||||||
def_rewrite: define a sibbling mediapipe_proto_library with package "mediapipe"
|
def_rewrite: define a sibling mediapipe_proto_library with package "mediapipe"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mediapipe_proto_library_impl(
|
mediapipe_proto_library_impl(
|
||||||
|
|
|
@ -183,12 +183,13 @@ absl::Status FindCorrespondingStreams(
|
||||||
// name, calculator, input_stream, output_stream, input_side_packet,
|
// name, calculator, input_stream, output_stream, input_side_packet,
|
||||||
// output_side_packet, options.
|
// output_side_packet, options.
|
||||||
// All other fields are only applicable to calculators.
|
// All other fields are only applicable to calculators.
|
||||||
|
// TODO: Check whether executor is not set in the subgraph node
|
||||||
|
// after this issue is properly solved.
|
||||||
absl::Status ValidateSubgraphFields(
|
absl::Status ValidateSubgraphFields(
|
||||||
const CalculatorGraphConfig::Node& subgraph_node) {
|
const CalculatorGraphConfig::Node& subgraph_node) {
|
||||||
if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() ||
|
if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() ||
|
||||||
subgraph_node.has_output_stream_handler() ||
|
subgraph_node.has_output_stream_handler() ||
|
||||||
subgraph_node.input_stream_info_size() != 0 ||
|
subgraph_node.input_stream_info_size() != 0) {
|
||||||
!subgraph_node.executor().empty()) {
|
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "Subgraph \"" << subgraph_node.name()
|
<< "Subgraph \"" << subgraph_node.name()
|
||||||
<< "\" has a field that is only applicable to calculators.";
|
<< "\" has a field that is only applicable to calculators.";
|
||||||
|
|
|
@ -467,6 +467,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:frame_buffer",
|
"//mediapipe/framework/formats:frame_buffer",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/formats:yuv_image",
|
"//mediapipe/framework/formats:yuv_image",
|
||||||
|
"//mediapipe/util/frame_buffer:frame_buffer_util",
|
||||||
"//third_party/libyuv",
|
"//third_party/libyuv",
|
||||||
"@com_google_absl//absl/log",
|
"@com_google_absl//absl/log",
|
||||||
"@com_google_absl//absl/log:check",
|
"@com_google_absl//absl/log:check",
|
||||||
|
@ -932,7 +933,7 @@ mediapipe_proto_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
proto_library(
|
mediapipe_proto_library(
|
||||||
name = "gl_scaler_calculator_proto",
|
name = "gl_scaler_calculator_proto",
|
||||||
srcs = ["gl_scaler_calculator.proto"],
|
srcs = ["gl_scaler_calculator.proto"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
@ -942,17 +943,6 @@ proto_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_cc_proto_library(
|
|
||||||
name = "gl_scaler_calculator_cc_proto",
|
|
||||||
srcs = ["gl_scaler_calculator.proto"],
|
|
||||||
cc_deps = [
|
|
||||||
":scale_mode_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
|
||||||
],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [":gl_scaler_calculator_proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gl_scaler_calculator",
|
name = "gl_scaler_calculator",
|
||||||
srcs = ["gl_scaler_calculator.cc"],
|
srcs = ["gl_scaler_calculator.cc"],
|
||||||
|
|
|
@ -66,7 +66,8 @@ public class GraphTextureFrame implements TextureFrame {
|
||||||
if (nativeBufferHandle == 0) {
|
if (nativeBufferHandle == 0) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) {
|
long contextHandle = nativeGetCurrentExternalContextHandle();
|
||||||
|
if (contextHandle != 0 && activeConsumerContextHandleSet.add(contextHandle)) {
|
||||||
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
|
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
|
||||||
// PacketGetter.getTextureFrameDeferredSync().
|
// PacketGetter.getTextureFrameDeferredSync().
|
||||||
if (deferredSync) {
|
if (deferredSync) {
|
||||||
|
@ -116,7 +117,14 @@ public class GraphTextureFrame implements TextureFrame {
|
||||||
GlSyncToken consumerToken = null;
|
GlSyncToken consumerToken = null;
|
||||||
// Note that this remove should be moved to the other overload of release when b/68808951 is
|
// Note that this remove should be moved to the other overload of release when b/68808951 is
|
||||||
// addressed.
|
// addressed.
|
||||||
if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) {
|
final long contextHandle = nativeGetCurrentExternalContextHandle();
|
||||||
|
if (contextHandle == 0 && !activeConsumerContextHandleSet.isEmpty()) {
|
||||||
|
logger.atWarning().log(
|
||||||
|
"GraphTextureFrame is being released on non GL thread while having active consumers,"
|
||||||
|
+ " which may lead to external / internal GL contexts synchronization issues.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (contextHandle != 0 && activeConsumerContextHandleSet.remove(contextHandle)) {
|
||||||
consumerToken =
|
consumerToken =
|
||||||
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
|
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
|
||||||
}
|
}
|
||||||
|
@ -169,7 +177,9 @@ public class GraphTextureFrame implements TextureFrame {
|
||||||
private native void nativeReleaseBuffer(long nativeHandle);
|
private native void nativeReleaseBuffer(long nativeHandle);
|
||||||
|
|
||||||
private native int nativeGetTextureName(long nativeHandle);
|
private native int nativeGetTextureName(long nativeHandle);
|
||||||
|
|
||||||
private native int nativeGetWidth(long nativeHandle);
|
private native int nativeGetWidth(long nativeHandle);
|
||||||
|
|
||||||
private native int nativeGetHeight(long nativeHandle);
|
private native int nativeGetHeight(long nativeHandle);
|
||||||
|
|
||||||
private native void nativeGpuWait(long nativeHandle);
|
private native void nativeGpuWait(long nativeHandle);
|
||||||
|
|
|
@ -357,6 +357,22 @@ def mediapipe_java_proto_srcs(name = ""):
|
||||||
target = "//mediapipe/framework/formats:rect_java_proto_lite",
|
target = "//mediapipe/framework/formats:rect_java_proto_lite",
|
||||||
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
|
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
|
||||||
))
|
))
|
||||||
|
|
||||||
|
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||||
|
target = "//mediapipe/util:color_java_proto_lite",
|
||||||
|
src_out = "com/google/mediapipe/util/proto/ColorProto.java",
|
||||||
|
))
|
||||||
|
|
||||||
|
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||||
|
target = "//mediapipe/util:label_map_java_proto_lite",
|
||||||
|
src_out = "com/google/mediapipe/util/proto/LabelMapProto.java",
|
||||||
|
))
|
||||||
|
|
||||||
|
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||||
|
target = "//mediapipe/util:render_data_java_proto_lite",
|
||||||
|
src_out = "com/google/mediapipe/util/proto/RenderDataProto.java",
|
||||||
|
))
|
||||||
|
|
||||||
return proto_src_list
|
return proto_src_list
|
||||||
|
|
||||||
def mediapipe_logging_java_proto_srcs(name = ""):
|
def mediapipe_logging_java_proto_srcs(name = ""):
|
||||||
|
|
24
mediapipe/model_maker/models/face_stylizer/BUILD
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/model_maker/python/vision/face_stylizer:__subpackages__"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "models",
|
||||||
|
srcs = glob([
|
||||||
|
"**",
|
||||||
|
]),
|
||||||
|
)
|
|
@ -18,7 +18,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Callable, Optional, Tuple, TypeVar
|
from typing import Any, Callable, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
# Dependency imports
|
# Dependency imports
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -66,12 +66,14 @@ class Dataset(object):
|
||||||
"""
|
"""
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
def gen_tf_dataset(self,
|
def gen_tf_dataset(
|
||||||
|
self,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
is_training: bool = False,
|
is_training: bool = False,
|
||||||
shuffle: bool = False,
|
shuffle: bool = False,
|
||||||
preprocess: Optional[Callable[..., bool]] = None,
|
preprocess: Optional[Callable[..., Any]] = None,
|
||||||
drop_remainder: bool = False) -> tf.data.Dataset:
|
drop_remainder: bool = False,
|
||||||
|
) -> tf.data.Dataset:
|
||||||
"""Generates a batched tf.data.Dataset for training/evaluation.
|
"""Generates a batched tf.data.Dataset for training/evaluation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._hparams: hp.BaseHParams = None
|
self._hparams: hp.BaseHParams = None
|
||||||
self._history: tf.keras.callbacks.History = None
|
self._history: tf.keras.callbacks.History = None
|
||||||
|
|
||||||
def _train_model(self,
|
def _train_model(
|
||||||
|
self,
|
||||||
train_data: classification_ds.ClassificationDataset,
|
train_data: classification_ds.ClassificationDataset,
|
||||||
validation_data: classification_ds.ClassificationDataset,
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
preprocessor: Optional[Callable[..., bool]] = None,
|
preprocessor: Optional[Callable[..., Any]] = None,
|
||||||
checkpoint_path: Optional[str] = None):
|
checkpoint_path: Optional[str] = None,
|
||||||
|
):
|
||||||
"""Trains the classifier model.
|
"""Trains the classifier model.
|
||||||
|
|
||||||
Compiles and fits the tf.keras `_model` and records the `_history`.
|
Compiles and fits the tf.keras `_model` and records the `_history`.
|
||||||
|
|
|
@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||||
def convert_to_tflite(
|
def convert_to_tflite(
|
||||||
model: tf.keras.Model,
|
model: tf.keras.Model,
|
||||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
supported_ops: Tuple[tf.lite.OpsSet,
|
supported_ops: Tuple[tf.lite.OpsSet, ...] = (
|
||||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
tf.lite.OpsSet.TFLITE_BUILTINS,
|
||||||
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
|
),
|
||||||
|
preprocess: Optional[Callable[..., Any]] = None,
|
||||||
|
) -> bytearray:
|
||||||
"""Converts the input Keras model to TFLite format.
|
"""Converts the input Keras model to TFLite format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
48
mediapipe/model_maker/python/vision/face_stylizer/BUILD
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = glob([
|
||||||
|
"testdata/**",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
data = [
|
||||||
|
":testdata",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe Model Maker Python Public API For Face Stylization."""
|
98
mediapipe/model_maker/python/vision/face_stylizer/dataset.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Face stylizer dataset library."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_utils
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Change to an unlabeled dataset if it makes sense.
|
||||||
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
"""Dataset library for face stylizer fine tuning."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_folder(
|
||||||
|
cls, dirname: str
|
||||||
|
) -> classification_dataset.ClassificationDataset:
|
||||||
|
"""Loads images from the given directory.
|
||||||
|
|
||||||
|
The style image dataset directory is expected to contain one subdirectory
|
||||||
|
whose name represents the label of the style. There can be one or multiple
|
||||||
|
images of the same style in that subdirectory. Supported input image formats
|
||||||
|
include 'jpg', 'jpeg', 'png'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dirname: Name of the directory containing the image files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing images and labels and other related info.
|
||||||
|
Raises:
|
||||||
|
ValueError: if the input data directory is empty.
|
||||||
|
"""
|
||||||
|
data_root = os.path.abspath(dirname)
|
||||||
|
|
||||||
|
# Assumes the image data of the same label are in the same subdirectory,
|
||||||
|
# gets image path and label names.
|
||||||
|
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||||
|
all_image_size = len(all_image_paths)
|
||||||
|
if all_image_size == 0:
|
||||||
|
raise ValueError('Invalid input data directory')
|
||||||
|
if not any(
|
||||||
|
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
|
||||||
|
):
|
||||||
|
raise ValueError('No images found under given directory')
|
||||||
|
|
||||||
|
label_names = sorted(
|
||||||
|
name
|
||||||
|
for name in os.listdir(data_root)
|
||||||
|
if os.path.isdir(os.path.join(data_root, name))
|
||||||
|
)
|
||||||
|
all_label_size = len(label_names)
|
||||||
|
index_by_label = dict(
|
||||||
|
(name, index) for index, name in enumerate(label_names)
|
||||||
|
)
|
||||||
|
# Get the style label from the subdirectory name.
|
||||||
|
all_image_labels = [
|
||||||
|
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||||
|
for path in all_image_paths
|
||||||
|
]
|
||||||
|
|
||||||
|
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||||
|
|
||||||
|
image_ds = path_ds.map(
|
||||||
|
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load label
|
||||||
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
tf.cast(all_image_labels, tf.int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a dataset of (image, label) pairs
|
||||||
|
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
'Load images dataset with size: %d, num_label: %d, labels: %s.',
|
||||||
|
all_image_size,
|
||||||
|
all_label_size,
|
||||||
|
', '.join(label_names),
|
||||||
|
)
|
||||||
|
return Dataset(
|
||||||
|
dataset=image_label_ds, size=all_image_size, label_names=label_names
|
||||||
|
)
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# TODO: Replace the stylize image dataset with licensed images.
|
||||||
|
self._test_data_dirname = 'testdata'
|
||||||
|
|
||||||
|
def test_from_folder(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
|
||||||
|
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||||
|
self.assertEqual(data.num_classes, 2)
|
||||||
|
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
||||||
|
self.assertLen(data, 2)
|
||||||
|
|
||||||
|
def test_from_folder_raise_value_error_for_invalid_path(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
|
||||||
|
dataset.Dataset.from_folder(dirname='invalid')
|
||||||
|
|
||||||
|
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path('face_stylizer')
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'No images found under given directory'
|
||||||
|
):
|
||||||
|
dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
vendored
Normal file
After Width: | Height: | Size: 347 KiB |
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
vendored
Normal file
After Width: | Height: | Size: 336 KiB |
|
@ -15,6 +15,7 @@
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
|
@ -31,6 +32,7 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat
|
||||||
tf.keras.backend.experimental.enable_tf_random_generator()
|
tf.keras.backend.experimental.enable_tf_random_generator()
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip('b/273818271')
|
||||||
class GestureRecognizerTest(tf.test.TestCase):
|
class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _load_data(self):
|
def _load_data(self):
|
||||||
|
@ -72,8 +74,10 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
@unittest.skip('b/273818271')
|
||||||
@unittest_mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
|
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense
|
||||||
|
)
|
||||||
def test_gesture_recognizer_model_layer_widths(self, mock_dense):
|
def test_gesture_recognizer_model_layer_widths(self, mock_dense):
|
||||||
layer_widths = [64, 32]
|
layer_widths = [64, 32]
|
||||||
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
|
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
|
||||||
|
@ -143,12 +147,14 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
hyperparameters,
|
hyperparameters,
|
||||||
'HParams',
|
'HParams',
|
||||||
autospec=True,
|
autospec=True,
|
||||||
return_value=gesture_recognizer.HParams(epochs=1))
|
return_value=gesture_recognizer.HParams(epochs=1),
|
||||||
|
)
|
||||||
@unittest_mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
model_options,
|
model_options,
|
||||||
'GestureRecognizerModelOptions',
|
'GestureRecognizerModelOptions',
|
||||||
autospec=True,
|
autospec=True,
|
||||||
return_value=gesture_recognizer.ModelOptions())
|
return_value=gesture_recognizer.ModelOptions(),
|
||||||
|
)
|
||||||
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
|
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
|
||||||
self, mock_hparams, mock_model_options):
|
self, mock_hparams, mock_model_options):
|
||||||
options = gesture_recognizer.GestureRecognizerOptions()
|
options = gesture_recognizer.GestureRecognizerOptions()
|
||||||
|
|
|
@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
|
||||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||||
constexpr char kLabelsTag[] = "LABELS";
|
constexpr char kLabelsTag[] = "LABELS";
|
||||||
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
|
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
|
||||||
|
constexpr char kLabelMapTag[] = "LABEL_MAP";
|
||||||
|
|
||||||
using mediapipe::RE2;
|
using mediapipe::RE2;
|
||||||
using Detections = std::vector<Detection>;
|
using Detections = std::vector<Detection>;
|
||||||
|
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||||
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
|
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||||
|
cc->InputSidePackets()
|
||||||
|
.Tag(kLabelMapTag)
|
||||||
|
.Set<std::unique_ptr<std::map<int, std::string>>>();
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
options_ = cc->Options<FilterDetectionCalculatorOptions>();
|
options_ = cc->Options<FilterDetectionCalculatorOptions>();
|
||||||
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
|
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
|
||||||
cc->InputSidePackets().HasTag(kLabelsCsvTag);
|
cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
|
||||||
|
cc->InputSidePackets().HasTag(kLabelMapTag);
|
||||||
if (limit_labels_) {
|
if (limit_labels_) {
|
||||||
Strings allowlist_labels;
|
Strings allowlist_labels;
|
||||||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||||
|
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
||||||
for (auto& e : allowlist_labels) {
|
for (auto& e : allowlist_labels) {
|
||||||
absl::StripAsciiWhitespace(&e);
|
absl::StripAsciiWhitespace(&e);
|
||||||
}
|
}
|
||||||
} else {
|
} else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
|
||||||
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
|
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
|
||||||
|
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||||
|
auto label_map = cc->InputSidePackets()
|
||||||
|
.Tag(kLabelMapTag)
|
||||||
|
.Get<std::unique_ptr<std::map<int, std::string>>>()
|
||||||
|
.get();
|
||||||
|
for (const auto& [_, v] : *label_map) {
|
||||||
|
allowlist_labels.push_back(v);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
|
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
|
||||||
|
auto runner = std::make_unique<CalculatorRunner>(
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "FilterDetectionCalculator"
|
||||||
|
input_stream: "DETECTION:input"
|
||||||
|
input_side_packet: "LABEL_MAP:input_map"
|
||||||
|
output_stream: "DETECTION:output"
|
||||||
|
options {
|
||||||
|
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
runner->MutableInputs()->Tag("DETECTION").packets = {
|
||||||
|
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||||
|
label: "a"
|
||||||
|
label: "b"
|
||||||
|
label: "c"
|
||||||
|
label: "d"
|
||||||
|
score: 1
|
||||||
|
score: 0.8
|
||||||
|
score: 0.3
|
||||||
|
score: 0.9
|
||||||
|
)pb"))
|
||||||
|
.At(Timestamp(20)),
|
||||||
|
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||||
|
label: "a"
|
||||||
|
label: "b"
|
||||||
|
label: "c"
|
||||||
|
label: "e"
|
||||||
|
score: 0.6
|
||||||
|
score: 0.4
|
||||||
|
score: 0.2
|
||||||
|
score: 0.7
|
||||||
|
)pb"))
|
||||||
|
.At(Timestamp(40)),
|
||||||
|
};
|
||||||
|
|
||||||
|
auto label_map = std::make_unique<std::map<int, std::string>>();
|
||||||
|
(*label_map)[0] = "a";
|
||||||
|
(*label_map)[1] = "b";
|
||||||
|
(*label_map)[2] = "c";
|
||||||
|
runner->MutableSidePackets()->Tag("LABEL_MAP") =
|
||||||
|
AdoptAsUniquePtr(label_map.release());
|
||||||
|
|
||||||
|
// Run graph.
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
|
||||||
|
// Check output.
|
||||||
|
EXPECT_THAT(
|
||||||
|
runner->Outputs().Tag("DETECTION").packets,
|
||||||
|
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
|
||||||
|
Eq(Timestamp(20)),
|
||||||
|
EqualsProto(R"pb(
|
||||||
|
label: "a" label: "b" score: 1 score: 0.8
|
||||||
|
)pb")), // Packet 1 at timestamp 20.
|
||||||
|
PacketContainsTimestampAndPayload<Detection>(
|
||||||
|
Eq(Timestamp(40)),
|
||||||
|
EqualsProto(R"pb(
|
||||||
|
label: "a" score: 0.6
|
||||||
|
)pb")) // Packet 2 at timestamp 40.
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -57,6 +57,7 @@ pybind_extension(
|
||||||
"//mediapipe/framework/formats:landmark_registration",
|
"//mediapipe/framework/formats:landmark_registration",
|
||||||
"//mediapipe/framework/formats:rect_registration",
|
"//mediapipe/framework/formats:rect_registration",
|
||||||
"//mediapipe/modules/objectron/calculators:annotation_registration",
|
"//mediapipe/modules/objectron/calculators:annotation_registration",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -94,6 +95,8 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
"//mediapipe/tasks/cc/vision/image_embedder:image_embedder_graph",
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||||
] + select({
|
] + select({
|
||||||
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.
|
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.
|
||||||
"//mediapipe:windows": [],
|
"//mediapipe:windows": [],
|
||||||
|
|
|
@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus";
|
||||||
//
|
//
|
||||||
// At runtime, such codes are meant to be attached (where applicable) to a
|
// At runtime, such codes are meant to be attached (where applicable) to a
|
||||||
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
|
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
|
||||||
// stringifed error code as value (aka payload). This logic is encapsulated in
|
// stringified error code as value (aka payload). This logic is encapsulated in
|
||||||
// the `CreateStatusWithPayload` helper below for convenience.
|
// the `CreateStatusWithPayload` helper below for convenience.
|
||||||
//
|
//
|
||||||
// The returned status includes:
|
// The returned status includes:
|
||||||
|
|
|
@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
|
||||||
auto model_bundle_resources = absl::WrapUnique(
|
auto model_bundle_resources = absl::WrapUnique(
|
||||||
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
|
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
model_bundle_resources->ExtractModelFilesFromExternalFileProto());
|
model_bundle_resources->ExtractFilesFromExternalFileProto());
|
||||||
return model_bundle_resources;
|
return model_bundle_resources;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status
|
absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
|
||||||
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
|
|
||||||
if (model_asset_bundle_file_->has_file_name()) {
|
if (model_asset_bundle_file_->has_file_name()) {
|
||||||
// If the model asset bundle file name is a relative path, searches the file
|
// If the model asset bundle file name is a relative path, searches the file
|
||||||
// in a platform-specific location and returns the absolute path on success.
|
// in a platform-specific location and returns the absolute path on success.
|
||||||
|
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
|
||||||
model_asset_bundle_file_handler_->GetFileContent().data();
|
model_asset_bundle_file_handler_->GetFileContent().data();
|
||||||
size_t buffer_size =
|
size_t buffer_size =
|
||||||
model_asset_bundle_file_handler_->GetFileContent().size();
|
model_asset_bundle_file_handler_->GetFileContent().size();
|
||||||
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
|
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
|
||||||
&model_files_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
|
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
|
||||||
const std::string& filename) const {
|
const std::string& filename) const {
|
||||||
auto it = model_files_.find(filename);
|
auto it = files_.find(filename);
|
||||||
if (it == model_files_.end()) {
|
if (it == files_.end()) {
|
||||||
auto model_files = ListModelFiles();
|
auto files = ListFiles();
|
||||||
std::string all_model_files =
|
std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
|
||||||
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
|
|
||||||
|
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
StatusCode::kNotFound,
|
StatusCode::kNotFound,
|
||||||
absl::StrFormat("No model file with name: %s. All model files in the "
|
absl::StrFormat("No file with name: %s. All files in the model asset "
|
||||||
"model asset bundle are: %s.",
|
"bundle are: %s.",
|
||||||
filename, all_model_files),
|
filename, all_files),
|
||||||
MediaPipeTasksStatus::kFileNotFoundError);
|
MediaPipeTasksStatus::kFileNotFoundError);
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
|
std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
|
||||||
std::vector<std::string> model_names;
|
std::vector<std::string> file_names;
|
||||||
for (const auto& [model_name, _] : model_files_) {
|
for (const auto& [file_name, _] : files_) {
|
||||||
model_names.push_back(model_name);
|
file_names.push_back(file_name);
|
||||||
}
|
}
|
||||||
return model_names;
|
return file_names;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -28,8 +28,8 @@ namespace core {
|
||||||
// The mediapipe task model asset bundle resources class.
|
// The mediapipe task model asset bundle resources class.
|
||||||
// A ModelAssetBundleResources object, created from an external file proto,
|
// A ModelAssetBundleResources object, created from an external file proto,
|
||||||
// contains model asset bundle related resources and the method to extract the
|
// contains model asset bundle related resources and the method to extract the
|
||||||
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
|
// tflite models, resource files or model asset bundles for the mediapipe
|
||||||
// resources are owned by the ModelAssetBundleResources object
|
// sub-tasks. As the resources are owned by the ModelAssetBundleResources object
|
||||||
// callers must keep ModelAssetBundleResources alive while using any of the
|
// callers must keep ModelAssetBundleResources alive while using any of the
|
||||||
// resources.
|
// resources.
|
||||||
class ModelAssetBundleResources {
|
class ModelAssetBundleResources {
|
||||||
|
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
|
||||||
// Returns the model asset bundle resources tag.
|
// Returns the model asset bundle resources tag.
|
||||||
std::string GetTag() const { return tag_; }
|
std::string GetTag() const { return tag_; }
|
||||||
|
|
||||||
// Gets the contents of the model file (either tflite model file or model
|
// Gets the contents of the model file (either tflite model file, resource
|
||||||
// bundle file) with the provided name. An error is returned if there is no
|
// file or model bundle file) with the provided name. An error is returned if
|
||||||
// such model file.
|
// there is no such model file.
|
||||||
absl::StatusOr<absl::string_view> GetModelFile(
|
absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
|
||||||
const std::string& filename) const;
|
|
||||||
|
|
||||||
// Lists all the model file names in the model asset model.
|
// Lists all the file names in the model asset bundle.
|
||||||
std::vector<std::string> ListModelFiles() const;
|
std::vector<std::string> ListFiles() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Constructor.
|
// Constructor.
|
||||||
|
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
|
||||||
const std::string& tag,
|
const std::string& tag,
|
||||||
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
|
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
|
||||||
|
|
||||||
// Extracts the model files (either tflite model file or model bundle file)
|
// Extracts the model files (either tflite model file, resource file or model
|
||||||
// from the external file proto.
|
// bundle file) from the external file proto.
|
||||||
absl::Status ExtractModelFilesFromExternalFileProto();
|
absl::Status ExtractFilesFromExternalFileProto();
|
||||||
|
|
||||||
// The model asset bundle resources tag.
|
// The model asset bundle resources tag.
|
||||||
const std::string tag_;
|
const std::string tag_;
|
||||||
|
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
|
||||||
// The ExternalFileHandler for the model asset bundle.
|
// The ExternalFileHandler for the model asset bundle.
|
||||||
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
|
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
|
||||||
|
|
||||||
// The model files bundled in model asset bundle, as a map with the filename
|
// The files bundled in model asset bundle, as a map with the filename
|
||||||
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
|
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
|
||||||
// a pointer to the file contents as value. Each model file can be either
|
// a pointer to the file contents as value. Each file can be either a TFLite
|
||||||
// a TFLite model file or a model bundle file for sub-task.
|
// model file, resource file or a model bundle file for sub-task.
|
||||||
absl::flat_hash_map<std::string, absl::string_view> model_files_;
|
absl::flat_hash_map<std::string, absl::string_view> files_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
|
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto status_or_model_bundle_file =
|
auto status_or_model_bundle_file =
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task");
|
||||||
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||||
|
|
||||||
// Creates sub-task model asset bundle resources.
|
// Creates sub-task model asset bundle resources.
|
||||||
|
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(hand_landmaker_model_file)));
|
std::move(hand_landmaker_model_file)));
|
||||||
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||||
->GetModelFile("dummy_hand_detector.tflite")
|
->GetFile("dummy_hand_detector.tflite")
|
||||||
.status());
|
.status());
|
||||||
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||||
->GetModelFile("dummy_hand_landmarker.tflite")
|
->GetFile("dummy_hand_landmarker.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto status_or_model_bundle_file =
|
auto status_or_model_bundle_file =
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
|
||||||
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||||
|
|
||||||
// Verify tflite model works.
|
// Verify tflite model works.
|
||||||
|
@ -200,11 +196,11 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
|
||||||
auto model_bundle_resources,
|
auto model_bundle_resources,
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto status = model_bundle_resources->GetModelFile("not_found.task").status();
|
auto status = model_bundle_resources->GetFile("not_found.task").status();
|
||||||
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(
|
||||||
testing::HasSubstr(
|
status.message(),
|
||||||
"No model file with name: not_found.task. All model files in "
|
testing::HasSubstr("No file with name: not_found.task. All files in "
|
||||||
"the model asset bundle are: "));
|
"the model asset bundle are: "));
|
||||||
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
|
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
|
||||||
testing::Optional(absl::Cord(
|
testing::Optional(absl::Cord(
|
||||||
|
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
|
||||||
auto model_bundle_resources,
|
auto model_bundle_resources,
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto model_files = model_bundle_resources->ListModelFiles();
|
auto model_files = model_bundle_resources->ListFiles();
|
||||||
std::vector<std::string> expected_model_files = {
|
std::vector<std::string> expected_model_files = {
|
||||||
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
|
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
|
||||||
std::sort(model_files.begin(), model_files.end());
|
std::sort(model_files.begin(), model_files.end());
|
||||||
|
|
|
@ -77,9 +77,11 @@ class ModelResourcesCalculator : public api2::Node {
|
||||||
if (options.has_model_file()) {
|
if (options.has_model_file()) {
|
||||||
RET_CHECK(options.model_file().has_file_content() ||
|
RET_CHECK(options.model_file().has_file_content() ||
|
||||||
options.model_file().has_file_descriptor_meta() ||
|
options.model_file().has_file_descriptor_meta() ||
|
||||||
options.model_file().has_file_name())
|
options.model_file().has_file_name() ||
|
||||||
|
options.model_file().has_file_pointer_meta())
|
||||||
<< "'model_file' must specify at least one of "
|
<< "'model_file' must specify at least one of "
|
||||||
"'file_content', 'file_descriptor_meta', or 'file_name'";
|
"'file_content', 'file_descriptor_meta', 'file_name', or "
|
||||||
|
"'file_pointer_meta'";
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -179,9 +179,9 @@ TEST_F(ModelResourcesCalculatorTest, EmptyExternalFileProto) {
|
||||||
auto status = graph.Initialize(graph_config);
|
auto status = graph.Initialize(graph_config);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(status.message(),
|
||||||
testing::HasSubstr(
|
testing::HasSubstr("'model_file' must specify at least one of "
|
||||||
"'model_file' must specify at least one of "
|
"'file_content', 'file_descriptor_meta', "
|
||||||
"'file_content', 'file_descriptor_meta', or 'file_name'"));
|
"'file_name', or 'file_pointer_meta'"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) {
|
TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) {
|
||||||
|
|
|
@ -138,7 +138,7 @@ class InferenceSubgraph : public Subgraph {
|
||||||
delegate.mutable_tflite()->CopyFrom(acceleration.tflite());
|
delegate.mutable_tflite()->CopyFrom(acceleration.tflite());
|
||||||
break;
|
break;
|
||||||
case Acceleration::DELEGATE_NOT_SET:
|
case Acceleration::DELEGATE_NOT_SET:
|
||||||
// Deafult inference calculator setting.
|
// Default inference calculator setting.
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
return delegate;
|
return delegate;
|
||||||
|
|
|
@ -124,10 +124,10 @@ class ModelTaskGraph : public Subgraph {
|
||||||
// Inserts a mediapipe task inference subgraph into the provided
|
// Inserts a mediapipe task inference subgraph into the provided
|
||||||
// GraphBuilder. The returned node provides the following interfaces to the
|
// GraphBuilder. The returned node provides the following interfaces to the
|
||||||
// rest of the graph:
|
// rest of the graph:
|
||||||
// - a tensor vector (std::vector<MeidaPipe::Tensor>) input stream with tag
|
// - a tensor vector (std::vector<mediapipe::Tensor>) input stream with tag
|
||||||
// "TENSORS", representing the input tensors to be consumed by the
|
// "TENSORS", representing the input tensors to be consumed by the
|
||||||
// inference engine.
|
// inference engine.
|
||||||
// - a tensor vector (std::vector<MeidaPipe::Tensor>) output stream with tag
|
// - a tensor vector (std::vector<mediapipe::Tensor>) output stream with tag
|
||||||
// "TENSORS", representing the output tensors generated by the inference
|
// "TENSORS", representing the output tensors generated by the inference
|
||||||
// engine.
|
// engine.
|
||||||
// - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR".
|
// - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR".
|
||||||
|
|
|
@ -301,7 +301,7 @@ absl::Status TaskRunner::Close() {
|
||||||
}
|
}
|
||||||
is_running_ = false;
|
is_running_ = false;
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
AddPayload(graph_.CloseAllInputStreams(), "Fail to close intput streams",
|
AddPayload(graph_.CloseAllInputStreams(), "Fail to close input streams",
|
||||||
MediaPipeTasksStatus::kRunnerFailsToCloseError));
|
MediaPipeTasksStatus::kRunnerFailsToCloseError));
|
||||||
MP_RETURN_IF_ERROR(AddPayload(
|
MP_RETURN_IF_ERROR(AddPayload(
|
||||||
graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.",
|
graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.",
|
||||||
|
|
|
@ -65,7 +65,7 @@ class TaskRunner {
|
||||||
// Creates the task runner with a CalculatorGraphConfig proto.
|
// Creates the task runner with a CalculatorGraphConfig proto.
|
||||||
// If a tflite op resolver object is provided, the task runner will take
|
// If a tflite op resolver object is provided, the task runner will take
|
||||||
// it as the global op resolver for all models running within this task.
|
// it as the global op resolver for all models running within this task.
|
||||||
// The op resolver's owernship will be transferred into the pipeleine runner.
|
// The op resolver's ownership will be transferred into the pipeleine runner.
|
||||||
// When a user-defined PacketsCallback is provided, clients must use the
|
// When a user-defined PacketsCallback is provided, clients must use the
|
||||||
// asynchronous method, Send(), to provide the input packets. If the packets
|
// asynchronous method, Send(), to provide the input packets. If the packets
|
||||||
// callback is absent, clients must use the synchronous method, Process(), to
|
// callback is absent, clients must use the synchronous method, Process(), to
|
||||||
|
@ -84,7 +84,7 @@ class TaskRunner {
|
||||||
// frames from a video file and an audio file. The call blocks the current
|
// frames from a video file and an audio file. The call blocks the current
|
||||||
// thread until a failure status or a successful result is returned.
|
// thread until a failure status or a successful result is returned.
|
||||||
// If the input packets have no timestamp, an internal timestamp will be
|
// If the input packets have no timestamp, an internal timestamp will be
|
||||||
// assigend per invocation. Otherwise, when the timestamp is set in the
|
// assigned per invocation. Otherwise, when the timestamp is set in the
|
||||||
// input packets, the caller must ensure that the input packet timestamps are
|
// input packets, the caller must ensure that the input packet timestamps are
|
||||||
// greater than the timestamps of the previous invocation. This method is
|
// greater than the timestamps of the previous invocation. This method is
|
||||||
// thread-unsafe and it is the caller's responsibility to synchronize access
|
// thread-unsafe and it is the caller's responsibility to synchronize access
|
||||||
|
|
|
@ -64,7 +64,7 @@ class ModelMetadataPopulator {
|
||||||
// Loads associated files into the TFLite FlatBuffer model. The input is a map
|
// Loads associated files into the TFLite FlatBuffer model. The input is a map
|
||||||
// of {filename, file contents}.
|
// of {filename, file contents}.
|
||||||
//
|
//
|
||||||
// Warning: this method removes any previoulsy present associated files.
|
// Warning: this method removes any previously present associated files.
|
||||||
// Calling this method multiple time removes any associated files from
|
// Calling this method multiple time removes any associated files from
|
||||||
// previous calls, so this method should usually be called only once.
|
// previous calls, so this method should usually be called only once.
|
||||||
void LoadAssociatedFiles(
|
void LoadAssociatedFiles(
|
||||||
|
|
|
@ -213,7 +213,7 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
|
||||||
Version* min_version) {
|
Version* min_version) {
|
||||||
if (table == nullptr) return;
|
if (table == nullptr) return;
|
||||||
|
|
||||||
// Checks the ContenProperties field.
|
// Checks the ContentProperties field.
|
||||||
if (table->content_properties_type() == ContentProperties_AudioProperties) {
|
if (table->content_properties_type() == ContentProperties_AudioProperties) {
|
||||||
UpdateMinimumVersion(
|
UpdateMinimumVersion(
|
||||||
GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),
|
GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),
|
||||||
|
|
|
@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) {
|
||||||
|
|
||||||
// Using pybind11 type conversions to convert between Python and native
|
// Using pybind11 type conversions to convert between Python and native
|
||||||
// C++ types. There are other options to provide access to native Python types
|
// C++ types. There are other options to provide access to native Python types
|
||||||
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
|
// in C++ and vice versa. See the pybind 11 instruction [1] for more details.
|
||||||
// Type converstions is recommended by pybind11, though the main downside
|
// Type conversions are recommended by pybind11, though the main downside
|
||||||
// is that a copy of the data must be made on every Python to C++ transition:
|
// is that a copy of the data must be made on every Python to C++ transition:
|
||||||
// this is needed since the C++ and Python versions of the same type generally
|
// this is needed since the C++ and Python versions of the same type generally
|
||||||
// won’t have the same memory layout.
|
// won’t have the same memory layout.
|
||||||
|
|
|
@ -79,7 +79,7 @@ TEST(MetadataVersionTest,
|
||||||
auto metadata = metadata_builder.Finish();
|
auto metadata = metadata_builder.Finish();
|
||||||
FinishModelMetadataBuffer(builder, metadata);
|
FinishModelMetadataBuffer(builder, metadata);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -100,7 +100,7 @@ TEST(MetadataVersionTest,
|
||||||
auto metadata = metadata_builder.Finish();
|
auto metadata = metadata_builder.Finish();
|
||||||
builder.Finish(metadata);
|
builder.Finish(metadata);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version and triggers error.
|
// Gets the minimum metadata parser version and triggers error.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -121,7 +121,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_associated_files(associated_files);
|
metadata_builder.add_associated_files(associated_files);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -147,7 +147,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -172,7 +172,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -203,7 +203,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -234,7 +234,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -265,7 +265,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -294,7 +294,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -323,7 +323,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -348,7 +348,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -373,7 +373,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -404,7 +404,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -431,7 +431,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -453,7 +453,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_associated_files(associated_files);
|
metadata_builder.add_associated_files(associated_files);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -476,7 +476,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_associated_files(associated_files);
|
metadata_builder.add_associated_files(associated_files);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
|
|
@ -42,3 +42,36 @@ cc_test(
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Custom TFLite op library: maps the character ngrams of an input string to
# vocabulary indices (see ngram_hash.cc).
cc_library(
    name = "ngram_hash",
    srcs = ["ngram_hash.cc"],
    hdrs = ["ngram_hash.h"],
    copts = tflite_copts(),
    deps = [
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
    ],
    alwayslink = 1,
)
|
||||||
|
|
||||||
|
# Unit tests for the ":ngram_hash" custom op (ngram_hash_test.cc).
cc_test(
    name = "ngram_hash_test",
    srcs = ["ngram_hash_test.cc"],
    deps = [
        ":ngram_hash",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
        "@com_google_absl//absl/types:optional",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:framework",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/c:common",
        "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
    ],
)
|
||||||
|
|
|
@ -0,0 +1,264 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
|
||||||
|
namespace ngram_op {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::flexbuffers::GetRoot;
|
||||||
|
using ::flexbuffers::Map;
|
||||||
|
using ::flexbuffers::TypedVector;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::
|
||||||
|
LowercaseUnicodeStr;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
|
||||||
|
MurmurHash64WithSeed;
|
||||||
|
using ::tflite::GetString;
|
||||||
|
using ::tflite::StringRef;
|
||||||
|
|
||||||
|
constexpr int kInputMessage = 0;
|
||||||
|
constexpr int kOutputLabel = 0;
|
||||||
|
constexpr int kDefaultMaxSplits = 128;
|
||||||
|
|
||||||
|
// This op takes in a string, finds the character ngrams for it and then
|
||||||
|
// maps each of these ngrams to an index using the specified vocabulary sizes.
|
||||||
|
|
||||||
|
// Input(s):
|
||||||
|
// - input: Input string.
|
||||||
|
// - seed: Seed passed to the hash function that is applied to each ngram.
|
||||||
|
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
|
||||||
|
// be interpreted as generating unigrams, bigrams, and trigrams.
|
||||||
|
// - vocab_sizes: Size of the vocabulary for each of the ngram features
|
||||||
|
// respectively. The op would generate vocab ids to be less than or equal to
|
||||||
|
// the vocab size. The index 0 implies an invalid ngram.
|
||||||
|
// - max_splits: Maximum number of tokens in the output. If this is unset, the
|
||||||
|
// limit is `kDefaultMaxSplits`.
|
||||||
|
// - lowercase_input: If this is set to true (the default), the input string is
|
||||||
|
// lower-cased before any processing.
|
||||||
|
|
||||||
|
// Output(s):
|
||||||
|
// - output: A tensor of size [number of ngrams, number of tokens + 2],
|
||||||
|
// where 2 tokens are reserved for the padding. If `max_splits` is set, this
|
||||||
|
// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
|
||||||
|
|
||||||
|
// Helper class used for pre-processing the input.
|
||||||
|
class NGramHashParams {
|
||||||
|
public:
|
||||||
|
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
|
||||||
|
const std::vector<int>& vocab_sizes, int max_splits,
|
||||||
|
bool lower_case_input)
|
||||||
|
: seed_(seed),
|
||||||
|
ngram_lengths_(ngram_lengths),
|
||||||
|
vocab_sizes_(vocab_sizes),
|
||||||
|
max_splits_(max_splits),
|
||||||
|
lower_case_input_(lower_case_input) {}
|
||||||
|
|
||||||
|
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
|
||||||
|
TfLiteContext* context) {
|
||||||
|
if (input_t->bytes == 0) {
|
||||||
|
context->ReportError(context, "Empty input not supported.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do sanity checks on the input.
|
||||||
|
if (ngram_lengths_.empty()) {
|
||||||
|
context->ReportError(context, "`ngram_lengths` must be non-empty.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vocab_sizes_.empty()) {
|
||||||
|
context->ReportError(context, "`vocab_sizes` must be non-empty.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ngram_lengths_.size() != vocab_sizes_.size()) {
|
||||||
|
context->ReportError(
|
||||||
|
context,
|
||||||
|
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (max_splits_ <= 0) {
|
||||||
|
context->ReportError(context, "`max_splits` must be > 0.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain and tokenize the input.
|
||||||
|
StringRef inputref = GetString(input_t, /*string_index=*/0);
|
||||||
|
if (lower_case_input_) {
|
||||||
|
std::string lower_cased_str;
|
||||||
|
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
|
||||||
|
|
||||||
|
tokenized_output_ =
|
||||||
|
Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
|
||||||
|
/*exclude_nonalphaspace_tokens=*/true);
|
||||||
|
} else {
|
||||||
|
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
|
||||||
|
/*exclude_nonalphaspace_tokens=*/true);
|
||||||
|
}
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
uint64_t GetSeed() const { return seed_; }
|
||||||
|
|
||||||
|
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
|
||||||
|
|
||||||
|
int GetNumNGrams() const { return ngram_lengths_.size(); }
|
||||||
|
|
||||||
|
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
|
||||||
|
|
||||||
|
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
|
||||||
|
|
||||||
|
const TokenizedOutput& GetTokenizedOutput() const {
|
||||||
|
return tokenized_output_;
|
||||||
|
}
|
||||||
|
|
||||||
|
TokenizedOutput tokenized_output_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const uint64_t seed_;
|
||||||
|
std::vector<int> ngram_lengths_;
|
||||||
|
std::vector<int> vocab_sizes_;
|
||||||
|
const int max_splits_;
|
||||||
|
const bool lower_case_input_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert the TypedVector into a regular std::vector.
|
||||||
|
std::vector<int> GetIntVector(TypedVector typed_vec) {
|
||||||
|
std::vector<int> vec(typed_vec.size());
|
||||||
|
for (int j = 0; j < typed_vec.size(); j++) {
|
||||||
|
vec[j] = typed_vec[j].AsInt32();
|
||||||
|
}
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes one vocabulary index per (ngram length, token) pair into `data`.
//
// `data` is addressed as a row-major table with one row per configured ngram
// length and one column per token. For the ngram starting at a token, the
// hashed bytes begin at that token's `.first` offset into the tokenized string
// and span the summed `.second` values of up to `ngram_length` consecutive
// tokens. The stored value is `(hash % vocab_size) + 1`.
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
  const auto& tokenized = params->GetTokenizedOutput();
  const auto lengths = params->GetNGramLengths();
  const auto vocabs = params->GetVocabSizes();
  const auto hash_seed = params->GetSeed();
  const int row_width = params->GetNumTokens();

  // One output row per ngram length.
  for (int row = 0; row < lengths.size(); ++row) {
    const int span = lengths[row];
    const int vocab = vocabs[row];

    // One output column per starting token.
    for (int first = 0; first < tokenized.tokens.size(); ++first) {
      // Total number of bytes covered by the ngram that starts at `first`;
      // the ngram is cut short when it would run past the last token.
      int byte_count = 0;
      int cur = first;
      while (cur < tokenized.tokens.size() && cur < (first + span)) {
        byte_count += tokenized.tokens[cur].second;
        ++cur;
      }

      // Hash the ngram bytes and fold the hash into the vocabulary range.
      const auto hashed = MurmurHash64WithSeed(
          tokenized.str.c_str() + tokenized.tokens[first].first, byte_count,
          hash_seed);
      data[row * row_width + first] = (hashed % vocab) + 1;
    }
  }
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Parses the op's flexbuffer-encoded custom options from `buffer` and returns
// a heap-allocated NGramHashParams built from them; Free() deletes it.
//
// "seed", "ngram_lengths" and "vocab_sizes" are read unconditionally.
// "max_splits" falls back to kDefaultMaxSplits and "lowercase_input" to true
// when the corresponding entry is null.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  const Map& options =
      GetRoot(reinterpret_cast<const uint8_t*>(buffer), length).AsMap();

  int max_splits = kDefaultMaxSplits;
  if (!options["max_splits"].IsNull()) {
    max_splits = options["max_splits"].AsInt32();
  }

  bool lowercase_input = true;
  if (!options["lowercase_input"].IsNull()) {
    lowercase_input = options["lowercase_input"].AsBool();
  }

  const uint64_t seed = options["seed"].AsUInt64();
  const std::vector<int> ngram_lengths =
      GetIntVector(options["ngram_lengths"].AsTypedVector());
  const std::vector<int> vocab_sizes =
      GetIntVector(options["vocab_sizes"].AsTypedVector());

  return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
                             lowercase_input);
}
|
||||||
|
|
||||||
|
// Destroys the NGramHashParams object that Init() allocated for this node.
void Free(TfLiteContext* context, void* buffer) {
  auto* params = static_cast<NGramHashParams*>(buffer);
  delete params;
}
|
||||||
|
|
||||||
|
// Marks the op's output tensor as dynamic; its actual shape is set in Eval()
// once the input has been tokenized. Fails if the output tensor is missing.
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
  TfLiteTensor* const output = GetOutput(context, node, kOutputLabel);
  TF_LITE_ENSURE(context, output != nullptr);

  SetTensorToDynamic(output);
  return kTfLiteOk;
}
|
||||||
|
|
||||||
|
// Runs the op: tokenizes the input string, resizes the (dynamic) output to
// [1, number of ngram lengths, number of tokens] and fills it with the hashed
// vocabulary indices.
//
// Returns kTfLiteError when the input or output tensor is missing, when
// pre-processing fails, when the output tensor is not dynamic, when resizing
// fails, or when the output type is not Int32.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);

  // Reject a missing input tensor before pre-processing touches it.
  const TfLiteTensor* input = GetInput(context, node, kInputMessage);
  TF_LITE_ENSURE(context, input != nullptr);
  TF_LITE_ENSURE_OK(context, params->PreprocessInput(input, context));

  TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
  TF_LITE_ENSURE(context, output != nullptr);
  if (IsDynamicTensor(output)) {
    // NOTE(review): no explicit free of `output_size` here; this relies on
    // ResizeTensor taking ownership of the array — confirm against the
    // TfLiteContext contract.
    TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
    output_size->data[0] = 1;
    output_size->data[1] = params->GetNumNGrams();
    output_size->data[2] = params->GetNumTokens();
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output, output_size));
  } else {
    context->ReportError(context, "Output must be dynamic.");
    return kTfLiteError;
  }

  if (output->type == kTfLiteInt32) {
    GetNGramHashIndices(params, output->data.i32);
  } else {
    context->ReportError(context, "Output type must be Int32.");
    return kTfLiteError;
  }

  return kTfLiteOk;
}
|
||||||
|
|
||||||
|
} // namespace ngram_op
|
||||||
|
|
||||||
|
// Returns the registration record for the NGramHash custom op. The record is a
// function-local static, so every call returns the same pointer.
TfLiteRegistration* Register_NGRAM_HASH() {
  static TfLiteRegistration registration = {
      ngram_op::Init,
      ngram_op::Free,
      ngram_op::Resize,
      ngram_op::Eval,
  };
  return &registration;
}
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
|
||||||
|
// Returns the TfLiteRegistration for the NGramHash custom op, which maps the
// ngrams of an input string to vocabulary indices.
TfLiteRegistration* Register_NGRAM_HASH();
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
|
@ -0,0 +1,313 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/model.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::flexbuffers::Builder;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
|
||||||
|
MurmurHash64WithSeed;
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
using ::testing::Message;
|
||||||
|
|
||||||
|
// Helper class for testing the op.
|
||||||
|
class NGramHashModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
explicit NGramHashModel(const uint64_t seed,
|
||||||
|
const std::vector<int>& ngram_lengths,
|
||||||
|
const std::vector<int>& vocab_sizes,
|
||||||
|
const absl::optional<int> max_splits = std::nullopt) {
|
||||||
|
// Setup the model inputs.
|
||||||
|
Builder fbb;
|
||||||
|
size_t start = fbb.StartMap();
|
||||||
|
fbb.UInt("seed", seed);
|
||||||
|
{
|
||||||
|
size_t start = fbb.StartVector("ngram_lengths");
|
||||||
|
for (const int& ngram_len : ngram_lengths) {
|
||||||
|
fbb.Int(ngram_len);
|
||||||
|
}
|
||||||
|
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
size_t start = fbb.StartVector("vocab_sizes");
|
||||||
|
for (const int& vocab_size : vocab_sizes) {
|
||||||
|
fbb.Int(vocab_size);
|
||||||
|
}
|
||||||
|
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
|
||||||
|
}
|
||||||
|
if (max_splits) {
|
||||||
|
fbb.Int("max_splits", *max_splits);
|
||||||
|
}
|
||||||
|
fbb.EndMap(start);
|
||||||
|
fbb.Finish();
|
||||||
|
output_ = AddOutput({TensorType_INT32, {}});
|
||||||
|
SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
|
||||||
|
BuildInterpreter({GetShape(input_)});
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetupInputTensor(const std::string& input) {
|
||||||
|
PopulateStringTensor(input_, {input});
|
||||||
|
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
|
||||||
|
<< "Cannot allocate tensors";
|
||||||
|
}
|
||||||
|
|
||||||
|
void Invoke(const std::string& input) {
|
||||||
|
SetupInputTensor(input);
|
||||||
|
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus InvokeUnchecked(const std::string& input) {
|
||||||
|
SetupInputTensor(input);
|
||||||
|
return SingleOpModel::Invoke();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> GetOutput() {
|
||||||
|
return ExtractVector<T>(output_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int input_ = AddInput(TensorType_STRING);
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
  // Checks that the op returns the expected value when the input is sane.
  // Also checks that when `max_splits` is not specified, the entire string is
  // tokenized.
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});
  const std::vector<std::string> testcase_inputs({
      "hi",
      "wow",
      "!",
      "HI",
  });

  // A hash function that maps the given string to an index in the embedding
  // table denoted by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };
  const std::vector<std::vector<int>> expected_testcase_outputs(
      {{
           // Unigram & Bigram output for "hi".
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "wow".
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "!" (which will get replaced by " ").
           hash("^", 0),
           hash(" ", 0),
           hash("$", 0),
           hash("^ ", 1),
           hash(" $", 1),
           hash("$", 1),
       },
       {
           // Unigram & Bigram output for "HI" (which will get lower-cased).
           hash("^", 0),
           hash("h", 0),
           hash("i", 0),
           hash("$", 0),
           hash("^h", 1),
           hash("hi", 1),
           hash("i$", 1),
           hash("$", 1),
       }});

  // Each input must have exactly one expectation; otherwise the indexed
  // lookup in the loop below would read out of bounds.
  ASSERT_EQ(testcase_inputs.size(), expected_testcase_outputs.size());

  NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
  // `size_t` index avoids a signed/unsigned comparison against `size()`.
  for (size_t test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
    // Spelled `std::string` so the test does not rely on an unqualified
    // `string` alias being visible from an enclosing namespace.
    const std::string& testcase_input = testcase_inputs[test_idx];
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where the testcases' input is: "
                           << testcase_input);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(expected_testcase_outputs[test_idx]));
    EXPECT_THAT(m.GetOutputShape(),
                ElementsAreArray(
                    {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
                     static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
  }
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
  // Verifies the op output for the input "wow" across a range of explicit
  // `max_splits` values, from "room for the delimiters only" up to "more room
  // than the input needs".
  const uint64_t kSeed = 123;
  const std::vector<int> vocab_sizes({100, 200});
  std::vector<int> ngram_lengths({1, 2});

  const std::string testcase_input = "wow";

  // Maps `str` to an index in the embedding table selected by `vocab_idx`.
  auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
    const auto hash_value =
        MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
    return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
  };

  // One `max_splits` setting together with the unigram & bigram output the op
  // is expected to produce for it.
  struct MaxSplitsCase {
    int max_splits;
    std::vector<int> expected_output;
  };
  const std::vector<MaxSplitsCase> cases = {
      // `max_splits` == 2: no actual tokens fit, since there is only enough
      // space for the delimiters.
      {2,
       {
           hash("^", 0),
           hash("$", 0),
           hash("^$", 1),
           hash("$", 1),
       }},
      // `max_splits` == 3: the first token of the input string fits.
      {3,
       {
           hash("^", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("w$", 1),
           hash("$", 1),
       }},
      // `max_splits` == 4.
      {4,
       {
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("o$", 1),
           hash("$", 1),
       }},
      // `max_splits` == 5: the full input string fits.
      {5,
       {
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       }},
      // `max_splits` == 6: more than the full input string requires.
      {6,
       {
           hash("^", 0),
           hash("w", 0),
           hash("o", 0),
           hash("w", 0),
           hash("$", 0),
           hash("^w", 1),
           hash("wo", 1),
           hash("ow", 1),
           hash("w$", 1),
           hash("$", 1),
       }},
  };

  // Longest possible tokenization when using the entire input: every byte of
  // the input plus the two padding delimiters.
  const int untruncated_length =
      static_cast<int>(testcase_input.size()) + /*padding*/ 2;

  for (const MaxSplitsCase& testcase : cases) {
    NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase.max_splits);
    m.Invoke(testcase_input);
    SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase.max_splits);
    EXPECT_THAT(m.GetOutput<int>(),
                ElementsAreArray(testcase.expected_output));
    // The last dimension is capped by `max_splits` whenever that is smaller
    // than the untruncated length.
    EXPECT_THAT(
        m.GetOutputShape(),
        ElementsAreArray(
            {/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
             std::min(untruncated_length, testcase.max_splits)}));
  }
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, InvalidMaxSplitsValue) {
  // Zero and negative `max_splits` values are expected to make the op fail at
  // invocation time.
  for (const int max_splits : {0, -1, -5, -100}) {
    NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
                     /*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
    const TfLiteStatus status = m.InvokeUnchecked("hi");
    EXPECT_EQ(status, kTfLiteError);
  }
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
  // The op is expected to error out whenever `ngram_lengths` and `vocab_sizes`
  // do not have the same number of entries.

  // More ngram lengths than vocab sizes.
  NGramHashModel extra_ngram_lengths(/*seed=*/123,
                                     /*ngram_lengths=*/{100, 200, 300},
                                     /*vocab_sizes=*/{1, 2});
  EXPECT_EQ(extra_ngram_lengths.InvokeUnchecked("hi"), kTfLiteError);

  // More vocab sizes than ngram lengths.
  NGramHashModel extra_vocab_sizes(/*seed=*/123, /*ngram_lengths=*/{100, 200},
                                   /*vocab_sizes=*/{1, 2, 3});
  EXPECT_EQ(extra_vocab_sizes.InvokeUnchecked("hi"), kTfLiteError);
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite::ops::custom
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# Tokenization and lowercasing helpers for the n-gram hash custom op.
cc_library(
    name = "ngram_hash_ops_utils",
    srcs = ["ngram_hash_ops_utils.cc"],
    hdrs = ["ngram_hash_ops_utils.h"],
    deps = ["//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf"],
)
|
||||||
|
|
||||||
|
# Unit tests for :ngram_hash_ops_utils.
cc_test(
    name = "ngram_hash_ops_utils_test",
    size = "small",
    srcs = ["ngram_hash_ops_utils_test.cc"],
    deps = [
        ":ngram_hash_ops_utils",
        "//mediapipe/framework/port:gtest_main",
    ],
)
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# 64-bit seeded Murmur hash.
cc_library(
    name = "murmur",
    srcs = [
        "murmur.cc",
    ],
    hdrs = [
        "murmur.h",
    ],
    deps = [
        "//mediapipe/framework/port:integral_types",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/base:endian",
    ],
)
|
||||||
|
|
||||||
|
# Unit tests for :murmur.
cc_test(
    name = "murmur_test",
    srcs = [
        "murmur_test.cc",
    ],
    deps = [
        ":murmur",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:integral_types",
    ],
)
|
|
@ -0,0 +1,95 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a library written by Austin Appleby and Jyrki Alakuijala.
|
||||||
|
// Original copyright message below.
|
||||||
|
// Copyright 2009 Google Inc. All Rights Reserved.
|
||||||
|
// Author: aappleby@google.com (Austin Appleby)
|
||||||
|
// jyrki@google.com (Jyrki Alakuijala)
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "absl/base/internal/endian.h"
|
||||||
|
#include "absl/base/optimization.h"
|
||||||
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::absl::little_endian::Load64;
|
||||||
|
|
||||||
|
// Murmur 2.0 multiplication constant. Every mixing step in this file
// multiplies by it: the per-word step, the tail step and the final avalanche.
|
||||||
|
static const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
|
||||||
|
|
||||||
|
// We need to mix some of the bits that get propagated and mixed into the
|
||||||
|
// high bits by multiplication back into the low bits. 17 last bits get
|
||||||
|
// more efficiently mixed with this.
|
||||||
|
// XORs `val` with its own top 17 bits moved down into the low positions.
inline uint64_t ShiftMix(uint64_t val) {
  const uint64_t high_bits = val >> 47;
  return val ^ high_bits;
}
|
||||||
|
|
||||||
|
// Accumulate 8 bytes into 64-bit Murmur hash
|
||||||
|
// Folds one 64-bit word `data` into the running `hash` and returns the new
// hash value.
inline uint64_t MurmurStep(uint64_t hash, uint64_t data) {
  const uint64_t mixed_data = ShiftMix(data * kMul) * kMul;
  return (hash ^ mixed_data) * kMul;
}
|
||||||
|
|
||||||
|
// Build a uint64 from 1-8 bytes.
|
||||||
|
// 8 * len least significant bits are loaded from the memory with
|
||||||
|
// LittleEndian order. The 64 - 8 * len most significant bits are
|
||||||
|
// set all to 0.
|
||||||
|
// In latex-friendly words, this function returns:
|
||||||
|
// $\sum_{i=0}^{len-1} p[i] 256^{i}$, where p[i] is unsigned.
|
||||||
|
//
|
||||||
|
// This function is equivalent to:
|
||||||
|
// uint64 val = 0;
|
||||||
|
// memcpy(&val, p, len);
|
||||||
|
// return ToHost64(val);
|
||||||
|
//
|
||||||
|
// The caller needs to guarantee that 0 <= len <= 8.
|
||||||
|
uint64_t Load64VariableLength(const void* const p, int len) {
  ABSL_ASSUME(len >= 0 && len <= 8);
  const uint8_t* const bytes = static_cast<const uint8_t*>(p);
  // Walk from the last byte to the first so that bytes[0] ends up in the
  // least significant position (little-endian layout).
  uint64_t result = 0;
  for (int idx = len - 1; idx >= 0; --idx) {
    result = (result << 8) | static_cast<uint64_t>(bytes[idx]);
  }
  return result;
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
unsigned long long MurmurHash64WithSeed(const char* buf, // NOLINT
|
||||||
|
const size_t len, const uint64_t seed) {
|
||||||
|
// Let's remove the bytes not divisible by the sizeof(uint64).
|
||||||
|
// This allows the inner loop to process the data as 64 bit integers.
|
||||||
|
const size_t len_aligned = len & ~0x7;
|
||||||
|
const char* const end = buf + len_aligned;
|
||||||
|
uint64_t hash = seed ^ (len * kMul);
|
||||||
|
for (const char* p = buf; p != end; p += 8) {
|
||||||
|
hash = MurmurStep(hash, Load64(p));
|
||||||
|
}
|
||||||
|
if ((len & 0x7) != 0) {
|
||||||
|
const uint64_t data = Load64VariableLength(end, len & 0x7);
|
||||||
|
hash ^= data;
|
||||||
|
hash *= kMul;
|
||||||
|
}
|
||||||
|
hash = ShiftMix(hash) * kMul;
|
||||||
|
hash = ShiftMix(hash);
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
|
|
@ -0,0 +1,43 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a library written by Austin Appleby and Jyrki Alakuijala.
|
||||||
|
// Original copyright message below.
|
||||||
|
// Copyright 2009 Google Inc. All Rights Reserved.
|
||||||
|
// Author: aappleby@google.com (Austin Appleby)
|
||||||
|
// jyrki@google.com (Jyrki Alakuijala)
|
||||||
|
//
|
||||||
|
// MurmurHash is a fast multiplication and shifting based algorithm,
|
||||||
|
// based on Austin Appleby's MurmurHash 2.0 algorithm.
|
||||||
|
|
||||||
|
// The include guard is derived from this header's full path so that it cannot
// collide with the guard of the library this file was forked from.
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_HASH_MURMUR_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_HASH_MURMUR_H_

#include <stddef.h>
#include <stdlib.h>  // for size_t.

#include <cstdint>

#include "mediapipe/framework/port/integral_types.h"

namespace mediapipe::tasks::text::language_detector::custom_ops::hash {

// Hash function for a byte array. Has a seed which allows this hash function to
// be used in algorithms that need a family of parameterized hash functions.
// e.g. Minhash.
//
// `buf` points at the bytes to hash, `len` is the number of bytes to read from
// it, and `seed` selects the member of the hash family. Returns the 64-bit
// hash value.
unsigned long long MurmurHash64WithSeed(const char* buf, size_t len,  // NOLINT
                                        uint64_t seed);
}  // namespace mediapipe::tasks::text::language_detector::custom_ops::hash

#endif  // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_HASH_MURMUR_H_
|
|
@ -0,0 +1,66 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a test library written by Jyrki Alakuijala.
|
||||||
|
// Original copyright message below.
|
||||||
|
// Copyright 2009 Google Inc. All Rights Reserved.
|
||||||
|
// Author: jyrki@google.com (Jyrki Alakuijala)
|
||||||
|
//
|
||||||
|
// Tests for the fast hashing algorithm based on Austin Appleby's
|
||||||
|
// MurmurHash 2.0 algorithm. See http://murmurhash.googlepages.com/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
|
||||||
|
|
||||||
|
TEST(Murmur, EmptyData64) {
  // Hashing zero bytes with a zero seed is expected to yield zero.
  const auto empty_hash = MurmurHash64WithSeed(nullptr, uint64_t{0}, 0);
  EXPECT_EQ(uint64_t{0}, empty_hash);
}
|
||||||
|
|
||||||
|
TEST(Murmur, VaryWithDifferentSeeds) {
  // Two seeds could in theory collide on the same data, but that is unlikely
  // enough for the test to treat a collision as a failure.
  char data1 = 'x';
  const auto hash_with_seed_100 = MurmurHash64WithSeed(&data1, 1, 100);
  const auto hash_with_seed_101 = MurmurHash64WithSeed(&data1, 1, 101);
  EXPECT_NE(hash_with_seed_100, hash_with_seed_101);
}
|
||||||
|
|
||||||
|
// Hashes don't change.
|
||||||
|
TEST(Murmur, Idempotence) {
  // Hashing the same bytes twice with the same seed must give the same value,
  // for every seed in [0, 10).
  const auto expect_repeatable = [](const char* bytes) {
    const size_t num_bytes = strlen(bytes);
    for (int seed = 0; seed < 10; seed++) {
      EXPECT_EQ(MurmurHash64WithSeed(bytes, num_bytes, seed),
                MurmurHash64WithSeed(bytes, num_bytes, seed));
    }
  };

  // Exactly eight bytes: a single whole word and no tail.
  expect_repeatable("deadbeef");
  // Fourteen bytes: one whole word followed by a six-byte tail.
  expect_repeatable("deadbeef000---");
}
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
|
|
@ -0,0 +1,96 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops {
|
||||||
|
|
||||||
|
// Splits the first `len` bytes of `input_str` into one token per decoded rune,
// wrapped in a "^" prefix token and a "$" suffix token. Decoding stops early
// once only the suffix still fits within `max_tokens`, or when no further rune
// can be read. When `exclude_nonalphaspace_tokens` is set, every rune rejected
// by `utf_isalpharune` is emitted as a single " " instead.
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
                         bool exclude_nonalphaspace_tokens) {
  const std::string kPrefix = "^";
  const std::string kSuffix = "$";
  const std::string kReplacementToken = " ";

  TokenizedOutput output;
  output.str.reserve(len + 2);
  output.tokens.reserve(len + 2);

  // Byte offset in `output.str` at which the next token will begin.
  size_t token_start = 0;

  // Appends `size` bytes from `data` to the output string and records them as
  // one token.
  const auto emit_token = [&output, &token_start](const char* data,
                                                  size_t size) {
    output.str.append(data, size);
    output.tokens.push_back(std::make_pair(token_start, size));
    token_start += size;
  };

  emit_token(kPrefix.data(), kPrefix.size());

  Rune token;
  // The "+ 1" keeps one slot free for the suffix token.
  for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) {
    // Use the standard UTF-8 library to find the next token.
    size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);

    // Nothing more can be decoded from the remaining bytes.
    if (bytes_read == 0) {
      break;
    }

    const bool use_replacement =
        exclude_nonalphaspace_tokens && !utf_isalpharune(token);
    if (use_replacement) {
      emit_token(kReplacementToken.data(), kReplacementToken.size());
    } else {
      // Copy the rune's original bytes through unchanged.
      emit_token(input_str + i, bytes_read);
    }
    i += bytes_read;
  }

  emit_token(kSuffix.data(), kSuffix.size());

  return output;
}
|
||||||
|
|
||||||
|
void LowercaseUnicodeStr(const char* input_str, int len,
|
||||||
|
std::string* output_str) {
|
||||||
|
for (int i = 0; i < len;) {
|
||||||
|
Rune token;
|
||||||
|
|
||||||
|
// Tokenize the given string, and get the appropriate lowercase token.
|
||||||
|
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
|
||||||
|
token = utf_isalpharune(token) ? utf_tolowerrune(token) : token;
|
||||||
|
|
||||||
|
// Write back the token to the output string.
|
||||||
|
char token_buf[UTFmax];
|
||||||
|
size_t bytes_to_write = utf_runetochar(token_buf, &token);
|
||||||
|
output_str->append(token_buf, bytes_to_write);
|
||||||
|
|
||||||
|
i += bytes_read;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops
|
|
@ -0,0 +1,56 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops {
|
||||||
|
|
||||||
|
// Result of tokenizing a string: the processed text plus the location of every
// token inside it.
struct TokenizedOutput {
  // The processed string, i.e. the input after the prefix and suffix have been
  // added and any replaced tokens substituted.
  std::string str;

  // One (start, length) pair per token: `first` is the byte offset of the
  // token within `str`, `second` is the token's length in bytes.
  std::vector<std::pair<const size_t, const size_t>> tokens;
};
|
||||||
|
|
||||||
|
// Tokenizes the given input string on Unicode token boundaries, with a maximum
|
||||||
|
// of `max_tokens` tokens.
|
||||||
|
//
|
||||||
|
// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores
|
||||||
|
// non-alphabetic tokens, and replaces them with a replacement token (" ").
|
||||||
|
//
|
||||||
|
// The method returns the output in the `TokenizedOutput` struct, which stores
|
||||||
|
// both, the processed input string, and the indices and sizes of each token
|
||||||
|
// within that string.
|
||||||
|
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
|
||||||
|
bool exclude_nonalphaspace_tokens);
|
||||||
|
|
||||||
|
// Converts the given unicode string (`input_str`) with the specified length
|
||||||
|
// (`len`) to a lowercase string.
|
||||||
|
//
|
||||||
|
// The method populates the lowercased string in `output_str`.
|
||||||
|
void LowercaseUnicodeStr(const char* input_str, int len,
|
||||||
|
std::string* output_str);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
|
|
@ -0,0 +1,135 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::Values;
|
||||||
|
|
||||||
|
// Concatenates, in order, the byte ranges of `output.str` described by
// `output.tokens`.
std::string ReconstructStringFromTokens(TokenizedOutput output) {
  std::string rebuilt;
  for (const auto& token : output.tokens) {
    // `first` is the token's start offset, `second` its length in bytes.
    const char* const token_begin = output.str.c_str() + token.first;
    rebuilt.append(token_begin, token.second);
  }
  return rebuilt;
}
|
||||||
|
|
||||||
|
// One parameterized test case for `Tokenize`.
struct TokenizeTestParams {
  // Text handed to `Tokenize`.
  std::string input_str;
  // Value passed as the `max_tokens` argument.
  size_t max_tokens;
  // Value passed as the `exclude_nonalphaspace_tokens` argument.
  bool exclude_nonalphaspace_tokens;
  // Expected contents of `TokenizedOutput::str`.
  std::string expected_output_str;
};
|
||||||
|
|
||||||
|
class TokenizeParameterizedTest
|
||||||
|
: public ::testing::Test,
|
||||||
|
public testing::WithParamInterface<TokenizeTestParams> {};
|
||||||
|
|
||||||
|
TEST_P(TokenizeParameterizedTest, Tokenize) {
  // Runs `Tokenize` with the case's arguments and compares the result against
  // the case's expected string.
  const TokenizeTestParams& test_case = GetParam();
  const std::string& input = test_case.input_str;
  const TokenizedOutput output = Tokenize(
      /*input_str=*/input.c_str(),
      /*len=*/input.size(),
      /*max_tokens=*/test_case.max_tokens,
      /*exclude_nonalphaspace_tokens=*/test_case.exclude_nonalphaspace_tokens);

  // Both the processed string itself and the string rebuilt from the recorded
  // token ranges must match the expectation.
  EXPECT_EQ(output.str, test_case.expected_output_str);
  EXPECT_EQ(ReconstructStringFromTokens(output),
            test_case.expected_output_str);
}
|
||||||
|
|
||||||
|
// Parameter sets for TokenizeParameterizedTest. The inline argument comments
// name the TokenizeTestParams field each value initializes.
INSTANTIATE_TEST_SUITE_P(
    TokenizeParameterizedTests, TokenizeParameterizedTest,
    Values(
        // Test including non-alphanumeric characters.
        TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
                            /*exclude_nonalphaspace_tokens=*/false,
                            /*expected_output_str=*/"^hi!$"}),
        // Test not including non-alphanumeric characters.
        TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
                            /*exclude_nonalphaspace_tokens=*/true,
                            /*expected_output_str=*/"^hi $"}),
        // Test with a maximum of 3 tokens.
        TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3,
                            /*exclude_nonalphaspace_tokens=*/true,
                            /*expected_output_str=*/"^h$"}),
        // Test with non-latin characters.
        TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100,
                            /*exclude_nonalphaspace_tokens=*/true,
                            /*expected_output_str=*/"^ありがと$"})));
|
||||||
|
|
||||||
|
TEST(LowercaseUnicodeTest, TestLowercaseUnicode) {
  // Runs LowercaseUnicodeStr() over `text` and returns what it wrote into the
  // output string.
  const auto lowercased = [](const std::string& text) {
    std::string result;
    LowercaseUnicodeStr(
        /*input_str=*/text.c_str(),
        /*len=*/text.size(),
        /*output_str=*/&result);
    return result;
  };

  // Check that the method is a no-op when the string is already lowercase.
  EXPECT_EQ(lowercased("hello"), "hello");

  // Check that uppercase characters are converted to lowercase.
  EXPECT_EQ(lowercased("hElLo"), "hello");

  // Check that the method works with non-latin scripts.
  // Cyrillic has the concept of cases, so it should change the input.
  EXPECT_EQ(lowercased("БЙп"), "бйп");

  // Check that the method works with non-latin scripts.
  // Japanese doesn't have the concept of cases, so it should not change.
  EXPECT_EQ(lowercased("ありがと"), "ありがと");
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
# UTF-8 rune helpers declared in utf.h.
cc_library(
    name = "utf",
    srcs = [
        "rune.c",
        "runetype.c",
        # Not compiled on its own: runetype.c pulls it in with #include.
        "runetypebody.h",
    ],
    hdrs = ["utf.h"],
)
|
|
@ -0,0 +1,233 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a library written by Rob Pike and Ken Thompson. Original
|
||||||
|
// copyright message below.
|
||||||
|
/*
|
||||||
|
* The authors of this software are Rob Pike and Ken Thompson.
|
||||||
|
* Copyright (c) 2002 by Lucent Technologies.
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose without fee is hereby granted, provided that this entire notice
|
||||||
|
* is included in all copies of any software which is or includes a copy
|
||||||
|
* or modification of this software and in all copies of the supporting
|
||||||
|
* documentation for such software.
|
||||||
|
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
|
||||||
|
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
|
||||||
|
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
|
||||||
|
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
|
||||||
|
*/
|
||||||
|
#include <stdarg.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
|
||||||
|
|
||||||
|
enum
|
||||||
|
{
|
||||||
|
Bit1 = 7,
|
||||||
|
Bitx = 6,
|
||||||
|
Bit2 = 5,
|
||||||
|
Bit3 = 4,
|
||||||
|
Bit4 = 3,
|
||||||
|
Bit5 = 2,
|
||||||
|
|
||||||
|
T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */
|
||||||
|
Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */
|
||||||
|
T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */
|
||||||
|
T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */
|
||||||
|
T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */
|
||||||
|
T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */
|
||||||
|
|
||||||
|
Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */
|
||||||
|
Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */
|
||||||
|
Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */
|
||||||
|
Rune4 = (1<<(Bit4+3*Bitx))-1,
|
||||||
|
/* 0001 1111 1111 1111 1111 1111 */
|
||||||
|
|
||||||
|
Maskx = (1<<Bitx)-1, /* 0011 1111 */
|
||||||
|
Testx = Maskx ^ 0xFF, /* 1100 0000 */
|
||||||
|
|
||||||
|
Bad = Runeerror,
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
|
||||||
|
* This is a slower but "safe" version of the old chartorune
|
||||||
|
* that works on strings that are not necessarily null-terminated.
|
||||||
|
*
|
||||||
|
* If you know for sure that your string is null-terminated,
|
||||||
|
* chartorune will be a bit faster.
|
||||||
|
*
|
||||||
|
* It is guaranteed not to read more than "length" bytes
|
||||||
|
* past the incoming pointer. This is to avoid
|
||||||
|
* possible access violations. If the string appears to be
|
||||||
|
* well-formed but incomplete (i.e., to get the whole Rune
|
||||||
|
* we'd need to read past str+length) then we'll set the Rune
|
||||||
|
* to Bad and return 0.
|
||||||
|
*
|
||||||
|
* Note that if we have decoding problems for other
|
||||||
|
* reasons, we return 1 instead of 0.
|
||||||
|
*/
|
||||||
|
int
|
||||||
|
utf_charntorune(Rune *rune, const char *str, int length)
|
||||||
|
{
|
||||||
|
int c, c1, c2, c3;
|
||||||
|
long l;
|
||||||
|
|
||||||
|
/* When we're not allowed to read anything */
|
||||||
|
if(length <= 0) {
|
||||||
|
goto badlen;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* one character sequence (7-bit value)
|
||||||
|
* 00000-0007F => T1
|
||||||
|
*/
|
||||||
|
c = *(uchar*)str;
|
||||||
|
if(c < Tx) {
|
||||||
|
*rune = c;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can't read more than one character we must stop
|
||||||
|
if(length <= 1) {
|
||||||
|
goto badlen;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* two character sequence (11-bit value)
|
||||||
|
* 0080-07FF => T2 Tx
|
||||||
|
*/
|
||||||
|
c1 = *(uchar*)(str+1) ^ Tx;
|
||||||
|
if(c1 & Testx)
|
||||||
|
goto bad;
|
||||||
|
if(c < T3) {
|
||||||
|
if(c < T2)
|
||||||
|
goto bad;
|
||||||
|
l = ((c << Bitx) | c1) & Rune2;
|
||||||
|
if(l <= Rune1)
|
||||||
|
goto bad;
|
||||||
|
*rune = l;
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can't read more than two characters we must stop
|
||||||
|
if(length <= 2) {
|
||||||
|
goto badlen;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* three character sequence (16-bit value)
|
||||||
|
* 0800-FFFF => T3 Tx Tx
|
||||||
|
*/
|
||||||
|
c2 = *(uchar*)(str+2) ^ Tx;
|
||||||
|
if(c2 & Testx)
|
||||||
|
goto bad;
|
||||||
|
if(c < T4) {
|
||||||
|
l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
|
||||||
|
if(l <= Rune2)
|
||||||
|
goto bad;
|
||||||
|
*rune = l;
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (length <= 3)
|
||||||
|
goto badlen;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* four character sequence (21-bit value)
|
||||||
|
* 10000-1FFFFF => T4 Tx Tx Tx
|
||||||
|
*/
|
||||||
|
c3 = *(uchar*)(str+3) ^ Tx;
|
||||||
|
if (c3 & Testx)
|
||||||
|
goto bad;
|
||||||
|
if (c < T5) {
|
||||||
|
l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
|
||||||
|
if (l <= Rune3)
|
||||||
|
goto bad;
|
||||||
|
if (l > Runemax)
|
||||||
|
goto bad;
|
||||||
|
*rune = l;
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support for 5-byte or longer UTF-8 would go here, but
|
||||||
|
// since we don't have that, we'll just fall through to bad.
|
||||||
|
|
||||||
|
/*
|
||||||
|
* bad decoding
|
||||||
|
*/
|
||||||
|
bad:
|
||||||
|
*rune = Bad;
|
||||||
|
return 1;
|
||||||
|
badlen:
|
||||||
|
*rune = Bad;
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
utf_runetochar(char *str, const Rune *rune)
|
||||||
|
{
|
||||||
|
/* Runes are signed, so convert to unsigned for range check. */
|
||||||
|
unsigned long c;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* one character sequence
|
||||||
|
* 00000-0007F => 00-7F
|
||||||
|
*/
|
||||||
|
c = *rune;
|
||||||
|
if(c <= Rune1) {
|
||||||
|
str[0] = c;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* two character sequence
|
||||||
|
* 0080-07FF => T2 Tx
|
||||||
|
*/
|
||||||
|
if(c <= Rune2) {
|
||||||
|
str[0] = T2 | (c >> 1*Bitx);
|
||||||
|
str[1] = Tx | (c & Maskx);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* If the Rune is out of range, convert it to the error rune.
|
||||||
|
* Do this test here because the error rune encodes to three bytes.
|
||||||
|
* Doing it earlier would duplicate work, since an out of range
|
||||||
|
* Rune wouldn't have fit in one or two bytes.
|
||||||
|
*/
|
||||||
|
if (c > Runemax)
|
||||||
|
c = Runeerror;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* three character sequence
|
||||||
|
* 0800-FFFF => T3 Tx Tx
|
||||||
|
*/
|
||||||
|
if (c <= Rune3) {
|
||||||
|
str[0] = T3 | (c >> 2*Bitx);
|
||||||
|
str[1] = Tx | ((c >> 1*Bitx) & Maskx);
|
||||||
|
str[2] = Tx | (c & Maskx);
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* four character sequence (21-bit value)
|
||||||
|
* 10000-1FFFFF => T4 Tx Tx Tx
|
||||||
|
*/
|
||||||
|
str[0] = T4 | (c >> 3*Bitx);
|
||||||
|
str[1] = Tx | ((c >> 2*Bitx) & Maskx);
|
||||||
|
str[2] = Tx | ((c >> 1*Bitx) & Maskx);
|
||||||
|
str[3] = Tx | (c & Maskx);
|
||||||
|
return 4;
|
||||||
|
}
|
|
@ -0,0 +1,54 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a library written by Rob Pike and Ken Thompson. Original
|
||||||
|
// copyright message below.
|
||||||
|
/*
|
||||||
|
* The authors of this software are Rob Pike and Ken Thompson.
|
||||||
|
* Copyright (c) 2002 by Lucent Technologies.
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose without fee is hereby granted, provided that this entire notice
|
||||||
|
* is included in all copies of any software which is or includes a copy
|
||||||
|
* or modification of this software and in all copies of the supporting
|
||||||
|
* documentation for such software.
|
||||||
|
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
|
||||||
|
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
|
||||||
|
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
|
||||||
|
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
|
||||||
|
*/
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
|
||||||
|
|
||||||
|
static
|
||||||
|
Rune*
|
||||||
|
rbsearch(Rune c, Rune *t, int n, int ne)
|
||||||
|
{
|
||||||
|
Rune *p;
|
||||||
|
int m;
|
||||||
|
|
||||||
|
while(n > 1) {
|
||||||
|
m = n >> 1;
|
||||||
|
p = t + m*ne;
|
||||||
|
if(c >= p[0]) {
|
||||||
|
t = p;
|
||||||
|
n = n-m;
|
||||||
|
} else
|
||||||
|
n = m;
|
||||||
|
}
|
||||||
|
if(n && c >= t[0])
|
||||||
|
return t;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define RUNETYPEBODY
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/runetypebody.h"
|
|
@ -0,0 +1,212 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifdef RUNETYPEBODY
|
||||||
|
|
||||||
|
static Rune __isalphar[] = {
|
||||||
|
0x0041, 0x005a, 0x0061, 0x007a, 0x00c0, 0x00d6, 0x00d8, 0x00f6,
|
||||||
|
0x00f8, 0x02c1, 0x02c6, 0x02d1, 0x02e0, 0x02e4, 0x0370, 0x0374,
|
||||||
|
0x0376, 0x0377, 0x037a, 0x037d, 0x0388, 0x038a, 0x038e, 0x03a1,
|
||||||
|
0x03a3, 0x03f5, 0x03f7, 0x0481, 0x048a, 0x0527, 0x0531, 0x0556,
|
||||||
|
0x0561, 0x0587, 0x05d0, 0x05ea, 0x05f0, 0x05f2, 0x0620, 0x064a,
|
||||||
|
0x066e, 0x066f, 0x0671, 0x06d3, 0x06e5, 0x06e6, 0x06ee, 0x06ef,
|
||||||
|
0x06fa, 0x06fc, 0x0712, 0x072f, 0x074d, 0x07a5, 0x07ca, 0x07ea,
|
||||||
|
0x07f4, 0x07f5, 0x0800, 0x0815, 0x0840, 0x0858, 0x08a2, 0x08ac,
|
||||||
|
0x0904, 0x0939, 0x0958, 0x0961, 0x0971, 0x0977, 0x0979, 0x097f,
|
||||||
|
0x0985, 0x098c, 0x098f, 0x0990, 0x0993, 0x09a8, 0x09aa, 0x09b0,
|
||||||
|
0x09b6, 0x09b9, 0x09dc, 0x09dd, 0x09df, 0x09e1, 0x09f0, 0x09f1,
|
||||||
|
0x0a05, 0x0a0a, 0x0a0f, 0x0a10, 0x0a13, 0x0a28, 0x0a2a, 0x0a30,
|
||||||
|
0x0a32, 0x0a33, 0x0a35, 0x0a36, 0x0a38, 0x0a39, 0x0a59, 0x0a5c,
|
||||||
|
0x0a72, 0x0a74, 0x0a85, 0x0a8d, 0x0a8f, 0x0a91, 0x0a93, 0x0aa8,
|
||||||
|
0x0aaa, 0x0ab0, 0x0ab2, 0x0ab3, 0x0ab5, 0x0ab9, 0x0ae0, 0x0ae1,
|
||||||
|
0x0b05, 0x0b0c, 0x0b0f, 0x0b10, 0x0b13, 0x0b28, 0x0b2a, 0x0b30,
|
||||||
|
0x0b32, 0x0b33, 0x0b35, 0x0b39, 0x0b5c, 0x0b5d, 0x0b5f, 0x0b61,
|
||||||
|
0x0b85, 0x0b8a, 0x0b8e, 0x0b90, 0x0b92, 0x0b95, 0x0b99, 0x0b9a,
|
||||||
|
0x0b9e, 0x0b9f, 0x0ba3, 0x0ba4, 0x0ba8, 0x0baa, 0x0bae, 0x0bb9,
|
||||||
|
0x0c05, 0x0c0c, 0x0c0e, 0x0c10, 0x0c12, 0x0c28, 0x0c2a, 0x0c33,
|
||||||
|
0x0c35, 0x0c39, 0x0c58, 0x0c59, 0x0c60, 0x0c61, 0x0c85, 0x0c8c,
|
||||||
|
0x0c8e, 0x0c90, 0x0c92, 0x0ca8, 0x0caa, 0x0cb3, 0x0cb5, 0x0cb9,
|
||||||
|
0x0ce0, 0x0ce1, 0x0cf1, 0x0cf2, 0x0d05, 0x0d0c, 0x0d0e, 0x0d10,
|
||||||
|
0x0d12, 0x0d3a, 0x0d60, 0x0d61, 0x0d7a, 0x0d7f, 0x0d85, 0x0d96,
|
||||||
|
0x0d9a, 0x0db1, 0x0db3, 0x0dbb, 0x0dc0, 0x0dc6, 0x0e01, 0x0e30,
|
||||||
|
0x0e32, 0x0e33, 0x0e40, 0x0e46, 0x0e81, 0x0e82, 0x0e87, 0x0e88,
|
||||||
|
0x0e94, 0x0e97, 0x0e99, 0x0e9f, 0x0ea1, 0x0ea3, 0x0eaa, 0x0eab,
|
||||||
|
0x0ead, 0x0eb0, 0x0eb2, 0x0eb3, 0x0ec0, 0x0ec4, 0x0edc, 0x0edf,
|
||||||
|
0x0f40, 0x0f47, 0x0f49, 0x0f6c, 0x0f88, 0x0f8c, 0x1000, 0x102a,
|
||||||
|
0x1050, 0x1055, 0x105a, 0x105d, 0x1065, 0x1066, 0x106e, 0x1070,
|
||||||
|
0x1075, 0x1081, 0x10a0, 0x10c5, 0x10d0, 0x10fa, 0x10fc, 0x1248,
|
||||||
|
0x124a, 0x124d, 0x1250, 0x1256, 0x125a, 0x125d, 0x1260, 0x1288,
|
||||||
|
0x128a, 0x128d, 0x1290, 0x12b0, 0x12b2, 0x12b5, 0x12b8, 0x12be,
|
||||||
|
0x12c2, 0x12c5, 0x12c8, 0x12d6, 0x12d8, 0x1310, 0x1312, 0x1315,
|
||||||
|
0x1318, 0x135a, 0x1380, 0x138f, 0x13a0, 0x13f4, 0x1401, 0x166c,
|
||||||
|
0x166f, 0x167f, 0x1681, 0x169a, 0x16a0, 0x16ea, 0x1700, 0x170c,
|
||||||
|
0x170e, 0x1711, 0x1720, 0x1731, 0x1740, 0x1751, 0x1760, 0x176c,
|
||||||
|
0x176e, 0x1770, 0x1780, 0x17b3, 0x1820, 0x1877, 0x1880, 0x18a8,
|
||||||
|
0x18b0, 0x18f5, 0x1900, 0x191c, 0x1950, 0x196d, 0x1970, 0x1974,
|
||||||
|
0x1980, 0x19ab, 0x19c1, 0x19c7, 0x1a00, 0x1a16, 0x1a20, 0x1a54,
|
||||||
|
0x1b05, 0x1b33, 0x1b45, 0x1b4b, 0x1b83, 0x1ba0, 0x1bae, 0x1baf,
|
||||||
|
0x1bba, 0x1be5, 0x1c00, 0x1c23, 0x1c4d, 0x1c4f, 0x1c5a, 0x1c7d,
|
||||||
|
0x1ce9, 0x1cec, 0x1cee, 0x1cf1, 0x1cf5, 0x1cf6, 0x1d00, 0x1dbf,
|
||||||
|
0x1e00, 0x1f15, 0x1f18, 0x1f1d, 0x1f20, 0x1f45, 0x1f48, 0x1f4d,
|
||||||
|
0x1f50, 0x1f57, 0x1f5f, 0x1f7d, 0x1f80, 0x1fb4, 0x1fb6, 0x1fbc,
|
||||||
|
0x1fc2, 0x1fc4, 0x1fc6, 0x1fcc, 0x1fd0, 0x1fd3, 0x1fd6, 0x1fdb,
|
||||||
|
0x1fe0, 0x1fec, 0x1ff2, 0x1ff4, 0x1ff6, 0x1ffc, 0x2090, 0x209c,
|
||||||
|
0x210a, 0x2113, 0x2119, 0x211d, 0x212a, 0x212d, 0x212f, 0x2139,
|
||||||
|
0x213c, 0x213f, 0x2145, 0x2149, 0x2183, 0x2184, 0x2c00, 0x2c2e,
|
||||||
|
0x2c30, 0x2c5e, 0x2c60, 0x2ce4, 0x2ceb, 0x2cee, 0x2cf2, 0x2cf3,
|
||||||
|
0x2d00, 0x2d25, 0x2d30, 0x2d67, 0x2d80, 0x2d96, 0x2da0, 0x2da6,
|
||||||
|
0x2da8, 0x2dae, 0x2db0, 0x2db6, 0x2db8, 0x2dbe, 0x2dc0, 0x2dc6,
|
||||||
|
0x2dc8, 0x2dce, 0x2dd0, 0x2dd6, 0x2dd8, 0x2dde, 0x3005, 0x3006,
|
||||||
|
0x3031, 0x3035, 0x303b, 0x303c, 0x3041, 0x3096, 0x309d, 0x309f,
|
||||||
|
0x30a1, 0x30fa, 0x30fc, 0x30ff, 0x3105, 0x312d, 0x3131, 0x318e,
|
||||||
|
0x31a0, 0x31ba, 0x31f0, 0x31ff, 0x3400, 0x4db5, 0x4e00, 0x9fcc,
|
||||||
|
0xa000, 0xa48c, 0xa4d0, 0xa4fd, 0xa500, 0xa60c, 0xa610, 0xa61f,
|
||||||
|
0xa62a, 0xa62b, 0xa640, 0xa66e, 0xa67f, 0xa697, 0xa6a0, 0xa6e5,
|
||||||
|
0xa717, 0xa71f, 0xa722, 0xa788, 0xa78b, 0xa78e, 0xa790, 0xa793,
|
||||||
|
0xa7a0, 0xa7aa, 0xa7f8, 0xa801, 0xa803, 0xa805, 0xa807, 0xa80a,
|
||||||
|
0xa80c, 0xa822, 0xa840, 0xa873, 0xa882, 0xa8b3, 0xa8f2, 0xa8f7,
|
||||||
|
0xa90a, 0xa925, 0xa930, 0xa946, 0xa960, 0xa97c, 0xa984, 0xa9b2,
|
||||||
|
0xaa00, 0xaa28, 0xaa40, 0xaa42, 0xaa44, 0xaa4b, 0xaa60, 0xaa76,
|
||||||
|
0xaa80, 0xaaaf, 0xaab5, 0xaab6, 0xaab9, 0xaabd, 0xaadb, 0xaadd,
|
||||||
|
0xaae0, 0xaaea, 0xaaf2, 0xaaf4, 0xab01, 0xab06, 0xab09, 0xab0e,
|
||||||
|
0xab11, 0xab16, 0xab20, 0xab26, 0xab28, 0xab2e, 0xabc0, 0xabe2,
|
||||||
|
0xac00, 0xd7a3, 0xd7b0, 0xd7c6, 0xd7cb, 0xd7fb, 0xf900, 0xfa6d,
|
||||||
|
0xfa70, 0xfad9, 0xfb00, 0xfb06, 0xfb13, 0xfb17, 0xfb1f, 0xfb28,
|
||||||
|
0xfb2a, 0xfb36, 0xfb38, 0xfb3c, 0xfb40, 0xfb41, 0xfb43, 0xfb44,
|
||||||
|
0xfb46, 0xfbb1, 0xfbd3, 0xfd3d, 0xfd50, 0xfd8f, 0xfd92, 0xfdc7,
|
||||||
|
0xfdf0, 0xfdfb, 0xfe70, 0xfe74, 0xfe76, 0xfefc, 0xff21, 0xff3a,
|
||||||
|
0xff41, 0xff5a, 0xff66, 0xffbe, 0xffc2, 0xffc7, 0xffca, 0xffcf,
|
||||||
|
0xffd2, 0xffd7, 0xffda, 0xffdc, 0x10000, 0x1000b, 0x1000d, 0x10026,
|
||||||
|
0x10028, 0x1003a, 0x1003c, 0x1003d, 0x1003f, 0x1004d, 0x10050, 0x1005d,
|
||||||
|
0x10080, 0x100fa, 0x10280, 0x1029c, 0x102a0, 0x102d0, 0x10300, 0x1031e,
|
||||||
|
0x10330, 0x10340, 0x10342, 0x10349, 0x10380, 0x1039d, 0x103a0, 0x103c3,
|
||||||
|
0x103c8, 0x103cf, 0x10400, 0x1049d, 0x10800, 0x10805, 0x1080a, 0x10835,
|
||||||
|
0x10837, 0x10838, 0x1083f, 0x10855, 0x10900, 0x10915, 0x10920, 0x10939,
|
||||||
|
0x10980, 0x109b7, 0x109be, 0x109bf, 0x10a10, 0x10a13, 0x10a15, 0x10a17,
|
||||||
|
0x10a19, 0x10a33, 0x10a60, 0x10a7c, 0x10b00, 0x10b35, 0x10b40, 0x10b55,
|
||||||
|
0x10b60, 0x10b72, 0x10c00, 0x10c48, 0x11003, 0x11037, 0x11083, 0x110af,
|
||||||
|
0x110d0, 0x110e8, 0x11103, 0x11126, 0x11183, 0x111b2, 0x111c1, 0x111c4,
|
||||||
|
0x11680, 0x116aa, 0x12000, 0x1236e, 0x13000, 0x1342e, 0x16800, 0x16a38,
|
||||||
|
0x16f00, 0x16f44, 0x16f93, 0x16f9f, 0x1b000, 0x1b001, 0x1d400, 0x1d454,
|
||||||
|
0x1d456, 0x1d49c, 0x1d49e, 0x1d49f, 0x1d4a5, 0x1d4a6, 0x1d4a9, 0x1d4ac,
|
||||||
|
0x1d4ae, 0x1d4b9, 0x1d4bd, 0x1d4c3, 0x1d4c5, 0x1d505, 0x1d507, 0x1d50a,
|
||||||
|
0x1d50d, 0x1d514, 0x1d516, 0x1d51c, 0x1d51e, 0x1d539, 0x1d53b, 0x1d53e,
|
||||||
|
0x1d540, 0x1d544, 0x1d54a, 0x1d550, 0x1d552, 0x1d6a5, 0x1d6a8, 0x1d6c0,
|
||||||
|
0x1d6c2, 0x1d6da, 0x1d6dc, 0x1d6fa, 0x1d6fc, 0x1d714, 0x1d716, 0x1d734,
|
||||||
|
0x1d736, 0x1d74e, 0x1d750, 0x1d76e, 0x1d770, 0x1d788, 0x1d78a, 0x1d7a8,
|
||||||
|
0x1d7aa, 0x1d7c2, 0x1d7c4, 0x1d7cb, 0x1ee00, 0x1ee03, 0x1ee05, 0x1ee1f,
|
||||||
|
0x1ee21, 0x1ee22, 0x1ee29, 0x1ee32, 0x1ee34, 0x1ee37, 0x1ee4d, 0x1ee4f,
|
||||||
|
0x1ee51, 0x1ee52, 0x1ee61, 0x1ee62, 0x1ee67, 0x1ee6a, 0x1ee6c, 0x1ee72,
|
||||||
|
0x1ee74, 0x1ee77, 0x1ee79, 0x1ee7c, 0x1ee80, 0x1ee89, 0x1ee8b, 0x1ee9b,
|
||||||
|
0x1eea1, 0x1eea3, 0x1eea5, 0x1eea9, 0x1eeab, 0x1eebb, 0x20000, 0x2a6d6,
|
||||||
|
0x2a700, 0x2b734, 0x2b740, 0x2b81d, 0x2f800, 0x2fa1d,
|
||||||
|
};
|
||||||
|
|
||||||
|
static Rune __isalphas[] = {
|
||||||
|
0x00aa, 0x00b5, 0x00ba, 0x02ec, 0x02ee, 0x0386, 0x038c, 0x0559,
|
||||||
|
0x06d5, 0x06ff, 0x0710, 0x07b1, 0x07fa, 0x081a, 0x0824, 0x0828,
|
||||||
|
0x08a0, 0x093d, 0x0950, 0x09b2, 0x09bd, 0x09ce, 0x0a5e, 0x0abd,
|
||||||
|
0x0ad0, 0x0b3d, 0x0b71, 0x0b83, 0x0b9c, 0x0bd0, 0x0c3d, 0x0cbd,
|
||||||
|
0x0cde, 0x0d3d, 0x0d4e, 0x0dbd, 0x0e84, 0x0e8a, 0x0e8d, 0x0ea5,
|
||||||
|
0x0ea7, 0x0ebd, 0x0ec6, 0x0f00, 0x103f, 0x1061, 0x108e, 0x10c7,
|
||||||
|
0x10cd, 0x1258, 0x12c0, 0x17d7, 0x17dc, 0x18aa, 0x1aa7, 0x1f59,
|
||||||
|
0x1f5b, 0x1f5d, 0x1fbe, 0x2071, 0x207f, 0x2102, 0x2107, 0x2115,
|
||||||
|
0x2124, 0x2126, 0x2128, 0x214e, 0x2d27, 0x2d2d, 0x2d6f, 0x2e2f,
|
||||||
|
0xa8fb, 0xa9cf, 0xaa7a, 0xaab1, 0xaac0, 0xaac2, 0xfb1d, 0xfb3e,
|
||||||
|
0x10808, 0x1083c, 0x10a00, 0x16f50, 0x1d4a2, 0x1d4bb, 0x1d546, 0x1ee24,
|
||||||
|
0x1ee27, 0x1ee39, 0x1ee3b, 0x1ee42, 0x1ee47, 0x1ee49, 0x1ee4b, 0x1ee54,
|
||||||
|
0x1ee57, 0x1ee59, 0x1ee5b, 0x1ee5d, 0x1ee5f, 0x1ee64, 0x1ee7e,
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
 * Returns 1 if c is listed in the alphabetic tables, 0 otherwise.
 * __isalphar is searched as inclusive [first, last] pairs and __isalphas as
 * single values.
 */
int utf_isalpharune(Rune c) {
  Rune *hit;

  /* Range table: two Runes per record. */
  hit = rbsearch(c, __isalphar, nelem(__isalphar) / 2, 2);
  if (hit != 0 && hit[0] <= c && c <= hit[1]) {
    return 1;
  }

  /* Singleton table: one Rune per record. */
  hit = rbsearch(c, __isalphas, nelem(__isalphas), 1);
  if (hit != 0 && hit[0] == c) {
    return 1;
  }

  return 0;
}
|
||||||
|
|
||||||
|
static Rune __tolowerr[] = {
|
||||||
|
0x0041, 0x005a, 1048608, 0x00c0, 0x00d6, 1048608, 0x00d8, 0x00de, 1048608,
|
||||||
|
0x0189, 0x018a, 1048781, 0x01b1, 0x01b2, 1048793, 0x0388, 0x038a, 1048613,
|
||||||
|
0x038e, 0x038f, 1048639, 0x0391, 0x03a1, 1048608, 0x03a3, 0x03ab, 1048608,
|
||||||
|
0x03fd, 0x03ff, 1048446, 0x0400, 0x040f, 1048656, 0x0410, 0x042f, 1048608,
|
||||||
|
0x0531, 0x0556, 1048624, 0x10a0, 0x10c5, 1055840, 0x1f08, 0x1f0f, 1048568,
|
||||||
|
0x1f18, 0x1f1d, 1048568, 0x1f28, 0x1f2f, 1048568, 0x1f38, 0x1f3f, 1048568,
|
||||||
|
0x1f48, 0x1f4d, 1048568, 0x1f68, 0x1f6f, 1048568, 0x1f88, 0x1f8f, 1048568,
|
||||||
|
0x1f98, 0x1f9f, 1048568, 0x1fa8, 0x1faf, 1048568, 0x1fb8, 0x1fb9, 1048568,
|
||||||
|
0x1fba, 0x1fbb, 1048502, 0x1fc8, 0x1fcb, 1048490, 0x1fd8, 0x1fd9, 1048568,
|
||||||
|
0x1fda, 0x1fdb, 1048476, 0x1fe8, 0x1fe9, 1048568, 0x1fea, 0x1feb, 1048464,
|
||||||
|
0x1ff8, 0x1ff9, 1048448, 0x1ffa, 0x1ffb, 1048450, 0x2160, 0x216f, 1048592,
|
||||||
|
0x24b6, 0x24cf, 1048602, 0x2c00, 0x2c2e, 1048624, 0x2c7e, 0x2c7f, 1037761,
|
||||||
|
0xff21, 0xff3a, 1048608, 0x10400, 0x10427, 1048616,
|
||||||
|
};
|
||||||
|
|
||||||
|
static Rune __tolowerp[] = {
|
||||||
|
0x0100, 0x012e, 1048577, 0x0132, 0x0136, 1048577, 0x0139, 0x0147, 1048577,
|
||||||
|
0x014a, 0x0176, 1048577, 0x017b, 0x017d, 1048577, 0x01a2, 0x01a4, 1048577,
|
||||||
|
0x01b3, 0x01b5, 1048577, 0x01cd, 0x01db, 1048577, 0x01de, 0x01ee, 1048577,
|
||||||
|
0x01f8, 0x021e, 1048577, 0x0222, 0x0232, 1048577, 0x0248, 0x024e, 1048577,
|
||||||
|
0x0370, 0x0372, 1048577, 0x03d8, 0x03ee, 1048577, 0x0460, 0x0480, 1048577,
|
||||||
|
0x048a, 0x04be, 1048577, 0x04c3, 0x04cd, 1048577, 0x04d0, 0x0526, 1048577,
|
||||||
|
0x1e00, 0x1e94, 1048577, 0x1ea0, 0x1efe, 1048577, 0x1f59, 0x1f5f, 1048568,
|
||||||
|
0x2c67, 0x2c6b, 1048577, 0x2c80, 0x2ce2, 1048577, 0x2ceb, 0x2ced, 1048577,
|
||||||
|
0xa640, 0xa66c, 1048577, 0xa680, 0xa696, 1048577, 0xa722, 0xa72e, 1048577,
|
||||||
|
0xa732, 0xa76e, 1048577, 0xa779, 0xa77b, 1048577, 0xa780, 0xa786, 1048577,
|
||||||
|
0xa790, 0xa792, 1048577, 0xa7a0, 0xa7a8, 1048577,
|
||||||
|
};
|
||||||
|
|
||||||
|
static Rune __tolowers[] = {
|
||||||
|
0x0130, 1048377, 0x0178, 1048455, 0x0179, 1048577, 0x0181, 1048786,
|
||||||
|
0x0182, 1048577, 0x0184, 1048577, 0x0186, 1048782, 0x0187, 1048577,
|
||||||
|
0x018b, 1048577, 0x018e, 1048655, 0x018f, 1048778, 0x0190, 1048779,
|
||||||
|
0x0191, 1048577, 0x0193, 1048781, 0x0194, 1048783, 0x0196, 1048787,
|
||||||
|
0x0197, 1048785, 0x0198, 1048577, 0x019c, 1048787, 0x019d, 1048789,
|
||||||
|
0x019f, 1048790, 0x01a0, 1048577, 0x01a6, 1048794, 0x01a7, 1048577,
|
||||||
|
0x01a9, 1048794, 0x01ac, 1048577, 0x01ae, 1048794, 0x01af, 1048577,
|
||||||
|
0x01b7, 1048795, 0x01b8, 1048577, 0x01bc, 1048577, 0x01c4, 1048578,
|
||||||
|
0x01c5, 1048577, 0x01c7, 1048578, 0x01c8, 1048577, 0x01ca, 1048578,
|
||||||
|
0x01cb, 1048577, 0x01f1, 1048578, 0x01f2, 1048577, 0x01f4, 1048577,
|
||||||
|
0x01f6, 1048479, 0x01f7, 1048520, 0x0220, 1048446, 0x023a, 1059371,
|
||||||
|
0x023b, 1048577, 0x023d, 1048413, 0x023e, 1059368, 0x0241, 1048577,
|
||||||
|
0x0243, 1048381, 0x0244, 1048645, 0x0245, 1048647, 0x0246, 1048577,
|
||||||
|
0x0376, 1048577, 0x0386, 1048614, 0x038c, 1048640, 0x03cf, 1048584,
|
||||||
|
0x03f4, 1048516, 0x03f7, 1048577, 0x03f9, 1048569, 0x03fa, 1048577,
|
||||||
|
0x04c0, 1048591, 0x04c1, 1048577, 0x10c7, 1055840, 0x10cd, 1055840,
|
||||||
|
0x1e9e, 1040961, 0x1fbc, 1048567, 0x1fcc, 1048567, 0x1fec, 1048569,
|
||||||
|
0x1ffc, 1048567, 0x2126, 1041059, 0x212a, 1040193, 0x212b, 1040314,
|
||||||
|
0x2132, 1048604, 0x2183, 1048577, 0x2c60, 1048577, 0x2c62, 1037833,
|
||||||
|
0x2c63, 1044762, 0x2c64, 1037849, 0x2c6d, 1037796, 0x2c6e, 1037827,
|
||||||
|
0x2c6f, 1037793, 0x2c70, 1037794, 0x2c72, 1048577, 0x2c75, 1048577,
|
||||||
|
0x2cf2, 1048577, 0xa77d, 1013244, 0xa77e, 1048577, 0xa78b, 1048577,
|
||||||
|
0xa78d, 1006296, 0xa7aa, 1006268,
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
 * Returns the lower-case mapping of c according to the __tolower* tables, or
 * c itself when no table has an entry for it.
 */
Rune utf_tolowerrune(Rune c) {
  /*
   * The tables store each delta with 1048576 (0x100000) added to it;
   * subtracting the bias gives the signed distance to the mapped rune.
   */
  enum { kDeltaBias = 1048576 };
  Rune *hit;

  /* [first, last, delta]: every rune in the range is shifted by delta. */
  hit = rbsearch(c, __tolowerr, nelem(__tolowerr) / 3, 3);
  if (hit != 0 && hit[0] <= c && c <= hit[1]) {
    return c + hit[2] - kDeltaBias;
  }

  /* [first, last, delta]: only runes at an even offset from first map. */
  hit = rbsearch(c, __tolowerp, nelem(__tolowerp) / 3, 3);
  if (hit != 0 && hit[0] <= c && c <= hit[1] && ((c - hit[0]) & 1) == 0) {
    return c + hit[2] - kDeltaBias;
  }

  /* [rune, delta]: single-rune mappings. */
  hit = rbsearch(c, __tolowers, nelem(__tolowers) / 2, 2);
  if (hit != 0 && hit[0] == c) {
    return c + hit[1] - kDeltaBias;
  }

  return c;
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,98 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Fork of several UTF utils originally written by Rob Pike and Ken Thompson.
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_ 1
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
// Code-point values in Unicode 4.0 are 21 bits wide.
|
||||||
|
typedef signed int Rune;
|
||||||
|
|
||||||
|
#define uchar _utfuchar
|
||||||
|
|
||||||
|
typedef unsigned char uchar;
|
||||||
|
|
||||||
|
#define nelem(x) (sizeof(x) / sizeof((x)[0]))
|
||||||
|
|
||||||
|
enum {
|
||||||
|
UTFmax = 4, // maximum bytes per rune
|
||||||
|
Runeerror = 0xFFFD, // decoding error in UTF
|
||||||
|
Runemax = 0x10FFFF, // maximum rune value
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/*
|
||||||
|
* rune routines
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
* These routines were written by Rob Pike and Ken Thompson
|
||||||
|
* and first appeared in Plan 9.
|
||||||
|
* SEE ALSO
|
||||||
|
* utf (7)
|
||||||
|
* tcs (1)
|
||||||
|
*/
|
||||||
|
|
||||||
|
// utf_runetochar copies (encodes) one rune, pointed to by r, to at most
|
||||||
|
// UTFmax bytes starting at s and returns the number of bytes generated.
|
||||||
|
|
||||||
|
int utf_runetochar(char* s, const Rune* r);
|
||||||
|
|
||||||
|
// utf_charntorune copies (decodes) at most UTFmax bytes starting at `str` to
|
||||||
|
// one rune, pointed to by `rune`, access at most `length` bytes of `str`, and
|
||||||
|
// returns the number of bytes consumed.
|
||||||
|
// If the UTF sequence is incomplete within `length` bytes,
|
||||||
|
// utf_charntorune will set *rune to Runeerror and return 0. If it is complete
|
||||||
|
// but not in UTF format, it will set *rune to Runeerror and return 1.
|
||||||
|
//
|
||||||
|
// Added 2004-09-24 by Wei-Hwa Huang
|
||||||
|
|
||||||
|
int utf_charntorune(Rune* rune, const char* str, int length);
|
||||||
|
|
||||||
|
// Unicode defines some characters as letters and
|
||||||
|
// specifies three cases: upper, lower, and title. Mappings among the
|
||||||
|
// cases are also defined, although they are not exhaustive: some
|
||||||
|
// upper case letters have no lower case mapping, and so on. Unicode
|
||||||
|
// also defines several character properties, a subset of which are
|
||||||
|
// checked by these routines. These routines are based on Unicode
|
||||||
|
// version 3.0.0.
|
||||||
|
//
|
||||||
|
// NOTE: The routines are implemented in C, so utf_isalpharune returns 0 for
// false
|
||||||
|
// and 1 for true.
|
||||||
|
//
|
||||||
|
// utf_tolowerrune is the Unicode case mapping. It returns the character
|
||||||
|
// unchanged if it has no defined mapping.
|
||||||
|
|
||||||
|
Rune utf_tolowerrune(Rune r);
|
||||||
|
|
||||||
|
// utf_isalpharune tests for Unicode letters; this includes ideographs in
|
||||||
|
// addition to alphabetic characters.
|
||||||
|
|
||||||
|
int utf_isalpharune(Rune r);
|
||||||
|
|
||||||
|
// (The comments in this file were copied from the manpage files rune.3,
|
||||||
|
// isalpharune.3, and runestrcat.3. Some formatting changes were also made
|
||||||
|
// to conform to Google style. /JRM 11/11/05)
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_UTF_UTF_H_
|
|
@ -34,7 +34,7 @@ constexpr char kTestSPModelPath[] =
|
||||||
|
|
||||||
std::unique_ptr<SentencePieceTokenizer> CreateSentencePieceTokenizer(
|
std::unique_ptr<SentencePieceTokenizer> CreateSentencePieceTokenizer(
|
||||||
absl::string_view model_path) {
|
absl::string_view model_path) {
|
||||||
// We are using `LoadBinaryContent()` instead of loading the model direclty
|
// We are using `LoadBinaryContent()` instead of loading the model directly
|
||||||
// via `SentencePieceTokenizer` so that the file can be located on Windows
|
// via `SentencePieceTokenizer` so that the file can be located on Windows
|
||||||
std::string buffer = LoadBinaryContent(kTestSPModelPath);
|
std::string buffer = LoadBinaryContent(kTestSPModelPath);
|
||||||
return absl::make_unique<SentencePieceTokenizer>(buffer.data(),
|
return absl::make_unique<SentencePieceTokenizer>(buffer.data(),
|
||||||
|
|
|
@ -74,7 +74,7 @@ class FaceDetector : core::BaseVisionTaskApi {
|
||||||
// three running modes:
|
// three running modes:
|
||||||
// 1) Image mode for detecting faces on single image inputs. Users
|
// 1) Image mode for detecting faces on single image inputs. Users
|
||||||
// provide mediapipe::Image to the `Detect` method, and will receive the
|
// provide mediapipe::Image to the `Detect` method, and will receive the
|
||||||
// deteced face detection results as the return value.
|
// detected face detection results as the return value.
|
||||||
// 2) Video mode for detecting faces on the decoded frames of a
|
// 2) Video mode for detecting faces on the decoded frames of a
|
||||||
// video. Users call `DetectForVideo` method, and will receive the detected
|
// video. Users call `DetectForVideo` method, and will receive the detected
|
||||||
// face detection results as the return value.
|
// face detection results as the return value.
|
||||||
|
|
|
@ -19,9 +19,6 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "face_geometry_from_landmarks_graph",
|
name = "face_geometry_from_landmarks_graph",
|
||||||
srcs = ["face_geometry_from_landmarks_graph.cc"],
|
srcs = ["face_geometry_from_landmarks_graph.cc"],
|
||||||
data = [
|
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/data:geometry_pipeline_metadata_landmarks",
|
|
||||||
],
|
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/calculators/core:begin_loop_calculator",
|
"//mediapipe/calculators/core:begin_loop_calculator",
|
||||||
"//mediapipe/calculators/core:end_loop_calculator",
|
"//mediapipe/calculators/core:end_loop_calculator",
|
||||||
|
@ -39,6 +36,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_cc_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_graph_options_cc_proto",
|
||||||
"//mediapipe/util:graph_builder_utils",
|
"//mediapipe/util:graph_builder_utils",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
|
|
|
@ -45,6 +45,7 @@ mediapipe_proto_library(
|
||||||
srcs = ["geometry_pipeline_calculator.proto"],
|
srcs = ["geometry_pipeline_calculator.proto"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:external_file_proto",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -59,6 +60,9 @@ cc_library(
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:statusor",
|
"//mediapipe/framework/port:statusor",
|
||||||
|
"//mediapipe/tasks/cc:common",
|
||||||
|
"//mediapipe/tasks/cc/core:external_file_handler",
|
||||||
|
"//mediapipe/tasks/cc/core/proto:external_file_cc_proto",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
|
"//mediapipe/tasks/cc/vision/face_geometry/libs:geometry_pipeline",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils",
|
"//mediapipe/tasks/cc/vision/face_geometry/libs:validation_utils",
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:environment_cc_proto",
|
||||||
|
@ -66,6 +70,7 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:geometry_pipeline_metadata_cc_proto",
|
||||||
"//mediapipe/util:resource_util",
|
"//mediapipe/util:resource_util",
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -18,12 +18,16 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/memory/memory.h"
|
#include "absl/memory/memory.h"
|
||||||
|
#include "absl/strings/str_format.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/status_macros.h"
|
#include "mediapipe/framework/port/status_macros.h"
|
||||||
#include "mediapipe/framework/port/statusor.h"
|
#include "mediapipe/framework/port/statusor.h"
|
||||||
|
#include "mediapipe/tasks/cc/common.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/external_file_handler.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/libs/geometry_pipeline.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/libs/validation_utils.h"
|
||||||
|
@ -39,13 +43,50 @@ static constexpr char kEnvironmentTag[] = "ENVIRONMENT";
|
||||||
static constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
static constexpr char kImageSizeTag[] = "IMAGE_SIZE";
|
||||||
static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY";
|
static constexpr char kMultiFaceGeometryTag[] = "MULTI_FACE_GEOMETRY";
|
||||||
static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS";
|
static constexpr char kMultiFaceLandmarksTag[] = "MULTI_FACE_LANDMARKS";
|
||||||
|
static constexpr char kFaceGeometryTag[] = "FACE_GEOMETRY";
|
||||||
|
static constexpr char kFaceLandmarksTag[] = "FACE_LANDMARKS";
|
||||||
|
|
||||||
using ::mediapipe::tasks::vision::face_geometry::proto::Environment;
|
using ::mediapipe::tasks::vision::face_geometry::proto::Environment;
|
||||||
using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry;
|
using ::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry;
|
||||||
using ::mediapipe::tasks::vision::face_geometry::proto::
|
using ::mediapipe::tasks::vision::face_geometry::proto::
|
||||||
GeometryPipelineMetadata;
|
GeometryPipelineMetadata;
|
||||||
|
|
||||||
// A calculator that renders a visual effect for multiple faces.
|
// Validates which landmark/geometry streams are connected in the contract.
//
// Two stream pairs are recognised: the single-face pair
// (FACE_LANDMARKS input, FACE_GEOMETRY output) and the multi-face pair
// (MULTI_FACE_LANDMARKS input, MULTI_FACE_GEOMETRY output). The checks below
// require that exactly one landmarks input and exactly one geometry output
// are connected, and that the connected input and output belong to the same
// pair.
//
// Returns absl::OkStatus() when the wiring is valid; otherwise returns the
// kInvalidArgument status built by CreateStatusWithPayload for the first
// violated rule.
absl::Status SanityCheck(CalculatorContract* cc) {
  const bool has_face_landmarks = cc->Inputs().HasTag(kFaceLandmarksTag);
  const bool has_multi_face_landmarks =
      cc->Inputs().HasTag(kMultiFaceLandmarksTag);
  const bool has_face_geometry = cc->Outputs().HasTag(kFaceGeometryTag);
  const bool has_multi_face_geometry =
      cc->Outputs().HasTag(kMultiFaceGeometryTag);

  // Exactly one landmarks input must be connected. Note that this also
  // rejects the case where neither input is connected.
  if (has_face_landmarks == has_multi_face_landmarks) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Only one of %s and %s can be set at a time.",
                        kFaceLandmarksTag, kMultiFaceLandmarksTag));
  }

  // Exactly one geometry output must be connected (same rule as above).
  if (has_face_geometry == has_multi_face_geometry) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat("Only one of %s and %s can be set at a time.",
                        kFaceGeometryTag, kMultiFaceGeometryTag));
  }

  // The single-face input and the single-face output go together.
  if (has_face_landmarks != has_face_geometry) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "%s and %s must both be set or neither be set at a time.",
            kFaceLandmarksTag, kFaceGeometryTag));
  }

  // The multi-face input and the multi-face output go together.
  if (has_multi_face_landmarks != has_multi_face_geometry) {
    return CreateStatusWithPayload(
        absl::StatusCode::kInvalidArgument,
        absl::StrFormat(
            "%s and %s must both be set or neither be set at a time.",
            kMultiFaceLandmarksTag, kMultiFaceGeometryTag));
  }

  return absl::OkStatus();
}
|
||||||
|
|
||||||
|
// A calculator that renders a visual effect for multiple faces. Supports single
|
||||||
|
// face landmarks or multiple face landmarks.
|
||||||
//
|
//
|
||||||
// Inputs:
|
// Inputs:
|
||||||
// IMAGE_SIZE (`std::pair<int, int>`, required):
|
// IMAGE_SIZE (`std::pair<int, int>`, required):
|
||||||
|
@ -56,8 +97,12 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
|
||||||
// ratio. If used as-is, the resulting face geometry visualization should be
|
// ratio. If used as-is, the resulting face geometry visualization should be
|
||||||
// happening on a frame with the same ratio as well.
|
// happening on a frame with the same ratio as well.
|
||||||
//
|
//
|
||||||
// MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, required):
|
// MULTI_FACE_LANDMARKS (`std::vector<NormalizedLandmarkList>`, optional):
|
||||||
// A vector of face landmark lists.
|
// A vector of face landmark lists. If connected, the output stream
|
||||||
|
// MULTI_FACE_GEOMETRY must be connected.
|
||||||
|
// FACE_LANDMARKS (NormalizedLandmarkList, optional):
|
||||||
|
// A NormalizedLandmarkList with the landmarks of a single face. If connected, the
|
||||||
|
// output stream FACE_GEOMETRY must be connected.
|
||||||
//
|
//
|
||||||
// Input side packets:
|
// Input side packets:
|
||||||
// ENVIRONMENT (`proto::Environment`, required)
|
// ENVIRONMENT (`proto::Environment`, required)
|
||||||
|
@ -65,12 +110,14 @@ using ::mediapipe::tasks::vision::face_geometry::proto::
|
||||||
// as well as virtual camera parameters.
|
// as well as virtual camera parameters.
|
||||||
//
|
//
|
||||||
// Output:
|
// Output:
|
||||||
// MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, required):
|
// MULTI_FACE_GEOMETRY (`std::vector<FaceGeometry>`, optional):
|
||||||
// A vector of face geometry data.
|
// A vector of face geometry data if MULTI_FACE_LANDMARKS is connected.
|
||||||
|
// FACE_GEOMETRY (FaceGeometry, optional):
|
||||||
|
// A FaceGeometry of the face landmarks if FACE_LANDMARKS is connected.
|
||||||
//
|
//
|
||||||
// Options:
|
// Options:
|
||||||
// metadata_path (`string`, optional):
|
// metadata_file (`ExternalFile`, optional):
|
||||||
// Defines a path for the geometry pipeline metadata file.
|
// Defines an ExternalFile for the geometry pipeline metadata file.
|
||||||
//
|
//
|
||||||
// The geometry pipeline metadata file format must be the binary
|
// The geometry pipeline metadata file format must be the binary
|
||||||
// `GeometryPipelineMetadata` proto.
|
// `GeometryPipelineMetadata` proto.
|
||||||
|
@ -79,13 +126,21 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static absl::Status GetContract(CalculatorContract* cc) {
|
static absl::Status GetContract(CalculatorContract* cc) {
|
||||||
cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>();
|
cc->InputSidePackets().Tag(kEnvironmentTag).Set<Environment>();
|
||||||
|
MP_RETURN_IF_ERROR(SanityCheck(cc));
|
||||||
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
cc->Inputs().Tag(kImageSizeTag).Set<std::pair<int, int>>();
|
||||||
|
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
|
||||||
cc->Inputs()
|
cc->Inputs()
|
||||||
.Tag(kMultiFaceLandmarksTag)
|
.Tag(kMultiFaceLandmarksTag)
|
||||||
.Set<std::vector<mediapipe::NormalizedLandmarkList>>();
|
.Set<std::vector<mediapipe::NormalizedLandmarkList>>();
|
||||||
cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>();
|
cc->Outputs().Tag(kMultiFaceGeometryTag).Set<std::vector<FaceGeometry>>();
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
|
} else {
|
||||||
|
cc->Inputs()
|
||||||
|
.Tag(kFaceLandmarksTag)
|
||||||
|
.Set<mediapipe::NormalizedLandmarkList>();
|
||||||
|
cc->Outputs().Tag(kFaceGeometryTag).Set<FaceGeometry>();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override {
|
absl::Status Open(CalculatorContext* cc) override {
|
||||||
|
@ -95,7 +150,7 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(
|
||||||
GeometryPipelineMetadata metadata,
|
GeometryPipelineMetadata metadata,
|
||||||
ReadMetadataFromFile(options.metadata_path()),
|
ReadMetadataFromFile(options.metadata_file()),
|
||||||
_ << "Failed to read the geometry pipeline metadata from file!");
|
_ << "Failed to read the geometry pipeline metadata from file!");
|
||||||
|
|
||||||
MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata))
|
MP_RETURN_IF_ERROR(ValidateGeometryPipelineMetadata(metadata))
|
||||||
|
@ -110,21 +165,26 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
||||||
ASSIGN_OR_RETURN(geometry_pipeline_,
|
ASSIGN_OR_RETURN(geometry_pipeline_,
|
||||||
CreateGeometryPipeline(environment, metadata),
|
CreateGeometryPipeline(environment, metadata),
|
||||||
_ << "Failed to create a geometry pipeline!");
|
_ << "Failed to create a geometry pipeline!");
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
// Both the `IMAGE_SIZE` and the `MULTI_FACE_LANDMARKS` streams are required
|
// Both the `IMAGE_SIZE` and either the `FACE_LANDMARKS` or
|
||||||
// to have a non-empty packet. In case this requirement is not met, there's
|
// `MULTI_FACE_LANDMARKS` streams are required to have a non-empty packet.
|
||||||
// nothing to be processed at the current timestamp.
|
// In case this requirement is not met, there's nothing to be processed at
|
||||||
if (cc->Inputs().Tag(kImageSizeTag).IsEmpty() ||
|
// the current timestamp and we return early (checked here and below).
|
||||||
cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) {
|
if (cc->Inputs().Tag(kImageSizeTag).IsEmpty()) {
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& image_size =
|
const auto& image_size =
|
||||||
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
cc->Inputs().Tag(kImageSizeTag).Get<std::pair<int, int>>();
|
||||||
|
|
||||||
|
if (cc->Inputs().HasTag(kMultiFaceLandmarksTag)) {
|
||||||
|
if (cc->Inputs().Tag(kMultiFaceLandmarksTag).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
const auto& multi_face_landmarks =
|
const auto& multi_face_landmarks =
|
||||||
cc->Inputs()
|
cc->Inputs()
|
||||||
.Tag(kMultiFaceLandmarksTag)
|
.Tag(kMultiFaceLandmarksTag)
|
||||||
|
@ -145,6 +205,29 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
||||||
.AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>(
|
.AddPacket(mediapipe::Adopt<std::vector<FaceGeometry>>(
|
||||||
multi_face_geometry.release())
|
multi_face_geometry.release())
|
||||||
.At(cc->InputTimestamp()));
|
.At(cc->InputTimestamp()));
|
||||||
|
} else if (cc->Inputs().HasTag(kFaceLandmarksTag)) {
|
||||||
|
if (cc->Inputs().Tag(kFaceLandmarksTag).IsEmpty()) {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& face_landmarks =
|
||||||
|
cc->Inputs()
|
||||||
|
.Tag(kFaceLandmarksTag)
|
||||||
|
.Get<mediapipe::NormalizedLandmarkList>();
|
||||||
|
|
||||||
|
ASSIGN_OR_RETURN(
|
||||||
|
std::vector<FaceGeometry> multi_face_geometry,
|
||||||
|
geometry_pipeline_->EstimateFaceGeometry(
|
||||||
|
{face_landmarks}, //
|
||||||
|
/*frame_width*/ image_size.first,
|
||||||
|
/*frame_height*/ image_size.second),
|
||||||
|
_ << "Failed to estimate face geometry for multiple faces!");
|
||||||
|
|
||||||
|
cc->Outputs()
|
||||||
|
.Tag(kFaceGeometryTag)
|
||||||
|
.AddPacket(mediapipe::MakePacket<FaceGeometry>(multi_face_geometry[0])
|
||||||
|
.At(cc->InputTimestamp()));
|
||||||
|
}
|
||||||
|
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -155,32 +238,19 @@ class GeometryPipelineCalculator : public CalculatorBase {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile(
|
static absl::StatusOr<GeometryPipelineMetadata> ReadMetadataFromFile(
|
||||||
const std::string& metadata_path) {
|
const core::proto::ExternalFile& metadata_file) {
|
||||||
ASSIGN_OR_RETURN(std::string metadata_blob,
|
ASSIGN_OR_RETURN(
|
||||||
ReadContentBlobFromFile(metadata_path),
|
const auto file_handler,
|
||||||
_ << "Failed to read a metadata blob from file!");
|
core::ExternalFileHandler::CreateFromExternalFile(&metadata_file));
|
||||||
|
|
||||||
GeometryPipelineMetadata metadata;
|
GeometryPipelineMetadata metadata;
|
||||||
RET_CHECK(metadata.ParseFromString(metadata_blob))
|
RET_CHECK(
|
||||||
|
metadata.ParseFromString(std::string(file_handler->GetFileContent())))
|
||||||
<< "Failed to parse a metadata proto from a binary blob!";
|
<< "Failed to parse a metadata proto from a binary blob!";
|
||||||
|
|
||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
static absl::StatusOr<std::string> ReadContentBlobFromFile(
|
|
||||||
const std::string& unresolved_path) {
|
|
||||||
ASSIGN_OR_RETURN(std::string resolved_path,
|
|
||||||
mediapipe::PathToResourceAsFile(unresolved_path),
|
|
||||||
_ << "Failed to resolve path! Path = " << unresolved_path);
|
|
||||||
|
|
||||||
std::string content_blob;
|
|
||||||
MP_RETURN_IF_ERROR(
|
|
||||||
mediapipe::GetResourceContents(resolved_path, &content_blob))
|
|
||||||
<< "Failed to read content blob! Resolved path = " << resolved_path;
|
|
||||||
|
|
||||||
return content_blob;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<GeometryPipeline> geometry_pipeline_;
|
std::unique_ptr<GeometryPipeline> geometry_pipeline_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -17,11 +17,12 @@ syntax = "proto2";
|
||||||
package mediapipe.tasks.vision.face_geometry;
|
package mediapipe.tasks.vision.face_geometry;
|
||||||
|
|
||||||
import "mediapipe/framework/calculator_options.proto";
|
import "mediapipe/framework/calculator_options.proto";
|
||||||
|
import "mediapipe/tasks/cc/core/proto/external_file.proto";
|
||||||
|
|
||||||
message FaceGeometryPipelineCalculatorOptions {
|
message FaceGeometryPipelineCalculatorOptions {
|
||||||
extend mediapipe.CalculatorOptions {
|
extend mediapipe.CalculatorOptions {
|
||||||
optional FaceGeometryPipelineCalculatorOptions ext = 512499200;
|
optional FaceGeometryPipelineCalculatorOptions ext = 512499200;
|
||||||
}
|
}
|
||||||
|
|
||||||
optional string metadata_path = 1;
|
optional core.proto.ExternalFile metadata_file = 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/calculators/geometry_pipeline_calculator.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
||||||
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry_graph_options.pb.h"
|
||||||
#include "mediapipe/util/graph_builder_utils.h"
|
#include "mediapipe/util/graph_builder_utils.h"
|
||||||
|
|
||||||
namespace mediapipe::tasks::vision::face_geometry {
|
namespace mediapipe::tasks::vision::face_geometry {
|
||||||
|
@ -49,10 +50,6 @@ constexpr char kIterableTag[] = "ITERABLE";
|
||||||
constexpr char kBatchEndTag[] = "BATCH_END";
|
constexpr char kBatchEndTag[] = "BATCH_END";
|
||||||
constexpr char kItemTag[] = "ITEM";
|
constexpr char kItemTag[] = "ITEM";
|
||||||
|
|
||||||
constexpr char kGeometryPipelineMetadataPath[] =
|
|
||||||
"mediapipe/tasks/cc/vision/face_geometry/data/"
|
|
||||||
"geometry_pipeline_metadata_landmarks.binarypb";
|
|
||||||
|
|
||||||
struct FaceGeometryOuts {
|
struct FaceGeometryOuts {
|
||||||
Stream<std::vector<FaceGeometry>> multi_face_geometry;
|
Stream<std::vector<FaceGeometry>> multi_face_geometry;
|
||||||
};
|
};
|
||||||
|
@ -127,6 +124,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
|
||||||
}
|
}
|
||||||
ASSIGN_OR_RETURN(auto outs,
|
ASSIGN_OR_RETURN(auto outs,
|
||||||
BuildFaceGeometryFromLandmarksGraph(
|
BuildFaceGeometryFromLandmarksGraph(
|
||||||
|
*sc->MutableOptions<proto::FaceGeometryGraphOptions>(),
|
||||||
graph.In(kFaceLandmarksTag)
|
graph.In(kFaceLandmarksTag)
|
||||||
.Cast<std::vector<NormalizedLandmarkList>>(),
|
.Cast<std::vector<NormalizedLandmarkList>>(),
|
||||||
graph.In(kImageSizeTag).Cast<std::pair<int, int>>(),
|
graph.In(kImageSizeTag).Cast<std::pair<int, int>>(),
|
||||||
|
@ -138,6 +136,7 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph(
|
absl::StatusOr<FaceGeometryOuts> BuildFaceGeometryFromLandmarksGraph(
|
||||||
|
proto::FaceGeometryGraphOptions& graph_options,
|
||||||
Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks,
|
Stream<std::vector<NormalizedLandmarkList>> multi_face_landmarks,
|
||||||
Stream<std::pair<int, int>> image_size,
|
Stream<std::pair<int, int>> image_size,
|
||||||
std::optional<SidePacket<Environment>> environment, Graph& graph) {
|
std::optional<SidePacket<Environment>> environment, Graph& graph) {
|
||||||
|
@ -185,7 +184,8 @@ class FaceGeometryFromLandmarksGraph : public Subgraph {
|
||||||
"mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator");
|
"mediapipe.tasks.vision.face_geometry.FaceGeometryPipelineCalculator");
|
||||||
auto& geometry_pipeline_options =
|
auto& geometry_pipeline_options =
|
||||||
geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>();
|
geometry_pipeline.GetOptions<FaceGeometryPipelineCalculatorOptions>();
|
||||||
geometry_pipeline_options.set_metadata_path(kGeometryPipelineMetadataPath);
|
geometry_pipeline_options.Swap(
|
||||||
|
graph_options.mutable_geometry_pipeline_options());
|
||||||
image_size >> geometry_pipeline.In(kImageSizeTag);
|
image_size >> geometry_pipeline.In(kImageSizeTag);
|
||||||
multi_face_landmarks_no_iris >>
|
multi_face_landmarks_no_iris >>
|
||||||
geometry_pipeline.In(kMultiFaceLandmarksTag);
|
geometry_pipeline.In(kMultiFaceLandmarksTag);
|
||||||
|
|
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "absl/strings/str_format.h"
|
#include "absl/strings/str_format.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
|
#include "absl/strings/substitute.h"
|
||||||
#include "mediapipe/framework/api2/port.h"
|
#include "mediapipe/framework/api2/port.h"
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/calculator_runner.h"
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
@ -31,6 +32,7 @@ limitations under the License.
|
||||||
#include "mediapipe/framework/port/gtest.h"
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
#include "mediapipe/framework/tool/sink.h"
|
#include "mediapipe/framework/tool/sink.h"
|
||||||
|
#include "mediapipe/tasks/cc/core/proto/external_file.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/environment.pb.h"
|
||||||
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
#include "mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"
|
||||||
|
|
||||||
|
@ -49,6 +51,9 @@ constexpr char kTestDataDirectory[] = "/mediapipe/tasks/testdata/vision/";
|
||||||
constexpr char kFaceLandmarksFileName[] =
|
constexpr char kFaceLandmarksFileName[] =
|
||||||
"face_blendshapes_in_landmarks.prototxt";
|
"face_blendshapes_in_landmarks.prototxt";
|
||||||
constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt";
|
constexpr char kFaceGeometryFileName[] = "face_geometry_expected_out.pbtxt";
|
||||||
|
constexpr char kGeometryPipelineMetadataPath[] =
|
||||||
|
"mediapipe/tasks/cc/vision/face_geometry/data/"
|
||||||
|
"geometry_pipeline_metadata_landmarks.binarypb";
|
||||||
|
|
||||||
std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) {
|
std::vector<NormalizedLandmarkList> GetLandmarks(absl::string_view filename) {
|
||||||
NormalizedLandmarkList landmarks;
|
NormalizedLandmarkList landmarks;
|
||||||
|
@ -89,7 +94,8 @@ void MakeInputPacketsAndRunGraph(CalculatorGraph& graph) {
|
||||||
|
|
||||||
TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
|
TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
|
||||||
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
||||||
CalculatorGraphConfig>(R"pb(
|
CalculatorGraphConfig>(absl::Substitute(
|
||||||
|
R"pb(
|
||||||
input_stream: "FACE_LANDMARKS:face_landmarks"
|
input_stream: "FACE_LANDMARKS:face_landmarks"
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
output_stream: "FACE_GEOMETRY:face_geometry"
|
output_stream: "FACE_GEOMETRY:face_geometry"
|
||||||
|
@ -98,8 +104,15 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
|
||||||
input_stream: "FACE_LANDMARKS:face_landmarks"
|
input_stream: "FACE_LANDMARKS:face_landmarks"
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
output_stream: "FACE_GEOMETRY:face_geometry"
|
output_stream: "FACE_GEOMETRY:face_geometry"
|
||||||
|
options: {
|
||||||
|
[mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
|
||||||
|
.ext] {
|
||||||
|
geometry_pipeline_options { metadata_file { file_name: "$0" } }
|
||||||
}
|
}
|
||||||
)pb");
|
}
|
||||||
|
}
|
||||||
|
)pb",
|
||||||
|
kGeometryPipelineMetadataPath));
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
@ -116,7 +129,8 @@ TEST(FaceGeometryFromLandmarksGraphTest, DefaultEnvironment) {
|
||||||
|
|
||||||
TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
|
TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
|
||||||
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
CalculatorGraphConfig graph_config = ParseTextProtoOrDie<
|
||||||
CalculatorGraphConfig>(R"pb(
|
CalculatorGraphConfig>(absl::Substitute(
|
||||||
|
R"pb(
|
||||||
input_stream: "FACE_LANDMARKS:face_landmarks"
|
input_stream: "FACE_LANDMARKS:face_landmarks"
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
input_side_packet: "ENVIRONMENT:environment"
|
input_side_packet: "ENVIRONMENT:environment"
|
||||||
|
@ -127,8 +141,15 @@ TEST(FaceGeometryFromLandmarksGraphTest, SideInEnvironment) {
|
||||||
input_stream: "IMAGE_SIZE:image_size"
|
input_stream: "IMAGE_SIZE:image_size"
|
||||||
input_side_packet: "ENVIRONMENT:environment"
|
input_side_packet: "ENVIRONMENT:environment"
|
||||||
output_stream: "FACE_GEOMETRY:face_geometry"
|
output_stream: "FACE_GEOMETRY:face_geometry"
|
||||||
|
options: {
|
||||||
|
[mediapipe.tasks.vision.face_geometry.proto.FaceGeometryGraphOptions
|
||||||
|
.ext] {
|
||||||
|
geometry_pipeline_options { metadata_file { file_name: "$0" } }
|
||||||
}
|
}
|
||||||
)pb");
|
}
|
||||||
|
}
|
||||||
|
)pb",
|
||||||
|
kGeometryPipelineMetadataPath));
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
tool::AddVectorSink("face_geometry", &graph_config, &output_packets);
|
||||||
|
|
||||||
|
|
|
@ -99,7 +99,7 @@ class ScreenToMetricSpaceConverter {
|
||||||
//
|
//
|
||||||
// (3) Use the canonical-to-runtime scale from (2) to unproject the screen
|
// (3) Use the canonical-to-runtime scale from (2) to unproject the screen
|
||||||
// landmarks. The result is referenced as "intermediate landmarks" because
|
// landmarks. The result is referenced as "intermediate landmarks" because
|
||||||
// they are the first estimation of the resuling metric landmarks, but are
|
// they are the first estimation of the resulting metric landmarks, but are
|
||||||
// not quite there yet.
|
// not quite there yet.
|
||||||
//
|
//
|
||||||
// (4) Estimate a canonical-to-runtime landmark set scale by running the
|
// (4) Estimate a canonical-to-runtime landmark set scale by running the
|
||||||
|
@ -347,7 +347,7 @@ class GeometryPipelineImpl : public GeometryPipeline {
|
||||||
proto::Mesh3d* mutable_mesh = face_geometry.mutable_mesh();
|
proto::Mesh3d* mutable_mesh = face_geometry.mutable_mesh();
|
||||||
// Copy the canonical face mesh as the face geometry mesh.
|
// Copy the canonical face mesh as the face geometry mesh.
|
||||||
mutable_mesh->CopyFrom(canonical_mesh_);
|
mutable_mesh->CopyFrom(canonical_mesh_);
|
||||||
// Replace XYZ vertex mesh coodinates with the metric landmark positions.
|
// Replace XYZ vertex mesh coordinates with the metric landmark positions.
|
||||||
for (int i = 0; i < canonical_mesh_num_vertices_; ++i) {
|
for (int i = 0; i < canonical_mesh_num_vertices_; ++i) {
|
||||||
uint32_t vertex_buffer_offset = canonical_mesh_vertex_size_ * i +
|
uint32_t vertex_buffer_offset = canonical_mesh_vertex_size_ * i +
|
||||||
canonical_mesh_vertex_position_offset_;
|
canonical_mesh_vertex_position_offset_;
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||||
|
load("//mediapipe/framework:mediapipe_register_type.bzl", "mediapipe_register_type")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
|
||||||
|
@ -23,6 +24,16 @@ mediapipe_proto_library(
|
||||||
srcs = ["environment.proto"],
|
srcs = ["environment.proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_register_type(
|
||||||
|
base_name = "face_geometry",
|
||||||
|
include_headers = ["mediapipe/tasks/cc/vision/face_geometry/proto/face_geometry.pb.h"],
|
||||||
|
types = [
|
||||||
|
"::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry",
|
||||||
|
"::std::vector<::mediapipe::tasks::vision::face_geometry::proto::FaceGeometry>",
|
||||||
|
],
|
||||||
|
deps = [":face_geometry_cc_proto"],
|
||||||
|
)
|
||||||
|
|
||||||
mediapipe_proto_library(
|
mediapipe_proto_library(
|
||||||
name = "face_geometry_proto",
|
name = "face_geometry_proto",
|
||||||
srcs = ["face_geometry.proto"],
|
srcs = ["face_geometry.proto"],
|
||||||
|
@ -44,3 +55,12 @@ mediapipe_proto_library(
|
||||||
name = "mesh_3d_proto",
|
name = "mesh_3d_proto",
|
||||||
srcs = ["mesh_3d.proto"],
|
srcs = ["mesh_3d.proto"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "face_geometry_graph_options_proto",
|
||||||
|
srcs = ["face_geometry_graph_options.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/calculators:geometry_pipeline_calculator_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -16,7 +16,7 @@ syntax = "proto2";
|
||||||
|
|
||||||
package mediapipe.tasks.vision.face_geometry.proto;
|
package mediapipe.tasks.vision.face_geometry.proto;
|
||||||
|
|
||||||
option java_package = "mediapipe.tasks.vision.facegeometry.proto";
|
option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto";
|
||||||
option java_outer_classname = "EnvironmentProto";
|
option java_outer_classname = "EnvironmentProto";
|
||||||
|
|
||||||
// Defines the (0, 0) origin point location of the environment.
|
// Defines the (0, 0) origin point location of the environment.
|
||||||
|
|
|
@ -19,7 +19,7 @@ package mediapipe.tasks.vision.face_geometry.proto;
|
||||||
import "mediapipe/framework/formats/matrix_data.proto";
|
import "mediapipe/framework/formats/matrix_data.proto";
|
||||||
import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto";
|
import "mediapipe/tasks/cc/vision/face_geometry/proto/mesh_3d.proto";
|
||||||
|
|
||||||
option java_package = "mediapipe.tasks.vision.facegeometry.proto";
|
option java_package = "com.google.mediapipe.tasks.vision.facegeometry.proto";
|
||||||
option java_outer_classname = "FaceGeometryProto";
|
option java_outer_classname = "FaceGeometryProto";
|
||||||
|
|
||||||
// Defines the face geometry pipeline estimation result format.
|
// Defines the face geometry pipeline estimation result format.
|
||||||
|
@ -28,7 +28,7 @@ message FaceGeometry {
|
||||||
// the face landmark IDs.
|
// the face landmark IDs.
|
||||||
//
|
//
|
||||||
// XYZ coordinates exist in the right-handed Metric 3D space configured by an
|
// XYZ coordinates exist in the right-handed Metric 3D space configured by an
|
||||||
// environment. UV coodinates are taken from the canonical face mesh model.
|
// environment. UV coordinates are taken from the canonical face mesh model.
|
||||||
//
|
//
|
||||||
// XY coordinates are guaranteed to match the screen positions of
|
// XY coordinates are guaranteed to match the screen positions of
|
||||||
// the input face landmarks after (1) being multiplied by the face pose
|
// the input face landmarks after (1) being multiplied by the face pose
|
||||||
|
|