Merge branch 'master' into face-stylizer-python
17
LICENSE
|
@ -199,3 +199,20 @@
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
|
|
||||||
|
===========================================================================
|
||||||
|
For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
|
||||||
|
===========================================================================
|
||||||
|
/*
|
||||||
|
* The authors of this software are Rob Pike and Ken Thompson.
|
||||||
|
* Copyright (c) 2002 by Lucent Technologies.
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose without fee is hereby granted, provided that this entire notice
|
||||||
|
* is included in all copies of any software which is or includes a copy
|
||||||
|
* or modification of this software and in all copies of the supporting
|
||||||
|
* documentation for such software.
|
||||||
|
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
|
||||||
|
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
|
||||||
|
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
|
||||||
|
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
|
||||||
|
*/
|
||||||
|
|
46
WORKSPACE
|
@ -270,7 +270,7 @@ new_local_repository(
|
||||||
# For local MacOS builds, the path should point to an opencv@3 installation.
|
# For local MacOS builds, the path should point to an opencv@3 installation.
|
||||||
# If you edit the path here, you will also need to update the corresponding
|
# If you edit the path here, you will also need to update the corresponding
|
||||||
# prefix in "opencv_macos.BUILD".
|
# prefix in "opencv_macos.BUILD".
|
||||||
path = "/usr/local",
|
path = "/usr/local", # e.g. /usr/local/Cellar for HomeBrew
|
||||||
)
|
)
|
||||||
|
|
||||||
new_local_repository(
|
new_local_repository(
|
||||||
|
@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool")
|
||||||
# Node dependencies
|
# Node dependencies
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "build_bazel_rules_nodejs",
|
name = "build_bazel_rules_nodejs",
|
||||||
sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
|
sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a",
|
||||||
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
|
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"],
|
||||||
)
|
)
|
||||||
|
|
||||||
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
|
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
|
||||||
|
@ -543,3 +543,43 @@ external_files()
|
||||||
|
|
||||||
load("@//third_party:wasm_files.bzl", "wasm_files")
|
load("@//third_party:wasm_files.bzl", "wasm_files")
|
||||||
wasm_files()
|
wasm_files()
|
||||||
|
|
||||||
|
# Halide
|
||||||
|
|
||||||
|
new_local_repository(
|
||||||
|
name = "halide",
|
||||||
|
build_file = "@//third_party/halide:BUILD.bazel",
|
||||||
|
path = "third_party/halide"
|
||||||
|
)
|
||||||
|
|
||||||
|
http_archive(
|
||||||
|
name = "linux_halide",
|
||||||
|
sha256 = "f62b2914823d6e33d18693f5b74484f274523bf5402ce51988e24393d123b375",
|
||||||
|
strip_prefix = "Halide-15.0.0-x86-64-linux",
|
||||||
|
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-linux-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
|
||||||
|
build_file = "@//third_party:halide.BUILD",
|
||||||
|
)
|
||||||
|
|
||||||
|
http_archive(
|
||||||
|
name = "macos_x86_64_halide",
|
||||||
|
sha256 = "3d832aed942080ea89aa832462c68fbb906f3055c440b7b6d35093d7c52f6aab",
|
||||||
|
strip_prefix = "Halide-15.0.0-x86-64-osx",
|
||||||
|
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
|
||||||
|
build_file = "@//third_party:halide.BUILD",
|
||||||
|
)
|
||||||
|
|
||||||
|
http_archive(
|
||||||
|
name = "macos_arm_64_halide",
|
||||||
|
sha256 = "b1fad3c9810122b187303d7031d9e35fb43761f345d18cc4492c00ed5877f641",
|
||||||
|
strip_prefix = "Halide-15.0.0-arm-64-osx",
|
||||||
|
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-arm-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
|
||||||
|
build_file = "@//third_party:halide.BUILD",
|
||||||
|
)
|
||||||
|
|
||||||
|
http_archive(
|
||||||
|
name = "windows_halide",
|
||||||
|
sha256 = "5acf6fe161dd375856a2b43f4bb0a32815ba958b0585ee312c44e008aa7b0b64",
|
||||||
|
strip_prefix = "Halide-15.0.0-x86-64-windows",
|
||||||
|
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-windows-d7651f4b32f9dbd764f243134001f7554378d62d.zip"],
|
||||||
|
build_file = "@//third_party:halide.BUILD",
|
||||||
|
)
|
||||||
|
|
|
@ -113,14 +113,14 @@ Warning: On the other hand, it is not guaranteed that an input packet will
|
||||||
always be available for all streams.
|
always be available for all streams.
|
||||||
|
|
||||||
To explain how it works, we need to introduce the definition of a settled
|
To explain how it works, we need to introduce the definition of a settled
|
||||||
timestamp. We say that a timestamp in a stream is *settled* if it lower than the
|
timestamp. We say that a timestamp in a stream is *settled* if it is lower than
|
||||||
timestamp bound. In other words, a timestamp is settled for a stream once the
|
the timestamp bound. In other words, a timestamp is settled for a stream once
|
||||||
state of the input at that timestamp is irrevocably known: either there is a
|
the state of the input at that timestamp is irrevocably known: either there is a
|
||||||
packet, or there is the certainty that a packet with that timestamp will not
|
packet, or there is the certainty that a packet with that timestamp will not
|
||||||
arrive.
|
arrive.
|
||||||
|
|
||||||
Note: For this reason, MediaPipe also allows a stream producer to explicitly
|
Note: For this reason, MediaPipe also allows a stream producer to explicitly
|
||||||
advance the timestamp bound farther that what the last packet implies, i.e. to
|
advance the timestamp bound farther than what the last packet implies, i.e. to
|
||||||
provide a tighter bound. This can allow the downstream nodes to settle their
|
provide a tighter bound. This can allow the downstream nodes to settle their
|
||||||
inputs sooner.
|
inputs sooner.
|
||||||
|
|
||||||
|
|
|
@ -108,6 +108,8 @@ one over the other.
|
||||||
|
|
||||||
* [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
|
* [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
|
||||||
* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
|
* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
|
||||||
|
* [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
|
||||||
|
* [Model information](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)
|
||||||
|
|
||||||
### [Objectron](https://google.github.io/mediapipe/solutions/objectron)
|
### [Objectron](https://google.github.io/mediapipe/solutions/objectron)
|
||||||
|
|
||||||
|
|
|
@ -118,9 +118,9 @@ on how to build MediaPipe examples.
|
||||||
* With a TensorFlow Model
|
* With a TensorFlow Model
|
||||||
|
|
||||||
This uses the
|
This uses the
|
||||||
[TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model)
|
[TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
|
||||||
( see also
|
( see also
|
||||||
[model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)),
|
[model info](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)),
|
||||||
and the pipeline is implemented in this
|
and the pipeline is implemented in this
|
||||||
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).
|
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).
|
||||||
|
|
||||||
|
|
62
docs/solutions/object_detection_saved_model.md
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
## TensorFlow/TFLite Object Detection Model
|
||||||
|
|
||||||
|
### TensorFlow model
|
||||||
|
|
||||||
|
The model is trained on [MSCOCO 2014](http://cocodataset.org) dataset using [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with 0.5 depth multiplier. Detailed training configuration is in the provided `pipeline.config`. The model is a relatively compact model which has `0.171 mAP` to achieve real-time performance on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).
|
||||||
|
|
||||||
|
|
||||||
|
### TFLite model
|
||||||
|
|
||||||
|
The TFLite model is converted from the TensorFlow above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) with minor modifications. Assuming now we have a trained TensorFlow model which includes the checkpoint files and the training configuration file, for example the files provided in this repo:
|
||||||
|
|
||||||
|
* `model.ckpt.index`
|
||||||
|
* `model.ckpt.meta`
|
||||||
|
* `model.ckpt.data-00000-of-00001`
|
||||||
|
* `pipeline.config`
|
||||||
|
|
||||||
|
Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ PATH_TO_MODEL=path/to/the/model
|
||||||
|
$ bazel run object_detection:export_tflite_ssd_graph -- \
|
||||||
|
--pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
|
||||||
|
--trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
|
||||||
|
--output_directory ${PATH_TO_MODEL} \
|
||||||
|
--add_postprocessing_op=False
|
||||||
|
```
|
||||||
|
|
||||||
|
The exported model contains two files:
|
||||||
|
|
||||||
|
* `tflite_graph.pb`
|
||||||
|
* `tflite_graph.pbtxt`
|
||||||
|
|
||||||
|
The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. In MediaPipe, we have provided all the calculators needed for post-processing such that we can exclude the custom TFLite ops for post-processing in the original graph, e.g., non-maximum suppression. This enables the flexibility to integrate with different post-processing algorithms and implementations.
|
||||||
|
|
||||||
|
Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ bazel run graph_transforms:summarize_graph -- \
|
||||||
|
--in_graph=${PATH_TO_MODEL}/tflite_graph.pb
|
||||||
|
```
|
||||||
|
|
||||||
|
You should be able to see the input image size of the model is 320x320 and the outputs of the model are:
|
||||||
|
|
||||||
|
* `raw_outputs/box_encodings`
|
||||||
|
* `raw_outputs/class_predictions`
|
||||||
|
|
||||||
|
The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ tflite_convert -- \
|
||||||
|
--graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
|
||||||
|
--output_file=${PATH_TO_MODEL}/model.tflite \
|
||||||
|
--input_format=TENSORFLOW_GRAPHDEF \
|
||||||
|
--output_format=TFLITE \
|
||||||
|
--inference_type=FLOAT \
|
||||||
|
--input_shapes=1,320,320,3 \
|
||||||
|
--input_arrays=normalized_input_image_tensor \
|
||||||
|
--output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.
|
|
@ -269,6 +269,7 @@ Supported configuration options:
|
||||||
```python
|
```python
|
||||||
import cv2
|
import cv2
|
||||||
import mediapipe as mp
|
import mediapipe as mp
|
||||||
|
import numpy as np
|
||||||
mp_drawing = mp.solutions.drawing_utils
|
mp_drawing = mp.solutions.drawing_utils
|
||||||
mp_drawing_styles = mp.solutions.drawing_styles
|
mp_drawing_styles = mp.solutions.drawing_styles
|
||||||
mp_pose = mp.solutions.pose
|
mp_pose = mp.solutions.pose
|
||||||
|
|
|
@ -156,6 +156,7 @@ cc_library(
|
||||||
"//mediapipe/framework/port:opencv_core",
|
"//mediapipe/framework/port:opencv_core",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/port:vector",
|
"//mediapipe/framework/port:vector",
|
||||||
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
] + select({
|
] + select({
|
||||||
"//mediapipe/gpu:disable_gpu": [],
|
"//mediapipe/gpu:disable_gpu": [],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
|
@ -168,6 +169,25 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "set_alpha_calculator_test",
|
||||||
|
srcs = ["set_alpha_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":set_alpha_calculator",
|
||||||
|
":set_alpha_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework/formats:image_frame_opencv",
|
||||||
|
"//mediapipe/framework/formats:rect_cc_proto",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:opencv_core",
|
||||||
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
|
"//mediapipe/framework/port:status",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "bilateral_filter_calculator",
|
name = "bilateral_filter_calculator",
|
||||||
srcs = ["bilateral_filter_calculator.cc"],
|
srcs = ["bilateral_filter_calculator.cc"],
|
||||||
|
@ -748,6 +768,7 @@ cc_test(
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
|
||||||
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
|
||||||
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
|
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
|
||||||
],
|
],
|
||||||
tags = ["desktop_only_test"],
|
tags = ["desktop_only_test"],
|
||||||
|
|
|
@ -29,6 +29,9 @@ class AffineTransformation {
|
||||||
// pixels will be calculated.
|
// pixels will be calculated.
|
||||||
enum class BorderMode { kZero, kReplicate };
|
enum class BorderMode { kZero, kReplicate };
|
||||||
|
|
||||||
|
// Pixel sampling interpolation method.
|
||||||
|
enum class Interpolation { kLinear, kCubic };
|
||||||
|
|
||||||
struct Size {
|
struct Size {
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
|
|
|
@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
|
||||||
std::unique_ptr<GpuBuffer>> {
|
std::unique_ptr<GpuBuffer>> {
|
||||||
public:
|
public:
|
||||||
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
|
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
|
||||||
GpuOrigin::Mode gpu_origin)
|
GpuOrigin::Mode gpu_origin,
|
||||||
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {}
|
AffineTransformation::Interpolation interpolation)
|
||||||
|
: gl_helper_(gl_helper),
|
||||||
|
gpu_origin_(gpu_origin),
|
||||||
|
interpolation_(interpolation) {}
|
||||||
absl::Status Init() {
|
absl::Status Init() {
|
||||||
return gl_helper_->RunInGlContext([this]() -> absl::Status {
|
return gl_helper_->RunInGlContext([this]() -> absl::Status {
|
||||||
const GLint attr_location[kNumAttributes] = {
|
const GLint attr_location[kNumAttributes] = {
|
||||||
|
@ -103,10 +106,13 @@ class GlTextureWarpAffineRunner
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
// TODO Move bicubic code to common shared place.
|
||||||
constexpr GLchar kFragShader[] = R"(
|
constexpr GLchar kFragShader[] = R"(
|
||||||
DEFAULT_PRECISION(highp, float)
|
DEFAULT_PRECISION(highp, float)
|
||||||
|
|
||||||
in vec2 sample_coordinate;
|
in vec2 sample_coordinate;
|
||||||
uniform sampler2D input_texture;
|
uniform sampler2D input_texture;
|
||||||
|
uniform vec2 input_size;
|
||||||
|
|
||||||
#ifdef GL_ES
|
#ifdef GL_ES
|
||||||
#define fragColor gl_FragColor
|
#define fragColor gl_FragColor
|
||||||
|
@ -114,8 +120,60 @@ class GlTextureWarpAffineRunner
|
||||||
out vec4 fragColor;
|
out vec4 fragColor;
|
||||||
#endif // defined(GL_ES);
|
#endif // defined(GL_ES);
|
||||||
|
|
||||||
|
#ifdef CUBIC_INTERPOLATION
|
||||||
|
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||||
|
const vec2 halve = vec2(0.5,0.5);
|
||||||
|
const vec2 one = vec2(1.0,1.0);
|
||||||
|
const vec2 two = vec2(2.0,2.0);
|
||||||
|
const vec2 three = vec2(3.0,3.0);
|
||||||
|
const vec2 six = vec2(6.0,6.0);
|
||||||
|
|
||||||
|
// Calculate the fraction and integer.
|
||||||
|
tex_coord = tex_coord * tex_size - halve;
|
||||||
|
vec2 frac = fract(tex_coord);
|
||||||
|
vec2 index = tex_coord - frac + halve;
|
||||||
|
|
||||||
|
// Calculate weights for Catmull-Rom filter.
|
||||||
|
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
|
||||||
|
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
|
||||||
|
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
|
||||||
|
vec2 w3 = frac * frac * (-halve + halve * frac);
|
||||||
|
|
||||||
|
// Calculate weights to take advantage of bilinear texture lookup.
|
||||||
|
vec2 w12 = w1 + w2;
|
||||||
|
vec2 offset12 = w2 / (w1 + w2);
|
||||||
|
|
||||||
|
vec2 index_tl = index - one;
|
||||||
|
vec2 index_br = index + two;
|
||||||
|
vec2 index_eq = index + offset12;
|
||||||
|
|
||||||
|
index_tl /= tex_size;
|
||||||
|
index_br /= tex_size;
|
||||||
|
index_eq /= tex_size;
|
||||||
|
|
||||||
|
// 9 texture lookup and linear blending.
|
||||||
|
vec4 color = vec4(0.0);
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
|
||||||
|
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
|
||||||
|
|
||||||
|
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
|
||||||
|
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
|
||||||
|
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
|
||||||
|
return color;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
|
||||||
|
return texture2D(tex, tex_coord);
|
||||||
|
}
|
||||||
|
#endif // defined(CUBIC_INTERPOLATION)
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
vec4 color = texture2D(input_texture, sample_coordinate);
|
vec4 color = sample(input_texture, sample_coordinate, input_size);
|
||||||
#ifdef CUSTOM_ZERO_BORDER_MODE
|
#ifdef CUSTOM_ZERO_BORDER_MODE
|
||||||
float out_of_bounds =
|
float out_of_bounds =
|
||||||
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
|
||||||
|
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
|
||||||
glUseProgram(program);
|
glUseProgram(program);
|
||||||
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
|
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
|
||||||
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
|
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
|
||||||
return Program{.id = program, .matrix_id = matrix_id};
|
GLint size_id = glGetUniformLocation(program, "input_size");
|
||||||
|
return Program{
|
||||||
|
.id = program, .matrix_id = matrix_id, .size_id = size_id};
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::string vert_src =
|
const std::string vert_src =
|
||||||
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
|
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
|
||||||
|
|
||||||
const std::string frag_src = absl::StrCat(
|
std::string interpolation_def;
|
||||||
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader);
|
switch (interpolation_) {
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
interpolation_def = R"(
|
||||||
|
#define CUBIC_INTERPOLATION
|
||||||
|
)";
|
||||||
|
break;
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string frag_src =
|
||||||
|
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||||
|
interpolation_def, kFragShader);
|
||||||
|
|
||||||
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
|
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
|
||||||
|
|
||||||
|
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
|
||||||
std::string custom_zero_border_mode_def = R"(
|
std::string custom_zero_border_mode_def = R"(
|
||||||
#define CUSTOM_ZERO_BORDER_MODE
|
#define CUSTOM_ZERO_BORDER_MODE
|
||||||
)";
|
)";
|
||||||
const std::string frag_custom_zero_src =
|
const std::string frag_custom_zero_src = absl::StrCat(
|
||||||
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
|
mediapipe::kMediaPipeFragmentShaderPreamble,
|
||||||
custom_zero_border_mode_def, kFragShader);
|
custom_zero_border_mode_def, interpolation_def, kFragShader);
|
||||||
return create_fn(vert_src, frag_custom_zero_src);
|
return create_fn(vert_src, frag_custom_zero_src);
|
||||||
};
|
};
|
||||||
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||||
|
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
|
||||||
}
|
}
|
||||||
glUseProgram(program->id);
|
glUseProgram(program->id);
|
||||||
|
|
||||||
|
// uniforms
|
||||||
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
|
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
|
||||||
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
|
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
|
||||||
// @matrix describes affine transformation in terms of TOP LEFT origin, so
|
// @matrix describes affine transformation in terms of TOP LEFT origin, so
|
||||||
|
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
|
||||||
eigen_mat.transposeInPlace();
|
eigen_mat.transposeInPlace();
|
||||||
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
|
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
|
||||||
|
|
||||||
|
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
|
||||||
|
glUniform2f(program->size_id, texture.width(), texture.height());
|
||||||
|
}
|
||||||
|
|
||||||
// vao
|
// vao
|
||||||
glBindVertexArray(vao_);
|
glBindVertexArray(vao_);
|
||||||
|
|
||||||
|
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
|
||||||
struct Program {
|
struct Program {
|
||||||
GLuint id;
|
GLuint id;
|
||||||
GLint matrix_id;
|
GLint matrix_id;
|
||||||
|
GLint size_id;
|
||||||
};
|
};
|
||||||
std::shared_ptr<GlCalculatorHelper> gl_helper_;
|
std::shared_ptr<GlCalculatorHelper> gl_helper_;
|
||||||
GpuOrigin::Mode gpu_origin_;
|
GpuOrigin::Mode gpu_origin_;
|
||||||
|
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
|
||||||
Program program_;
|
Program program_;
|
||||||
std::optional<Program> program_custom_zero_;
|
std::optional<Program> program_custom_zero_;
|
||||||
GLuint framebuffer_ = 0;
|
GLuint framebuffer_ = 0;
|
||||||
|
AffineTransformation::Interpolation interpolation_ =
|
||||||
|
AffineTransformation::Interpolation::kLinear;
|
||||||
};
|
};
|
||||||
|
|
||||||
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
|
||||||
|
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
|
||||||
absl::StatusOr<std::unique_ptr<
|
absl::StatusOr<std::unique_ptr<
|
||||||
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
|
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
|
||||||
CreateAffineTransformationGlRunner(
|
CreateAffineTransformationGlRunner(
|
||||||
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) {
|
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
|
||||||
auto runner =
|
AffineTransformation::Interpolation interpolation) {
|
||||||
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin);
|
auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
|
||||||
|
gl_helper, gpu_origin, interpolation);
|
||||||
MP_RETURN_IF_ERROR(runner->Init());
|
MP_RETURN_IF_ERROR(runner->Init());
|
||||||
return runner;
|
return runner;
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
|
||||||
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
|
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
|
||||||
CreateAffineTransformationGlRunner(
|
CreateAffineTransformationGlRunner(
|
||||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
|
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
|
||||||
mediapipe::GpuOrigin::Mode gpu_origin);
|
mediapipe::GpuOrigin::Mode gpu_origin,
|
||||||
|
AffineTransformation::Interpolation interpolation);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int GetInterpolationForOpenCv(
|
||||||
|
AffineTransformation::Interpolation interpolation) {
|
||||||
|
switch (interpolation) {
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
return cv::INTER_LINEAR;
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
return cv::INTER_CUBIC;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class OpenCvRunner
|
class OpenCvRunner
|
||||||
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
|
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
|
||||||
public:
|
public:
|
||||||
|
OpenCvRunner(AffineTransformation::Interpolation interpolation)
|
||||||
|
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
|
||||||
|
|
||||||
absl::StatusOr<ImageFrame> Run(
|
absl::StatusOr<ImageFrame> Run(
|
||||||
const ImageFrame& input, const std::array<float, 16>& matrix,
|
const ImageFrame& input, const std::array<float, 16>& matrix,
|
||||||
const AffineTransformation::Size& size,
|
const AffineTransformation::Size& size,
|
||||||
|
@ -142,19 +155,23 @@ class OpenCvRunner
|
||||||
|
|
||||||
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
|
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
|
||||||
cv::Size(out_mat.cols, out_mat.rows),
|
cv::Size(out_mat.cols, out_mat.rows),
|
||||||
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP,
|
/*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
|
||||||
GetBorderModeForOpenCv(border_mode));
|
GetBorderModeForOpenCv(border_mode));
|
||||||
|
|
||||||
return out_image;
|
return out_image;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int interpolation_ = cv::INTER_LINEAR;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
absl::StatusOr<
|
absl::StatusOr<
|
||||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||||
CreateAffineTransformationOpenCvRunner() {
|
CreateAffineTransformationOpenCvRunner(
|
||||||
return absl::make_unique<OpenCvRunner>();
|
AffineTransformation::Interpolation interpolation) {
|
||||||
|
return absl::make_unique<OpenCvRunner>(interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -25,7 +25,8 @@ namespace mediapipe {
|
||||||
|
|
||||||
absl::StatusOr<
|
absl::StatusOr<
|
||||||
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
|
||||||
CreateAffineTransformationOpenCvRunner();
|
CreateAffineTransformationOpenCvRunner(
|
||||||
|
AffineTransformation::Interpolation interpolation);
|
||||||
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
||||||
|
|
|
@ -81,7 +81,8 @@ class ImageCloneCalculator : public Node {
|
||||||
absl::Status Process(CalculatorContext* cc) override {
|
absl::Status Process(CalculatorContext* cc) override {
|
||||||
std::unique_ptr<Image> output;
|
std::unique_ptr<Image> output;
|
||||||
const auto& input = *kIn(cc);
|
const auto& input = *kIn(cc);
|
||||||
if (input.UsesGpu()) {
|
bool input_on_gpu = input.UsesGpu();
|
||||||
|
if (input_on_gpu) {
|
||||||
#if !MEDIAPIPE_DISABLE_GPU
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
// Create an output Image that co-owns the underlying texture buffer as
|
// Create an output Image that co-owns the underlying texture buffer as
|
||||||
// the input Image.
|
// the input Image.
|
||||||
|
@ -97,15 +98,15 @@ class ImageCloneCalculator : public Node {
|
||||||
// Image. This ensures a correct life span of the shared pixel data.
|
// Image. This ensures a correct life span of the shared pixel data.
|
||||||
output = std::make_unique<Image>(std::make_unique<mediapipe::ImageFrame>(
|
output = std::make_unique<Image>(std::make_unique<mediapipe::ImageFrame>(
|
||||||
input.image_format(), input.width(), input.height(), input.step(),
|
input.image_format(), input.width(), input.height(), input.step(),
|
||||||
const_cast<uint8*>(input.GetImageFrameSharedPtr()->PixelData()),
|
const_cast<uint8_t*>(input.GetImageFrameSharedPtr()->PixelData()),
|
||||||
[packet_copy_ptr](uint8*) { delete packet_copy_ptr; }));
|
[packet_copy_ptr](uint8_t*) { delete packet_copy_ptr; }));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (output_on_gpu_) {
|
if (output_on_gpu_ && !input_on_gpu) {
|
||||||
#if !MEDIAPIPE_DISABLE_GPU
|
#if !MEDIAPIPE_DISABLE_GPU
|
||||||
gpu_helper_.RunInGlContext([&output]() { output->ConvertToGpu(); });
|
gpu_helper_.RunInGlContext([&output]() { output->ConvertToGpu(); });
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
} else {
|
} else if (!output_on_gpu_ && input_on_gpu) {
|
||||||
output->ConvertToCpu();
|
output->ConvertToCpu();
|
||||||
}
|
}
|
||||||
kOut(cc).Send(std::move(output));
|
kOut(cc).Send(std::move(output));
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
#include "mediapipe/framework/port/logging.h"
|
#include "mediapipe/framework/port/logging.h"
|
||||||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/framework/port/vector.h"
|
#include "mediapipe/framework/port/vector.h"
|
||||||
|
|
||||||
|
@ -53,24 +54,16 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
|
||||||
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
|
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
|
||||||
// must be uchar.
|
// must be uchar.
|
||||||
template <typename AlphaType>
|
template <typename AlphaType>
|
||||||
absl::Status MergeRGBA8Image(const cv::Mat input_mat, const cv::Mat& alpha_mat,
|
absl::Status CopyAlphaImage(const cv::Mat& alpha_mat, cv::Mat& output_mat) {
|
||||||
cv::Mat& output_mat) {
|
RET_CHECK_EQ(output_mat.rows, alpha_mat.rows);
|
||||||
RET_CHECK_EQ(input_mat.rows, alpha_mat.rows);
|
RET_CHECK_EQ(output_mat.cols, alpha_mat.cols);
|
||||||
RET_CHECK_EQ(input_mat.cols, alpha_mat.cols);
|
|
||||||
RET_CHECK_EQ(input_mat.rows, output_mat.rows);
|
|
||||||
RET_CHECK_EQ(input_mat.cols, output_mat.cols);
|
|
||||||
|
|
||||||
for (int i = 0; i < output_mat.rows; ++i) {
|
for (int i = 0; i < output_mat.rows; ++i) {
|
||||||
const uchar* in_ptr = input_mat.ptr<uchar>(i);
|
|
||||||
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
|
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
|
||||||
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
||||||
for (int j = 0; j < output_mat.cols; ++j) {
|
for (int j = 0; j < output_mat.cols; ++j) {
|
||||||
const int out_idx = j * kNumChannelsRGBA;
|
const int out_idx = j * kNumChannelsRGBA;
|
||||||
const int in_idx = j * input_mat.channels();
|
|
||||||
const int alpha_idx = j * alpha_mat.channels();
|
const int alpha_idx = j * alpha_mat.channels();
|
||||||
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
|
|
||||||
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
|
|
||||||
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
|
|
||||||
if constexpr (std::is_same<AlphaType, uchar>::value) {
|
if constexpr (std::is_same<AlphaType, uchar>::value) {
|
||||||
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
|
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
|
||||||
} else {
|
} else {
|
||||||
|
@ -273,7 +266,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
|
||||||
|
|
||||||
// Setup source image
|
// Setup source image
|
||||||
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
|
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
|
||||||
const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame);
|
const cv::Mat input_mat = formats::MatView(&input_frame);
|
||||||
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
|
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
|
||||||
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
|
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
|
||||||
}
|
}
|
||||||
|
@ -281,38 +274,38 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
|
||||||
// Setup destination image
|
// Setup destination image
|
||||||
auto output_frame = absl::make_unique<ImageFrame>(
|
auto output_frame = absl::make_unique<ImageFrame>(
|
||||||
ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
|
ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
|
||||||
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
|
cv::Mat output_mat = formats::MatView(output_frame.get());
|
||||||
|
|
||||||
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
|
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
|
||||||
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
|
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
|
||||||
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
|
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
|
||||||
|
|
||||||
// Setup alpha image and Update image in CPU.
|
// Copy rgb part of the image in CPU
|
||||||
|
if (input_mat.channels() == 3) {
|
||||||
|
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA);
|
||||||
|
} else {
|
||||||
|
input_mat.copyTo(output_mat);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup alpha image in CPU.
|
||||||
if (use_alpha_mask) {
|
if (use_alpha_mask) {
|
||||||
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
|
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
|
||||||
cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask);
|
cv::Mat alpha_mat = formats::MatView(&alpha_mask);
|
||||||
|
|
||||||
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
|
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
|
||||||
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
|
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
|
||||||
|
|
||||||
if (alpha_is_float) {
|
if (alpha_is_float) {
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(CopyAlphaImage<float>(alpha_mat, output_mat));
|
||||||
MergeRGBA8Image<float>(input_mat, alpha_mat, output_mat));
|
|
||||||
} else {
|
} else {
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(CopyAlphaImage<uchar>(alpha_mat, output_mat));
|
||||||
MergeRGBA8Image<uchar>(input_mat, alpha_mat, output_mat));
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
|
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
|
||||||
for (int i = 0; i < output_mat.rows; ++i) {
|
for (int i = 0; i < output_mat.rows; ++i) {
|
||||||
const uchar* in_ptr = input_mat.ptr<uchar>(i);
|
|
||||||
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
uchar* out_ptr = output_mat.ptr<uchar>(i);
|
||||||
for (int j = 0; j < output_mat.cols; ++j) {
|
for (int j = 0; j < output_mat.cols; ++j) {
|
||||||
const int out_idx = j * kNumChannelsRGBA;
|
const int out_idx = j * kNumChannelsRGBA;
|
||||||
const int in_idx = j * input_mat.channels();
|
|
||||||
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
|
|
||||||
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
|
|
||||||
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
|
|
||||||
out_ptr[out_idx + 3] = alpha_value; // use value from options
|
out_ptr[out_idx + 3] = alpha_value; // use value from options
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
156
mediapipe/calculators/image/set_alpha_calculator_test.cc
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
|
#include "mediapipe/framework/formats/rect.pb.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||||
|
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "testing/base/public/benchmark.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr int input_width = 100;
|
||||||
|
constexpr int input_height = 100;
|
||||||
|
|
||||||
|
std::unique_ptr<ImageFrame> GetInputFrame(int width, int height, int channel) {
|
||||||
|
const int total_size = width * height * channel;
|
||||||
|
|
||||||
|
ImageFormat::Format image_format;
|
||||||
|
if (channel == 4) {
|
||||||
|
image_format = ImageFormat::SRGBA;
|
||||||
|
} else if (channel == 3) {
|
||||||
|
image_format = ImageFormat::SRGB;
|
||||||
|
} else {
|
||||||
|
image_format = ImageFormat::GRAY8;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input_frame = std::make_unique<ImageFrame>(image_format, width, height,
|
||||||
|
/*alignment_boundary =*/1);
|
||||||
|
for (int i = 0; i < total_size; ++i) {
|
||||||
|
input_frame->MutablePixelData()[i] = i % 256;
|
||||||
|
}
|
||||||
|
return input_frame;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test SetAlphaCalculator with RGB IMAGE input.
|
||||||
|
TEST(SetAlphaCalculatorTest, CpuRgb) {
|
||||||
|
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
|
||||||
|
R"pb(
|
||||||
|
calculator: "SetAlphaCalculator"
|
||||||
|
input_stream: "IMAGE:input_frames"
|
||||||
|
input_stream: "ALPHA:masks"
|
||||||
|
output_stream: "IMAGE:output_frames"
|
||||||
|
)pb");
|
||||||
|
CalculatorRunner runner(calculator_node);
|
||||||
|
|
||||||
|
// Input frames.
|
||||||
|
const auto input_frame = GetInputFrame(input_width, input_height, 3);
|
||||||
|
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
|
||||||
|
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
|
||||||
|
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
|
||||||
|
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||||
|
input_frame_packet.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
|
||||||
|
mask_frame_packet.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs();
|
||||||
|
EXPECT_EQ(outputs.NumEntries(), 1);
|
||||||
|
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
|
||||||
|
|
||||||
|
// Generate ground truth (expected_mat).
|
||||||
|
const auto image = GetInputFrame(input_width, input_height, 3);
|
||||||
|
const auto input_mat = formats::MatView(image.get());
|
||||||
|
const auto mask = GetInputFrame(input_width, input_height, 1);
|
||||||
|
const auto mask_mat = formats::MatView(mask.get());
|
||||||
|
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
|
||||||
|
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
|
||||||
|
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 3, 3});
|
||||||
|
|
||||||
|
cv::Mat output_mat = formats::MatView(&output_image);
|
||||||
|
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
|
||||||
|
EXPECT_FLOAT_EQ(max_diff, 0);
|
||||||
|
} // TEST
|
||||||
|
|
||||||
|
// Test SetAlphaCalculator with RGBA IMAGE input.
|
||||||
|
TEST(SetAlphaCalculatorTest, CpuRgba) {
|
||||||
|
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
|
||||||
|
R"pb(
|
||||||
|
calculator: "SetAlphaCalculator"
|
||||||
|
input_stream: "IMAGE:input_frames"
|
||||||
|
input_stream: "ALPHA:masks"
|
||||||
|
output_stream: "IMAGE:output_frames"
|
||||||
|
)pb");
|
||||||
|
CalculatorRunner runner(calculator_node);
|
||||||
|
|
||||||
|
// Input frames.
|
||||||
|
const auto input_frame = GetInputFrame(input_width, input_height, 4);
|
||||||
|
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
|
||||||
|
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
|
||||||
|
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
|
||||||
|
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||||
|
input_frame_packet.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
|
||||||
|
mask_frame_packet.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs();
|
||||||
|
EXPECT_EQ(outputs.NumEntries(), 1);
|
||||||
|
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
|
||||||
|
|
||||||
|
// Generate ground truth (expected_mat).
|
||||||
|
const auto image = GetInputFrame(input_width, input_height, 4);
|
||||||
|
const auto input_mat = formats::MatView(image.get());
|
||||||
|
const auto mask = GetInputFrame(input_width, input_height, 1);
|
||||||
|
const auto mask_mat = formats::MatView(mask.get());
|
||||||
|
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
|
||||||
|
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
|
||||||
|
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 4, 3});
|
||||||
|
|
||||||
|
cv::Mat output_mat = formats::MatView(&output_image);
|
||||||
|
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
|
||||||
|
EXPECT_FLOAT_EQ(max_diff, 0);
|
||||||
|
} // TEST
|
||||||
|
|
||||||
|
static void BM_SetAlpha3ChannelImage(benchmark::State& state) {
|
||||||
|
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
|
||||||
|
R"pb(
|
||||||
|
calculator: "SetAlphaCalculator"
|
||||||
|
input_stream: "IMAGE:input_frames"
|
||||||
|
input_stream: "ALPHA:masks"
|
||||||
|
output_stream: "IMAGE:output_frames"
|
||||||
|
)pb");
|
||||||
|
CalculatorRunner runner(calculator_node);
|
||||||
|
|
||||||
|
// Input frames.
|
||||||
|
const auto input_frame = GetInputFrame(input_width, input_height, 3);
|
||||||
|
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
|
||||||
|
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
|
||||||
|
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
|
||||||
|
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
|
||||||
|
input_frame_packet.At(Timestamp(1)));
|
||||||
|
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
|
||||||
|
mask_frame_packet.At(Timestamp(1)));
|
||||||
|
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
const auto& outputs = runner.Outputs();
|
||||||
|
ASSERT_EQ(1, outputs.NumEntries());
|
||||||
|
|
||||||
|
for (const auto _ : state) {
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(BM_SetAlpha3ChannelImage);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AffineTransformation::Interpolation GetInterpolation(
|
||||||
|
mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
|
||||||
|
switch (interpolation) {
|
||||||
|
case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
|
||||||
|
case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
|
||||||
|
return AffineTransformation::Interpolation::kLinear;
|
||||||
|
case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
|
||||||
|
return AffineTransformation::Interpolation::kCubic;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename ImageT>
|
template <typename ImageT>
|
||||||
class WarpAffineRunnerHolder {};
|
class WarpAffineRunnerHolder {};
|
||||||
|
|
||||||
|
@ -61,16 +72,22 @@ template <>
|
||||||
class WarpAffineRunnerHolder<ImageFrame> {
|
class WarpAffineRunnerHolder<ImageFrame> {
|
||||||
public:
|
public:
|
||||||
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
|
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
|
||||||
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
|
absl::Status Open(CalculatorContext* cc) {
|
||||||
|
interpolation_ = GetInterpolation(
|
||||||
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
absl::StatusOr<RunnerType*> GetRunner() {
|
absl::StatusOr<RunnerType*> GetRunner() {
|
||||||
if (!runner_) {
|
if (!runner_) {
|
||||||
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner());
|
ASSIGN_OR_RETURN(runner_,
|
||||||
|
CreateAffineTransformationOpenCvRunner(interpolation_));
|
||||||
}
|
}
|
||||||
return runner_.get();
|
return runner_.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<RunnerType> runner_;
|
std::unique_ptr<RunnerType> runner_;
|
||||||
|
AffineTransformation::Interpolation interpolation_;
|
||||||
};
|
};
|
||||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||||
|
|
||||||
|
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
||||||
gpu_origin_ =
|
gpu_origin_ =
|
||||||
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
|
||||||
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
|
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
|
||||||
|
interpolation_ = GetInterpolation(
|
||||||
|
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
|
||||||
return gl_helper_->Open(cc);
|
return gl_helper_->Open(cc);
|
||||||
}
|
}
|
||||||
absl::StatusOr<RunnerType*> GetRunner() {
|
absl::StatusOr<RunnerType*> GetRunner() {
|
||||||
if (!runner_) {
|
if (!runner_) {
|
||||||
ASSIGN_OR_RETURN(
|
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
|
||||||
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_));
|
gl_helper_, gpu_origin_, interpolation_));
|
||||||
}
|
}
|
||||||
return runner_.get();
|
return runner_.get();
|
||||||
}
|
}
|
||||||
|
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
|
||||||
mediapipe::GpuOrigin::Mode gpu_origin_;
|
mediapipe::GpuOrigin::Mode gpu_origin_;
|
||||||
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
|
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
|
||||||
std::unique_ptr<RunnerType> runner_;
|
std::unique_ptr<RunnerType> runner_;
|
||||||
|
AffineTransformation::Interpolation interpolation_;
|
||||||
};
|
};
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
|
||||||
BORDER_REPLICATE = 2;
|
BORDER_REPLICATE = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pixel sampling interpolation methods. See @interpolation.
|
||||||
|
enum Interpolation {
|
||||||
|
INTER_UNSPECIFIED = 0;
|
||||||
|
INTER_LINEAR = 1;
|
||||||
|
INTER_CUBIC = 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Pixel extrapolation method.
|
// Pixel extrapolation method.
|
||||||
// When converting image to tensor it may happen that tensor needs to read
|
// When converting image to tensor it may happen that tensor needs to read
|
||||||
// pixels outside image boundaries. Border mode helps to specify how such
|
// pixels outside image boundaries. Border mode helps to specify how such
|
||||||
|
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
|
||||||
// to be flipped vertically as tensors are expected to start at top.
|
// to be flipped vertically as tensors are expected to start at top.
|
||||||
// (DEFAULT or unset interpreted as CONVENTIONAL.)
|
// (DEFAULT or unset interpreted as CONVENTIONAL.)
|
||||||
optional GpuOrigin.Mode gpu_origin = 2;
|
optional GpuOrigin.Mode gpu_origin = 2;
|
||||||
|
|
||||||
|
// Sampling method for neighboring pixels.
|
||||||
|
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
|
||||||
|
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
|
||||||
|
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
|
||||||
|
optional Interpolation interpolation = 3;
|
||||||
}
|
}
|
||||||
|
|
|
@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
||||||
const cv::Mat& input, cv::Mat expected_result,
|
const cv::Mat& input, cv::Mat expected_result,
|
||||||
float similarity_threshold, std::array<float, 16> matrix,
|
float similarity_threshold, std::array<float, 16> matrix,
|
||||||
int out_width, int out_height,
|
int out_width, int out_height,
|
||||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||||
std::string border_mode_str;
|
std::string border_mode_str;
|
||||||
if (border_mode) {
|
if (border_mode) {
|
||||||
switch (*border_mode) {
|
switch (*border_mode) {
|
||||||
|
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
std::string interpolation_str;
|
||||||
|
if (interpolation) {
|
||||||
|
switch (*interpolation) {
|
||||||
|
case AffineTransformation::Interpolation::kLinear:
|
||||||
|
interpolation_str = "interpolation: INTER_LINEAR";
|
||||||
|
break;
|
||||||
|
case AffineTransformation::Interpolation::kCubic:
|
||||||
|
interpolation_str = "interpolation: INTER_CUBIC";
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
|
||||||
absl::Substitute(graph_text, /*$0=*/border_mode_str));
|
absl::Substitute(graph_text, /*$0=*/border_mode_str,
|
||||||
|
/*$1=*/interpolation_str));
|
||||||
|
|
||||||
std::vector<Packet> output_packets;
|
std::vector<Packet> output_packets;
|
||||||
tool::AddVectorSink("output_image", &graph_config, &output_packets);
|
tool::AddVectorSink("output_image", &graph_config, &output_packets);
|
||||||
|
@ -132,7 +145,8 @@ struct SimilarityConfig {
|
||||||
void RunTest(cv::Mat input, cv::Mat expected_result,
|
void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
const SimilarityConfig& similarity, std::array<float, 16> matrix,
|
const SimilarityConfig& similarity, std::array<float, 16> matrix,
|
||||||
int out_width, int out_height,
|
int out_width, int out_height,
|
||||||
absl::optional<AffineTransformation::BorderMode> border_mode) {
|
std::optional<AffineTransformation::BorderMode> border_mode,
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation) {
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
input_stream: "output_size"
|
input_stream: "output_size"
|
||||||
|
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
|
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
|
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
|
||||||
matrix, out_width, out_height, border_mode);
|
matrix, out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
gpu_origin: TOP_LEFT
|
gpu_origin: TOP_LEFT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
|
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
|
||||||
RunTest(R"(
|
RunTest(R"(
|
||||||
input_stream: "input_image"
|
input_stream: "input_image"
|
||||||
|
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
options {
|
options {
|
||||||
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
[mediapipe.WarpAffineCalculatorOptions.ext] {
|
||||||
$0 # border mode
|
$0 # border mode
|
||||||
|
$1 # interpolation
|
||||||
gpu_origin: TOP_LEFT
|
gpu_origin: TOP_LEFT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||||
}
|
}
|
||||||
)",
|
)",
|
||||||
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
|
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
|
||||||
matrix, out_width, out_height, border_mode);
|
matrix, out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
|
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
|
||||||
|
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
|
||||||
int out_height = 256;
|
int out_height = 256;
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
||||||
|
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
||||||
|
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
||||||
|
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
||||||
|
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
||||||
|
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
|
||||||
|
mediapipe::NormalizedRect roi;
|
||||||
|
roi.set_x_center(0.65f);
|
||||||
|
roi.set_y_center(0.4f);
|
||||||
|
roi.set_width(0.5f);
|
||||||
|
roi.set_height(0.5f);
|
||||||
|
roi.set_rotation(M_PI * -45.0f / 180.0f);
|
||||||
|
auto input = GetRgb(
|
||||||
|
"/mediapipe/calculators/"
|
||||||
|
"tensor/testdata/image_to_tensor/input.jpg");
|
||||||
|
auto expected_output = GetRgb(
|
||||||
|
"/mediapipe/calculators/"
|
||||||
|
"tensor/testdata/image_to_tensor/"
|
||||||
|
"medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
|
||||||
|
int out_width = 256;
|
||||||
|
int out_height = 256;
|
||||||
|
bool keep_aspect_ratio = false;
|
||||||
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation =
|
||||||
|
AffineTransformation::Interpolation::kCubic;
|
||||||
|
RunTest(input, expected_output,
|
||||||
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
|
||||||
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
||||||
|
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
||||||
|
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
|
||||||
bool keep_aspect_ratio = false;
|
bool keep_aspect_ratio = false;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
||||||
|
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
||||||
|
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
||||||
|
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
||||||
int out_height = 128;
|
int out_height = 128;
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
std::optional<AffineTransformation::BorderMode> border_mode = {};
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||||
|
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, NoOp) {
|
TEST(WarpAffineCalculatorTest, NoOp) {
|
||||||
|
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kReplicate;
|
AffineTransformation::BorderMode::kReplicate;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
||||||
|
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
|
||||||
bool keep_aspect_ratio = true;
|
bool keep_aspect_ratio = true;
|
||||||
std::optional<AffineTransformation::BorderMode> border_mode =
|
std::optional<AffineTransformation::BorderMode> border_mode =
|
||||||
AffineTransformation::BorderMode::kZero;
|
AffineTransformation::BorderMode::kZero;
|
||||||
|
std::optional<AffineTransformation::Interpolation> interpolation = {};
|
||||||
RunTest(input, expected_output,
|
RunTest(input, expected_output,
|
||||||
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
|
||||||
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
|
||||||
out_width, out_height, border_mode);
|
out_width, out_height, border_mode, interpolation);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -997,17 +997,20 @@ cc_library(
|
||||||
":image_to_tensor_converter_gl_buffer",
|
":image_to_tensor_converter_gl_buffer",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
"//mediapipe:apple": [
|
"//mediapipe:apple": [
|
||||||
":image_to_tensor_converter_metal",
|
":image_to_tensor_converter_metal",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:MPPMetalHelper",
|
"//mediapipe/gpu:MPPMetalHelper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
":image_to_tensor_converter_gl_buffer",
|
":image_to_tensor_converter_gl_buffer",
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
|
"//mediapipe/gpu:gpu_service",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -1045,6 +1048,10 @@ cc_test(
|
||||||
":image_to_tensor_calculator",
|
":image_to_tensor_calculator",
|
||||||
":image_to_tensor_converter",
|
":image_to_tensor_converter",
|
||||||
":image_to_tensor_utils",
|
":image_to_tensor_utils",
|
||||||
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:calculator_runner",
|
"//mediapipe/framework:calculator_runner",
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/deps:file_path",
|
||||||
|
@ -1061,11 +1068,10 @@ cc_test(
|
||||||
"//mediapipe/framework/port:opencv_imgproc",
|
"//mediapipe/framework/port:opencv_imgproc",
|
||||||
"//mediapipe/framework/port:parse_text_proto",
|
"//mediapipe/framework/port:parse_text_proto",
|
||||||
"//mediapipe/util:image_test_utils",
|
"//mediapipe/util:image_test_utils",
|
||||||
"@com_google_absl//absl/flags:flag",
|
] + select({
|
||||||
"@com_google_absl//absl/memory",
|
"//mediapipe:apple": [],
|
||||||
"@com_google_absl//absl/strings",
|
"//conditions:default": ["//mediapipe/gpu:gl_context"],
|
||||||
"@com_google_absl//absl/strings:str_format",
|
}),
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
|
|
@ -45,9 +45,11 @@
|
||||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||||
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
|
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
#include "mediapipe/gpu/gpu_service.h"
|
||||||
#else
|
#else
|
||||||
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
|
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
#include "mediapipe/gpu/gpu_service.h"
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node {
|
||||||
#if MEDIAPIPE_METAL_ENABLED
|
#if MEDIAPIPE_METAL_ENABLED
|
||||||
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
|
||||||
#else
|
#else
|
||||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
cc->UseService(kGpuService).Optional();
|
||||||
#endif // MEDIAPIPE_METAL_ENABLED
|
#endif // MEDIAPIPE_METAL_ENABLED
|
||||||
#endif // MEDIAPIPE_DISABLE_GPU
|
#endif // MEDIAPIPE_DISABLE_GPU
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,10 @@
|
||||||
#include "mediapipe/framework/port/status_matchers.h"
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
#include "mediapipe/util/image_test_utils.h"
|
#include "mediapipe/util/image_test_utils.h"
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
#include "mediapipe/gpu/gl_context.h"
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) {
|
||||||
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
|
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
|
||||||
/*keep_aspect=*/false, BorderMode::kZero, roi);
|
/*keep_aspect=*/false, BorderMode::kZero, roi);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) {
|
||||||
|
auto graph_config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input_image"
|
||||||
|
node {
|
||||||
|
calculator: "ImageToTensorCalculator"
|
||||||
|
input_stream: "IMAGE:input_image"
|
||||||
|
output_stream: "TENSORS:tensor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||||
|
output_tensor_float_range { min: 0.0f max: 1.0f }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
auto image_frame =
|
||||||
|
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
|
||||||
|
Image image = Image(std::move(image_frame));
|
||||||
|
Packet packet = MakePacket<Image>(std::move(image));
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||||
|
MP_ASSERT_OK(graph.CloseAllPacketSources());
|
||||||
|
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||||
|
}
|
||||||
|
|
||||||
|
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
|
TEST(ImageToTensorCalculatorTest,
|
||||||
|
FailsGracefullyWhenGpuServiceNeededButNotAvailable) {
|
||||||
|
auto graph_config =
|
||||||
|
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||||
|
input_stream: "input_image"
|
||||||
|
node {
|
||||||
|
calculator: "ImageToTensorCalculator"
|
||||||
|
input_stream: "IMAGE:input_image"
|
||||||
|
output_stream: "TENSORS:tensor"
|
||||||
|
options {
|
||||||
|
[mediapipe.ImageToTensorCalculatorOptions.ext] {
|
||||||
|
output_tensor_float_range { min: 0.0f max: 1.0f }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
CalculatorGraph graph;
|
||||||
|
MP_ASSERT_OK(graph.Initialize(graph_config));
|
||||||
|
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
|
||||||
|
MP_ASSERT_OK(graph.StartRun({}));
|
||||||
|
|
||||||
|
MP_ASSERT_OK_AND_ASSIGN(auto context,
|
||||||
|
GlContext::Create(nullptr, /*create_thread=*/true));
|
||||||
|
Packet packet;
|
||||||
|
context->Run([&packet]() {
|
||||||
|
auto image_frame =
|
||||||
|
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
|
||||||
|
Image image = Image(std::move(image_frame));
|
||||||
|
// Ensure image is available on GPU to force ImageToTensorCalculator to
|
||||||
|
// run on GPU.
|
||||||
|
ASSERT_TRUE(image.ConvertToGpu());
|
||||||
|
packet = MakePacket<Image>(std::move(image));
|
||||||
|
});
|
||||||
|
MP_ASSERT_OK(
|
||||||
|
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
|
||||||
|
EXPECT_THAT(graph.WaitUntilIdle(),
|
||||||
|
StatusIs(absl::StatusCode::kInternal,
|
||||||
|
HasSubstr("GPU service not available")));
|
||||||
|
}
|
||||||
|
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -141,7 +141,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
|
||||||
}
|
}
|
||||||
// Run inference.
|
// Run inference.
|
||||||
{
|
{
|
||||||
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
|
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE_ADVANCED, cc);
|
||||||
return tflite_gpu_runner_->Invoke();
|
return tflite_gpu_runner_->Invoke();
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// The option proto for the TensorsReadbackCalculator.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
|
||||||
|
message TensorsReadbackCalculatorOptions {
|
||||||
|
extend mediapipe.CalculatorOptions {
|
||||||
|
optional TensorsReadbackCalculatorOptions ext = 514750372;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expected shapes of the input tensors.
|
||||||
|
// The calculator uses these shape to build the GPU programs during
|
||||||
|
// initialization, and check the actual tensor shapes against the expected
|
||||||
|
// shapes during runtime.
|
||||||
|
// Batch size of the tensor is set to be 1. `TensorShape` here can be C, WC,
|
||||||
|
// or HWC.
|
||||||
|
// For example {dims: 1 dims: 2} represents a tensor with batch_size = 1,
|
||||||
|
// width = 1, and num_channels = 2.
|
||||||
|
message TensorShape {
|
||||||
|
repeated int32 dims = 1 [packed = true];
|
||||||
|
}
|
||||||
|
// tensor_shape specifies the shape of each input tensors.
|
||||||
|
repeated TensorShape tensor_shape = 1;
|
||||||
|
}
|
After Width: | Height: | Size: 64 KiB |
|
@ -14,6 +14,7 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
|
||||||
|
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
|
||||||
load("@bazel_skylib//lib:selects.bzl", "selects")
|
load("@bazel_skylib//lib:selects.bzl", "selects")
|
||||||
|
|
||||||
licenses(["notice"])
|
licenses(["notice"])
|
||||||
|
@ -312,15 +313,19 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
# TODO: Re-evaluate which of these libraries we can avoid making
|
||||||
|
# cc_library_with_tflite and can be changed back to cc_library.
|
||||||
|
cc_library_with_tflite(
|
||||||
name = "tflite_model_calculator",
|
name = "tflite_model_calculator",
|
||||||
srcs = ["tflite_model_calculator.cc"],
|
srcs = ["tflite_model_calculator.cc"],
|
||||||
|
tflite_deps = [
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||||
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"@com_google_absl//absl/status",
|
"@com_google_absl//absl/status",
|
||||||
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
|
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
|
@ -66,7 +66,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
|
||||||
} else {
|
} else {
|
||||||
cc->OutputSidePackets()
|
cc->OutputSidePackets()
|
||||||
.Index(0)
|
.Index(0)
|
||||||
.Set<tflite_shims::ops::builtin::BuiltinOpResolver>();
|
.Set<tflite::ops::builtin::BuiltinOpResolver>();
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
@ -77,7 +77,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
|
||||||
const TfLiteCustomOpResolverCalculatorOptions& options =
|
const TfLiteCustomOpResolverCalculatorOptions& options =
|
||||||
cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
|
cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
|
||||||
|
|
||||||
std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver> op_resolver;
|
std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> op_resolver;
|
||||||
if (options.use_gpu()) {
|
if (options.use_gpu()) {
|
||||||
op_resolver = absl::make_unique<mediapipe::OpResolver>();
|
op_resolver = absl::make_unique<mediapipe::OpResolver>();
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -21,7 +21,7 @@
|
||||||
#include "mediapipe/framework/packet.h"
|
#include "mediapipe/framework/packet.h"
|
||||||
#include "mediapipe/framework/port/ret_check.h"
|
#include "mediapipe/framework/port/ret_check.h"
|
||||||
#include "tensorflow/lite/allocation.h"
|
#include "tensorflow/lite/allocation.h"
|
||||||
#include "tensorflow/lite/core/shims/cc/model.h"
|
#include "tensorflow/lite/model.h"
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
|
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
|
||||||
#ifdef ABSL_HAVE_MMAP
|
#if defined(ABSL_HAVE_MMAP) && !TFLITE_WITH_STABLE_ABI
|
||||||
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
|
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
|
||||||
const auto& model_fd =
|
const auto& model_fd =
|
||||||
model_packet.Get<std::tuple<int, size_t, size_t>>();
|
model_packet.Get<std::tuple<int, size_t, size_t>>();
|
||||||
|
|
|
@ -1270,6 +1270,50 @@ cc_library(
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mediapipe_proto_library(
|
||||||
|
name = "flat_color_image_calculator_proto",
|
||||||
|
srcs = ["flat_color_image_calculator.proto"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
"//mediapipe/framework:calculator_proto",
|
||||||
|
"//mediapipe/util:color_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "flat_color_image_calculator",
|
||||||
|
srcs = ["flat_color_image_calculator.cc"],
|
||||||
|
deps = [
|
||||||
|
":flat_color_image_calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework/api2:node",
|
||||||
|
"//mediapipe/framework/formats:image",
|
||||||
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
"//mediapipe/framework/formats:image_frame_opencv",
|
||||||
|
"//mediapipe/framework/port:opencv_core",
|
||||||
|
"//mediapipe/util:color_cc_proto",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "flat_color_image_calculator_test",
|
||||||
|
srcs = ["flat_color_image_calculator_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":flat_color_image_calculator",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_runner",
|
||||||
|
"//mediapipe/framework:packet",
|
||||||
|
"//mediapipe/framework/formats:image",
|
||||||
|
"//mediapipe/framework/formats:image_frame",
|
||||||
|
"//mediapipe/framework/port:gtest",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/util:color_cc_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "from_image_calculator",
|
name = "from_image_calculator",
|
||||||
srcs = ["from_image_calculator.cc"],
|
srcs = ["from_image_calculator.cc"],
|
||||||
|
|
138
mediapipe/calculators/util/flat_color_image_calculator.cc
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
|
||||||
|
#include "mediapipe/framework/api2/node.h"
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame_opencv.h"
|
||||||
|
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||||
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::mediapipe::api2::Input;
|
||||||
|
using ::mediapipe::api2::Node;
|
||||||
|
using ::mediapipe::api2::Output;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// A calculator for generating an image filled with a single color.
|
||||||
|
//
|
||||||
|
// Inputs:
|
||||||
|
// IMAGE (Image, optional)
|
||||||
|
// If provided, the output will have the same size
|
||||||
|
// COLOR (Color proto, optional)
|
||||||
|
// Color to paint the output with. Takes precedence over the equivalent
|
||||||
|
// calculator options.
|
||||||
|
//
|
||||||
|
// Outputs:
|
||||||
|
// IMAGE (Image)
|
||||||
|
// Image filled with the requested color.
|
||||||
|
//
|
||||||
|
// Example useage:
|
||||||
|
// node {
|
||||||
|
// calculator: "FlatColorImageCalculator"
|
||||||
|
// input_stream: "IMAGE:image"
|
||||||
|
// input_stream: "COLOR:color"
|
||||||
|
// output_stream: "IMAGE:blank_image"
|
||||||
|
// options {
|
||||||
|
// [mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
// color: {
|
||||||
|
// r: 255
|
||||||
|
// g: 255
|
||||||
|
// b: 255
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
class FlatColorImageCalculator : public Node {
|
||||||
|
public:
|
||||||
|
static constexpr Input<Image>::Optional kInImage{"IMAGE"};
|
||||||
|
static constexpr Input<Color>::Optional kInColor{"COLOR"};
|
||||||
|
static constexpr Output<Image> kOutImage{"IMAGE"};
|
||||||
|
|
||||||
|
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
|
||||||
|
|
||||||
|
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||||
|
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
|
||||||
|
|
||||||
|
RET_CHECK(kInImage(cc).IsConnected() ^
|
||||||
|
(options.has_output_height() || options.has_output_width()))
|
||||||
|
<< "Either set IMAGE input stream, or set through options";
|
||||||
|
RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
|
||||||
|
<< "Either set COLOR input stream, or set through options";
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool use_dimension_from_option_ = false;
|
||||||
|
bool use_color_from_option_ = false;
|
||||||
|
};
|
||||||
|
MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
|
||||||
|
|
||||||
|
absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
|
||||||
|
use_dimension_from_option_ = !kInImage(cc).IsConnected();
|
||||||
|
use_color_from_option_ = !kInColor(cc).IsConnected();
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
|
||||||
|
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
|
||||||
|
|
||||||
|
int output_height = -1;
|
||||||
|
int output_width = -1;
|
||||||
|
if (use_dimension_from_option_) {
|
||||||
|
output_height = options.output_height();
|
||||||
|
output_width = options.output_width();
|
||||||
|
} else if (!kInImage(cc).IsEmpty()) {
|
||||||
|
const Image& input_image = kInImage(cc).Get();
|
||||||
|
output_height = input_image.height();
|
||||||
|
output_width = input_image.width();
|
||||||
|
} else {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
if (use_color_from_option_) {
|
||||||
|
color = options.color();
|
||||||
|
} else if (!kInColor(cc).IsEmpty()) {
|
||||||
|
color = kInColor(cc).Get();
|
||||||
|
} else {
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
output_width, output_height);
|
||||||
|
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
|
||||||
|
|
||||||
|
output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
|
||||||
|
|
||||||
|
kOutImage(cc).Send(Image(output_frame));
|
||||||
|
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe
|
32
mediapipe/calculators/util/flat_color_image_calculator.proto
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
syntax = "proto2";
|
||||||
|
|
||||||
|
package mediapipe;
|
||||||
|
|
||||||
|
import "mediapipe/framework/calculator.proto";
|
||||||
|
import "mediapipe/util/color.proto";
|
||||||
|
|
||||||
|
message FlatColorImageCalculatorOptions {
|
||||||
|
extend CalculatorOptions {
|
||||||
|
optional FlatColorImageCalculatorOptions ext = 515548435;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output dimensions.
|
||||||
|
optional int32 output_width = 1;
|
||||||
|
optional int32 output_height = 2;
|
||||||
|
// The color to fill with in the output image.
|
||||||
|
optional Color color = 3;
|
||||||
|
}
|
210
mediapipe/calculators/util/flat_color_image_calculator_test.cc
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
// Copyright 2023 The MediaPipe Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
|
#include "mediapipe/framework/calculator_runner.h"
|
||||||
|
#include "mediapipe/framework/formats/image.h"
|
||||||
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
|
#include "mediapipe/framework/packet.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/status_matchers.h"
|
||||||
|
#include "mediapipe/util/color.pb.h"
|
||||||
|
|
||||||
|
namespace mediapipe {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::HasSubstr;
|
||||||
|
|
||||||
|
constexpr char kImageTag[] = "IMAGE";
|
||||||
|
constexpr char kColorTag[] = "COLOR";
|
||||||
|
constexpr int kImageWidth = 256;
|
||||||
|
constexpr int kImageHeight = 256;
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
color: {
|
||||||
|
r: 100,
|
||||||
|
g: 200,
|
||||||
|
b: 255,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
kImageWidth, kImageHeight);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
|
MakePacket<Image>(image_frame).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
|
||||||
|
ASSERT_EQ(outputs.size(), 3);
|
||||||
|
|
||||||
|
for (const auto& packet : outputs) {
|
||||||
|
const auto& image = packet.Get<Image>();
|
||||||
|
EXPECT_EQ(image.width(), kImageWidth);
|
||||||
|
EXPECT_EQ(image.height(), kImageHeight);
|
||||||
|
auto image_frame = image.GetImageFrameSharedPtr();
|
||||||
|
auto* pixel_data = image_frame->PixelData();
|
||||||
|
EXPECT_EQ(pixel_data[0], 100);
|
||||||
|
EXPECT_EQ(pixel_data[1], 200);
|
||||||
|
EXPECT_EQ(pixel_data[2], 255);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
output_width: 7,
|
||||||
|
output_height: 13,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
color.set_r(0);
|
||||||
|
color.set_g(5);
|
||||||
|
color.set_b(0);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
|
||||||
|
MakePacket<Color>(color).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
MP_ASSERT_OK(runner.Run());
|
||||||
|
|
||||||
|
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
|
||||||
|
ASSERT_EQ(outputs.size(), 3);
|
||||||
|
|
||||||
|
for (const auto& packet : outputs) {
|
||||||
|
const auto& image = packet.Get<Image>();
|
||||||
|
EXPECT_EQ(image.width(), 7);
|
||||||
|
EXPECT_EQ(image.height(), 13);
|
||||||
|
auto image_frame = image.GetImageFrameSharedPtr();
|
||||||
|
const uint8_t* pixel_data = image_frame->PixelData();
|
||||||
|
EXPECT_EQ(pixel_data[0], 0);
|
||||||
|
EXPECT_EQ(pixel_data[1], 5);
|
||||||
|
EXPECT_EQ(pixel_data[2], 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
color.set_r(0);
|
||||||
|
color.set_g(5);
|
||||||
|
color.set_b(0);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
|
||||||
|
MakePacket<Color>(color).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set IMAGE input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureMissingColor) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
kImageWidth, kImageHeight);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
|
MakePacket<Image>(image_frame).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set COLOR input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
output_width: 7,
|
||||||
|
output_height: 13,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
|
||||||
|
kImageWidth, kImageHeight);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
|
||||||
|
MakePacket<Image>(image_frame).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set IMAGE input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
|
||||||
|
CalculatorRunner runner(R"pb(
|
||||||
|
calculator: "FlatColorImageCalculator"
|
||||||
|
input_stream: "IMAGE:image"
|
||||||
|
input_stream: "COLOR:color"
|
||||||
|
output_stream: "IMAGE:out_image"
|
||||||
|
options {
|
||||||
|
[mediapipe.FlatColorImageCalculatorOptions.ext] {
|
||||||
|
color: {
|
||||||
|
r: 100,
|
||||||
|
g: 200,
|
||||||
|
b: 255,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)pb");
|
||||||
|
|
||||||
|
Color color;
|
||||||
|
color.set_r(0);
|
||||||
|
color.set_g(5);
|
||||||
|
color.set_b(0);
|
||||||
|
|
||||||
|
for (int ts = 0; ts < 3; ++ts) {
|
||||||
|
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
|
||||||
|
MakePacket<Color>(color).At(Timestamp(ts)));
|
||||||
|
}
|
||||||
|
ASSERT_THAT(runner.Run().message(),
|
||||||
|
HasSubstr("Either set COLOR input stream"));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe
|
|
@ -1,5 +1,6 @@
|
||||||
distributionBase=GRADLE_USER_HOME
|
distributionBase=GRADLE_USER_HOME
|
||||||
distributionPath=wrapper/dists
|
distributionPath=wrapper/dists
|
||||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
|
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
|
||||||
|
networkTimeout=10000
|
||||||
zipStoreBase=GRADLE_USER_HOME
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
zipStorePath=wrapper/dists
|
zipStorePath=wrapper/dists
|
||||||
|
|
285
mediapipe/examples/android/solutions/gradlew
vendored
|
@ -1,7 +1,7 @@
|
||||||
#!/usr/bin/env sh
|
#!/bin/sh
|
||||||
|
|
||||||
#
|
#
|
||||||
# Copyright 2015 the original author or authors.
|
# Copyright © 2015-2021 the original authors.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -17,67 +17,101 @@
|
||||||
#
|
#
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
##
|
#
|
||||||
## Gradle start up script for UN*X
|
# Gradle start up script for POSIX generated by Gradle.
|
||||||
##
|
#
|
||||||
|
# Important for running:
|
||||||
|
#
|
||||||
|
# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
|
||||||
|
# noncompliant, but you have some other compliant shell such as ksh or
|
||||||
|
# bash, then to run this script, type that shell name before the whole
|
||||||
|
# command line, like:
|
||||||
|
#
|
||||||
|
# ksh Gradle
|
||||||
|
#
|
||||||
|
# Busybox and similar reduced shells will NOT work, because this script
|
||||||
|
# requires all of these POSIX shell features:
|
||||||
|
# * functions;
|
||||||
|
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
|
||||||
|
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
|
||||||
|
# * compound commands having a testable exit status, especially «case»;
|
||||||
|
# * various built-in commands including «command», «set», and «ulimit».
|
||||||
|
#
|
||||||
|
# Important for patching:
|
||||||
|
#
|
||||||
|
# (2) This script targets any POSIX shell, so it avoids extensions provided
|
||||||
|
# by Bash, Ksh, etc; in particular arrays are avoided.
|
||||||
|
#
|
||||||
|
# The "traditional" practice of packing multiple parameters into a
|
||||||
|
# space-separated string is a well documented source of bugs and security
|
||||||
|
# problems, so this is (mostly) avoided, by progressively accumulating
|
||||||
|
# options in "$@", and eventually passing that to Java.
|
||||||
|
#
|
||||||
|
# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
|
||||||
|
# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
|
||||||
|
# see the in-line comments for details.
|
||||||
|
#
|
||||||
|
# There are tweaks for specific operating systems such as AIX, CygWin,
|
||||||
|
# Darwin, MinGW, and NonStop.
|
||||||
|
#
|
||||||
|
# (3) This script is generated from the Groovy template
|
||||||
|
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
|
||||||
|
# within the Gradle project.
|
||||||
|
#
|
||||||
|
# You can find Gradle at https://github.com/gradle/gradle/.
|
||||||
|
#
|
||||||
##############################################################################
|
##############################################################################
|
||||||
|
|
||||||
# Attempt to set APP_HOME
|
# Attempt to set APP_HOME
|
||||||
# Resolve links: $0 may be a link
|
|
||||||
PRG="$0"
|
|
||||||
# Need this for relative symlinks.
|
|
||||||
while [ -h "$PRG" ] ; do
|
|
||||||
ls=`ls -ld "$PRG"`
|
|
||||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
|
||||||
if expr "$link" : '/.*' > /dev/null; then
|
|
||||||
PRG="$link"
|
|
||||||
else
|
|
||||||
PRG=`dirname "$PRG"`"/$link"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
SAVED="`pwd`"
|
|
||||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
|
||||||
APP_HOME="`pwd -P`"
|
|
||||||
cd "$SAVED" >/dev/null
|
|
||||||
|
|
||||||
APP_NAME="Gradle"
|
# Resolve links: $0 may be a link
|
||||||
APP_BASE_NAME=`basename "$0"`
|
app_path=$0
|
||||||
|
|
||||||
|
# Need this for daisy-chained symlinks.
|
||||||
|
while
|
||||||
|
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
|
||||||
|
[ -h "$app_path" ]
|
||||||
|
do
|
||||||
|
ls=$( ls -ld "$app_path" )
|
||||||
|
link=${ls#*' -> '}
|
||||||
|
case $link in #(
|
||||||
|
/*) app_path=$link ;; #(
|
||||||
|
*) app_path=$APP_HOME$link ;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# This is normally unused
|
||||||
|
# shellcheck disable=SC2034
|
||||||
|
APP_BASE_NAME=${0##*/}
|
||||||
|
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
|
||||||
|
|
||||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||||
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||||
|
|
||||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||||
MAX_FD="maximum"
|
MAX_FD=maximum
|
||||||
|
|
||||||
warn () {
|
warn () {
|
||||||
echo "$*"
|
echo "$*"
|
||||||
}
|
} >&2
|
||||||
|
|
||||||
die () {
|
die () {
|
||||||
echo
|
echo
|
||||||
echo "$*"
|
echo "$*"
|
||||||
echo
|
echo
|
||||||
exit 1
|
exit 1
|
||||||
}
|
} >&2
|
||||||
|
|
||||||
# OS specific support (must be 'true' or 'false').
|
# OS specific support (must be 'true' or 'false').
|
||||||
cygwin=false
|
cygwin=false
|
||||||
msys=false
|
msys=false
|
||||||
darwin=false
|
darwin=false
|
||||||
nonstop=false
|
nonstop=false
|
||||||
case "`uname`" in
|
case "$( uname )" in #(
|
||||||
CYGWIN* )
|
CYGWIN* ) cygwin=true ;; #(
|
||||||
cygwin=true
|
Darwin* ) darwin=true ;; #(
|
||||||
;;
|
MSYS* | MINGW* ) msys=true ;; #(
|
||||||
Darwin* )
|
NONSTOP* ) nonstop=true ;;
|
||||||
darwin=true
|
|
||||||
;;
|
|
||||||
MINGW* )
|
|
||||||
msys=true
|
|
||||||
;;
|
|
||||||
NONSTOP* )
|
|
||||||
nonstop=true
|
|
||||||
;;
|
|
||||||
esac
|
esac
|
||||||
|
|
||||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||||
|
@ -87,9 +121,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||||
if [ -n "$JAVA_HOME" ] ; then
|
if [ -n "$JAVA_HOME" ] ; then
|
||||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||||
# IBM's JDK on AIX uses strange locations for the executables
|
# IBM's JDK on AIX uses strange locations for the executables
|
||||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
JAVACMD=$JAVA_HOME/jre/sh/java
|
||||||
else
|
else
|
||||||
JAVACMD="$JAVA_HOME/bin/java"
|
JAVACMD=$JAVA_HOME/bin/java
|
||||||
fi
|
fi
|
||||||
if [ ! -x "$JAVACMD" ] ; then
|
if [ ! -x "$JAVACMD" ] ; then
|
||||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||||
|
@ -98,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the
|
||||||
location of your Java installation."
|
location of your Java installation."
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
JAVACMD="java"
|
JAVACMD=java
|
||||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||||
|
|
||||||
Please set the JAVA_HOME variable in your environment to match the
|
Please set the JAVA_HOME variable in your environment to match the
|
||||||
|
@ -106,80 +140,105 @@ location of your Java installation."
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Increase the maximum file descriptors if we can.
|
# Increase the maximum file descriptors if we can.
|
||||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
|
||||||
MAX_FD_LIMIT=`ulimit -H -n`
|
case $MAX_FD in #(
|
||||||
if [ $? -eq 0 ] ; then
|
max*)
|
||||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
|
||||||
MAX_FD="$MAX_FD_LIMIT"
|
# shellcheck disable=SC3045
|
||||||
fi
|
MAX_FD=$( ulimit -H -n ) ||
|
||||||
ulimit -n $MAX_FD
|
warn "Could not query maximum file descriptor limit"
|
||||||
if [ $? -ne 0 ] ; then
|
esac
|
||||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
case $MAX_FD in #(
|
||||||
fi
|
'' | soft) :;; #(
|
||||||
else
|
*)
|
||||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
|
||||||
fi
|
# shellcheck disable=SC3045
|
||||||
fi
|
ulimit -n "$MAX_FD" ||
|
||||||
|
warn "Could not set maximum file descriptor limit to $MAX_FD"
|
||||||
# For Darwin, add options to specify how the application appears in the dock
|
|
||||||
if $darwin; then
|
|
||||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
|
||||||
fi
|
|
||||||
|
|
||||||
# For Cygwin or MSYS, switch paths to Windows format before running java
|
|
||||||
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
|
||||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
|
||||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
|
||||||
|
|
||||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
|
||||||
|
|
||||||
# We build the pattern for arguments to be converted via cygpath
|
|
||||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
|
||||||
SEP=""
|
|
||||||
for dir in $ROOTDIRSRAW ; do
|
|
||||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
|
||||||
SEP="|"
|
|
||||||
done
|
|
||||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
|
||||||
# Add a user-defined pattern to the cygpath arguments
|
|
||||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
|
||||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
|
||||||
fi
|
|
||||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
|
||||||
i=0
|
|
||||||
for arg in "$@" ; do
|
|
||||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
|
||||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
|
||||||
|
|
||||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
|
||||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
|
||||||
else
|
|
||||||
eval `echo args$i`="\"$arg\""
|
|
||||||
fi
|
|
||||||
i=`expr $i + 1`
|
|
||||||
done
|
|
||||||
case $i in
|
|
||||||
0) set -- ;;
|
|
||||||
1) set -- "$args0" ;;
|
|
||||||
2) set -- "$args0" "$args1" ;;
|
|
||||||
3) set -- "$args0" "$args1" "$args2" ;;
|
|
||||||
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
|
||||||
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
|
||||||
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
|
||||||
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
|
||||||
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
|
||||||
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
|
||||||
esac
|
esac
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Escape application args
|
# Collect all arguments for the java command, stacking in reverse order:
|
||||||
save () {
|
# * args from the command line
|
||||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
# * the main class name
|
||||||
echo " "
|
# * -classpath
|
||||||
}
|
# * -D...appname settings
|
||||||
APP_ARGS=`save "$@"`
|
# * --module-path (only if needed)
|
||||||
|
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
|
||||||
|
|
||||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
# For Cygwin or MSYS, switch paths to Windows format before running java
|
||||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
if "$cygwin" || "$msys" ; then
|
||||||
|
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
|
||||||
|
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
|
||||||
|
|
||||||
|
JAVACMD=$( cygpath --unix "$JAVACMD" )
|
||||||
|
|
||||||
|
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||||
|
for arg do
|
||||||
|
if
|
||||||
|
case $arg in #(
|
||||||
|
-*) false ;; # don't mess with options #(
|
||||||
|
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
|
||||||
|
[ -e "$t" ] ;; #(
|
||||||
|
*) false ;;
|
||||||
|
esac
|
||||||
|
then
|
||||||
|
arg=$( cygpath --path --ignore --mixed "$arg" )
|
||||||
|
fi
|
||||||
|
# Roll the args list around exactly as many times as the number of
|
||||||
|
# args, so each arg winds up back in the position where it started, but
|
||||||
|
# possibly modified.
|
||||||
|
#
|
||||||
|
# NB: a `for` loop captures its iteration list before it begins, so
|
||||||
|
# changing the positional parameters here affects neither the number of
|
||||||
|
# iterations, nor the values presented in `arg`.
|
||||||
|
shift # remove old arg
|
||||||
|
set -- "$@" "$arg" # push replacement arg
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Collect all arguments for the java command;
|
||||||
|
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
|
||||||
|
# shell script including quotes and variable substitutions, so put them in
|
||||||
|
# double quotes to make sure that they get re-expanded; and
|
||||||
|
# * put everything else in single quotes, so that it's not re-expanded.
|
||||||
|
|
||||||
|
set -- \
|
||||||
|
"-Dorg.gradle.appname=$APP_BASE_NAME" \
|
||||||
|
-classpath "$CLASSPATH" \
|
||||||
|
org.gradle.wrapper.GradleWrapperMain \
|
||||||
|
"$@"
|
||||||
|
|
||||||
|
# Stop when "xargs" is not available.
|
||||||
|
if ! command -v xargs >/dev/null 2>&1
|
||||||
|
then
|
||||||
|
die "xargs is not available"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use "xargs" to parse quoted args.
|
||||||
|
#
|
||||||
|
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
|
||||||
|
#
|
||||||
|
# In Bash we could simply go:
|
||||||
|
#
|
||||||
|
# readarray ARGS < <( xargs -n1 <<<"$var" ) &&
|
||||||
|
# set -- "${ARGS[@]}" "$@"
|
||||||
|
#
|
||||||
|
# but POSIX shell has neither arrays nor command substitution, so instead we
|
||||||
|
# post-process each arg (as a line of input to sed) to backslash-escape any
|
||||||
|
# character that might be a shell metacharacter, then use eval to reverse
|
||||||
|
# that process (while maintaining the separation between arguments), and wrap
|
||||||
|
# the whole thing up as a single "set" statement.
|
||||||
|
#
|
||||||
|
# This will of course break if any of these variables contains a newline or
|
||||||
|
# an unmatched quote.
|
||||||
|
#
|
||||||
|
|
||||||
|
eval "set -- $(
|
||||||
|
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
|
||||||
|
xargs -n1 |
|
||||||
|
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
|
||||||
|
tr '\n' ' '
|
||||||
|
)" '"$@"'
|
||||||
|
|
||||||
exec "$JAVACMD" "$@"
|
exec "$JAVACMD" "$@"
|
||||||
|
|
15
mediapipe/examples/android/solutions/gradlew.bat
vendored
|
@ -14,7 +14,7 @@
|
||||||
@rem limitations under the License.
|
@rem limitations under the License.
|
||||||
@rem
|
@rem
|
||||||
|
|
||||||
@if "%DEBUG%" == "" @echo off
|
@if "%DEBUG%"=="" @echo off
|
||||||
@rem ##########################################################################
|
@rem ##########################################################################
|
||||||
@rem
|
@rem
|
||||||
@rem Gradle startup script for Windows
|
@rem Gradle startup script for Windows
|
||||||
|
@ -25,7 +25,8 @@
|
||||||
if "%OS%"=="Windows_NT" setlocal
|
if "%OS%"=="Windows_NT" setlocal
|
||||||
|
|
||||||
set DIRNAME=%~dp0
|
set DIRNAME=%~dp0
|
||||||
if "%DIRNAME%" == "" set DIRNAME=.
|
if "%DIRNAME%"=="" set DIRNAME=.
|
||||||
|
@rem This is normally unused
|
||||||
set APP_BASE_NAME=%~n0
|
set APP_BASE_NAME=%~n0
|
||||||
set APP_HOME=%DIRNAME%
|
set APP_HOME=%DIRNAME%
|
||||||
|
|
||||||
|
@ -40,7 +41,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
|
||||||
|
|
||||||
set JAVA_EXE=java.exe
|
set JAVA_EXE=java.exe
|
||||||
%JAVA_EXE% -version >NUL 2>&1
|
%JAVA_EXE% -version >NUL 2>&1
|
||||||
if "%ERRORLEVEL%" == "0" goto execute
|
if %ERRORLEVEL% equ 0 goto execute
|
||||||
|
|
||||||
echo.
|
echo.
|
||||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||||
|
@ -75,13 +76,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||||
|
|
||||||
:end
|
:end
|
||||||
@rem End local scope for the variables with windows NT shell
|
@rem End local scope for the variables with windows NT shell
|
||||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
if %ERRORLEVEL% equ 0 goto mainEnd
|
||||||
|
|
||||||
:fail
|
:fail
|
||||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||||
rem the _cmd.exe /c_ return code!
|
rem the _cmd.exe /c_ return code!
|
||||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
set EXIT_CODE=%ERRORLEVEL%
|
||||||
exit /b 1
|
if %EXIT_CODE% equ 0 set EXIT_CODE=1
|
||||||
|
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
|
||||||
|
exit /b %EXIT_CODE%
|
||||||
|
|
||||||
:mainEnd
|
:mainEnd
|
||||||
if "%OS%"=="Windows_NT" endlocal
|
if "%OS%"=="Windows_NT" endlocal
|
||||||
|
|
|
@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio,
|
||||||
std::string result_image;
|
std::string result_image;
|
||||||
MP_ASSERT_OK(
|
MP_ASSERT_OK(
|
||||||
mediapipe::file::GetContents(result_string_path, &result_image));
|
mediapipe::file::GetContents(result_string_path, &result_image));
|
||||||
EXPECT_EQ(result_image, output_string);
|
if (result_image != output_string) {
|
||||||
|
// There may be slight differences due to the way the JPEG was encoded or
|
||||||
|
// the OpenCV version used to generate the reference files. Compare
|
||||||
|
// pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
|
||||||
|
cv::Mat result_mat =
|
||||||
|
cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1,
|
||||||
|
const_cast<char*>(result_image.data())),
|
||||||
|
cv::IMREAD_UNCHANGED);
|
||||||
|
cv::Mat output_mat =
|
||||||
|
cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1,
|
||||||
|
const_cast<char*>(output_string.data())),
|
||||||
|
cv::IMREAD_UNCHANGED);
|
||||||
|
ASSERT_EQ(result_mat.rows, output_mat.rows);
|
||||||
|
ASSERT_EQ(result_mat.cols, output_mat.cols);
|
||||||
|
ASSERT_EQ(result_mat.type(), output_mat.type());
|
||||||
|
EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
std::string output_string_path = mediapipe::file::JoinPath(
|
std::string output_string_path = mediapipe::file::JoinPath(
|
||||||
absl::GetFlag(FLAGS_output_folder),
|
absl::GetFlag(FLAGS_output_folder),
|
||||||
|
|
Before Width: | Height: | Size: 3.2 KiB After Width: | Height: | Size: 3.2 KiB |
Before Width: | Height: | Size: 6.1 KiB After Width: | Height: | Size: 6.1 KiB |
Before Width: | Height: | Size: 8.2 KiB After Width: | Height: | Size: 8.2 KiB |
Before Width: | Height: | Size: 7.6 KiB After Width: | Height: | Size: 7.6 KiB |
|
@ -56,5 +56,6 @@ objc_library(
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
|
||||||
"//mediapipe/graphs/edge_detection:mobile_calculators",
|
"//mediapipe/graphs/edge_detection:mobile_calculators",
|
||||||
|
"//third_party/apple_frameworks:Metal",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
@ -631,7 +631,13 @@ absl::Status CalculatorGraph::PrepareServices() {
|
||||||
for (const auto& [key, request] : node->Contract().ServiceRequests()) {
|
for (const auto& [key, request] : node->Contract().ServiceRequests()) {
|
||||||
auto packet = service_manager_.GetServicePacket(request.Service());
|
auto packet = service_manager_.GetServicePacket(request.Service());
|
||||||
if (!packet.IsEmpty()) continue;
|
if (!packet.IsEmpty()) continue;
|
||||||
auto packet_or = request.Service().CreateDefaultObject();
|
absl::StatusOr<Packet> packet_or;
|
||||||
|
if (allow_service_default_initialization_) {
|
||||||
|
packet_or = request.Service().CreateDefaultObject();
|
||||||
|
} else {
|
||||||
|
packet_or = absl::FailedPreconditionError(
|
||||||
|
"Service default initialization is disallowed.");
|
||||||
|
}
|
||||||
if (packet_or.ok()) {
|
if (packet_or.ok()) {
|
||||||
MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
|
MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
|
||||||
request.Service(), std::move(packet_or).value()));
|
request.Service(), std::move(packet_or).value()));
|
||||||
|
|
|
@ -405,6 +405,34 @@ class CalculatorGraph {
|
||||||
return service_manager_.GetServiceObject(service);
|
return service_manager_.GetServiceObject(service);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disallows/disables default initialization of MediaPipe graph services.
|
||||||
|
//
|
||||||
|
// IMPORTANT: MediaPipe graph serices, essentially a graph-level singletons,
|
||||||
|
// are designed in the way, so they may provide default initialization. For
|
||||||
|
// example, this allows to run OpenGL processing wihtin the graph without
|
||||||
|
// provinging a praticular OpenGL context as it can be provided by
|
||||||
|
// default-initializable `kGpuService`. (One caveat here, you may still need
|
||||||
|
// to initialize it manually to share graph context with external context.)
|
||||||
|
//
|
||||||
|
// Even if calculators require some service optionally
|
||||||
|
// (`calculator_contract->UseService(kSomeService).Optional()`), it will be
|
||||||
|
// still initialized if it allows default initialization.
|
||||||
|
//
|
||||||
|
// So far, in rare cases, this may be unwanted and strict control of what
|
||||||
|
// services are allowed in the graph can be achieved by calling this method,
|
||||||
|
// following `SetServiceObject` call for services which are allowed in the
|
||||||
|
// graph.
|
||||||
|
//
|
||||||
|
// Recommendation: do not use unless you have to (for example, default
|
||||||
|
// initialization has side effects)
|
||||||
|
//
|
||||||
|
// NOTE: must be called before `StartRun`/`Run`, where services are checked
|
||||||
|
// and can be default-initialized.
|
||||||
|
absl::Status DisallowServiceDefaultInitialization() {
|
||||||
|
allow_service_default_initialization_ = false;
|
||||||
|
return absl::OkStatus();
|
||||||
|
}
|
||||||
|
|
||||||
// Sets a service object, essentially a graph-level singleton, which can be
|
// Sets a service object, essentially a graph-level singleton, which can be
|
||||||
// accessed by calculators and subgraphs without requiring an explicit
|
// accessed by calculators and subgraphs without requiring an explicit
|
||||||
// connection.
|
// connection.
|
||||||
|
@ -644,6 +672,9 @@ class CalculatorGraph {
|
||||||
// Object to manage graph services.
|
// Object to manage graph services.
|
||||||
GraphServiceManager service_manager_;
|
GraphServiceManager service_manager_;
|
||||||
|
|
||||||
|
// Indicates whether service default initialization is allowed.
|
||||||
|
bool allow_service_default_initialization_ = true;
|
||||||
|
|
||||||
// Vector of errors encountered while running graph. Always use RecordError()
|
// Vector of errors encountered while running graph. Always use RecordError()
|
||||||
// to add an error to this vector.
|
// to add an error to this vector.
|
||||||
std::vector<absl::Status> errors_ ABSL_GUARDED_BY(error_mutex_);
|
std::vector<absl::Status> errors_ ABSL_GUARDED_BY(error_mutex_);
|
||||||
|
|
|
@ -136,6 +136,8 @@ message GraphTrace {
|
||||||
GPU_TASK_INVOKE = 16;
|
GPU_TASK_INVOKE = 16;
|
||||||
TPU_TASK_INVOKE = 17;
|
TPU_TASK_INVOKE = 17;
|
||||||
CPU_TASK_INVOKE = 18;
|
CPU_TASK_INVOKE = 18;
|
||||||
|
GPU_TASK_INVOKE_ADVANCED = 19;
|
||||||
|
TPU_TASK_INVOKE_ASYNC = 20;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The timing for one packet set being processed at one caclulator node.
|
// The timing for one packet set being processed at one caclulator node.
|
||||||
|
|
|
@ -315,11 +315,11 @@ cc_library(
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
"@com_google_absl//absl/flags:flag",
|
"@com_google_absl//absl/flags:flag",
|
||||||
|
"//mediapipe/framework/deps:file_path",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:statusor",
|
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"//mediapipe/framework/deps:file_path",
|
"//mediapipe/framework/port:statusor",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//mediapipe/framework/port:file_helpers",
|
"//mediapipe/framework/port:file_helpers",
|
||||||
|
@ -328,8 +328,8 @@ cc_library(
|
||||||
"//mediapipe/framework/port:file_helpers",
|
"//mediapipe/framework/port:file_helpers",
|
||||||
],
|
],
|
||||||
"//mediapipe:android": [
|
"//mediapipe:android": [
|
||||||
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
|
|
||||||
"//mediapipe/framework/port:file_helpers",
|
"//mediapipe/framework/port:file_helpers",
|
||||||
|
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
|
||||||
],
|
],
|
||||||
"//mediapipe:apple": [
|
"//mediapipe:apple": [
|
||||||
"//mediapipe/framework/port:file_helpers",
|
"//mediapipe/framework/port:file_helpers",
|
||||||
|
|
|
@ -112,6 +112,10 @@ struct TraceEvent {
|
||||||
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
|
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
|
||||||
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
|
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
|
||||||
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
|
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
|
||||||
|
static constexpr EventType GPU_TASK_INVOKE_ADVANCED =
|
||||||
|
GraphTrace::GPU_TASK_INVOKE_ADVANCED;
|
||||||
|
static constexpr EventType TPU_TASK_INVOKE_ASYNC =
|
||||||
|
GraphTrace::TPU_TASK_INVOKE_ASYNC;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Packet trace log buffer.
|
// Packet trace log buffer.
|
||||||
|
|
|
@ -57,7 +57,6 @@ struct hash<mediapipe::TaskId> {
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void BasicTraceEventTypes(TraceEventRegistry* result) {
|
void BasicTraceEventTypes(TraceEventRegistry* result) {
|
||||||
// The initializer arguments below are: event_type, description,
|
// The initializer arguments below are: event_type, description,
|
||||||
// is_packet_event, is_stream_event, id_event_data.
|
// is_packet_event, is_stream_event, id_event_data.
|
||||||
|
@ -84,6 +83,15 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
|
||||||
"A time measured by GPU clock and by CPU clock.", true, false},
|
"A time measured by GPU clock and by CPU clock.", true, false},
|
||||||
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
|
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
|
||||||
true, true, false},
|
true, true, false},
|
||||||
|
|
||||||
|
{TraceEvent::GPU_TASK_INVOKE, "CPU timing for initiating a GPU task."},
|
||||||
|
{TraceEvent::TPU_TASK_INVOKE, "CPU timing for initiating a TPU task."},
|
||||||
|
{TraceEvent::CPU_TASK_INVOKE, "CPU timing for initiating a CPU task."},
|
||||||
|
{TraceEvent::GPU_TASK_INVOKE_ADVANCED,
|
||||||
|
"CPU timing for initiating a GPU task bypassing the TFLite "
|
||||||
|
"interpreter."},
|
||||||
|
{TraceEvent::TPU_TASK_INVOKE_ASYNC,
|
||||||
|
"CPU timing for async initiation of a TPU task."},
|
||||||
};
|
};
|
||||||
for (const TraceEventType& t : basic_types) {
|
for (const TraceEventType& t : basic_types) {
|
||||||
(*result)[t.event_type()] = t;
|
(*result)[t.event_type()] = t;
|
||||||
|
|
|
@ -77,7 +77,6 @@ mediapipe_proto_library(
|
||||||
name = "calculator_graph_template_proto",
|
name = "calculator_graph_template_proto",
|
||||||
srcs = ["calculator_graph_template.proto"],
|
srcs = ["calculator_graph_template.proto"],
|
||||||
def_options_lib = False,
|
def_options_lib = False,
|
||||||
def_py_proto = False,
|
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework:calculator_options_proto",
|
"//mediapipe/framework:calculator_options_proto",
|
||||||
|
|
|
@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs):
|
||||||
'import public "' + join_path + '";',
|
'import public "' + join_path + '";',
|
||||||
)
|
)
|
||||||
rewrite_ref = SubsituteCommand(
|
rewrite_ref = SubsituteCommand(
|
||||||
r"mediapipe\\.(" + rewrite_message_regex + ")",
|
r"mediapipe\.(" + rewrite_message_regex + ")",
|
||||||
r"mediapipe.\\1",
|
r"mediapipe.\\1",
|
||||||
)
|
)
|
||||||
rewrite_objc = SubsituteCommand(
|
rewrite_objc = SubsituteCommand(
|
||||||
|
@ -284,7 +284,7 @@ def mediapipe_proto_library(
|
||||||
def_jspb_proto: define the jspb_proto_library target
|
def_jspb_proto: define the jspb_proto_library target
|
||||||
def_go_proto: define the go_proto_library target
|
def_go_proto: define the go_proto_library target
|
||||||
def_options_lib: define the mediapipe_options_library target
|
def_options_lib: define the mediapipe_options_library target
|
||||||
def_rewrite: define a sibbling mediapipe_proto_library with package "mediapipe"
|
def_rewrite: define a sibling mediapipe_proto_library with package "mediapipe"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
mediapipe_proto_library_impl(
|
mediapipe_proto_library_impl(
|
||||||
|
|
|
@ -183,12 +183,13 @@ absl::Status FindCorrespondingStreams(
|
||||||
// name, calculator, input_stream, output_stream, input_side_packet,
|
// name, calculator, input_stream, output_stream, input_side_packet,
|
||||||
// output_side_packet, options.
|
// output_side_packet, options.
|
||||||
// All other fields are only applicable to calculators.
|
// All other fields are only applicable to calculators.
|
||||||
|
// TODO: Check whether executor is not set in the subgraph node
|
||||||
|
// after this issues is properly solved.
|
||||||
absl::Status ValidateSubgraphFields(
|
absl::Status ValidateSubgraphFields(
|
||||||
const CalculatorGraphConfig::Node& subgraph_node) {
|
const CalculatorGraphConfig::Node& subgraph_node) {
|
||||||
if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() ||
|
if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() ||
|
||||||
subgraph_node.has_output_stream_handler() ||
|
subgraph_node.has_output_stream_handler() ||
|
||||||
subgraph_node.input_stream_info_size() != 0 ||
|
subgraph_node.input_stream_info_size() != 0) {
|
||||||
!subgraph_node.executor().empty()) {
|
|
||||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||||
<< "Subgraph \"" << subgraph_node.name()
|
<< "Subgraph \"" << subgraph_node.name()
|
||||||
<< "\" has a field that is only applicable to calculators.";
|
<< "\" has a field that is only applicable to calculators.";
|
||||||
|
|
|
@ -272,14 +272,6 @@ selects.config_setting_group(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
selects.config_setting_group(
|
|
||||||
name = "platform_ios_without_gpu",
|
|
||||||
match_all = [
|
|
||||||
":disable_gpu",
|
|
||||||
"//mediapipe:ios",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
selects.config_setting_group(
|
selects.config_setting_group(
|
||||||
name = "platform_macos_with_gpu",
|
name = "platform_macos_with_gpu",
|
||||||
match_all = [
|
match_all = [
|
||||||
|
@ -296,32 +288,33 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":gpu_buffer_format",
|
":gpu_buffer_format",
|
||||||
":gpu_buffer_storage",
|
":gpu_buffer_storage",
|
||||||
|
":gpu_buffer_storage_image_frame",
|
||||||
"@com_google_absl//absl/functional:bind_front",
|
"@com_google_absl//absl/functional:bind_front",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
":gpu_buffer_storage_image_frame",
|
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
":gl_texture_view",
|
|
||||||
":gl_texture_buffer",
|
":gl_texture_buffer",
|
||||||
|
":gl_texture_view",
|
||||||
],
|
],
|
||||||
":platform_ios_with_gpu": [
|
":platform_ios_with_gpu": [
|
||||||
":gl_texture_view",
|
":gl_texture_view",
|
||||||
":gpu_buffer_storage_cv_pixel_buffer",
|
":gpu_buffer_storage_cv_pixel_buffer",
|
||||||
"//mediapipe/objc:util",
|
|
||||||
"//mediapipe/objc:CFHolder",
|
"//mediapipe/objc:CFHolder",
|
||||||
],
|
],
|
||||||
":platform_macos_with_gpu": [
|
":platform_macos_with_gpu": [
|
||||||
"//mediapipe/objc:CFHolder",
|
|
||||||
":gl_texture_view",
|
|
||||||
":gl_texture_buffer",
|
":gl_texture_buffer",
|
||||||
],
|
":gl_texture_view",
|
||||||
":platform_ios_without_gpu": [
|
"//mediapipe/objc:CFHolder",
|
||||||
"//mediapipe/objc:util",
|
|
||||||
],
|
],
|
||||||
":disable_gpu": [],
|
":disable_gpu": [],
|
||||||
|
}) + select({
|
||||||
|
"//conditions:default": [],
|
||||||
|
"//mediapipe:ios": [
|
||||||
|
"//mediapipe/objc:util",
|
||||||
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -331,9 +324,9 @@ cc_library(
|
||||||
hdrs = ["gpu_buffer_format.h"],
|
hdrs = ["gpu_buffer_format.h"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//mediapipe/framework/formats:image_format_cc_proto",
|
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"//mediapipe/framework/deps:no_destructor",
|
"//mediapipe/framework/deps:no_destructor",
|
||||||
|
"//mediapipe/framework/formats:image_format_cc_proto",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
|
@ -474,6 +467,7 @@ cc_library(
|
||||||
"//mediapipe/framework/formats:frame_buffer",
|
"//mediapipe/framework/formats:frame_buffer",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/formats:yuv_image",
|
"//mediapipe/framework/formats:yuv_image",
|
||||||
|
"//mediapipe/util/frame_buffer:frame_buffer_util",
|
||||||
"//third_party/libyuv",
|
"//third_party/libyuv",
|
||||||
"@com_google_absl//absl/log",
|
"@com_google_absl//absl/log",
|
||||||
"@com_google_absl//absl/log:check",
|
"@com_google_absl//absl/log:check",
|
||||||
|
@ -619,22 +613,22 @@ cc_library(
|
||||||
}),
|
}),
|
||||||
visibility = ["//visibility:private"],
|
visibility = ["//visibility:private"],
|
||||||
deps = [
|
deps = [
|
||||||
":gl_context_options_cc_proto",
|
|
||||||
":graph_support",
|
|
||||||
"//mediapipe/framework:calculator_context",
|
|
||||||
"//mediapipe/framework:executor",
|
|
||||||
"//mediapipe/framework:calculator_node",
|
|
||||||
"//mediapipe/framework/port:ret_check",
|
|
||||||
"//mediapipe/framework/deps:no_destructor",
|
|
||||||
":gl_base",
|
":gl_base",
|
||||||
":gl_context",
|
":gl_context",
|
||||||
|
":gl_context_options_cc_proto",
|
||||||
":gpu_buffer_multi_pool",
|
":gpu_buffer_multi_pool",
|
||||||
":gpu_shared_data_header",
|
":gpu_shared_data_header",
|
||||||
|
":graph_support",
|
||||||
|
"//mediapipe/framework:calculator_context",
|
||||||
|
"//mediapipe/framework:calculator_node",
|
||||||
|
"//mediapipe/framework:executor",
|
||||||
|
"//mediapipe/framework/deps:no_destructor",
|
||||||
|
"//mediapipe/framework/port:ret_check",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
"//mediapipe:apple": [
|
"//mediapipe:apple": [
|
||||||
":metal_shared_resources",
|
|
||||||
":cv_texture_cache_manager",
|
":cv_texture_cache_manager",
|
||||||
|
":metal_shared_resources",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -703,13 +697,13 @@ cc_library(
|
||||||
":gpu_buffer",
|
":gpu_buffer",
|
||||||
":gpu_shared_data_header",
|
":gpu_shared_data_header",
|
||||||
":multi_pool",
|
":multi_pool",
|
||||||
|
"@com_google_absl//absl/hash",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
"//mediapipe/framework:calculator_context",
|
"//mediapipe/framework:calculator_context",
|
||||||
"//mediapipe/framework:calculator_node",
|
"//mediapipe/framework:calculator_node",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
"//mediapipe/util:resource_cache",
|
"//mediapipe/util:resource_cache",
|
||||||
"@com_google_absl//absl/hash",
|
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/synchronization",
|
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
":gl_texture_buffer",
|
":gl_texture_buffer",
|
||||||
|
@ -725,9 +719,9 @@ cc_library(
|
||||||
"//mediapipe:macos": [
|
"//mediapipe:macos": [
|
||||||
":cv_pixel_buffer_pool_wrapper",
|
":cv_pixel_buffer_pool_wrapper",
|
||||||
":cv_texture_cache_manager",
|
":cv_texture_cache_manager",
|
||||||
":pixel_buffer_pool_util",
|
|
||||||
":gl_texture_buffer",
|
":gl_texture_buffer",
|
||||||
":gl_texture_buffer_pool",
|
":gl_texture_buffer_pool",
|
||||||
|
":pixel_buffer_pool_util",
|
||||||
],
|
],
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -795,31 +789,31 @@ cc_library(
|
||||||
":gpu_buffer",
|
":gpu_buffer",
|
||||||
":gpu_buffer_format",
|
":gpu_buffer_format",
|
||||||
":gpu_buffer_multi_pool",
|
":gpu_buffer_multi_pool",
|
||||||
":gpu_shared_data_internal",
|
|
||||||
":gpu_service",
|
":gpu_service",
|
||||||
|
":gpu_shared_data_internal",
|
||||||
":graph_support",
|
":graph_support",
|
||||||
":image_frame_view",
|
":image_frame_view",
|
||||||
":shader_util",
|
":shader_util",
|
||||||
"//mediapipe/framework:calculator_framework",
|
|
||||||
"@com_google_absl//absl/base:core_headers",
|
"@com_google_absl//absl/base:core_headers",
|
||||||
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/synchronization",
|
||||||
"//mediapipe/framework:calculator_context",
|
"//mediapipe/framework:calculator_context",
|
||||||
"//mediapipe/framework:calculator_node",
|
|
||||||
"//mediapipe/framework:calculator_contract",
|
"//mediapipe/framework:calculator_contract",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_node",
|
||||||
"//mediapipe/framework:demangle",
|
"//mediapipe/framework:demangle",
|
||||||
"//mediapipe/framework:legacy_calculator_support",
|
"//mediapipe/framework:legacy_calculator_support",
|
||||||
"//mediapipe/framework:packet",
|
"//mediapipe/framework:packet",
|
||||||
"//mediapipe/framework:packet_set",
|
"//mediapipe/framework:packet_set",
|
||||||
"//mediapipe/framework:packet_type",
|
"//mediapipe/framework:packet_type",
|
||||||
"//mediapipe/framework:timestamp",
|
"//mediapipe/framework:timestamp",
|
||||||
|
"//mediapipe/framework/deps:registration",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
|
"//mediapipe/framework/port:map_util",
|
||||||
"//mediapipe/framework/port:ret_check",
|
"//mediapipe/framework/port:ret_check",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
"@com_google_absl//absl/memory",
|
|
||||||
"@com_google_absl//absl/synchronization",
|
|
||||||
"//mediapipe/framework/deps:registration",
|
|
||||||
"//mediapipe/framework/port:map_util",
|
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
],
|
],
|
||||||
|
@ -918,8 +912,6 @@ cc_library(
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":gl_calculator_helper",
|
":gl_calculator_helper",
|
||||||
":gpu_buffer_storage_image_frame",
|
|
||||||
"//mediapipe/framework/api2:node",
|
|
||||||
"//mediapipe/framework:calculator_framework",
|
"//mediapipe/framework:calculator_framework",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
|
@ -941,7 +933,7 @@ mediapipe_proto_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
proto_library(
|
mediapipe_proto_library(
|
||||||
name = "gl_scaler_calculator_proto",
|
name = "gl_scaler_calculator_proto",
|
||||||
srcs = ["gl_scaler_calculator.proto"],
|
srcs = ["gl_scaler_calculator.proto"],
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
|
@ -951,17 +943,6 @@ proto_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
mediapipe_cc_proto_library(
|
|
||||||
name = "gl_scaler_calculator_cc_proto",
|
|
||||||
srcs = ["gl_scaler_calculator.proto"],
|
|
||||||
cc_deps = [
|
|
||||||
":scale_mode_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
|
||||||
],
|
|
||||||
visibility = ["//visibility:public"],
|
|
||||||
deps = [":gl_scaler_calculator_proto"],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gl_scaler_calculator",
|
name = "gl_scaler_calculator",
|
||||||
srcs = ["gl_scaler_calculator.cc"],
|
srcs = ["gl_scaler_calculator.cc"],
|
||||||
|
|
|
@ -12,63 +12,73 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "mediapipe/framework/api2/node.h"
|
|
||||||
#include "mediapipe/framework/calculator_framework.h"
|
#include "mediapipe/framework/calculator_framework.h"
|
||||||
#include "mediapipe/framework/formats/image_frame.h"
|
#include "mediapipe/framework/formats/image_frame.h"
|
||||||
#include "mediapipe/framework/port/status.h"
|
#include "mediapipe/framework/port/status.h"
|
||||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||||
|
|
||||||
|
#ifdef __APPLE__
|
||||||
|
#include "mediapipe/objc/util.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace mediapipe {
|
namespace mediapipe {
|
||||||
namespace api2 {
|
|
||||||
|
|
||||||
class ImageFrameToGpuBufferCalculator
|
// Convert ImageFrame to GpuBuffer.
|
||||||
: public RegisteredNode<ImageFrameToGpuBufferCalculator> {
|
class ImageFrameToGpuBufferCalculator : public CalculatorBase {
|
||||||
public:
|
public:
|
||||||
static constexpr Input<ImageFrame> kIn{""};
|
ImageFrameToGpuBufferCalculator() {}
|
||||||
static constexpr Output<GpuBuffer> kOut{""};
|
|
||||||
|
|
||||||
MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut);
|
static absl::Status GetContract(CalculatorContract* cc);
|
||||||
|
|
||||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
|
||||||
|
|
||||||
absl::Status Open(CalculatorContext* cc) override;
|
absl::Status Open(CalculatorContext* cc) override;
|
||||||
absl::Status Process(CalculatorContext* cc) override;
|
absl::Status Process(CalculatorContext* cc) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
GlCalculatorHelper helper_;
|
GlCalculatorHelper helper_;
|
||||||
|
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
};
|
};
|
||||||
|
REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator);
|
||||||
|
|
||||||
// static
|
// static
|
||||||
absl::Status ImageFrameToGpuBufferCalculator::UpdateContract(
|
absl::Status ImageFrameToGpuBufferCalculator::GetContract(
|
||||||
CalculatorContract* cc) {
|
CalculatorContract* cc) {
|
||||||
|
cc->Inputs().Index(0).Set<ImageFrame>();
|
||||||
|
cc->Outputs().Index(0).Set<GpuBuffer>();
|
||||||
// Note: we call this method even on platforms where we don't use the helper,
|
// Note: we call this method even on platforms where we don't use the helper,
|
||||||
// to ensure the calculator's contract is the same. In particular, the helper
|
// to ensure the calculator's contract is the same. In particular, the helper
|
||||||
// enables support for the legacy side packet, which several graphs still use.
|
// enables support for the legacy side packet, which several graphs still use.
|
||||||
return GlCalculatorHelper::UpdateContract(cc);
|
MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
|
||||||
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) {
|
absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) {
|
||||||
|
// Inform the framework that we always output at the same timestamp
|
||||||
|
// as we receive a packet at.
|
||||||
|
cc->SetOffset(TimestampDiff(0));
|
||||||
|
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
MP_RETURN_IF_ERROR(helper_.Open(cc));
|
MP_RETURN_IF_ERROR(helper_.Open(cc));
|
||||||
|
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) {
|
absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) {
|
||||||
auto image_frame = std::const_pointer_cast<ImageFrame>(
|
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
mediapipe::SharedPtrWithPacket<ImageFrame>(kIn(cc).packet()));
|
CFHolder<CVPixelBufferRef> buffer;
|
||||||
auto gpu_buffer = api2::MakePacket<GpuBuffer>(
|
MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket(
|
||||||
std::make_shared<mediapipe::GpuBufferStorageImageFrame>(
|
cc->Inputs().Index(0).Value(), &buffer));
|
||||||
std::move(image_frame)))
|
cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp());
|
||||||
.At(cc->InputTimestamp());
|
#else
|
||||||
// This calculator's behavior has been to do the texture upload eagerly, and
|
const auto& input = cc->Inputs().Index(0).Get<ImageFrame>();
|
||||||
// some graphs may rely on running this on a separate GL context to avoid
|
helper_.RunInGlContext([this, &input, &cc]() {
|
||||||
// blocking another context with the read operation. So let's request GPU
|
auto src = helper_.CreateSourceTexture(input);
|
||||||
// access here to ensure that the behavior stays the same.
|
auto output = src.GetFrame<GpuBuffer>();
|
||||||
// TODO: have a better way to do this, or defer until later.
|
glFlush();
|
||||||
helper_.RunInGlContext(
|
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
|
||||||
[&gpu_buffer] { auto view = gpu_buffer->GetReadView<GlTextureView>(0); });
|
src.Release();
|
||||||
kOut(cc).Send(std::move(gpu_buffer));
|
});
|
||||||
|
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace api2
|
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -66,7 +66,8 @@ public class GraphTextureFrame implements TextureFrame {
|
||||||
if (nativeBufferHandle == 0) {
|
if (nativeBufferHandle == 0) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) {
|
long contextHandle = nativeGetCurrentExternalContextHandle();
|
||||||
|
if (contextHandle != 0 && activeConsumerContextHandleSet.add(contextHandle)) {
|
||||||
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
|
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
|
||||||
// PacketGetter.getTextureFrameDeferredSync().
|
// PacketGetter.getTextureFrameDeferredSync().
|
||||||
if (deferredSync) {
|
if (deferredSync) {
|
||||||
|
@ -116,7 +117,14 @@ public class GraphTextureFrame implements TextureFrame {
|
||||||
GlSyncToken consumerToken = null;
|
GlSyncToken consumerToken = null;
|
||||||
// Note that this remove should be moved to the other overload of release when b/68808951 is
|
// Note that this remove should be moved to the other overload of release when b/68808951 is
|
||||||
// addressed.
|
// addressed.
|
||||||
if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) {
|
final long contextHandle = nativeGetCurrentExternalContextHandle();
|
||||||
|
if (contextHandle == 0 && !activeConsumerContextHandleSet.isEmpty()) {
|
||||||
|
logger.atWarning().log(
|
||||||
|
"GraphTextureFrame is being released on non GL thread while having active consumers,"
|
||||||
|
+ " which may lead to external / internal GL contexts synchronization issues.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (contextHandle != 0 && activeConsumerContextHandleSet.remove(contextHandle)) {
|
||||||
consumerToken =
|
consumerToken =
|
||||||
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
|
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
|
||||||
}
|
}
|
||||||
|
@ -169,7 +177,9 @@ public class GraphTextureFrame implements TextureFrame {
|
||||||
private native void nativeReleaseBuffer(long nativeHandle);
|
private native void nativeReleaseBuffer(long nativeHandle);
|
||||||
|
|
||||||
private native int nativeGetTextureName(long nativeHandle);
|
private native int nativeGetTextureName(long nativeHandle);
|
||||||
|
|
||||||
private native int nativeGetWidth(long nativeHandle);
|
private native int nativeGetWidth(long nativeHandle);
|
||||||
|
|
||||||
private native int nativeGetHeight(long nativeHandle);
|
private native int nativeGetHeight(long nativeHandle);
|
||||||
|
|
||||||
private native void nativeGpuWait(long nativeHandle);
|
private native void nativeGpuWait(long nativeHandle);
|
||||||
|
|
|
@ -30,11 +30,11 @@ cc_library(
|
||||||
"compat_jni.cc",
|
"compat_jni.cc",
|
||||||
"graph.cc",
|
"graph.cc",
|
||||||
"graph_jni.cc",
|
"graph_jni.cc",
|
||||||
|
"graph_profiler_jni.cc",
|
||||||
"graph_service_jni.cc",
|
"graph_service_jni.cc",
|
||||||
"packet_context_jni.cc",
|
"packet_context_jni.cc",
|
||||||
"packet_creator_jni.cc",
|
"packet_creator_jni.cc",
|
||||||
"packet_getter_jni.cc",
|
"packet_getter_jni.cc",
|
||||||
"graph_profiler_jni.cc",
|
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
"//mediapipe:android": [
|
"//mediapipe:android": [
|
||||||
|
@ -54,11 +54,11 @@ cc_library(
|
||||||
"compat_jni.h",
|
"compat_jni.h",
|
||||||
"graph.h",
|
"graph.h",
|
||||||
"graph_jni.h",
|
"graph_jni.h",
|
||||||
|
"graph_profiler_jni.h",
|
||||||
"graph_service_jni.h",
|
"graph_service_jni.h",
|
||||||
"packet_context_jni.h",
|
"packet_context_jni.h",
|
||||||
"packet_creator_jni.h",
|
"packet_creator_jni.h",
|
||||||
"packet_getter_jni.h",
|
"packet_getter_jni.h",
|
||||||
"graph_profiler_jni.h",
|
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
"//mediapipe:android": [
|
"//mediapipe:android": [
|
||||||
|
@ -84,40 +84,40 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":class_registry",
|
":class_registry",
|
||||||
":jni_util",
|
":jni_util",
|
||||||
"//mediapipe/framework/formats:image_format_cc_proto",
|
|
||||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_profile_cc_proto",
|
|
||||||
"//mediapipe/framework:calculator_framework",
|
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"@eigen_archive//:eigen3",
|
"@eigen_archive//:eigen3",
|
||||||
|
"//mediapipe/framework:calculator_cc_proto",
|
||||||
|
"//mediapipe/framework:calculator_framework",
|
||||||
|
"//mediapipe/framework:calculator_profile_cc_proto",
|
||||||
"//mediapipe/framework:camera_intrinsics",
|
"//mediapipe/framework:camera_intrinsics",
|
||||||
"//mediapipe/framework/formats:image",
|
"//mediapipe/framework/formats:image",
|
||||||
|
"//mediapipe/framework/formats:image_format_cc_proto",
|
||||||
"//mediapipe/framework/formats:image_frame",
|
"//mediapipe/framework/formats:image_frame",
|
||||||
"//mediapipe/framework/formats:matrix",
|
"//mediapipe/framework/formats:matrix",
|
||||||
|
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||||
"//mediapipe/framework/formats:video_stream_header",
|
"//mediapipe/framework/formats:video_stream_header",
|
||||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
|
||||||
"//mediapipe/framework/tool:name_util",
|
|
||||||
"//mediapipe/framework/tool:executor_util",
|
|
||||||
"//mediapipe/framework/port:core_proto",
|
"//mediapipe/framework/port:core_proto",
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
"//mediapipe/framework/port:threadpool",
|
|
||||||
"//mediapipe/framework/port:singleton",
|
"//mediapipe/framework/port:singleton",
|
||||||
"//mediapipe/framework/port:status",
|
"//mediapipe/framework/port:status",
|
||||||
|
"//mediapipe/framework/port:threadpool",
|
||||||
|
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||||
|
"//mediapipe/framework/tool:executor_util",
|
||||||
|
"//mediapipe/framework/tool:name_util",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//mediapipe/framework/port:file_helpers",
|
"//mediapipe/framework/port:file_helpers",
|
||||||
],
|
],
|
||||||
"//mediapipe:android": [
|
"//mediapipe:android": [
|
||||||
"//mediapipe/util/android/file/base",
|
|
||||||
"//mediapipe/util/android:asset_manager_util",
|
"//mediapipe/util/android:asset_manager_util",
|
||||||
|
"//mediapipe/util/android/file/base",
|
||||||
],
|
],
|
||||||
}) + select({
|
}) + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
"//mediapipe/gpu:gl_quad_renderer",
|
|
||||||
"//mediapipe/gpu:gl_calculator_helper",
|
"//mediapipe/gpu:gl_calculator_helper",
|
||||||
|
"//mediapipe/gpu:gl_quad_renderer",
|
||||||
"//mediapipe/gpu:gl_surface_sink_calculator",
|
"//mediapipe/gpu:gl_surface_sink_calculator",
|
||||||
"//mediapipe/gpu:gpu_buffer",
|
"//mediapipe/gpu:gpu_buffer",
|
||||||
"//mediapipe/gpu:gpu_shared_data_internal",
|
"//mediapipe/gpu:gpu_shared_data_internal",
|
||||||
|
@ -153,9 +153,9 @@ cc_library(
|
||||||
srcs = ["class_registry.cc"],
|
srcs = ["class_registry.cc"],
|
||||||
hdrs = ["class_registry.h"],
|
hdrs = ["class_registry.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
"@com_google_absl//absl/container:node_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_absl//absl/container:node_hash_map",
|
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
],
|
],
|
||||||
|
@ -172,9 +172,9 @@ cc_library(
|
||||||
":class_registry",
|
":class_registry",
|
||||||
":loose_headers",
|
":loose_headers",
|
||||||
":mediapipe_framework_jni",
|
":mediapipe_framework_jni",
|
||||||
|
"@com_google_absl//absl/container:node_hash_map",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_absl//absl/container:node_hash_map",
|
|
||||||
"//mediapipe/framework/port:logging",
|
"//mediapipe/framework/port:logging",
|
||||||
] + select({
|
] + select({
|
||||||
"//conditions:default": [
|
"//conditions:default": [
|
||||||
|
|
|
@ -357,6 +357,22 @@ def mediapipe_java_proto_srcs(name = ""):
|
||||||
target = "//mediapipe/framework/formats:rect_java_proto_lite",
|
target = "//mediapipe/framework/formats:rect_java_proto_lite",
|
||||||
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
|
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
|
||||||
))
|
))
|
||||||
|
|
||||||
|
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||||
|
target = "//mediapipe/util:color_java_proto_lite",
|
||||||
|
src_out = "com/google/mediapipe/util/proto/ColorProto.java",
|
||||||
|
))
|
||||||
|
|
||||||
|
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||||
|
target = "//mediapipe/util:label_map_java_proto_lite",
|
||||||
|
src_out = "com/google/mediapipe/util/proto/LabelMapProto.java",
|
||||||
|
))
|
||||||
|
|
||||||
|
proto_src_list.append(mediapipe_java_proto_src_extractor(
|
||||||
|
target = "//mediapipe/util:render_data_java_proto_lite",
|
||||||
|
src_out = "com/google/mediapipe/util/proto/RenderDataProto.java",
|
||||||
|
))
|
||||||
|
|
||||||
return proto_src_list
|
return proto_src_list
|
||||||
|
|
||||||
def mediapipe_logging_java_proto_srcs(name = ""):
|
def mediapipe_logging_java_proto_srcs(name = ""):
|
||||||
|
|
24
mediapipe/model_maker/models/face_stylizer/BUILD
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/model_maker/python/vision/face_stylizer:__subpackages__"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "models",
|
||||||
|
srcs = glob([
|
||||||
|
"**",
|
||||||
|
]),
|
||||||
|
)
|
|
@ -18,7 +18,7 @@ from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from typing import Callable, Optional, Tuple, TypeVar
|
from typing import Any, Callable, Optional, Tuple, TypeVar
|
||||||
|
|
||||||
# Dependency imports
|
# Dependency imports
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
@ -66,12 +66,14 @@ class Dataset(object):
|
||||||
"""
|
"""
|
||||||
return self._size
|
return self._size
|
||||||
|
|
||||||
def gen_tf_dataset(self,
|
def gen_tf_dataset(
|
||||||
|
self,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
is_training: bool = False,
|
is_training: bool = False,
|
||||||
shuffle: bool = False,
|
shuffle: bool = False,
|
||||||
preprocess: Optional[Callable[..., bool]] = None,
|
preprocess: Optional[Callable[..., Any]] = None,
|
||||||
drop_remainder: bool = False) -> tf.data.Dataset:
|
drop_remainder: bool = False,
|
||||||
|
) -> tf.data.Dataset:
|
||||||
"""Generates a batched tf.data.Dataset for training/evaluation.
|
"""Generates a batched tf.data.Dataset for training/evaluation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
|
@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
|
||||||
self._hparams: hp.BaseHParams = None
|
self._hparams: hp.BaseHParams = None
|
||||||
self._history: tf.keras.callbacks.History = None
|
self._history: tf.keras.callbacks.History = None
|
||||||
|
|
||||||
def _train_model(self,
|
def _train_model(
|
||||||
|
self,
|
||||||
train_data: classification_ds.ClassificationDataset,
|
train_data: classification_ds.ClassificationDataset,
|
||||||
validation_data: classification_ds.ClassificationDataset,
|
validation_data: classification_ds.ClassificationDataset,
|
||||||
preprocessor: Optional[Callable[..., bool]] = None,
|
preprocessor: Optional[Callable[..., Any]] = None,
|
||||||
checkpoint_path: Optional[str] = None):
|
checkpoint_path: Optional[str] = None,
|
||||||
|
):
|
||||||
"""Trains the classifier model.
|
"""Trains the classifier model.
|
||||||
|
|
||||||
Compiles and fits the tf.keras `_model` and records the `_history`.
|
Compiles and fits the tf.keras `_model` and records the `_history`.
|
||||||
|
|
|
@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
|
||||||
def convert_to_tflite(
|
def convert_to_tflite(
|
||||||
model: tf.keras.Model,
|
model: tf.keras.Model,
|
||||||
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
quantization_config: Optional[quantization.QuantizationConfig] = None,
|
||||||
supported_ops: Tuple[tf.lite.OpsSet,
|
supported_ops: Tuple[tf.lite.OpsSet, ...] = (
|
||||||
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
|
tf.lite.OpsSet.TFLITE_BUILTINS,
|
||||||
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
|
),
|
||||||
|
preprocess: Optional[Callable[..., Any]] = None,
|
||||||
|
) -> bytearray:
|
||||||
"""Converts the input Keras model to TFLite format.
|
"""Converts the input Keras model to TFLite format.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
48
mediapipe/model_maker/python/vision/face_stylizer/BUILD
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# Placeholder for internal Python strict test compatibility macro.
|
||||||
|
# Placeholder for internal Python strict library and test compatibility macro.
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe:__subpackages__"])
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "testdata",
|
||||||
|
srcs = glob([
|
||||||
|
"testdata/**",
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "dataset",
|
||||||
|
srcs = ["dataset.py"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/model_maker/python/core/data:classification_dataset",
|
||||||
|
"//mediapipe/model_maker/python/vision/core:image_utils",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "dataset_test",
|
||||||
|
srcs = ["dataset_test.py"],
|
||||||
|
data = [
|
||||||
|
":testdata",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":dataset",
|
||||||
|
"//mediapipe/tasks/python/test:test_utils",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,14 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MediaPipe Model Maker Python Public API For Face Stylization."""
|
98
mediapipe/model_maker/python/vision/face_stylizer/dataset.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Face stylizer dataset library."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.core.data import classification_dataset
|
||||||
|
from mediapipe.model_maker.python.vision.core import image_utils
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Change to a unlabeled dataset if it makes sense.
|
||||||
|
class Dataset(classification_dataset.ClassificationDataset):
|
||||||
|
"""Dataset library for face stylizer fine tuning."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_folder(
|
||||||
|
cls, dirname: str
|
||||||
|
) -> classification_dataset.ClassificationDataset:
|
||||||
|
"""Loads images from the given directory.
|
||||||
|
|
||||||
|
The style image dataset directory is expected to contain one subdirectory
|
||||||
|
whose name represents the label of the style. There can be one or multiple
|
||||||
|
images of the same style in that subdirectory. Supported input image formats
|
||||||
|
include 'jpg', 'jpeg', 'png'.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dirname: Name of the directory containing the image files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset containing images and labels and other related info.
|
||||||
|
Raises:
|
||||||
|
ValueError: if the input data directory is empty.
|
||||||
|
"""
|
||||||
|
data_root = os.path.abspath(dirname)
|
||||||
|
|
||||||
|
# Assumes the image data of the same label are in the same subdirectory,
|
||||||
|
# gets image path and label names.
|
||||||
|
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
|
||||||
|
all_image_size = len(all_image_paths)
|
||||||
|
if all_image_size == 0:
|
||||||
|
raise ValueError('Invalid input data directory')
|
||||||
|
if not any(
|
||||||
|
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
|
||||||
|
):
|
||||||
|
raise ValueError('No images found under given directory')
|
||||||
|
|
||||||
|
label_names = sorted(
|
||||||
|
name
|
||||||
|
for name in os.listdir(data_root)
|
||||||
|
if os.path.isdir(os.path.join(data_root, name))
|
||||||
|
)
|
||||||
|
all_label_size = len(label_names)
|
||||||
|
index_by_label = dict(
|
||||||
|
(name, index) for index, name in enumerate(label_names)
|
||||||
|
)
|
||||||
|
# Get the style label from the subdirectory name.
|
||||||
|
all_image_labels = [
|
||||||
|
index_by_label[os.path.basename(os.path.dirname(path))]
|
||||||
|
for path in all_image_paths
|
||||||
|
]
|
||||||
|
|
||||||
|
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
|
||||||
|
|
||||||
|
image_ds = path_ds.map(
|
||||||
|
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load label
|
||||||
|
label_ds = tf.data.Dataset.from_tensor_slices(
|
||||||
|
tf.cast(all_image_labels, tf.int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a dataset of (image, label) pairs
|
||||||
|
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
'Load images dataset with size: %d, num_label: %d, labels: %s.',
|
||||||
|
all_image_size,
|
||||||
|
all_label_size,
|
||||||
|
', '.join(label_names),
|
||||||
|
)
|
||||||
|
return Dataset(
|
||||||
|
dataset=image_label_ds, size=all_image_size, label_names=label_names
|
||||||
|
)
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from mediapipe.model_maker.python.vision.face_stylizer import dataset
|
||||||
|
from mediapipe.tasks.python.test import test_utils
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetTest(tf.test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super().setUp()
|
||||||
|
# TODO: Replace the stylize image dataset with licensed images.
|
||||||
|
self._test_data_dirname = 'testdata'
|
||||||
|
|
||||||
|
def test_from_folder(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
|
||||||
|
data = dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||||
|
self.assertEqual(data.num_classes, 2)
|
||||||
|
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
|
||||||
|
self.assertLen(data, 2)
|
||||||
|
|
||||||
|
def test_from_folder_raise_value_error_for_invalid_path(self):
|
||||||
|
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
|
||||||
|
dataset.Dataset.from_folder(dirname='invalid')
|
||||||
|
|
||||||
|
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
|
||||||
|
input_data_dir = test_utils.get_test_data_path('face_stylizer')
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, 'No images found under given directory'
|
||||||
|
):
|
||||||
|
dataset.Dataset.from_folder(dirname=input_data_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
tf.test.main()
|
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/cartoon/disney.png
vendored
Normal file
After Width: | Height: | Size: 347 KiB |
BIN
mediapipe/model_maker/python/vision/face_stylizer/testdata/sketch/sketch.png
vendored
Normal file
After Width: | Height: | Size: 336 KiB |
|
@ -15,6 +15,7 @@
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import unittest
|
||||||
from unittest import mock as unittest_mock
|
from unittest import mock as unittest_mock
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
|
@ -31,6 +32,7 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat
|
||||||
tf.keras.backend.experimental.enable_tf_random_generator()
|
tf.keras.backend.experimental.enable_tf_random_generator()
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skip('b/273818271')
|
||||||
class GestureRecognizerTest(tf.test.TestCase):
|
class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
def _load_data(self):
|
def _load_data(self):
|
||||||
|
@ -72,8 +74,10 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
|
|
||||||
self._test_accuracy(model)
|
self._test_accuracy(model)
|
||||||
|
|
||||||
|
@unittest.skip('b/273818271')
|
||||||
@unittest_mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
|
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense
|
||||||
|
)
|
||||||
def test_gesture_recognizer_model_layer_widths(self, mock_dense):
|
def test_gesture_recognizer_model_layer_widths(self, mock_dense):
|
||||||
layer_widths = [64, 32]
|
layer_widths = [64, 32]
|
||||||
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
|
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
|
||||||
|
@ -143,12 +147,14 @@ class GestureRecognizerTest(tf.test.TestCase):
|
||||||
hyperparameters,
|
hyperparameters,
|
||||||
'HParams',
|
'HParams',
|
||||||
autospec=True,
|
autospec=True,
|
||||||
return_value=gesture_recognizer.HParams(epochs=1))
|
return_value=gesture_recognizer.HParams(epochs=1),
|
||||||
|
)
|
||||||
@unittest_mock.patch.object(
|
@unittest_mock.patch.object(
|
||||||
model_options,
|
model_options,
|
||||||
'GestureRecognizerModelOptions',
|
'GestureRecognizerModelOptions',
|
||||||
autospec=True,
|
autospec=True,
|
||||||
return_value=gesture_recognizer.ModelOptions())
|
return_value=gesture_recognizer.ModelOptions(),
|
||||||
|
)
|
||||||
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
|
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
|
||||||
self, mock_hparams, mock_model_options):
|
self, mock_hparams, mock_model_options):
|
||||||
options = gesture_recognizer.GestureRecognizerOptions()
|
options = gesture_recognizer.GestureRecognizerOptions()
|
||||||
|
|
|
@ -28,7 +28,7 @@ class ModelSpec(object):
|
||||||
uri: str,
|
uri: str,
|
||||||
input_image_shape: Optional[List[int]] = None,
|
input_image_shape: Optional[List[int]] = None,
|
||||||
name: str = ''):
|
name: str = ''):
|
||||||
"""Initializes a new instance of the `ImageModelSpec` class.
|
"""Initializes a new instance of the image classifier `ModelSpec` class.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
uri: str, URI to the pretrained model.
|
uri: str, URI to the pretrained model.
|
||||||
|
|
|
@ -5,4 +5,4 @@ opencv-python
|
||||||
tensorflow>=2.10
|
tensorflow>=2.10
|
||||||
tensorflow-datasets
|
tensorflow-datasets
|
||||||
tensorflow-hub
|
tensorflow-hub
|
||||||
tf-models-official>=2.10.1
|
tf-models-official>=2.11.4
|
||||||
|
|
|
@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
|
||||||
constexpr char kDetectionsTag[] = "DETECTIONS";
|
constexpr char kDetectionsTag[] = "DETECTIONS";
|
||||||
constexpr char kLabelsTag[] = "LABELS";
|
constexpr char kLabelsTag[] = "LABELS";
|
||||||
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
|
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
|
||||||
|
constexpr char kLabelMapTag[] = "LABEL_MAP";
|
||||||
|
|
||||||
using mediapipe::RE2;
|
using mediapipe::RE2;
|
||||||
using Detections = std::vector<Detection>;
|
using Detections = std::vector<Detection>;
|
||||||
|
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
|
||||||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||||
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
|
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
|
||||||
}
|
}
|
||||||
|
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||||
|
cc->InputSidePackets()
|
||||||
|
.Tag(kLabelMapTag)
|
||||||
|
.Set<std::unique_ptr<std::map<int, std::string>>>();
|
||||||
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
||||||
cc->SetOffset(TimestampDiff(0));
|
cc->SetOffset(TimestampDiff(0));
|
||||||
options_ = cc->Options<FilterDetectionCalculatorOptions>();
|
options_ = cc->Options<FilterDetectionCalculatorOptions>();
|
||||||
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
|
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
|
||||||
cc->InputSidePackets().HasTag(kLabelsCsvTag);
|
cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
|
||||||
|
cc->InputSidePackets().HasTag(kLabelMapTag);
|
||||||
if (limit_labels_) {
|
if (limit_labels_) {
|
||||||
Strings allowlist_labels;
|
Strings allowlist_labels;
|
||||||
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
|
||||||
|
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
|
||||||
for (auto& e : allowlist_labels) {
|
for (auto& e : allowlist_labels) {
|
||||||
absl::StripAsciiWhitespace(&e);
|
absl::StripAsciiWhitespace(&e);
|
||||||
}
|
}
|
||||||
} else {
|
} else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
|
||||||
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
|
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
|
||||||
|
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
|
||||||
|
auto label_map = cc->InputSidePackets()
|
||||||
|
.Tag(kLabelMapTag)
|
||||||
|
.Get<std::unique_ptr<std::map<int, std::string>>>()
|
||||||
|
.get();
|
||||||
|
for (const auto& [_, v] : *label_map) {
|
||||||
|
allowlist_labels.push_back(v);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
|
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
|
||||||
|
auto runner = std::make_unique<CalculatorRunner>(
|
||||||
|
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
|
||||||
|
calculator: "FilterDetectionCalculator"
|
||||||
|
input_stream: "DETECTION:input"
|
||||||
|
input_side_packet: "LABEL_MAP:input_map"
|
||||||
|
output_stream: "DETECTION:output"
|
||||||
|
options {
|
||||||
|
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
|
||||||
|
}
|
||||||
|
)pb"));
|
||||||
|
|
||||||
|
runner->MutableInputs()->Tag("DETECTION").packets = {
|
||||||
|
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||||
|
label: "a"
|
||||||
|
label: "b"
|
||||||
|
label: "c"
|
||||||
|
label: "d"
|
||||||
|
score: 1
|
||||||
|
score: 0.8
|
||||||
|
score: 0.3
|
||||||
|
score: 0.9
|
||||||
|
)pb"))
|
||||||
|
.At(Timestamp(20)),
|
||||||
|
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
|
||||||
|
label: "a"
|
||||||
|
label: "b"
|
||||||
|
label: "c"
|
||||||
|
label: "e"
|
||||||
|
score: 0.6
|
||||||
|
score: 0.4
|
||||||
|
score: 0.2
|
||||||
|
score: 0.7
|
||||||
|
)pb"))
|
||||||
|
.At(Timestamp(40)),
|
||||||
|
};
|
||||||
|
|
||||||
|
auto label_map = std::make_unique<std::map<int, std::string>>();
|
||||||
|
(*label_map)[0] = "a";
|
||||||
|
(*label_map)[1] = "b";
|
||||||
|
(*label_map)[2] = "c";
|
||||||
|
runner->MutableSidePackets()->Tag("LABEL_MAP") =
|
||||||
|
AdoptAsUniquePtr(label_map.release());
|
||||||
|
|
||||||
|
// Run graph.
|
||||||
|
MP_ASSERT_OK(runner->Run());
|
||||||
|
|
||||||
|
// Check output.
|
||||||
|
EXPECT_THAT(
|
||||||
|
runner->Outputs().Tag("DETECTION").packets,
|
||||||
|
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
|
||||||
|
Eq(Timestamp(20)),
|
||||||
|
EqualsProto(R"pb(
|
||||||
|
label: "a" label: "b" score: 1 score: 0.8
|
||||||
|
)pb")), // Packet 1 at timestamp 20.
|
||||||
|
PacketContainsTimestampAndPayload<Detection>(
|
||||||
|
Eq(Timestamp(40)),
|
||||||
|
EqualsProto(R"pb(
|
||||||
|
label: "a" score: 0.6
|
||||||
|
)pb")) // Packet 2 at timestamp 40.
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace mediapipe
|
} // namespace mediapipe
|
||||||
|
|
|
@ -57,6 +57,7 @@ pybind_extension(
|
||||||
"//mediapipe/framework/formats:landmark_registration",
|
"//mediapipe/framework/formats:landmark_registration",
|
||||||
"//mediapipe/framework/formats:rect_registration",
|
"//mediapipe/framework/formats:rect_registration",
|
||||||
"//mediapipe/modules/objectron/calculators:annotation_registration",
|
"//mediapipe/modules/objectron/calculators:annotation_registration",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -95,6 +96,8 @@ cc_library(
|
||||||
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
|
||||||
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
|
||||||
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
|
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
|
||||||
|
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
|
||||||
] + select({
|
] + select({
|
||||||
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.
|
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.
|
||||||
"//mediapipe:windows": [],
|
"//mediapipe:windows": [],
|
||||||
|
|
|
@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus";
|
||||||
//
|
//
|
||||||
// At runtime, such codes are meant to be attached (where applicable) to a
|
// At runtime, such codes are meant to be attached (where applicable) to a
|
||||||
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
|
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
|
||||||
// stringifed error code as value (aka payload). This logic is encapsulated in
|
// stringified error code as value (aka payload). This logic is encapsulated in
|
||||||
// the `CreateStatusWithPayload` helper below for convenience.
|
// the `CreateStatusWithPayload` helper below for convenience.
|
||||||
//
|
//
|
||||||
// The returned status includes:
|
// The returned status includes:
|
||||||
|
|
|
@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
|
||||||
auto model_bundle_resources = absl::WrapUnique(
|
auto model_bundle_resources = absl::WrapUnique(
|
||||||
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
|
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
model_bundle_resources->ExtractModelFilesFromExternalFileProto());
|
model_bundle_resources->ExtractFilesFromExternalFileProto());
|
||||||
return model_bundle_resources;
|
return model_bundle_resources;
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::Status
|
absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
|
||||||
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
|
|
||||||
if (model_asset_bundle_file_->has_file_name()) {
|
if (model_asset_bundle_file_->has_file_name()) {
|
||||||
// If the model asset bundle file name is a relative path, searches the file
|
// If the model asset bundle file name is a relative path, searches the file
|
||||||
// in a platform-specific location and returns the absolute path on success.
|
// in a platform-specific location and returns the absolute path on success.
|
||||||
|
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
|
||||||
model_asset_bundle_file_handler_->GetFileContent().data();
|
model_asset_bundle_file_handler_->GetFileContent().data();
|
||||||
size_t buffer_size =
|
size_t buffer_size =
|
||||||
model_asset_bundle_file_handler_->GetFileContent().size();
|
model_asset_bundle_file_handler_->GetFileContent().size();
|
||||||
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
|
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
|
||||||
&model_files_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
|
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
|
||||||
const std::string& filename) const {
|
const std::string& filename) const {
|
||||||
auto it = model_files_.find(filename);
|
auto it = files_.find(filename);
|
||||||
if (it == model_files_.end()) {
|
if (it == files_.end()) {
|
||||||
auto model_files = ListModelFiles();
|
auto files = ListFiles();
|
||||||
std::string all_model_files =
|
std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
|
||||||
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
|
|
||||||
|
|
||||||
return CreateStatusWithPayload(
|
return CreateStatusWithPayload(
|
||||||
StatusCode::kNotFound,
|
StatusCode::kNotFound,
|
||||||
absl::StrFormat("No model file with name: %s. All model files in the "
|
absl::StrFormat("No file with name: %s. All files in the model asset "
|
||||||
"model asset bundle are: %s.",
|
"bundle are: %s.",
|
||||||
filename, all_model_files),
|
filename, all_files),
|
||||||
MediaPipeTasksStatus::kFileNotFoundError);
|
MediaPipeTasksStatus::kFileNotFoundError);
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
|
std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
|
||||||
std::vector<std::string> model_names;
|
std::vector<std::string> file_names;
|
||||||
for (const auto& [model_name, _] : model_files_) {
|
for (const auto& [file_name, _] : files_) {
|
||||||
model_names.push_back(model_name);
|
file_names.push_back(file_name);
|
||||||
}
|
}
|
||||||
return model_names;
|
return file_names;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -28,8 +28,8 @@ namespace core {
|
||||||
// The mediapipe task model asset bundle resources class.
|
// The mediapipe task model asset bundle resources class.
|
||||||
// A ModelAssetBundleResources object, created from an external file proto,
|
// A ModelAssetBundleResources object, created from an external file proto,
|
||||||
// contains model asset bundle related resources and the method to extract the
|
// contains model asset bundle related resources and the method to extract the
|
||||||
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
|
// tflite models, resource files or model asset bundles for the mediapipe
|
||||||
// resources are owned by the ModelAssetBundleResources object
|
// sub-tasks. As the resources are owned by the ModelAssetBundleResources object
|
||||||
// callers must keep ModelAssetBundleResources alive while using any of the
|
// callers must keep ModelAssetBundleResources alive while using any of the
|
||||||
// resources.
|
// resources.
|
||||||
class ModelAssetBundleResources {
|
class ModelAssetBundleResources {
|
||||||
|
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
|
||||||
// Returns the model asset bundle resources tag.
|
// Returns the model asset bundle resources tag.
|
||||||
std::string GetTag() const { return tag_; }
|
std::string GetTag() const { return tag_; }
|
||||||
|
|
||||||
// Gets the contents of the model file (either tflite model file or model
|
// Gets the contents of the model file (either tflite model file, resource
|
||||||
// bundle file) with the provided name. An error is returned if there is no
|
// file or model bundle file) with the provided name. An error is returned if
|
||||||
// such model file.
|
// there is no such model file.
|
||||||
absl::StatusOr<absl::string_view> GetModelFile(
|
absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
|
||||||
const std::string& filename) const;
|
|
||||||
|
|
||||||
// Lists all the model file names in the model asset model.
|
// Lists all the file names in the model asset model.
|
||||||
std::vector<std::string> ListModelFiles() const;
|
std::vector<std::string> ListFiles() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Constructor.
|
// Constructor.
|
||||||
|
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
|
||||||
const std::string& tag,
|
const std::string& tag,
|
||||||
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
|
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
|
||||||
|
|
||||||
// Extracts the model files (either tflite model file or model bundle file)
|
// Extracts the model files (either tflite model file, resource file or model
|
||||||
// from the external file proto.
|
// bundle file) from the external file proto.
|
||||||
absl::Status ExtractModelFilesFromExternalFileProto();
|
absl::Status ExtractFilesFromExternalFileProto();
|
||||||
|
|
||||||
// The model asset bundle resources tag.
|
// The model asset bundle resources tag.
|
||||||
const std::string tag_;
|
const std::string tag_;
|
||||||
|
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
|
||||||
// The ExternalFileHandler for the model asset bundle.
|
// The ExternalFileHandler for the model asset bundle.
|
||||||
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
|
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
|
||||||
|
|
||||||
// The model files bundled in model asset bundle, as a map with the filename
|
// The files bundled in model asset bundle, as a map with the filename
|
||||||
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
|
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
|
||||||
// a pointer to the file contents as value. Each model file can be either
|
// a pointer to the file contents as value. Each file can be either a TFLite
|
||||||
// a TFLite model file or a model bundle file for sub-task.
|
// model file, resource file or a model bundle file for sub-task.
|
||||||
absl::flat_hash_map<std::string, absl::string_view> model_files_;
|
absl::flat_hash_map<std::string, absl::string_view> files_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace core
|
} // namespace core
|
||||||
|
|
|
@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
#endif // _WIN32
|
#endif // _WIN32
|
||||||
|
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
|
||||||
.status());
|
|
||||||
MP_EXPECT_OK(
|
MP_EXPECT_OK(
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto status_or_model_bundle_file =
|
auto status_or_model_bundle_file =
|
||||||
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
|
model_bundle_resources->GetFile("dummy_hand_landmarker.task");
|
||||||
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||||
|
|
||||||
// Creates sub-task model asset bundle resources.
|
// Creates sub-task model asset bundle resources.
|
||||||
|
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(hand_landmaker_model_file)));
|
std::move(hand_landmaker_model_file)));
|
||||||
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||||
->GetModelFile("dummy_hand_detector.tflite")
|
->GetFile("dummy_hand_detector.tflite")
|
||||||
.status());
|
.status());
|
||||||
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
|
||||||
->GetModelFile("dummy_hand_landmarker.tflite")
|
->GetFile("dummy_hand_landmarker.tflite")
|
||||||
.status());
|
.status());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto status_or_model_bundle_file =
|
auto status_or_model_bundle_file =
|
||||||
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
|
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
|
||||||
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
MP_EXPECT_OK(status_or_model_bundle_file.status());
|
||||||
|
|
||||||
// Verify tflite model works.
|
// Verify tflite model works.
|
||||||
|
@ -200,11 +196,11 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
|
||||||
auto model_bundle_resources,
|
auto model_bundle_resources,
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto status = model_bundle_resources->GetModelFile("not_found.task").status();
|
auto status = model_bundle_resources->GetFile("not_found.task").status();
|
||||||
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(
|
||||||
testing::HasSubstr(
|
status.message(),
|
||||||
"No model file with name: not_found.task. All model files in "
|
testing::HasSubstr("No file with name: not_found.task. All files in "
|
||||||
"the model asset bundle are: "));
|
"the model asset bundle are: "));
|
||||||
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
|
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
|
||||||
testing::Optional(absl::Cord(
|
testing::Optional(absl::Cord(
|
||||||
|
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
|
||||||
auto model_bundle_resources,
|
auto model_bundle_resources,
|
||||||
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
|
||||||
std::move(model_file)));
|
std::move(model_file)));
|
||||||
auto model_files = model_bundle_resources->ListModelFiles();
|
auto model_files = model_bundle_resources->ListFiles();
|
||||||
std::vector<std::string> expected_model_files = {
|
std::vector<std::string> expected_model_files = {
|
||||||
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
|
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
|
||||||
std::sort(model_files.begin(), model_files.end());
|
std::sort(model_files.begin(), model_files.end());
|
||||||
|
|
|
@ -77,9 +77,11 @@ class ModelResourcesCalculator : public api2::Node {
|
||||||
if (options.has_model_file()) {
|
if (options.has_model_file()) {
|
||||||
RET_CHECK(options.model_file().has_file_content() ||
|
RET_CHECK(options.model_file().has_file_content() ||
|
||||||
options.model_file().has_file_descriptor_meta() ||
|
options.model_file().has_file_descriptor_meta() ||
|
||||||
options.model_file().has_file_name())
|
options.model_file().has_file_name() ||
|
||||||
|
options.model_file().has_file_pointer_meta())
|
||||||
<< "'model_file' must specify at least one of "
|
<< "'model_file' must specify at least one of "
|
||||||
"'file_content', 'file_descriptor_meta', or 'file_name'";
|
"'file_content', 'file_descriptor_meta', 'file_name', or "
|
||||||
|
"'file_pointer_meta'";
|
||||||
}
|
}
|
||||||
return absl::OkStatus();
|
return absl::OkStatus();
|
||||||
}
|
}
|
||||||
|
|
|
@ -179,9 +179,9 @@ TEST_F(ModelResourcesCalculatorTest, EmptyExternalFileProto) {
|
||||||
auto status = graph.Initialize(graph_config);
|
auto status = graph.Initialize(graph_config);
|
||||||
ASSERT_FALSE(status.ok());
|
ASSERT_FALSE(status.ok());
|
||||||
EXPECT_THAT(status.message(),
|
EXPECT_THAT(status.message(),
|
||||||
testing::HasSubstr(
|
testing::HasSubstr("'model_file' must specify at least one of "
|
||||||
"'model_file' must specify at least one of "
|
"'file_content', 'file_descriptor_meta', "
|
||||||
"'file_content', 'file_descriptor_meta', or 'file_name'"));
|
"'file_name', or 'file_pointer_meta'"));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) {
|
TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) {
|
||||||
|
|
|
@ -138,7 +138,7 @@ class InferenceSubgraph : public Subgraph {
|
||||||
delegate.mutable_tflite()->CopyFrom(acceleration.tflite());
|
delegate.mutable_tflite()->CopyFrom(acceleration.tflite());
|
||||||
break;
|
break;
|
||||||
case Acceleration::DELEGATE_NOT_SET:
|
case Acceleration::DELEGATE_NOT_SET:
|
||||||
// Deafult inference calculator setting.
|
// Default inference calculator setting.
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
return delegate;
|
return delegate;
|
||||||
|
|
|
@ -124,10 +124,10 @@ class ModelTaskGraph : public Subgraph {
|
||||||
// Inserts a mediapipe task inference subgraph into the provided
|
// Inserts a mediapipe task inference subgraph into the provided
|
||||||
// GraphBuilder. The returned node provides the following interfaces to the
|
// GraphBuilder. The returned node provides the following interfaces to the
|
||||||
// the rest of the graph:
|
// the rest of the graph:
|
||||||
// - a tensor vector (std::vector<MeidaPipe::Tensor>) input stream with tag
|
// - a tensor vector (std::vector<mediapipe::Tensor>) input stream with tag
|
||||||
// "TENSORS", representing the input tensors to be consumed by the
|
// "TENSORS", representing the input tensors to be consumed by the
|
||||||
// inference engine.
|
// inference engine.
|
||||||
// - a tensor vector (std::vector<MeidaPipe::Tensor>) output stream with tag
|
// - a tensor vector (std::vector<mediapipe::Tensor>) output stream with tag
|
||||||
// "TENSORS", representing the output tensors generated by the inference
|
// "TENSORS", representing the output tensors generated by the inference
|
||||||
// engine.
|
// engine.
|
||||||
// - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR".
|
// - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR".
|
||||||
|
|
|
@ -301,7 +301,7 @@ absl::Status TaskRunner::Close() {
|
||||||
}
|
}
|
||||||
is_running_ = false;
|
is_running_ = false;
|
||||||
MP_RETURN_IF_ERROR(
|
MP_RETURN_IF_ERROR(
|
||||||
AddPayload(graph_.CloseAllInputStreams(), "Fail to close intput streams",
|
AddPayload(graph_.CloseAllInputStreams(), "Fail to close input streams",
|
||||||
MediaPipeTasksStatus::kRunnerFailsToCloseError));
|
MediaPipeTasksStatus::kRunnerFailsToCloseError));
|
||||||
MP_RETURN_IF_ERROR(AddPayload(
|
MP_RETURN_IF_ERROR(AddPayload(
|
||||||
graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.",
|
graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.",
|
||||||
|
|
|
@ -65,7 +65,7 @@ class TaskRunner {
|
||||||
// Creates the task runner with a CalculatorGraphConfig proto.
|
// Creates the task runner with a CalculatorGraphConfig proto.
|
||||||
// If a tflite op resolver object is provided, the task runner will take
|
// If a tflite op resolver object is provided, the task runner will take
|
||||||
// it as the global op resolver for all models running within this task.
|
// it as the global op resolver for all models running within this task.
|
||||||
// The op resolver's owernship will be transferred into the pipeleine runner.
|
// The op resolver's ownership will be transferred into the pipeleine runner.
|
||||||
// When a user-defined PacketsCallback is provided, clients must use the
|
// When a user-defined PacketsCallback is provided, clients must use the
|
||||||
// asynchronous method, Send(), to provide the input packets. If the packets
|
// asynchronous method, Send(), to provide the input packets. If the packets
|
||||||
// callback is absent, clients must use the synchronous method, Process(), to
|
// callback is absent, clients must use the synchronous method, Process(), to
|
||||||
|
@ -84,7 +84,7 @@ class TaskRunner {
|
||||||
// frames from a video file and an audio file. The call blocks the current
|
// frames from a video file and an audio file. The call blocks the current
|
||||||
// thread until a failure status or a successful result is returned.
|
// thread until a failure status or a successful result is returned.
|
||||||
// If the input packets have no timestamp, an internal timestamp will be
|
// If the input packets have no timestamp, an internal timestamp will be
|
||||||
// assigend per invocation. Otherwise, when the timestamp is set in the
|
// assigned per invocation. Otherwise, when the timestamp is set in the
|
||||||
// input packets, the caller must ensure that the input packet timestamps are
|
// input packets, the caller must ensure that the input packet timestamps are
|
||||||
// greater than the timestamps of the previous invocation. This method is
|
// greater than the timestamps of the previous invocation. This method is
|
||||||
// thread-unsafe and it is the caller's responsibility to synchronize access
|
// thread-unsafe and it is the caller's responsibility to synchronize access
|
||||||
|
|
|
@ -64,7 +64,7 @@ class ModelMetadataPopulator {
|
||||||
// Loads associated files into the TFLite FlatBuffer model. The input is a map
|
// Loads associated files into the TFLite FlatBuffer model. The input is a map
|
||||||
// of {filename, file contents}.
|
// of {filename, file contents}.
|
||||||
//
|
//
|
||||||
// Warning: this method removes any previoulsy present associated files.
|
// Warning: this method removes any previously present associated files.
|
||||||
// Calling this method multiple time removes any associated files from
|
// Calling this method multiple time removes any associated files from
|
||||||
// previous calls, so this method should usually be called only once.
|
// previous calls, so this method should usually be called only once.
|
||||||
void LoadAssociatedFiles(
|
void LoadAssociatedFiles(
|
||||||
|
|
|
@ -213,7 +213,7 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
|
||||||
Version* min_version) {
|
Version* min_version) {
|
||||||
if (table == nullptr) return;
|
if (table == nullptr) return;
|
||||||
|
|
||||||
// Checks the ContenProperties field.
|
// Checks the ContentProperties field.
|
||||||
if (table->content_properties_type() == ContentProperties_AudioProperties) {
|
if (table->content_properties_type() == ContentProperties_AudioProperties) {
|
||||||
UpdateMinimumVersion(
|
UpdateMinimumVersion(
|
||||||
GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),
|
GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),
|
||||||
|
|
|
@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) {
|
||||||
|
|
||||||
// Using pybind11 type conversions to convert between Python and native
|
// Using pybind11 type conversions to convert between Python and native
|
||||||
// C++ types. There are other options to provide access to native Python types
|
// C++ types. There are other options to provide access to native Python types
|
||||||
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
|
// in C++ and vice versa. See the pybind 11 instruction [1] for more details.
|
||||||
// Type converstions is recommended by pybind11, though the main downside
|
// Type conversions is recommended by pybind11, though the main downside
|
||||||
// is that a copy of the data must be made on every Python to C++ transition:
|
// is that a copy of the data must be made on every Python to C++ transition:
|
||||||
// this is needed since the C++ and Python versions of the same type generally
|
// this is needed since the C++ and Python versions of the same type generally
|
||||||
// won’t have the same memory layout.
|
// won’t have the same memory layout.
|
||||||
|
|
|
@ -79,7 +79,7 @@ TEST(MetadataVersionTest,
|
||||||
auto metadata = metadata_builder.Finish();
|
auto metadata = metadata_builder.Finish();
|
||||||
FinishModelMetadataBuffer(builder, metadata);
|
FinishModelMetadataBuffer(builder, metadata);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -100,7 +100,7 @@ TEST(MetadataVersionTest,
|
||||||
auto metadata = metadata_builder.Finish();
|
auto metadata = metadata_builder.Finish();
|
||||||
builder.Finish(metadata);
|
builder.Finish(metadata);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version and triggers error.
|
// Gets the minimum metadata parser version and triggers error.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -121,7 +121,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_associated_files(associated_files);
|
metadata_builder.add_associated_files(associated_files);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -147,7 +147,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -172,7 +172,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -203,7 +203,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -234,7 +234,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -265,7 +265,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -294,7 +294,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -323,7 +323,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -348,7 +348,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -373,7 +373,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -404,7 +404,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -431,7 +431,7 @@ TEST(MetadataVersionTest,
|
||||||
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
|
||||||
CreateModelWithMetadata(tensors, builder);
|
CreateModelWithMetadata(tensors, builder);
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -453,7 +453,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_associated_files(associated_files);
|
metadata_builder.add_associated_files(associated_files);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -476,7 +476,7 @@ TEST(MetadataVersionTest,
|
||||||
metadata_builder.add_associated_files(associated_files);
|
metadata_builder.add_associated_files(associated_files);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
|
||||||
metadata_builder.add_subgraph_metadata(subgraphs);
|
metadata_builder.add_subgraph_metadata(subgraphs);
|
||||||
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
|
||||||
|
|
||||||
// Gets the mimimum metadata parser version.
|
// Gets the minimum metadata parser version.
|
||||||
std::string min_version;
|
std::string min_version;
|
||||||
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
|
||||||
builder.GetSize(), &min_version),
|
builder.GetSize(), &min_version),
|
||||||
|
|
|
@ -42,3 +42,36 @@ cc_test(
|
||||||
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "ngram_hash",
|
||||||
|
srcs = ["ngram_hash.cc"],
|
||||||
|
hdrs = ["ngram_hash.h"],
|
||||||
|
copts = tflite_copts(),
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
|
||||||
|
"@flatbuffers",
|
||||||
|
"@org_tensorflow//tensorflow/lite:string_util",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "ngram_hash_test",
|
||||||
|
srcs = ["ngram_hash_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":ngram_hash",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
|
"@flatbuffers",
|
||||||
|
"@org_tensorflow//tensorflow/lite:framework",
|
||||||
|
"@org_tensorflow//tensorflow/lite:string_util",
|
||||||
|
"@org_tensorflow//tensorflow/lite/c:common",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||||
|
"@org_tensorflow//tensorflow/lite/kernels:test_util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,264 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
|
||||||
|
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
|
||||||
|
namespace ngram_op {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::flexbuffers::GetRoot;
|
||||||
|
using ::flexbuffers::Map;
|
||||||
|
using ::flexbuffers::TypedVector;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::
|
||||||
|
LowercaseUnicodeStr;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
|
||||||
|
MurmurHash64WithSeed;
|
||||||
|
using ::tflite::GetString;
|
||||||
|
using ::tflite::StringRef;
|
||||||
|
|
||||||
|
constexpr int kInputMessage = 0;
|
||||||
|
constexpr int kOutputLabel = 0;
|
||||||
|
constexpr int kDefaultMaxSplits = 128;
|
||||||
|
|
||||||
|
// This op takes in a string, finds the character ngrams for it and then
|
||||||
|
// maps each of these ngrams to an index using the specified vocabulary sizes.
|
||||||
|
|
||||||
|
// Input(s):
|
||||||
|
// - input: Input string.
|
||||||
|
// - seeds: Seed for the random number generator.
|
||||||
|
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
|
||||||
|
// be interpreted as generating unigrams, bigrams, and trigrams.
|
||||||
|
// - vocab_sizes: Size of the vocabulary for each of the ngram features
|
||||||
|
// respectively. The op would generate vocab ids to be less than or equal to
|
||||||
|
// the vocab size. The index 0 implies an invalid ngram.
|
||||||
|
// - max_splits: Maximum number of tokens in the output. If this is unset, the
|
||||||
|
// limit is `kDefaultMaxSplits`.
|
||||||
|
// - lower_case_input: If this is set to true, the input string would be
|
||||||
|
// lower-cased before any processing.
|
||||||
|
|
||||||
|
// Output(s):
|
||||||
|
// - output: A tensor of size [number of ngrams, number of tokens + 2],
|
||||||
|
// where 2 tokens are reserved for the padding. If `max_splits` is set, this
|
||||||
|
// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
|
||||||
|
|
||||||
|
// Helper class used for pre-processing the input.
|
||||||
|
class NGramHashParams {
|
||||||
|
public:
|
||||||
|
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
|
||||||
|
const std::vector<int>& vocab_sizes, int max_splits,
|
||||||
|
bool lower_case_input)
|
||||||
|
: seed_(seed),
|
||||||
|
ngram_lengths_(ngram_lengths),
|
||||||
|
vocab_sizes_(vocab_sizes),
|
||||||
|
max_splits_(max_splits),
|
||||||
|
lower_case_input_(lower_case_input) {}
|
||||||
|
|
||||||
|
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
|
||||||
|
TfLiteContext* context) {
|
||||||
|
if (input_t->bytes == 0) {
|
||||||
|
context->ReportError(context, "Empty input not supported.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do sanity checks on the input.
|
||||||
|
if (ngram_lengths_.empty()) {
|
||||||
|
context->ReportError(context, "`ngram_lengths` must be non-empty.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (vocab_sizes_.empty()) {
|
||||||
|
context->ReportError(context, "`vocab_sizes` must be non-empty.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ngram_lengths_.size() != vocab_sizes_.size()) {
|
||||||
|
context->ReportError(
|
||||||
|
context,
|
||||||
|
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (max_splits_ <= 0) {
|
||||||
|
context->ReportError(context, "`max_splits` must be > 0.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Obtain and tokenize the input.
|
||||||
|
StringRef inputref = GetString(input_t, /*string_index=*/0);
|
||||||
|
if (lower_case_input_) {
|
||||||
|
std::string lower_cased_str;
|
||||||
|
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
|
||||||
|
|
||||||
|
tokenized_output_ =
|
||||||
|
Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
|
||||||
|
/*exclude_nonalphaspace_tokens=*/true);
|
||||||
|
} else {
|
||||||
|
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
|
||||||
|
/*exclude_nonalphaspace_tokens=*/true);
|
||||||
|
}
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
uint64_t GetSeed() const { return seed_; }
|
||||||
|
|
||||||
|
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
|
||||||
|
|
||||||
|
int GetNumNGrams() const { return ngram_lengths_.size(); }
|
||||||
|
|
||||||
|
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
|
||||||
|
|
||||||
|
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
|
||||||
|
|
||||||
|
const TokenizedOutput& GetTokenizedOutput() const {
|
||||||
|
return tokenized_output_;
|
||||||
|
}
|
||||||
|
|
||||||
|
TokenizedOutput tokenized_output_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const uint64_t seed_;
|
||||||
|
std::vector<int> ngram_lengths_;
|
||||||
|
std::vector<int> vocab_sizes_;
|
||||||
|
const int max_splits_;
|
||||||
|
const bool lower_case_input_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Convert the TypedVector into a regular std::vector.
|
||||||
|
std::vector<int> GetIntVector(TypedVector typed_vec) {
|
||||||
|
std::vector<int> vec(typed_vec.size());
|
||||||
|
for (int j = 0; j < typed_vec.size(); j++) {
|
||||||
|
vec[j] = typed_vec[j].AsInt32();
|
||||||
|
}
|
||||||
|
return vec;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
|
||||||
|
const int max_unicode_length = params->GetNumTokens();
|
||||||
|
const auto ngram_lengths = params->GetNGramLengths();
|
||||||
|
const auto vocab_sizes = params->GetVocabSizes();
|
||||||
|
const auto& tokenized_output = params->GetTokenizedOutput();
|
||||||
|
const auto seed = params->GetSeed();
|
||||||
|
|
||||||
|
// Compute for each ngram.
|
||||||
|
for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
|
||||||
|
const int vocab_size = vocab_sizes[ngram];
|
||||||
|
const int ngram_length = ngram_lengths[ngram];
|
||||||
|
|
||||||
|
// Compute for each token within the input.
|
||||||
|
for (int start = 0; start < tokenized_output.tokens.size(); start++) {
|
||||||
|
// Compute the number of bytes for the ngram starting at the given
|
||||||
|
// token.
|
||||||
|
int num_bytes = 0;
|
||||||
|
for (int i = start;
|
||||||
|
i < tokenized_output.tokens.size() && i < (start + ngram_length);
|
||||||
|
i++) {
|
||||||
|
num_bytes += tokenized_output.tokens[i].second;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the hash for the ngram starting at the token.
|
||||||
|
const auto str_hash = MurmurHash64WithSeed(
|
||||||
|
tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
|
||||||
|
num_bytes, seed);
|
||||||
|
|
||||||
|
// Map the hash to an index in the vocab.
|
||||||
|
data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
|
||||||
|
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
|
||||||
|
const Map& m = GetRoot(buffer_t, length).AsMap();
|
||||||
|
|
||||||
|
const uint64_t seed = m["seed"].AsUInt64();
|
||||||
|
const std::vector<int> ngram_lengths =
|
||||||
|
GetIntVector(m["ngram_lengths"].AsTypedVector());
|
||||||
|
const std::vector<int> vocab_sizes =
|
||||||
|
GetIntVector(m["vocab_sizes"].AsTypedVector());
|
||||||
|
const int max_splits =
|
||||||
|
m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
|
||||||
|
const bool lowercase_input =
|
||||||
|
m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();
|
||||||
|
|
||||||
|
return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
|
||||||
|
lowercase_input);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Free(TfLiteContext* context, void* buffer) {
|
||||||
|
delete reinterpret_cast<NGramHashParams*>(buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
SetTensorToDynamic(output);
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||||
|
NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
|
||||||
|
TF_LITE_ENSURE_OK(
|
||||||
|
context,
|
||||||
|
params->PreprocessInput(GetInput(context, node, kInputMessage), context));
|
||||||
|
|
||||||
|
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
|
||||||
|
TF_LITE_ENSURE(context, output != nullptr);
|
||||||
|
if (IsDynamicTensor(output)) {
|
||||||
|
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
|
||||||
|
output_size->data[0] = 1;
|
||||||
|
output_size->data[1] = params->GetNumNGrams();
|
||||||
|
output_size->data[2] = params->GetNumTokens();
|
||||||
|
TF_LITE_ENSURE_OK(context,
|
||||||
|
context->ResizeTensor(context, output, output_size));
|
||||||
|
} else {
|
||||||
|
context->ReportError(context, "Output must by dynamic.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output->type == kTfLiteInt32) {
|
||||||
|
GetNGramHashIndices(params, output->data.i32);
|
||||||
|
} else {
|
||||||
|
context->ReportError(context, "Output type must be Int32.");
|
||||||
|
return kTfLiteError;
|
||||||
|
}
|
||||||
|
|
||||||
|
return kTfLiteOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ngram_op
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_NGRAM_HASH() {
|
||||||
|
static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
|
||||||
|
ngram_op::Resize, ngram_op::Eval};
|
||||||
|
return &r;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
|
@ -0,0 +1,27 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
||||||
|
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
|
||||||
|
TfLiteRegistration* Register_NGRAM_HASH();
|
||||||
|
|
||||||
|
} // namespace tflite::ops::custom
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
|
|
@ -0,0 +1,313 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <optional>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "flatbuffers/flexbuffers.h"
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
#include "tensorflow/lite/c/common.h"
|
||||||
|
#include "tensorflow/lite/interpreter.h"
|
||||||
|
#include "tensorflow/lite/kernels/register.h"
|
||||||
|
#include "tensorflow/lite/kernels/test_util.h"
|
||||||
|
#include "tensorflow/lite/model.h"
|
||||||
|
#include "tensorflow/lite/string_util.h"
|
||||||
|
|
||||||
|
namespace tflite::ops::custom {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::flexbuffers::Builder;
|
||||||
|
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
|
||||||
|
MurmurHash64WithSeed;
|
||||||
|
using ::testing::ElementsAreArray;
|
||||||
|
using ::testing::Message;
|
||||||
|
|
||||||
|
// Helper class for testing the op.
|
||||||
|
class NGramHashModel : public SingleOpModel {
|
||||||
|
public:
|
||||||
|
explicit NGramHashModel(const uint64_t seed,
|
||||||
|
const std::vector<int>& ngram_lengths,
|
||||||
|
const std::vector<int>& vocab_sizes,
|
||||||
|
const absl::optional<int> max_splits = std::nullopt) {
|
||||||
|
// Setup the model inputs.
|
||||||
|
Builder fbb;
|
||||||
|
size_t start = fbb.StartMap();
|
||||||
|
fbb.UInt("seed", seed);
|
||||||
|
{
|
||||||
|
size_t start = fbb.StartVector("ngram_lengths");
|
||||||
|
for (const int& ngram_len : ngram_lengths) {
|
||||||
|
fbb.Int(ngram_len);
|
||||||
|
}
|
||||||
|
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
size_t start = fbb.StartVector("vocab_sizes");
|
||||||
|
for (const int& vocab_size : vocab_sizes) {
|
||||||
|
fbb.Int(vocab_size);
|
||||||
|
}
|
||||||
|
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
|
||||||
|
}
|
||||||
|
if (max_splits) {
|
||||||
|
fbb.Int("max_splits", *max_splits);
|
||||||
|
}
|
||||||
|
fbb.EndMap(start);
|
||||||
|
fbb.Finish();
|
||||||
|
output_ = AddOutput({TensorType_INT32, {}});
|
||||||
|
SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
|
||||||
|
BuildInterpreter({GetShape(input_)});
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetupInputTensor(const std::string& input) {
|
||||||
|
PopulateStringTensor(input_, {input});
|
||||||
|
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
|
||||||
|
<< "Cannot allocate tensors";
|
||||||
|
}
|
||||||
|
|
||||||
|
void Invoke(const std::string& input) {
|
||||||
|
SetupInputTensor(input);
|
||||||
|
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
|
||||||
|
}
|
||||||
|
|
||||||
|
TfLiteStatus InvokeUnchecked(const std::string& input) {
|
||||||
|
SetupInputTensor(input);
|
||||||
|
return SingleOpModel::Invoke();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
std::vector<T> GetOutput() {
|
||||||
|
return ExtractVector<T>(output_);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int input_ = AddInput(TensorType_STRING);
|
||||||
|
int output_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
|
||||||
|
// Checks that the op returns the expected value when the input is sane.
|
||||||
|
// Also checks that when `max_splits` is not specified, the entire string is
|
||||||
|
// tokenized.
|
||||||
|
const uint64_t kSeed = 123;
|
||||||
|
const std::vector<int> vocab_sizes({100, 200});
|
||||||
|
std::vector<int> ngram_lengths({1, 2});
|
||||||
|
const std::vector<std::string> testcase_inputs({
|
||||||
|
"hi",
|
||||||
|
"wow",
|
||||||
|
"!",
|
||||||
|
"HI",
|
||||||
|
});
|
||||||
|
|
||||||
|
// A hash function that maps the given string to an index in the embedding
|
||||||
|
// table denoted by `vocab_idx`.
|
||||||
|
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
|
||||||
|
const auto hash_value =
|
||||||
|
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
|
||||||
|
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
|
||||||
|
};
|
||||||
|
const std::vector<std::vector<int>> expected_testcase_outputs(
|
||||||
|
{{
|
||||||
|
// Unigram & Bigram output for "hi".
|
||||||
|
hash("^", 0),
|
||||||
|
hash("h", 0),
|
||||||
|
hash("i", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^h", 1),
|
||||||
|
hash("hi", 1),
|
||||||
|
hash("i$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow".
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("ow", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "!" (which will get replaced by " ").
|
||||||
|
hash("^", 0),
|
||||||
|
hash(" ", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^ ", 1),
|
||||||
|
hash(" $", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "HI" (which will get lower-cased).
|
||||||
|
hash("^", 0),
|
||||||
|
hash("h", 0),
|
||||||
|
hash("i", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^h", 1),
|
||||||
|
hash("hi", 1),
|
||||||
|
hash("i$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
}});
|
||||||
|
|
||||||
|
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
|
||||||
|
for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
|
||||||
|
const string& testcase_input = testcase_inputs[test_idx];
|
||||||
|
m.Invoke(testcase_input);
|
||||||
|
SCOPED_TRACE(Message() << "Where the testcases' input is: "
|
||||||
|
<< testcase_input);
|
||||||
|
EXPECT_THAT(m.GetOutput<int>(),
|
||||||
|
ElementsAreArray(expected_testcase_outputs[test_idx]));
|
||||||
|
EXPECT_THAT(m.GetOutputShape(),
|
||||||
|
ElementsAreArray(
|
||||||
|
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
|
||||||
|
static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
|
||||||
|
// Checks that the op returns the expected value when the input is correct
|
||||||
|
// when `max_splits` is specified.
|
||||||
|
const uint64_t kSeed = 123;
|
||||||
|
const std::vector<int> vocab_sizes({100, 200});
|
||||||
|
std::vector<int> ngram_lengths({1, 2});
|
||||||
|
|
||||||
|
const std::string testcase_input = "wow";
|
||||||
|
const std::vector<int> max_splits({2, 3, 4, 5, 6});
|
||||||
|
|
||||||
|
// A hash function that maps the given string to an index in the embedding
|
||||||
|
// table denoted by `vocab_idx`.
|
||||||
|
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
|
||||||
|
const auto hash_value =
|
||||||
|
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
|
||||||
|
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::vector<std::vector<int>> expected_testcase_outputs(
|
||||||
|
{{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 2.
|
||||||
|
// We cannot include any of the actual tokens, since `max_splits`
|
||||||
|
// only allows enough space for the delimiters.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 3.
|
||||||
|
// We can start to include some tokens from the input string.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 4.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("o$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 5.
|
||||||
|
// We can include the full input string.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("ow", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// Unigram & Bigram output for "wow", when `max_splits` == 6.
|
||||||
|
// `max_splits` is more than the full input string.
|
||||||
|
hash("^", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("o", 0),
|
||||||
|
hash("w", 0),
|
||||||
|
hash("$", 0),
|
||||||
|
hash("^w", 1),
|
||||||
|
hash("wo", 1),
|
||||||
|
hash("ow", 1),
|
||||||
|
hash("w$", 1),
|
||||||
|
hash("$", 1),
|
||||||
|
}});
|
||||||
|
|
||||||
|
for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
|
||||||
|
const int testcase_max_splits = max_splits[test_idx];
|
||||||
|
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
|
||||||
|
m.Invoke(testcase_input);
|
||||||
|
SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
|
||||||
|
EXPECT_THAT(m.GetOutput<int>(),
|
||||||
|
ElementsAreArray(expected_testcase_outputs[test_idx]));
|
||||||
|
EXPECT_THAT(
|
||||||
|
m.GetOutputShape(),
|
||||||
|
ElementsAreArray(
|
||||||
|
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
|
||||||
|
std::min(
|
||||||
|
// Longest possible tokenization when using the entire
|
||||||
|
// input.
|
||||||
|
static_cast<int>(testcase_input.size()) + /*padding*/ 2,
|
||||||
|
// Longest possible string when the `max_splits` value
|
||||||
|
// is < testcase_input.size() + 2 for padding.
|
||||||
|
testcase_max_splits)}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, InvalidMaxSplitsValue) {
|
||||||
|
// Check that the op errors out when given an invalid max splits value.
|
||||||
|
const std::vector<int> invalid_max_splits({0, -1, -5, -100});
|
||||||
|
for (const int max_splits : invalid_max_splits) {
|
||||||
|
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
|
||||||
|
/*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
|
||||||
|
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
|
||||||
|
// Check that the op errors out when ngram lengths and vocab sizes mistmatch.
|
||||||
|
{
|
||||||
|
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
|
||||||
|
/*vocab_sizes=*/{1, 2});
|
||||||
|
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
|
||||||
|
/*vocab_sizes=*/{1, 2, 3});
|
||||||
|
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tflite::ops::custom
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "ngram_hash_ops_utils",
|
||||||
|
srcs = [
|
||||||
|
"ngram_hash_ops_utils.cc",
|
||||||
|
],
|
||||||
|
hdrs = [
|
||||||
|
"ngram_hash_ops_utils.h",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "ngram_hash_ops_utils_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"ngram_hash_ops_utils_test.cc",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":ngram_hash_ops_utils",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "murmur",
|
||||||
|
srcs = ["murmur.cc"],
|
||||||
|
hdrs = ["murmur.h"],
|
||||||
|
deps = [
|
||||||
|
"//mediapipe/framework/port:integral_types",
|
||||||
|
"@com_google_absl//absl/base:core_headers",
|
||||||
|
"@com_google_absl//absl/base:endian",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "murmur_test",
|
||||||
|
srcs = ["murmur_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":murmur",
|
||||||
|
"//mediapipe/framework/port:gtest_main",
|
||||||
|
"//mediapipe/framework/port:integral_types",
|
||||||
|
],
|
||||||
|
)
|
|
@ -0,0 +1,95 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
|
||||||
|
// Original copyright message below.
|
||||||
|
// Copyright 2009 Google Inc. All Rights Reserved.
|
||||||
|
// Author: aappleby@google.com (Austin Appleby)
|
||||||
|
// jyrki@google.com (Jyrki Alakuijala)
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "absl/base/internal/endian.h"
|
||||||
|
#include "absl/base/optimization.h"
|
||||||
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::absl::little_endian::Load64;
|
||||||
|
|
||||||
|
// Murmur 2.0 multiplication constant.
|
||||||
|
static const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
|
||||||
|
|
||||||
|
// We need to mix some of the bits that get propagated and mixed into the
|
||||||
|
// high bits by multiplication back into the low bits. 17 last bits get
|
||||||
|
// a more efficiently mixed with this.
|
||||||
|
inline uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); }
|
||||||
|
|
||||||
|
// Accumulate 8 bytes into 64-bit Murmur hash
|
||||||
|
inline uint64_t MurmurStep(uint64_t hash, uint64_t data) {
|
||||||
|
hash ^= ShiftMix(data * kMul) * kMul;
|
||||||
|
hash *= kMul;
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a uint64 from 1-8 bytes.
|
||||||
|
// 8 * len least significant bits are loaded from the memory with
|
||||||
|
// LittleEndian order. The 64 - 8 * len most significant bits are
|
||||||
|
// set all to 0.
|
||||||
|
// In latex-friendly words, this function returns:
|
||||||
|
// $\sum_{i=0}^{len-1} p[i] 256^{i}$, where p[i] is unsigned.
|
||||||
|
//
|
||||||
|
// This function is equivalent to:
|
||||||
|
// uint64 val = 0;
|
||||||
|
// memcpy(&val, p, len);
|
||||||
|
// return ToHost64(val);
|
||||||
|
//
|
||||||
|
// The caller needs to guarantee that 0 <= len <= 8.
|
||||||
|
uint64_t Load64VariableLength(const void* const p, int len) {
|
||||||
|
ABSL_ASSUME(len >= 0 && len <= 8);
|
||||||
|
uint64_t val = 0;
|
||||||
|
const uint8_t* const src = static_cast<const uint8_t*>(p);
|
||||||
|
for (int i = 0; i < len; ++i) {
|
||||||
|
val |= static_cast<uint64_t>(src[i]) << (8 * i);
|
||||||
|
}
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
unsigned long long MurmurHash64WithSeed(const char* buf, // NOLINT
|
||||||
|
const size_t len, const uint64_t seed) {
|
||||||
|
// Let's remove the bytes not divisible by the sizeof(uint64).
|
||||||
|
// This allows the inner loop to process the data as 64 bit integers.
|
||||||
|
const size_t len_aligned = len & ~0x7;
|
||||||
|
const char* const end = buf + len_aligned;
|
||||||
|
uint64_t hash = seed ^ (len * kMul);
|
||||||
|
for (const char* p = buf; p != end; p += 8) {
|
||||||
|
hash = MurmurStep(hash, Load64(p));
|
||||||
|
}
|
||||||
|
if ((len & 0x7) != 0) {
|
||||||
|
const uint64_t data = Load64VariableLength(end, len & 0x7);
|
||||||
|
hash ^= data;
|
||||||
|
hash *= kMul;
|
||||||
|
}
|
||||||
|
hash = ShiftMix(hash) * kMul;
|
||||||
|
hash = ShiftMix(hash);
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
|
|
@ -0,0 +1,43 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
|
||||||
|
// Original copyright message below.
|
||||||
|
// Copyright 2009 Google Inc. All Rights Reserved.
|
||||||
|
// Author: aappleby@google.com (Austin Appelby)
|
||||||
|
// jyrki@google.com (Jyrki Alakuijala)
|
||||||
|
//
|
||||||
|
// MurmurHash is a fast multiplication and shifting based algorithm,
|
||||||
|
// based on Austin Appleby's MurmurHash 2.0 algorithm.
|
||||||
|
|
||||||
|
#ifndef UTIL_HASH_MURMUR_H_
|
||||||
|
#define UTIL_HASH_MURMUR_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdlib.h> // for size_t.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
|
||||||
|
|
||||||
|
// Hash function for a byte array. Has a seed which allows this hash function to
|
||||||
|
// be used in algorithms that need a family of parameterized hash functions.
|
||||||
|
// e.g. Minhash.
|
||||||
|
unsigned long long MurmurHash64WithSeed(const char* buf, size_t len, // NOLINT
|
||||||
|
uint64_t seed);
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
|
||||||
|
|
||||||
|
#endif // UTIL_HASH_MURMUR_H_
|
|
@ -0,0 +1,66 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a test library written by Jyrki Alakuijala.
|
||||||
|
// Original copyright message below.
|
||||||
|
// Copyright 2009 Google Inc. All Rights Reserved.
|
||||||
|
// Author: jyrki@google.com (Jyrki Alakuijala)
|
||||||
|
//
|
||||||
|
// Tests for the fast hashing algorithm based on Austin Appleby's
|
||||||
|
// MurmurHash 2.0 algorithm. See http://murmurhash.googlepages.com/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
|
||||||
|
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
#include "mediapipe/framework/port/integral_types.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
|
||||||
|
|
||||||
|
TEST(Murmur, EmptyData64) {
|
||||||
|
EXPECT_EQ(uint64_t{0}, MurmurHash64WithSeed(nullptr, uint64_t{0}, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Murmur, VaryWithDifferentSeeds) {
|
||||||
|
// While in theory different seeds could return the same
|
||||||
|
// hash for the same data this is unlikely.
|
||||||
|
char data1 = 'x';
|
||||||
|
EXPECT_NE(MurmurHash64WithSeed(&data1, 1, 100),
|
||||||
|
MurmurHash64WithSeed(&data1, 1, 101));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hashes don't change.
|
||||||
|
TEST(Murmur, Idempotence) {
|
||||||
|
const char data[] = "deadbeef";
|
||||||
|
const size_t dlen = strlen(data);
|
||||||
|
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
EXPECT_EQ(MurmurHash64WithSeed(data, dlen, i),
|
||||||
|
MurmurHash64WithSeed(data, dlen, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
const char next_data[] = "deadbeef000---";
|
||||||
|
const size_t next_dlen = strlen(next_data);
|
||||||
|
|
||||||
|
for (int i = 0; i < 10; i++) {
|
||||||
|
EXPECT_EQ(MurmurHash64WithSeed(next_data, next_dlen, i),
|
||||||
|
MurmurHash64WithSeed(next_data, next_dlen, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
|
|
@ -0,0 +1,96 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops {
|
||||||
|
|
||||||
|
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
|
||||||
|
bool exclude_nonalphaspace_tokens) {
|
||||||
|
const std::string kPrefix = "^";
|
||||||
|
const std::string kSuffix = "$";
|
||||||
|
const std::string kReplacementToken = " ";
|
||||||
|
|
||||||
|
TokenizedOutput output;
|
||||||
|
|
||||||
|
size_t token_start = 0;
|
||||||
|
output.str.reserve(len + 2);
|
||||||
|
output.tokens.reserve(len + 2);
|
||||||
|
|
||||||
|
output.str.append(kPrefix);
|
||||||
|
output.tokens.push_back(std::make_pair(token_start, kPrefix.size()));
|
||||||
|
token_start += kPrefix.size();
|
||||||
|
|
||||||
|
Rune token;
|
||||||
|
for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) {
|
||||||
|
// Use the standard UTF-8 library to find the next token.
|
||||||
|
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
|
||||||
|
|
||||||
|
// Stop processing, if we can't read any more tokens, or we have reached
|
||||||
|
// maximum allowed tokens, allocating one token for the suffix.
|
||||||
|
if (bytes_read == 0) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If `exclude_nonalphaspace_tokens` is set to true, and the token is not
|
||||||
|
// alphanumeric, replace it with a replacement token.
|
||||||
|
if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) {
|
||||||
|
output.str.append(kReplacementToken);
|
||||||
|
output.tokens.push_back(
|
||||||
|
std::make_pair(token_start, kReplacementToken.size()));
|
||||||
|
token_start += kReplacementToken.size();
|
||||||
|
i += bytes_read;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append the token in the output string, and note its position and the
|
||||||
|
// number of bytes that token consumed.
|
||||||
|
output.str.append(input_str + i, bytes_read);
|
||||||
|
output.tokens.push_back(std::make_pair(token_start, bytes_read));
|
||||||
|
token_start += bytes_read;
|
||||||
|
i += bytes_read;
|
||||||
|
}
|
||||||
|
output.str.append(kSuffix);
|
||||||
|
output.tokens.push_back(std::make_pair(token_start, kSuffix.size()));
|
||||||
|
token_start += kSuffix.size();
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
void LowercaseUnicodeStr(const char* input_str, int len,
|
||||||
|
std::string* output_str) {
|
||||||
|
for (int i = 0; i < len;) {
|
||||||
|
Rune token;
|
||||||
|
|
||||||
|
// Tokenize the given string, and get the appropriate lowercase token.
|
||||||
|
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
|
||||||
|
token = utf_isalpharune(token) ? utf_tolowerrune(token) : token;
|
||||||
|
|
||||||
|
// Write back the token to the output string.
|
||||||
|
char token_buf[UTFmax];
|
||||||
|
size_t bytes_to_write = utf_runetochar(token_buf, &token);
|
||||||
|
output_str->append(token_buf, bytes_to_write);
|
||||||
|
|
||||||
|
i += bytes_read;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops
|
|
@ -0,0 +1,56 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
|
||||||
|
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops {
|
||||||
|
|
||||||
|
struct TokenizedOutput {
|
||||||
|
// The processed string (with necessary prefix, suffix, skipped tokens, etc.).
|
||||||
|
std::string str;
|
||||||
|
|
||||||
|
// This vector contains pairs, where each pair has two members. The first
|
||||||
|
// denoting the starting index of the token in the `str` string, and the
|
||||||
|
// second denoting the length of that token in bytes.
|
||||||
|
std::vector<std::pair<const size_t, const size_t>> tokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Tokenizes the given input string on Unicode token boundaries, with a maximum
|
||||||
|
// of `max_tokens` tokens.
|
||||||
|
//
|
||||||
|
// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores
|
||||||
|
// non-alphanumeric tokens, and replaces them with a replacement token (" ").
|
||||||
|
//
|
||||||
|
// The method returns the output in the `TokenizedOutput` struct, which stores
|
||||||
|
// both, the processed input string, and the indices and sizes of each token
|
||||||
|
// within that string.
|
||||||
|
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
|
||||||
|
bool exclude_nonalphaspace_tokens);
|
||||||
|
|
||||||
|
// Converts the given unicode string (`input_str`) with the specified length
|
||||||
|
// (`len`) to a lowercase string.
|
||||||
|
//
|
||||||
|
// The method populates the lowercased string in `output_str`.
|
||||||
|
void LowercaseUnicodeStr(const char* input_str, int len,
|
||||||
|
std::string* output_str);
|
||||||
|
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops
|
||||||
|
|
||||||
|
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
|
|
@ -0,0 +1,135 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "mediapipe/framework/port/gmock.h"
|
||||||
|
#include "mediapipe/framework/port/gtest.h"
|
||||||
|
|
||||||
|
namespace mediapipe::tasks::text::language_detector::custom_ops {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using ::testing::Values;
|
||||||
|
|
||||||
|
std::string ReconstructStringFromTokens(TokenizedOutput output) {
|
||||||
|
std::string reconstructed_str;
|
||||||
|
for (int i = 0; i < output.tokens.size(); i++) {
|
||||||
|
reconstructed_str.append(
|
||||||
|
output.str.c_str() + output.tokens[i].first,
|
||||||
|
output.str.c_str() + output.tokens[i].first + output.tokens[i].second);
|
||||||
|
}
|
||||||
|
return reconstructed_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TokenizeTestParams {
|
||||||
|
std::string input_str;
|
||||||
|
size_t max_tokens;
|
||||||
|
bool exclude_nonalphaspace_tokens;
|
||||||
|
std::string expected_output_str;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TokenizeParameterizedTest
|
||||||
|
: public ::testing::Test,
|
||||||
|
public testing::WithParamInterface<TokenizeTestParams> {};
|
||||||
|
|
||||||
|
TEST_P(TokenizeParameterizedTest, Tokenize) {
|
||||||
|
// Checks that the Tokenize method returns the expected value.
|
||||||
|
const TokenizeTestParams params = TokenizeParameterizedTest::GetParam();
|
||||||
|
const TokenizedOutput output = Tokenize(
|
||||||
|
/*input_str=*/params.input_str.c_str(),
|
||||||
|
/*len=*/params.input_str.size(),
|
||||||
|
/*max_tokens=*/params.max_tokens,
|
||||||
|
/*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens);
|
||||||
|
|
||||||
|
// The output string should have the necessary prefixes, and the "!" token
|
||||||
|
// should have been replaced with a " ".
|
||||||
|
EXPECT_EQ(output.str, params.expected_output_str);
|
||||||
|
EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str);
|
||||||
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
TokenizeParameterizedTests, TokenizeParameterizedTest,
|
||||||
|
Values(
|
||||||
|
// Test including non-alphanumeric characters.
|
||||||
|
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
|
||||||
|
/*exclude_alphanonspace=*/false,
|
||||||
|
/*expected_output_str=*/"^hi!$"}),
|
||||||
|
// Test not including non-alphanumeric characters.
|
||||||
|
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
|
||||||
|
/*exclude_alphanonspace=*/true,
|
||||||
|
/*expected_output_str=*/"^hi $"}),
|
||||||
|
// Test with a maximum of 3 tokens.
|
||||||
|
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3,
|
||||||
|
/*exclude_alphanonspace=*/true,
|
||||||
|
/*expected_output_str=*/"^h$"}),
|
||||||
|
// Test with non-latin characters.
|
||||||
|
TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100,
|
||||||
|
/*exclude_alphanonspace=*/true,
|
||||||
|
/*expected_output_str=*/"^ありがと$"})));
|
||||||
|
|
||||||
|
TEST(LowercaseUnicodeTest, TestLowercaseUnicode) {
|
||||||
|
{
|
||||||
|
// Check that the method is a no-op when the string is lowercase.
|
||||||
|
std::string input_str = "hello";
|
||||||
|
std::string output_str;
|
||||||
|
LowercaseUnicodeStr(
|
||||||
|
/*input_str=*/input_str.c_str(),
|
||||||
|
/*len=*/input_str.size(),
|
||||||
|
/*output_str=*/&output_str);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_str, "hello");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Check that the method has uppercase characters.
|
||||||
|
std::string input_str = "hElLo";
|
||||||
|
std::string output_str;
|
||||||
|
LowercaseUnicodeStr(
|
||||||
|
/*input_str=*/input_str.c_str(),
|
||||||
|
/*len=*/input_str.size(),
|
||||||
|
/*output_str=*/&output_str);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_str, "hello");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Check that the method works with non-latin scripts.
|
||||||
|
// Cyrillic has the concept of cases, so it should change the input.
|
||||||
|
std::string input_str = "БЙп";
|
||||||
|
std::string output_str;
|
||||||
|
LowercaseUnicodeStr(
|
||||||
|
/*input_str=*/input_str.c_str(),
|
||||||
|
/*len=*/input_str.size(),
|
||||||
|
/*output_str=*/&output_str);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_str, "бйп");
|
||||||
|
}
|
||||||
|
{
|
||||||
|
// Check that the method works with non-latin scripts.
|
||||||
|
// Japanese doesn't have the concept of cases, so it should not change.
|
||||||
|
std::string input_str = "ありがと";
|
||||||
|
std::string output_str;
|
||||||
|
LowercaseUnicodeStr(
|
||||||
|
/*input_str=*/input_str.c_str(),
|
||||||
|
/*len=*/input_str.size(),
|
||||||
|
/*output_str=*/&output_str);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_str, "ありがと");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace mediapipe::tasks::text::language_detector::custom_ops
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
package(default_visibility = ["//mediapipe/tasks:internal"])
|
||||||
|
|
||||||
|
licenses(["notice"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "utf",
|
||||||
|
srcs = [
|
||||||
|
"rune.c",
|
||||||
|
"runetype.c",
|
||||||
|
"runetypebody.h",
|
||||||
|
],
|
||||||
|
hdrs = ["utf.h"],
|
||||||
|
)
|
|
@ -0,0 +1,233 @@
|
||||||
|
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
// Forked from a library written by Rob Pike and Ken Thompson. Original
|
||||||
|
// copyright message below.
|
||||||
|
/*
|
||||||
|
* The authors of this software are Rob Pike and Ken Thompson.
|
||||||
|
* Copyright (c) 2002 by Lucent Technologies.
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose without fee is hereby granted, provided that this entire notice
|
||||||
|
* is included in all copies of any software which is or includes a copy
|
||||||
|
* or modification of this software and in all copies of the supporting
|
||||||
|
* documentation for such software.
|
||||||
|
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
|
||||||
|
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
|
||||||
|
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
|
||||||
|
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
|
||||||
|
*/
|
||||||
|
#include <stdarg.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
|
||||||
|
|
||||||
|
enum
|
||||||
|
{
|
||||||
|
Bit1 = 7,
|
||||||
|
Bitx = 6,
|
||||||
|
Bit2 = 5,
|
||||||
|
Bit3 = 4,
|
||||||
|
Bit4 = 3,
|
||||||
|
Bit5 = 2,
|
||||||
|
|
||||||
|
T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */
|
||||||
|
Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */
|
||||||
|
T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */
|
||||||
|
T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */
|
||||||
|
T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */
|
||||||
|
T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */
|
||||||
|
|
||||||
|
Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */
|
||||||
|
Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */
|
||||||
|
Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */
|
||||||
|
Rune4 = (1<<(Bit4+3*Bitx))-1,
|
||||||
|
/* 0001 1111 1111 1111 1111 1111 */
|
||||||
|
|
||||||
|
Maskx = (1<<Bitx)-1, /* 0011 1111 */
|
||||||
|
Testx = Maskx ^ 0xFF, /* 1100 0000 */
|
||||||
|
|
||||||
|
Bad = Runeerror,
|
||||||
|
};
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
|
||||||
|
* This is a slower but "safe" version of the old chartorune
|
||||||
|
* that works on strings that are not necessarily null-terminated.
|
||||||
|
*
|
||||||
|
* If you know for sure that your string is null-terminated,
|
||||||
|
* chartorune will be a bit faster.
|
||||||
|
*
|
||||||
|
* It is guaranteed not to attempt to access "length"
|
||||||
|
* past the incoming pointer. This is to avoid
|
||||||
|
* possible access violations. If the string appears to be
|
||||||
|
* well-formed but incomplete (i.e., to get the whole Rune
|
||||||
|
* we'd need to read past str+length) then we'll set the Rune
|
||||||
|
* to Bad and return 0.
|
||||||
|
*
|
||||||
|
* Note that if we have decoding problems for other
|
||||||
|
* reasons, we return 1 instead of 0.
|
||||||
|
*/
|
||||||
|
int
|
||||||
|
utf_charntorune(Rune *rune, const char *str, int length)
|
||||||
|
{
|
||||||
|
int c, c1, c2, c3;
|
||||||
|
long l;
|
||||||
|
|
||||||
|
/* When we're not allowed to read anything */
|
||||||
|
if(length <= 0) {
|
||||||
|
goto badlen;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* one character sequence (7-bit value)
|
||||||
|
* 00000-0007F => T1
|
||||||
|
*/
|
||||||
|
c = *(uchar*)str;
|
||||||
|
if(c < Tx) {
|
||||||
|
*rune = c;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can't read more than one character we must stop
|
||||||
|
if(length <= 1) {
|
||||||
|
goto badlen;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* two character sequence (11-bit value)
|
||||||
|
* 0080-07FF => T2 Tx
|
||||||
|
*/
|
||||||
|
c1 = *(uchar*)(str+1) ^ Tx;
|
||||||
|
if(c1 & Testx)
|
||||||
|
goto bad;
|
||||||
|
if(c < T3) {
|
||||||
|
if(c < T2)
|
||||||
|
goto bad;
|
||||||
|
l = ((c << Bitx) | c1) & Rune2;
|
||||||
|
if(l <= Rune1)
|
||||||
|
goto bad;
|
||||||
|
*rune = l;
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can't read more than two characters we must stop
|
||||||
|
if(length <= 2) {
|
||||||
|
goto badlen;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* three character sequence (16-bit value)
|
||||||
|
* 0800-FFFF => T3 Tx Tx
|
||||||
|
*/
|
||||||
|
c2 = *(uchar*)(str+2) ^ Tx;
|
||||||
|
if(c2 & Testx)
|
||||||
|
goto bad;
|
||||||
|
if(c < T4) {
|
||||||
|
l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
|
||||||
|
if(l <= Rune2)
|
||||||
|
goto bad;
|
||||||
|
*rune = l;
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (length <= 3)
|
||||||
|
goto badlen;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* four character sequence (21-bit value)
|
||||||
|
* 10000-1FFFFF => T4 Tx Tx Tx
|
||||||
|
*/
|
||||||
|
c3 = *(uchar*)(str+3) ^ Tx;
|
||||||
|
if (c3 & Testx)
|
||||||
|
goto bad;
|
||||||
|
if (c < T5) {
|
||||||
|
l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
|
||||||
|
if (l <= Rune3)
|
||||||
|
goto bad;
|
||||||
|
if (l > Runemax)
|
||||||
|
goto bad;
|
||||||
|
*rune = l;
|
||||||
|
return 4;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support for 5-byte or longer UTF-8 would go here, but
|
||||||
|
// since we don't have that, we'll just fall through to bad.
|
||||||
|
|
||||||
|
/*
|
||||||
|
* bad decoding
|
||||||
|
*/
|
||||||
|
bad:
|
||||||
|
*rune = Bad;
|
||||||
|
return 1;
|
||||||
|
badlen:
|
||||||
|
*rune = Bad;
|
||||||
|
return 0;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
int
|
||||||
|
utf_runetochar(char *str, const Rune *rune)
|
||||||
|
{
|
||||||
|
/* Runes are signed, so convert to unsigned for range check. */
|
||||||
|
unsigned long c;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* one character sequence
|
||||||
|
* 00000-0007F => 00-7F
|
||||||
|
*/
|
||||||
|
c = *rune;
|
||||||
|
if(c <= Rune1) {
|
||||||
|
str[0] = c;
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* two character sequence
|
||||||
|
* 0080-07FF => T2 Tx
|
||||||
|
*/
|
||||||
|
if(c <= Rune2) {
|
||||||
|
str[0] = T2 | (c >> 1*Bitx);
|
||||||
|
str[1] = Tx | (c & Maskx);
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* If the Rune is out of range, convert it to the error rune.
|
||||||
|
* Do this test here because the error rune encodes to three bytes.
|
||||||
|
* Doing it earlier would duplicate work, since an out of range
|
||||||
|
* Rune wouldn't have fit in one or two bytes.
|
||||||
|
*/
|
||||||
|
if (c > Runemax)
|
||||||
|
c = Runeerror;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* three character sequence
|
||||||
|
* 0800-FFFF => T3 Tx Tx
|
||||||
|
*/
|
||||||
|
if (c <= Rune3) {
|
||||||
|
str[0] = T3 | (c >> 2*Bitx);
|
||||||
|
str[1] = Tx | ((c >> 1*Bitx) & Maskx);
|
||||||
|
str[2] = Tx | (c & Maskx);
|
||||||
|
return 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* four character sequence (21-bit value)
|
||||||
|
* 10000-1FFFFF => T4 Tx Tx Tx
|
||||||
|
*/
|
||||||
|
str[0] = T4 | (c >> 3*Bitx);
|
||||||
|
str[1] = Tx | ((c >> 2*Bitx) & Maskx);
|
||||||
|
str[2] = Tx | ((c >> 1*Bitx) & Maskx);
|
||||||
|
str[3] = Tx | (c & Maskx);
|
||||||
|
return 4;
|
||||||
|
}
|