Merge branch 'master' into face-stylizer-python

This commit is contained in:
Kinar R 2023-03-23 09:27:52 +05:30 committed by GitHub
commit 3afe4cafc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
305 changed files with 21319 additions and 1142 deletions

17
LICENSE
View File

@ -199,3 +199,20 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
===========================================================================
For files under tasks/cc/text/language_detector/custom_ops/utils/utf/
===========================================================================
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/

View File

@ -270,7 +270,7 @@ new_local_repository(
# For local MacOS builds, the path should point to an opencv@3 installation.
# If you edit the path here, you will also need to update the corresponding
# prefix in "opencv_macos.BUILD".
path = "/usr/local",
path = "/usr/local", # e.g. /usr/local/Cellar for HomeBrew
)
new_local_repository(
@ -499,8 +499,8 @@ cc_crosstool(name = "crosstool")
# Node dependencies
http_archive(
name = "build_bazel_rules_nodejs",
sha256 = "5aae76dced38f784b58d9776e4ab12278bc156a9ed2b1d9fcd3e39921dc88fda",
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.7.1/rules_nodejs-5.7.1.tar.gz"],
sha256 = "94070eff79305be05b7699207fbac5d2608054dd53e6109f7d00d923919ff45a",
urls = ["https://github.com/bazelbuild/rules_nodejs/releases/download/5.8.2/rules_nodejs-5.8.2.tar.gz"],
)
load("@build_bazel_rules_nodejs//:repositories.bzl", "build_bazel_rules_nodejs_dependencies")
@ -543,3 +543,43 @@ external_files()
load("@//third_party:wasm_files.bzl", "wasm_files")
wasm_files()
# Halide
new_local_repository(
name = "halide",
build_file = "@//third_party/halide:BUILD.bazel",
path = "third_party/halide"
)
http_archive(
name = "linux_halide",
sha256 = "f62b2914823d6e33d18693f5b74484f274523bf5402ce51988e24393d123b375",
strip_prefix = "Halide-15.0.0-x86-64-linux",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-linux-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
build_file = "@//third_party:halide.BUILD",
)
http_archive(
name = "macos_x86_64_halide",
sha256 = "3d832aed942080ea89aa832462c68fbb906f3055c440b7b6d35093d7c52f6aab",
strip_prefix = "Halide-15.0.0-x86-64-osx",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
build_file = "@//third_party:halide.BUILD",
)
http_archive(
name = "macos_arm_64_halide",
sha256 = "b1fad3c9810122b187303d7031d9e35fb43761f345d18cc4492c00ed5877f641",
strip_prefix = "Halide-15.0.0-arm-64-osx",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-arm-64-osx-d7651f4b32f9dbd764f243134001f7554378d62d.tar.gz"],
build_file = "@//third_party:halide.BUILD",
)
http_archive(
name = "windows_halide",
sha256 = "5acf6fe161dd375856a2b43f4bb0a32815ba958b0585ee312c44e008aa7b0b64",
strip_prefix = "Halide-15.0.0-x86-64-windows",
urls = ["https://github.com/halide/Halide/releases/download/v15.0.0/Halide-15.0.0-x86-64-windows-d7651f4b32f9dbd764f243134001f7554378d62d.zip"],
build_file = "@//third_party:halide.BUILD",
)

View File

@ -113,14 +113,14 @@ Warning: On the other hand, it is not guaranteed that an input packet will
always be available for all streams.
To explain how it works, we need to introduce the definition of a settled
timestamp. We say that a timestamp in a stream is *settled* if it lower than the
timestamp bound. In other words, a timestamp is settled for a stream once the
state of the input at that timestamp is irrevocably known: either there is a
timestamp. We say that a timestamp in a stream is *settled* if it is lower than
the timestamp bound. In other words, a timestamp is settled for a stream once
the state of the input at that timestamp is irrevocably known: either there is a
packet, or there is the certainty that a packet with that timestamp will not
arrive.
Note: For this reason, MediaPipe also allows a stream producer to explicitly
advance the timestamp bound farther that what the last packet implies, i.e. to
advance the timestamp bound farther than what the last packet implies, i.e. to
provide a tighter bound. This can allow the downstream nodes to settle their
inputs sooner.

View File

@ -108,6 +108,8 @@ one over the other.
* [TFLite model](https://storage.googleapis.com/mediapipe-assets/ssdlite_object_detection.tflite)
* [TFLite model quantized for EdgeTPU/Coral](https://github.com/google/mediapipe/tree/master/mediapipe/examples/coral/models/object-detector-quantized_edgetpu.tflite)
* [TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
* [Model information](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)
### [Objectron](https://google.github.io/mediapipe/solutions/objectron)

View File

@ -118,9 +118,9 @@ on how to build MediaPipe examples.
* With a TensorFlow Model
This uses the
[TensorFlow model](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model)
[TensorFlow model](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/archive.zip)
( see also
[model info](https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md)),
[model info](https://storage.googleapis.com/mediapipe-assets/object_detection_saved_model/README.md)),
and the pipeline is implemented in this
[graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/object_detection/object_detection_mobile_cpu.pbtxt).

View File

@ -0,0 +1,62 @@
## TensorFlow/TFLite Object Detection Model
### TensorFlow model
The model is trained on [MSCOCO 2014](http://cocodataset.org) dataset using [TensorFlow Object Detection API](https://github.com/tensorflow/models/tree/master/research/object_detection). It is a MobileNetV2-based SSD model with 0.5 depth multiplier. Detailed training configuration is in the provided `pipeline.config`. The model is a relatively compact model which has `0.171 mAP` to achieve real-time performance on mobile devices. You can compare it with other models from the [TensorFlow detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md).
### TFLite model
The TFLite model is converted from the TensorFlow above. The steps needed to convert the model are similar to [this tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) with minor modifications. Assuming now we have a trained TensorFlow model which includes the checkpoint files and the training configuration file, for example the files provided in this repo:
* `model.ckpt.index`
* `model.ckpt.meta`
* `model.ckpt.data-00000-of-00001`
* `pipeline.config`
Make sure you have installed these [python libraries](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1.md). Then to get the frozen graph, run the `export_tflite_ssd_graph.py` script from the `models/research` directory with this command:
```bash
$ PATH_TO_MODEL=path/to/the/model
$ bazel run object_detection:export_tflite_ssd_graph -- \
--pipeline_config_path ${PATH_TO_MODEL}/pipeline.config \
--trained_checkpoint_prefix ${PATH_TO_MODEL}/model.ckpt \
--output_directory ${PATH_TO_MODEL} \
--add_postprocessing_op=False
```
The exported model contains two files:
* `tflite_graph.pb`
* `tflite_graph.pbtxt`
The difference between this step and the one in [the tutorial](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193) is that we set `add_postprocessing_op` to False. In MediaPipe, we have provided all the calculators needed for post-processing such that we can exclude the custom TFLite ops for post-processing in the original graph, e.g., non-maximum suppression. This enables the flexibility to integrate with different post-processing algorithms and implementations.
Optional: You can install and use the [graph tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms) to inspect the input/output of the exported model:
```bash
$ bazel run graph_transforms:summarize_graph -- \
--in_graph=${PATH_TO_MODEL}/tflite_graph.pb
```
You should be able to see the input image size of the model is 320x320 and the outputs of the model are:
* `raw_outputs/box_encodings`
* `raw_outputs/class_predictions`
The last step is to convert the model to TFLite. You can look at [this guide](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r1/convert/cmdline_examples.md) for more detail. For this example, you just need to run:
```bash
$ tflite_convert -- \
--graph_def_file=${PATH_TO_MODEL}/tflite_graph.pb \
--output_file=${PATH_TO_MODEL}/model.tflite \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--inference_type=FLOAT \
--input_shapes=1,320,320,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays=raw_outputs/box_encodings,raw_outputs/class_predictions
```
Now you have the TFLite model `model.tflite` ready to use with MediaPipe Object Detection graphs. Please see the examples for more detail.

View File

@ -269,6 +269,7 @@ Supported configuration options:
```python
import cv2
import mediapipe as mp
import numpy as np
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose

View File

@ -156,6 +156,7 @@ cc_library(
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector",
"//mediapipe/framework/port:opencv_imgproc",
] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
@ -168,6 +169,25 @@ cc_library(
alwayslink = 1,
)
cc_test(
name = "set_alpha_calculator_test",
srcs = ["set_alpha_calculator_test.cc"],
deps = [
":set_alpha_calculator",
":set_alpha_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "bilateral_filter_calculator",
srcs = ["bilateral_filter_calculator.cc"],
@ -748,6 +768,7 @@ cc_test(
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_keep_aspect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/medium_sub_rect_with_rotation_border_zero_interp_cubic.png",
"//mediapipe/calculators/tensor:testdata/image_to_tensor/noop_except_range.png",
],
tags = ["desktop_only_test"],

View File

@ -29,6 +29,9 @@ class AffineTransformation {
// pixels will be calculated.
enum class BorderMode { kZero, kReplicate };
// Pixel sampling interpolation method.
enum class Interpolation { kLinear, kCubic };
struct Size {
int width;
int height;

View File

@ -77,8 +77,11 @@ class GlTextureWarpAffineRunner
std::unique_ptr<GpuBuffer>> {
public:
GlTextureWarpAffineRunner(std::shared_ptr<GlCalculatorHelper> gl_helper,
GpuOrigin::Mode gpu_origin)
: gl_helper_(gl_helper), gpu_origin_(gpu_origin) {}
GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation)
: gl_helper_(gl_helper),
gpu_origin_(gpu_origin),
interpolation_(interpolation) {}
absl::Status Init() {
return gl_helper_->RunInGlContext([this]() -> absl::Status {
const GLint attr_location[kNumAttributes] = {
@ -103,10 +106,13 @@ class GlTextureWarpAffineRunner
}
)";
// TODO Move bicubic code to common shared place.
constexpr GLchar kFragShader[] = R"(
DEFAULT_PRECISION(highp, float)
in vec2 sample_coordinate;
uniform sampler2D input_texture;
uniform vec2 input_size;
#ifdef GL_ES
#define fragColor gl_FragColor
@ -114,8 +120,60 @@ class GlTextureWarpAffineRunner
out vec4 fragColor;
#endif // defined(GL_ES);
#ifdef CUBIC_INTERPOLATION
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
const vec2 halve = vec2(0.5,0.5);
const vec2 one = vec2(1.0,1.0);
const vec2 two = vec2(2.0,2.0);
const vec2 three = vec2(3.0,3.0);
const vec2 six = vec2(6.0,6.0);
// Calculate the fraction and integer.
tex_coord = tex_coord * tex_size - halve;
vec2 frac = fract(tex_coord);
vec2 index = tex_coord - frac + halve;
// Calculate weights for Catmull-Rom filter.
vec2 w0 = frac * (-halve + frac * (one - halve * frac));
vec2 w1 = one + frac * frac * (-(two+halve) + three/two * frac);
vec2 w2 = frac * (halve + frac * (two - three/two * frac));
vec2 w3 = frac * frac * (-halve + halve * frac);
// Calculate weights to take advantage of bilinear texture lookup.
vec2 w12 = w1 + w2;
vec2 offset12 = w2 / (w1 + w2);
vec2 index_tl = index - one;
vec2 index_br = index + two;
vec2 index_eq = index + offset12;
index_tl /= tex_size;
index_br /= tex_size;
index_eq /= tex_size;
// 9 texture lookup and linear blending.
vec4 color = vec4(0.0);
color += texture2D(tex, vec2(index_tl.x, index_tl.y)) * w0.x * w0.y;
color += texture2D(tex, vec2(index_eq.x, index_tl.y)) * w12.x *w0.y;
color += texture2D(tex, vec2(index_br.x, index_tl.y)) * w3.x * w0.y;
color += texture2D(tex, vec2(index_tl.x, index_eq.y)) * w0.x * w12.y;
color += texture2D(tex, vec2(index_eq.x, index_eq.y)) * w12.x *w12.y;
color += texture2D(tex, vec2(index_br.x, index_eq.y)) * w3.x * w12.y;
color += texture2D(tex, vec2(index_tl.x, index_br.y)) * w0.x * w3.y;
color += texture2D(tex, vec2(index_eq.x, index_br.y)) * w12.x *w3.y;
color += texture2D(tex, vec2(index_br.x, index_br.y)) * w3.x * w3.y;
return color;
}
#else
vec4 sample(sampler2D tex, vec2 tex_coord, vec2 tex_size) {
return texture2D(tex, tex_coord);
}
#endif // defined(CUBIC_INTERPOLATION)
void main() {
vec4 color = texture2D(input_texture, sample_coordinate);
vec4 color = sample(input_texture, sample_coordinate, input_size);
#ifdef CUSTOM_ZERO_BORDER_MODE
float out_of_bounds =
float(sample_coordinate.x < 0.0 || sample_coordinate.x > 1.0 ||
@ -137,14 +195,28 @@ class GlTextureWarpAffineRunner
glUseProgram(program);
glUniform1i(glGetUniformLocation(program, "input_texture"), 1);
GLint matrix_id = glGetUniformLocation(program, "transform_matrix");
return Program{.id = program, .matrix_id = matrix_id};
GLint size_id = glGetUniformLocation(program, "input_size");
return Program{
.id = program, .matrix_id = matrix_id, .size_id = size_id};
};
const std::string vert_src =
absl::StrCat(mediapipe::kMediaPipeVertexShaderPreamble, kVertShader);
const std::string frag_src = absl::StrCat(
mediapipe::kMediaPipeFragmentShaderPreamble, kFragShader);
std::string interpolation_def;
switch (interpolation_) {
case AffineTransformation::Interpolation::kCubic:
interpolation_def = R"(
#define CUBIC_INTERPOLATION
)";
break;
case AffineTransformation::Interpolation::kLinear:
break;
}
const std::string frag_src =
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
interpolation_def, kFragShader);
ASSIGN_OR_RETURN(program_, create_fn(vert_src, frag_src));
@ -152,9 +224,9 @@ class GlTextureWarpAffineRunner
std::string custom_zero_border_mode_def = R"(
#define CUSTOM_ZERO_BORDER_MODE
)";
const std::string frag_custom_zero_src =
absl::StrCat(mediapipe::kMediaPipeFragmentShaderPreamble,
custom_zero_border_mode_def, kFragShader);
const std::string frag_custom_zero_src = absl::StrCat(
mediapipe::kMediaPipeFragmentShaderPreamble,
custom_zero_border_mode_def, interpolation_def, kFragShader);
return create_fn(vert_src, frag_custom_zero_src);
};
#if GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -256,6 +328,7 @@ class GlTextureWarpAffineRunner
}
glUseProgram(program->id);
// uniforms
Eigen::Matrix<float, 4, 4, Eigen::RowMajor> eigen_mat(matrix.data());
if (IsMatrixVerticalFlipNeeded(gpu_origin_)) {
// @matrix describes affine transformation in terms of TOP LEFT origin, so
@ -275,6 +348,10 @@ class GlTextureWarpAffineRunner
eigen_mat.transposeInPlace();
glUniformMatrix4fv(program->matrix_id, 1, GL_FALSE, eigen_mat.data());
if (interpolation_ == AffineTransformation::Interpolation::kCubic) {
glUniform2f(program->size_id, texture.width(), texture.height());
}
// vao
glBindVertexArray(vao_);
@ -327,6 +404,7 @@ class GlTextureWarpAffineRunner
struct Program {
GLuint id;
GLint matrix_id;
GLint size_id;
};
std::shared_ptr<GlCalculatorHelper> gl_helper_;
GpuOrigin::Mode gpu_origin_;
@ -335,6 +413,8 @@ class GlTextureWarpAffineRunner
Program program_;
std::optional<Program> program_custom_zero_;
GLuint framebuffer_ = 0;
AffineTransformation::Interpolation interpolation_ =
AffineTransformation::Interpolation::kLinear;
};
#undef GL_CLAMP_TO_BORDER_MAY_BE_SUPPORTED
@ -344,9 +424,10 @@ class GlTextureWarpAffineRunner
absl::StatusOr<std::unique_ptr<
AffineTransformation::Runner<GpuBuffer, std::unique_ptr<GpuBuffer>>>>
CreateAffineTransformationGlRunner(
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin) {
auto runner =
absl::make_unique<GlTextureWarpAffineRunner>(gl_helper, gpu_origin);
std::shared_ptr<GlCalculatorHelper> gl_helper, GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation) {
auto runner = absl::make_unique<GlTextureWarpAffineRunner>(
gl_helper, gpu_origin, interpolation);
MP_RETURN_IF_ERROR(runner->Init());
return runner;
}

View File

@ -29,7 +29,8 @@ absl::StatusOr<std::unique_ptr<AffineTransformation::Runner<
mediapipe::GpuBuffer, std::unique_ptr<mediapipe::GpuBuffer>>>>
CreateAffineTransformationGlRunner(
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper,
mediapipe::GpuOrigin::Mode gpu_origin);
mediapipe::GpuOrigin::Mode gpu_origin,
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe

View File

@ -39,9 +39,22 @@ cv::BorderTypes GetBorderModeForOpenCv(
}
}
int GetInterpolationForOpenCv(
AffineTransformation::Interpolation interpolation) {
switch (interpolation) {
case AffineTransformation::Interpolation::kLinear:
return cv::INTER_LINEAR;
case AffineTransformation::Interpolation::kCubic:
return cv::INTER_CUBIC;
}
}
class OpenCvRunner
: public AffineTransformation::Runner<ImageFrame, ImageFrame> {
public:
OpenCvRunner(AffineTransformation::Interpolation interpolation)
: interpolation_(GetInterpolationForOpenCv(interpolation)) {}
absl::StatusOr<ImageFrame> Run(
const ImageFrame& input, const std::array<float, 16>& matrix,
const AffineTransformation::Size& size,
@ -142,19 +155,23 @@ class OpenCvRunner
cv::warpAffine(in_mat, out_mat, cv_affine_transform,
cv::Size(out_mat.cols, out_mat.rows),
/*flags=*/cv::INTER_LINEAR | cv::WARP_INVERSE_MAP,
/*flags=*/interpolation_ | cv::WARP_INVERSE_MAP,
GetBorderModeForOpenCv(border_mode));
return out_image;
}
private:
int interpolation_ = cv::INTER_LINEAR;
};
} // namespace
absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner() {
return absl::make_unique<OpenCvRunner>();
CreateAffineTransformationOpenCvRunner(
AffineTransformation::Interpolation interpolation) {
return absl::make_unique<OpenCvRunner>(interpolation);
}
} // namespace mediapipe

View File

@ -25,7 +25,8 @@ namespace mediapipe {
absl::StatusOr<
std::unique_ptr<AffineTransformation::Runner<ImageFrame, ImageFrame>>>
CreateAffineTransformationOpenCvRunner();
CreateAffineTransformationOpenCvRunner(
AffineTransformation::Interpolation interpolation);
} // namespace mediapipe

View File

@ -81,7 +81,8 @@ class ImageCloneCalculator : public Node {
absl::Status Process(CalculatorContext* cc) override {
std::unique_ptr<Image> output;
const auto& input = *kIn(cc);
if (input.UsesGpu()) {
bool input_on_gpu = input.UsesGpu();
if (input_on_gpu) {
#if !MEDIAPIPE_DISABLE_GPU
// Create an output Image that co-owns the underlying texture buffer as
// the input Image.
@ -97,15 +98,15 @@ class ImageCloneCalculator : public Node {
// Image. This ensures a correct life span of the shared pixel data.
output = std::make_unique<Image>(std::make_unique<mediapipe::ImageFrame>(
input.image_format(), input.width(), input.height(), input.step(),
const_cast<uint8*>(input.GetImageFrameSharedPtr()->PixelData()),
[packet_copy_ptr](uint8*) { delete packet_copy_ptr; }));
const_cast<uint8_t*>(input.GetImageFrameSharedPtr()->PixelData()),
[packet_copy_ptr](uint8_t*) { delete packet_copy_ptr; }));
}
if (output_on_gpu_) {
if (output_on_gpu_ && !input_on_gpu) {
#if !MEDIAPIPE_DISABLE_GPU
gpu_helper_.RunInGlContext([&output]() { output->ConvertToGpu(); });
#endif // !MEDIAPIPE_DISABLE_GPU
} else {
} else if (!output_on_gpu_ && input_on_gpu) {
output->ConvertToCpu();
}
kOut(cc).Send(std::move(output));

View File

@ -22,6 +22,7 @@
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h"
@ -53,24 +54,16 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
// range of [0, 1). Only the first channel of Alpha is used. Input & output Mat
// must be uchar.
template <typename AlphaType>
absl::Status MergeRGBA8Image(const cv::Mat input_mat, const cv::Mat& alpha_mat,
cv::Mat& output_mat) {
RET_CHECK_EQ(input_mat.rows, alpha_mat.rows);
RET_CHECK_EQ(input_mat.cols, alpha_mat.cols);
RET_CHECK_EQ(input_mat.rows, output_mat.rows);
RET_CHECK_EQ(input_mat.cols, output_mat.cols);
absl::Status CopyAlphaImage(const cv::Mat& alpha_mat, cv::Mat& output_mat) {
RET_CHECK_EQ(output_mat.rows, alpha_mat.rows);
RET_CHECK_EQ(output_mat.cols, alpha_mat.cols);
for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
const AlphaType* alpha_ptr = alpha_mat.ptr<AlphaType>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
const int alpha_idx = j * alpha_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
if constexpr (std::is_same<AlphaType, uchar>::value) {
out_ptr[out_idx + 3] = alpha_ptr[alpha_idx + 0]; // channel 0 of mask
} else {
@ -273,7 +266,7 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup source image
const auto& input_frame = cc->Inputs().Tag(kInputFrameTag).Get<ImageFrame>();
const cv::Mat input_mat = mediapipe::formats::MatView(&input_frame);
const cv::Mat input_mat = formats::MatView(&input_frame);
if (!(input_mat.type() == CV_8UC3 || input_mat.type() == CV_8UC4)) {
LOG(ERROR) << "Only 3 or 4 channel 8-bit input image supported";
}
@ -281,38 +274,38 @@ absl::Status SetAlphaCalculator::RenderCpu(CalculatorContext* cc) {
// Setup destination image
auto output_frame = absl::make_unique<ImageFrame>(
ImageFormat::SRGBA, input_mat.cols, input_mat.rows);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
cv::Mat output_mat = formats::MatView(output_frame.get());
const bool has_alpha_mask = cc->Inputs().HasTag(kInputAlphaTag) &&
!cc->Inputs().Tag(kInputAlphaTag).IsEmpty();
const bool use_alpha_mask = alpha_value_ < 0 && has_alpha_mask;
// Setup alpha image and Update image in CPU.
// Copy rgb part of the image in CPU
if (input_mat.channels() == 3) {
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA);
} else {
input_mat.copyTo(output_mat);
}
// Setup alpha image in CPU.
if (use_alpha_mask) {
const auto& alpha_mask = cc->Inputs().Tag(kInputAlphaTag).Get<ImageFrame>();
cv::Mat alpha_mat = mediapipe::formats::MatView(&alpha_mask);
cv::Mat alpha_mat = formats::MatView(&alpha_mask);
const bool alpha_is_float = CV_MAT_DEPTH(alpha_mat.type()) == CV_32F;
RET_CHECK(alpha_is_float || CV_MAT_DEPTH(alpha_mat.type()) == CV_8U);
if (alpha_is_float) {
MP_RETURN_IF_ERROR(
MergeRGBA8Image<float>(input_mat, alpha_mat, output_mat));
MP_RETURN_IF_ERROR(CopyAlphaImage<float>(alpha_mat, output_mat));
} else {
MP_RETURN_IF_ERROR(
MergeRGBA8Image<uchar>(input_mat, alpha_mat, output_mat));
MP_RETURN_IF_ERROR(CopyAlphaImage<uchar>(alpha_mat, output_mat));
}
} else {
const uchar alpha_value = std::min(std::max(0.0f, alpha_value_), 255.0f);
for (int i = 0; i < output_mat.rows; ++i) {
const uchar* in_ptr = input_mat.ptr<uchar>(i);
uchar* out_ptr = output_mat.ptr<uchar>(i);
for (int j = 0; j < output_mat.cols; ++j) {
const int out_idx = j * kNumChannelsRGBA;
const int in_idx = j * input_mat.channels();
out_ptr[out_idx + 0] = in_ptr[in_idx + 0];
out_ptr[out_idx + 1] = in_ptr[in_idx + 1];
out_ptr[out_idx + 2] = in_ptr[in_idx + 2];
out_ptr[out_idx + 3] = alpha_value; // use value from options
}
}

View File

@ -0,0 +1,156 @@
#include <cstdint>
#include "mediapipe/calculators/image/set_alpha_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "testing/base/public/benchmark.h"
namespace mediapipe {
namespace {
constexpr int input_width = 100;
constexpr int input_height = 100;
std::unique_ptr<ImageFrame> GetInputFrame(int width, int height, int channel) {
const int total_size = width * height * channel;
ImageFormat::Format image_format;
if (channel == 4) {
image_format = ImageFormat::SRGBA;
} else if (channel == 3) {
image_format = ImageFormat::SRGB;
} else {
image_format = ImageFormat::GRAY8;
}
auto input_frame = std::make_unique<ImageFrame>(image_format, width, height,
/*alignment_boundary =*/1);
for (int i = 0; i < total_size; ++i) {
input_frame->MutablePixelData()[i] = i % 256;
}
return input_frame;
}
// Test SetAlphaCalculator with RGB IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgb) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 3);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 3, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
// Test SetAlphaCalculator with RGBA IMAGE input.
TEST(SetAlphaCalculatorTest, CpuRgba) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 4);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
EXPECT_EQ(outputs.NumEntries(), 1);
const auto& output_image = outputs.Tag("IMAGE").packets[0].Get<ImageFrame>();
// Generate ground truth (expected_mat).
const auto image = GetInputFrame(input_width, input_height, 4);
const auto input_mat = formats::MatView(image.get());
const auto mask = GetInputFrame(input_width, input_height, 1);
const auto mask_mat = formats::MatView(mask.get());
const std::array<cv::Mat, 2> input_mats = {input_mat, mask_mat};
cv::Mat expected_mat(input_width, input_height, CV_8UC4);
cv::mixChannels(input_mats, {expected_mat}, {0, 0, 1, 1, 2, 2, 4, 3});
cv::Mat output_mat = formats::MatView(&output_image);
double max_diff = cv::norm(expected_mat, output_mat, cv::NORM_INF);
EXPECT_FLOAT_EQ(max_diff, 0);
} // TEST
static void BM_SetAlpha3ChannelImage(benchmark::State& state) {
auto calculator_node = ParseTextProtoOrDie<CalculatorGraphConfig::Node>(
R"pb(
calculator: "SetAlphaCalculator"
input_stream: "IMAGE:input_frames"
input_stream: "ALPHA:masks"
output_stream: "IMAGE:output_frames"
)pb");
CalculatorRunner runner(calculator_node);
// Input frames.
const auto input_frame = GetInputFrame(input_width, input_height, 3);
const auto mask_frame = GetInputFrame(input_width, input_height, 1);
auto input_frame_packet = MakePacket<ImageFrame>(std::move(*input_frame));
auto mask_frame_packet = MakePacket<ImageFrame>(std::move(*mask_frame));
runner.MutableInputs()->Tag("IMAGE").packets.push_back(
input_frame_packet.At(Timestamp(1)));
runner.MutableInputs()->Tag("ALPHA").packets.push_back(
mask_frame_packet.At(Timestamp(1)));
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs();
ASSERT_EQ(1, outputs.NumEntries());
for (const auto _ : state) {
MP_ASSERT_OK(runner.Run());
}
}
BENCHMARK(BM_SetAlpha3ChannelImage);
} // namespace
} // namespace mediapipe

View File

@ -53,6 +53,17 @@ AffineTransformation::BorderMode GetBorderMode(
}
}
AffineTransformation::Interpolation GetInterpolation(
mediapipe::WarpAffineCalculatorOptions::Interpolation interpolation) {
switch (interpolation) {
case mediapipe::WarpAffineCalculatorOptions::INTER_UNSPECIFIED:
case mediapipe::WarpAffineCalculatorOptions::INTER_LINEAR:
return AffineTransformation::Interpolation::kLinear;
case mediapipe::WarpAffineCalculatorOptions::INTER_CUBIC:
return AffineTransformation::Interpolation::kCubic;
}
}
template <typename ImageT>
class WarpAffineRunnerHolder {};
@ -61,16 +72,22 @@ template <>
class WarpAffineRunnerHolder<ImageFrame> {
public:
using RunnerType = AffineTransformation::Runner<ImageFrame, ImageFrame>;
absl::Status Open(CalculatorContext* cc) { return absl::OkStatus(); }
absl::Status Open(CalculatorContext* cc) {
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return absl::OkStatus();
}
absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) {
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationOpenCvRunner());
ASSIGN_OR_RETURN(runner_,
CreateAffineTransformationOpenCvRunner(interpolation_));
}
return runner_.get();
}
private:
std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
};
#endif // !MEDIAPIPE_DISABLE_OPENCV
@ -85,12 +102,14 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
gpu_origin_ =
cc->Options<mediapipe::WarpAffineCalculatorOptions>().gpu_origin();
gl_helper_ = std::make_shared<mediapipe::GlCalculatorHelper>();
interpolation_ = GetInterpolation(
cc->Options<mediapipe::WarpAffineCalculatorOptions>().interpolation());
return gl_helper_->Open(cc);
}
absl::StatusOr<RunnerType*> GetRunner() {
if (!runner_) {
ASSIGN_OR_RETURN(
runner_, CreateAffineTransformationGlRunner(gl_helper_, gpu_origin_));
ASSIGN_OR_RETURN(runner_, CreateAffineTransformationGlRunner(
gl_helper_, gpu_origin_, interpolation_));
}
return runner_.get();
}
@ -99,6 +118,7 @@ class WarpAffineRunnerHolder<mediapipe::GpuBuffer> {
mediapipe::GpuOrigin::Mode gpu_origin_;
std::shared_ptr<mediapipe::GlCalculatorHelper> gl_helper_;
std::unique_ptr<RunnerType> runner_;
AffineTransformation::Interpolation interpolation_;
};
#endif // !MEDIAPIPE_DISABLE_GPU

View File

@ -31,6 +31,13 @@ message WarpAffineCalculatorOptions {
BORDER_REPLICATE = 2;
}
// Pixel sampling interpolation methods. See @interpolation.
enum Interpolation {
INTER_UNSPECIFIED = 0;
INTER_LINEAR = 1;
INTER_CUBIC = 2;
}
// Pixel extrapolation method.
// When converting image to tensor it may happen that tensor needs to read
// pixels outside image boundaries. Border mode helps to specify how such
@ -43,4 +50,10 @@ message WarpAffineCalculatorOptions {
// to be flipped vertically as tensors are expected to start at top.
// (DEFAULT or unset interpreted as CONVENTIONAL.)
optional GpuOrigin.Mode gpu_origin = 2;
// Sampling method for neighboring pixels.
// INTER_LINEAR (bilinear) linearly interpolates from the nearest 4 neighbors.
// INTER_CUBIC (bicubic) interpolates a small neighborhood with cubic weights.
// INTER_UNSPECIFIED or unset interpreted as INTER_LINEAR.
optional Interpolation interpolation = 3;
}

View File

@ -63,7 +63,8 @@ void RunTest(const std::string& graph_text, const std::string& tag,
const cv::Mat& input, cv::Mat expected_result,
float similarity_threshold, std::array<float, 16> matrix,
int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) {
std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
std::string border_mode_str;
if (border_mode) {
switch (*border_mode) {
@ -75,8 +76,20 @@ void RunTest(const std::string& graph_text, const std::string& tag,
break;
}
}
std::string interpolation_str;
if (interpolation) {
switch (*interpolation) {
case AffineTransformation::Interpolation::kLinear:
interpolation_str = "interpolation: INTER_LINEAR";
break;
case AffineTransformation::Interpolation::kCubic:
interpolation_str = "interpolation: INTER_CUBIC";
break;
}
}
auto graph_config = mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(graph_text, /*$0=*/border_mode_str));
absl::Substitute(graph_text, /*$0=*/border_mode_str,
/*$1=*/interpolation_str));
std::vector<Packet> output_packets;
tool::AddVectorSink("output_image", &graph_config, &output_packets);
@ -132,7 +145,8 @@ struct SimilarityConfig {
void RunTest(cv::Mat input, cv::Mat expected_result,
const SimilarityConfig& similarity, std::array<float, 16> matrix,
int out_width, int out_height,
absl::optional<AffineTransformation::BorderMode> border_mode) {
std::optional<AffineTransformation::BorderMode> border_mode,
std::optional<AffineTransformation::Interpolation> interpolation) {
RunTest(R"(
input_stream: "input_image"
input_stream: "output_size"
@ -146,12 +160,13 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
}
}
}
)",
"cpu", input, expected_result, similarity.threshold_on_cpu, matrix,
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
RunTest(R"(
input_stream: "input_image"
@ -171,6 +186,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
}
}
}
@ -181,7 +197,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
}
)",
"cpu_image", input, expected_result, similarity.threshold_on_cpu,
matrix, out_width, out_height, border_mode);
matrix, out_width, out_height, border_mode, interpolation);
RunTest(R"(
input_stream: "input_image"
@ -201,6 +217,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT
}
}
@ -212,7 +229,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
}
)",
"gpu", input, expected_result, similarity.threshold_on_gpu, matrix,
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
RunTest(R"(
input_stream: "input_image"
@ -237,6 +254,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
options {
[mediapipe.WarpAffineCalculatorOptions.ext] {
$0 # border mode
$1 # interpolation
gpu_origin: TOP_LEFT
}
}
@ -253,7 +271,7 @@ void RunTest(cv::Mat input, cv::Mat expected_result,
}
)",
"gpu_image", input, expected_result, similarity.threshold_on_gpu,
matrix, out_width, out_height, border_mode);
matrix, out_width, out_height, border_mode, interpolation);
}
std::array<float, 16> GetMatrix(cv::Mat input, mediapipe::NormalizedRect roi,
@ -287,10 +305,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspect) {
int out_height = 256;
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.82},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
@ -312,10 +331,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
@ -337,10 +357,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotation) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.77},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
@ -362,10 +383,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.75},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
@ -386,10 +408,11 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotation) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.81},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
@ -411,10 +434,38 @@ TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZero) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.80},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, MediumSubRectWithRotationBorderZeroInterpCubic) {
mediapipe::NormalizedRect roi;
roi.set_x_center(0.65f);
roi.set_y_center(0.4f);
roi.set_width(0.5f);
roi.set_height(0.5f);
roi.set_rotation(M_PI * -45.0f / 180.0f);
auto input = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/input.jpg");
auto expected_output = GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_with_rotation_border_zero_interp_cubic.png");
int out_width = 256;
int out_height = 256;
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation =
AffineTransformation::Interpolation::kCubic;
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.78},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRect) {
@ -435,10 +486,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRect) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.95},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
@ -459,10 +511,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectBorderZero) {
bool keep_aspect_ratio = false;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.92},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
@ -483,10 +536,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspect) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
@ -508,10 +562,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.97},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
@ -532,10 +587,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotation) {
int out_height = 128;
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode = {};
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.91},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
@ -557,10 +613,11 @@ TEST(WarpAffineCalculatorTest, LargeSubRectKeepAspectWithRotationBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.88},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, NoOp) {
@ -581,10 +638,11 @@ TEST(WarpAffineCalculatorTest, NoOp) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kReplicate;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
@ -605,10 +663,11 @@ TEST(WarpAffineCalculatorTest, NoOpBorderZero) {
bool keep_aspect_ratio = true;
std::optional<AffineTransformation::BorderMode> border_mode =
AffineTransformation::BorderMode::kZero;
std::optional<AffineTransformation::Interpolation> interpolation = {};
RunTest(input, expected_output,
{.threshold_on_cpu = 0.99, .threshold_on_gpu = 0.99},
GetMatrix(input, roi, keep_aspect_ratio, out_width, out_height),
out_width, out_height, border_mode);
out_width, out_height, border_mode, interpolation);
}
} // namespace

View File

@ -997,17 +997,20 @@ cc_library(
":image_to_tensor_converter_gl_buffer",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
],
"//mediapipe:apple": [
":image_to_tensor_converter_metal",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
],
"//conditions:default": [
":image_to_tensor_converter_gl_buffer",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_service",
],
}),
)
@ -1045,6 +1048,10 @@ cc_test(
":image_to_tensor_calculator",
":image_to_tensor_converter",
":image_to_tensor_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework/deps:file_path",
@ -1061,11 +1068,10 @@ cc_test(
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/util:image_test_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
] + select({
"//mediapipe:apple": [],
"//conditions:default": ["//mediapipe/gpu:gl_context"],
}),
)
cc_library(

View File

@ -45,9 +45,11 @@
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_buffer.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_service.h"
#else
#include "mediapipe/calculators/tensor/image_to_tensor_converter_gl_texture.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_service.h"
#endif // MEDIAPIPE_METAL_ENABLED
#endif // !MEDIAPIPE_DISABLE_GPU
@ -147,7 +149,7 @@ class ImageToTensorCalculator : public Node {
#if MEDIAPIPE_METAL_ENABLED
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#else
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
cc->UseService(kGpuService).Optional();
#endif // MEDIAPIPE_METAL_ENABLED
#endif // MEDIAPIPE_DISABLE_GPU

View File

@ -41,6 +41,10 @@
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/image_test_utils.h"
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
#include "mediapipe/gpu/gl_context.h"
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
namespace mediapipe {
namespace {
@ -507,5 +511,79 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeAndUseInputImageDims) {
/*tensor_width=*/std::nullopt, /*tensor_height=*/std::nullopt,
/*keep_aspect=*/false, BorderMode::kZero, roi);
}
TEST(ImageToTensorCalculatorTest, CanBeUsedWithoutGpuServiceSet) {
auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_float_range { min: 0.0f max: 1.0f }
}
}
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
MP_ASSERT_OK(graph.StartRun({}));
auto image_frame =
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
Image image = Image(std::move(image_frame));
Packet packet = MakePacket<Image>(std::move(image));
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
MP_ASSERT_OK(graph.WaitUntilIdle());
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
#if !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
TEST(ImageToTensorCalculatorTest,
FailsGracefullyWhenGpuServiceNeededButNotAvailable) {
auto graph_config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: "input_image"
node {
calculator: "ImageToTensorCalculator"
input_stream: "IMAGE:input_image"
output_stream: "TENSORS:tensor"
options {
[mediapipe.ImageToTensorCalculatorOptions.ext] {
output_tensor_float_range { min: 0.0f max: 1.0f }
}
}
}
)pb");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config));
MP_ASSERT_OK(graph.DisallowServiceDefaultInitialization());
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK_AND_ASSIGN(auto context,
GlContext::Create(nullptr, /*create_thread=*/true));
Packet packet;
context->Run([&packet]() {
auto image_frame =
std::make_shared<ImageFrame>(ImageFormat::SRGBA, 128, 256, 4);
Image image = Image(std::move(image_frame));
// Ensure image is available on GPU to force ImageToTensorCalculator to
// run on GPU.
ASSERT_TRUE(image.ConvertToGpu());
packet = MakePacket<Image>(std::move(image));
});
MP_ASSERT_OK(
graph.AddPacketToInputStream("input_image", packet.At(Timestamp(1))));
EXPECT_THAT(graph.WaitUntilIdle(),
StatusIs(absl::StatusCode::kInternal,
HasSubstr("GPU service not available")));
}
#endif // !MEDIAPIPE_DISABLE_GPU && !MEDIAPIPE_METAL_ENABLED
} // namespace
} // namespace mediapipe

View File

@ -141,7 +141,7 @@ InferenceCalculatorGlAdvancedImpl::GpuInferenceRunner::Process(
}
// Run inference.
{
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE, cc);
MEDIAPIPE_PROFILING(GPU_TASK_INVOKE_ADVANCED, cc);
return tflite_gpu_runner_->Invoke();
}
}));

View File

@ -0,0 +1,41 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The option proto for the TensorsReadbackCalculator.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TensorsReadbackCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional TensorsReadbackCalculatorOptions ext = 514750372;
}
// Expected shapes of the input tensors.
// The calculator uses these shape to build the GPU programs during
// initialization, and check the actual tensor shapes against the expected
// shapes during runtime.
// Batch size of the tensor is set to be 1. `TensorShape` here can be C, WC,
// or HWC.
// For example {dims: 1 dims: 2} represents a tensor with batch_size = 1,
// width = 1, and num_channels = 2.
message TensorShape {
repeated int32 dims = 1 [packed = true];
}
// tensor_shape specifies the shape of each input tensors.
repeated TensorShape tensor_shape = 1;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

View File

@ -14,6 +14,7 @@
#
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")
load("@bazel_skylib//lib:selects.bzl", "selects")
licenses(["notice"])
@ -312,15 +313,19 @@ cc_library(
alwayslink = 1,
)
cc_library(
# TODO: Re-evaluate which of these libraries we can avoid making
# cc_library_with_tflite and can be changed back to cc_library.
cc_library_with_tflite(
name = "tflite_model_calculator",
srcs = ["tflite_model_calculator.cc"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite:framework_stable",
],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite/core/shims:framework_stable",
],
alwayslink = 1,
)

View File

@ -66,7 +66,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
} else {
cc->OutputSidePackets()
.Index(0)
.Set<tflite_shims::ops::builtin::BuiltinOpResolver>();
.Set<tflite::ops::builtin::BuiltinOpResolver>();
}
return absl::OkStatus();
}
@ -77,7 +77,7 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
const TfLiteCustomOpResolverCalculatorOptions& options =
cc->Options<TfLiteCustomOpResolverCalculatorOptions>();
std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver> op_resolver;
std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> op_resolver;
if (options.use_gpu()) {
op_resolver = absl::make_unique<mediapipe::OpResolver>();
} else {

View File

@ -21,7 +21,7 @@
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/core/shims/cc/model.h"
#include "tensorflow/lite/model.h"
namespace mediapipe {
@ -82,7 +82,7 @@ class TfLiteModelCalculator : public CalculatorBase {
}
if (cc->InputSidePackets().HasTag("MODEL_FD")) {
#ifdef ABSL_HAVE_MMAP
#if defined(ABSL_HAVE_MMAP) && !TFLITE_WITH_STABLE_ABI
model_packet = cc->InputSidePackets().Tag("MODEL_FD");
const auto& model_fd =
model_packet.Get<std::tuple<int, size_t, size_t>>();

View File

@ -1270,6 +1270,50 @@ cc_library(
alwayslink = 1,
)
mediapipe_proto_library(
name = "flat_color_image_calculator_proto",
srcs = ["flat_color_image_calculator.proto"],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:color_proto",
],
)
cc_library(
name = "flat_color_image_calculator",
srcs = ["flat_color_image_calculator.cc"],
deps = [
":flat_color_image_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/util:color_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
alwayslink = 1,
)
cc_test(
name = "flat_color_image_calculator_test",
srcs = ["flat_color_image_calculator_test.cc"],
deps = [
":flat_color_image_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:gtest",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/util:color_cc_proto",
],
)
cc_library(
name = "from_image_calculator",
srcs = ["from_image_calculator.cc"],

View File

@ -0,0 +1,138 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mediapipe/calculators/util/flat_color_image_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::mediapipe::api2::Input;
using ::mediapipe::api2::Node;
using ::mediapipe::api2::Output;
} // namespace
// A calculator for generating an image filled with a single color.
//
// Inputs:
// IMAGE (Image, optional)
// If provided, the output will have the same size
// COLOR (Color proto, optional)
// Color to paint the output with. Takes precedence over the equivalent
// calculator options.
//
// Outputs:
// IMAGE (Image)
// Image filled with the requested color.
//
// Example useage:
// node {
// calculator: "FlatColorImageCalculator"
// input_stream: "IMAGE:image"
// input_stream: "COLOR:color"
// output_stream: "IMAGE:blank_image"
// options {
// [mediapipe.FlatColorImageCalculatorOptions.ext] {
// color: {
// r: 255
// g: 255
// b: 255
// }
// }
// }
// }
class FlatColorImageCalculator : public Node {
public:
static constexpr Input<Image>::Optional kInImage{"IMAGE"};
static constexpr Input<Color>::Optional kInColor{"COLOR"};
static constexpr Output<Image> kOutImage{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kInImage, kInColor, kOutImage);
static absl::Status UpdateContract(CalculatorContract* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
RET_CHECK(kInImage(cc).IsConnected() ^
(options.has_output_height() || options.has_output_width()))
<< "Either set IMAGE input stream, or set through options";
RET_CHECK(kInColor(cc).IsConnected() ^ options.has_color())
<< "Either set COLOR input stream, or set through options";
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
bool use_dimension_from_option_ = false;
bool use_color_from_option_ = false;
};
MEDIAPIPE_REGISTER_NODE(FlatColorImageCalculator);
absl::Status FlatColorImageCalculator::Open(CalculatorContext* cc) {
use_dimension_from_option_ = !kInImage(cc).IsConnected();
use_color_from_option_ = !kInColor(cc).IsConnected();
return absl::OkStatus();
}
absl::Status FlatColorImageCalculator::Process(CalculatorContext* cc) {
const auto& options = cc->Options<FlatColorImageCalculatorOptions>();
int output_height = -1;
int output_width = -1;
if (use_dimension_from_option_) {
output_height = options.output_height();
output_width = options.output_width();
} else if (!kInImage(cc).IsEmpty()) {
const Image& input_image = kInImage(cc).Get();
output_height = input_image.height();
output_width = input_image.width();
} else {
return absl::OkStatus();
}
Color color;
if (use_color_from_option_) {
color = options.color();
} else if (!kInColor(cc).IsEmpty()) {
color = kInColor(cc).Get();
} else {
return absl::OkStatus();
}
auto output_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
output_width, output_height);
cv::Mat output_mat = mediapipe::formats::MatView(output_frame.get());
output_mat.setTo(cv::Scalar(color.r(), color.g(), color.b()));
kOutImage(cc).Send(Image(output_frame));
return absl::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/color.proto";
message FlatColorImageCalculatorOptions {
extend CalculatorOptions {
optional FlatColorImageCalculatorOptions ext = 515548435;
}
// Output dimensions.
optional int32 output_width = 1;
optional int32 output_height = 2;
// The color to fill with in the output image.
optional Color color = 3;
}

View File

@ -0,0 +1,210 @@
// Copyright 2023 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/color.pb.h"
namespace mediapipe {
namespace {
using ::testing::HasSubstr;
constexpr char kImageTag[] = "IMAGE";
constexpr char kColorTag[] = "COLOR";
constexpr int kImageWidth = 256;
constexpr int kImageHeight = 256;
TEST(FlatColorImageCalculatorTest, SpecifyColorThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), kImageWidth);
EXPECT_EQ(image.height(), kImageHeight);
auto image_frame = image.GetImageFrameSharedPtr();
auto* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 100);
EXPECT_EQ(pixel_data[1], 200);
EXPECT_EQ(pixel_data[2], 255);
}
}
TEST(FlatColorImageCalculatorTest, SpecifyDimensionThroughOptions) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
MP_ASSERT_OK(runner.Run());
const auto& outputs = runner.Outputs().Tag(kImageTag).packets;
ASSERT_EQ(outputs.size(), 3);
for (const auto& packet : outputs) {
const auto& image = packet.Get<Image>();
EXPECT_EQ(image.width(), 7);
EXPECT_EQ(image.height(), 13);
auto image_frame = image.GetImageFrameSharedPtr();
const uint8_t* pixel_data = image_frame->PixelData();
EXPECT_EQ(pixel_data[0], 0);
EXPECT_EQ(pixel_data[1], 5);
EXPECT_EQ(pixel_data[2], 0);
}
}
TEST(FlatColorImageCalculatorTest, FailureMissingDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureMissingColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
output_stream: "IMAGE:out_image"
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateDimension) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
output_width: 7,
output_height: 13,
}
}
)pb");
auto image_frame = std::make_shared<ImageFrame>(ImageFormat::SRGB,
kImageWidth, kImageHeight);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kImageTag).packets.push_back(
MakePacket<Image>(image_frame).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set IMAGE input stream"));
}
TEST(FlatColorImageCalculatorTest, FailureDuplicateColor) {
CalculatorRunner runner(R"pb(
calculator: "FlatColorImageCalculator"
input_stream: "IMAGE:image"
input_stream: "COLOR:color"
output_stream: "IMAGE:out_image"
options {
[mediapipe.FlatColorImageCalculatorOptions.ext] {
color: {
r: 100,
g: 200,
b: 255,
}
}
}
)pb");
Color color;
color.set_r(0);
color.set_g(5);
color.set_b(0);
for (int ts = 0; ts < 3; ++ts) {
runner.MutableInputs()->Tag(kColorTag).packets.push_back(
MakePacket<Color>(color).At(Timestamp(ts)));
}
ASSERT_THAT(runner.Run().message(),
HasSubstr("Either set COLOR input stream"));
}
} // namespace
} // namespace mediapipe

View File

@ -1,5 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip
networkTimeout=10000
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

View File

@ -1,7 +1,7 @@
#!/usr/bin/env sh
#!/bin/sh
#
# Copyright 2015 the original author or authors.
# Copyright © 2015-2021 the original authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,67 +17,101 @@
#
##############################################################################
##
## Gradle start up script for UN*X
##
#
# Gradle start up script for POSIX generated by Gradle.
#
# Important for running:
#
# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is
# noncompliant, but you have some other compliant shell such as ksh or
# bash, then to run this script, type that shell name before the whole
# command line, like:
#
# ksh Gradle
#
# Busybox and similar reduced shells will NOT work, because this script
# requires all of these POSIX shell features:
# * functions;
# * expansions «$var», «${var}», «${var:-default}», «${var+SET}»,
# «${var#prefix}», «${var%suffix}», and «$( cmd )»;
# * compound commands having a testable exit status, especially «case»;
# * various built-in commands including «command», «set», and «ulimit».
#
# Important for patching:
#
# (2) This script targets any POSIX shell, so it avoids extensions provided
# by Bash, Ksh, etc; in particular arrays are avoided.
#
# The "traditional" practice of packing multiple parameters into a
# space-separated string is a well documented source of bugs and security
# problems, so this is (mostly) avoided, by progressively accumulating
# options in "$@", and eventually passing that to Java.
#
# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS,
# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly;
# see the in-line comments for details.
#
# There are tweaks for specific operating systems such as AIX, CygWin,
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
#
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Resolve links: $0 may be a link
app_path=$0
# Need this for daisy-chained symlinks.
while
APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path
[ -h "$app_path" ]
do
ls=$( ls -ld "$app_path" )
link=${ls#*' -> '}
case $link in #(
/*) app_path=$link ;; #(
*) app_path=$APP_HOME$link ;;
esac
done
# This is normally unused
# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
MAX_FD=maximum
warn () {
echo "$*"
}
} >&2
die () {
echo
echo "$*"
echo
exit 1
}
} >&2
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
case "$( uname )" in #(
CYGWIN* ) cygwin=true ;; #(
Darwin* ) darwin=true ;; #(
MSYS* | MINGW* ) msys=true ;; #(
NONSTOP* ) nonstop=true ;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
@ -87,9 +121,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
JAVACMD=$JAVA_HOME/jre/sh/java
else
JAVACMD="$JAVA_HOME/bin/java"
JAVACMD=$JAVA_HOME/bin/java
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
@ -98,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
JAVACMD=java
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
@ -106,80 +140,105 @@ location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
# In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
# In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
# shellcheck disable=SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, stacking in reverse order:
# * args from the command line
# * the main class name
# * -classpath
# * -D...appname settings
# * --module-path (only if needed)
# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables.
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
# For Cygwin or MSYS, switch paths to Windows format before running java
if "$cygwin" || "$msys" ; then
APP_HOME=$( cygpath --path --mixed "$APP_HOME" )
CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" )
JAVACMD=$( cygpath --unix "$JAVACMD" )
# Now convert the arguments - kludge to limit ourselves to /bin/sh
for arg do
if
case $arg in #(
-*) false ;; # don't mess with options #(
/?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath
[ -e "$t" ] ;; #(
*) false ;;
esac
then
arg=$( cygpath --path --ignore --mixed "$arg" )
fi
# Roll the args list around exactly as many times as the number of
# args, so each arg winds up back in the position where it started, but
# possibly modified.
#
# NB: a `for` loop captures its iteration list before it begins, so
# changing the positional parameters here affects neither the number of
# iterations, nor the values presented in `arg`.
shift # remove old arg
set -- "$@" "$arg" # push replacement arg
done
fi
# Collect all arguments for the java command;
# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
# shell script including quotes and variable substitutions, so put them in
# double quotes to make sure that they get re-expanded; and
# * put everything else in single quotes, so that it's not re-expanded.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
-classpath "$CLASSPATH" \
org.gradle.wrapper.GradleWrapperMain \
"$@"
# Stop when "xargs" is not available.
if ! command -v xargs >/dev/null 2>&1
then
die "xargs is not available"
fi
# Use "xargs" to parse quoted args.
#
# With -n1 it outputs one arg per line, with the quotes and backslashes removed.
#
# In Bash we could simply go:
#
# readarray ARGS < <( xargs -n1 <<<"$var" ) &&
# set -- "${ARGS[@]}" "$@"
#
# but POSIX shell has neither arrays nor command substitution, so instead we
# post-process each arg (as a line of input to sed) to backslash-escape any
# character that might be a shell metacharacter, then use eval to reverse
# that process (while maintaining the separation between arguments), and wrap
# the whole thing up as a single "set" statement.
#
# This will of course break if any of these variables contains a newline or
# an unmatched quote.
#
eval "set -- $(
printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" |
xargs -n1 |
sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' |
tr '\n' ' '
)" '"$@"'
exec "$JAVACMD" "$@"

View File

@ -26,6 +26,7 @@ if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@ -40,7 +41,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome
set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if "%ERRORLEVEL%" == "0" goto execute
if %ERRORLEVEL% equ 0 goto execute
echo.
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
@ -75,13 +76,15 @@ set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
:end
@rem End local scope for the variables with windows NT shell
if "%ERRORLEVEL%"=="0" goto mainEnd
if %ERRORLEVEL% equ 0 goto mainEnd
:fail
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
rem the _cmd.exe /c_ return code!
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
exit /b 1
set EXIT_CODE=%ERRORLEVEL%
if %EXIT_CODE% equ 0 set EXIT_CODE=1
if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE%
exit /b %EXIT_CODE%
:mainEnd
if "%OS%"=="Windows_NT" endlocal

View File

@ -138,7 +138,23 @@ void TestWithAspectRatio(const double aspect_ratio,
std::string result_image;
MP_ASSERT_OK(
mediapipe::file::GetContents(result_string_path, &result_image));
EXPECT_EQ(result_image, output_string);
if (result_image != output_string) {
// There may be slight differences due to the way the JPEG was encoded or
// the OpenCV version used to generate the reference files. Compare
// pixel-by-pixel using the Peak Signal-to-Noise Ratio instead.
cv::Mat result_mat =
cv::imdecode(cv::Mat(1, result_image.size(), CV_8UC1,
const_cast<char*>(result_image.data())),
cv::IMREAD_UNCHANGED);
cv::Mat output_mat =
cv::imdecode(cv::Mat(1, output_string.size(), CV_8UC1,
const_cast<char*>(output_string.data())),
cv::IMREAD_UNCHANGED);
ASSERT_EQ(result_mat.rows, output_mat.rows);
ASSERT_EQ(result_mat.cols, output_mat.cols);
ASSERT_EQ(result_mat.type(), output_mat.type());
EXPECT_GT(cv::PSNR(result_mat, output_mat), 45.0);
}
} else {
std::string output_string_path = mediapipe::file::JoinPath(
absl::GetFlag(FLAGS_output_folder),

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.1 KiB

After

Width:  |  Height:  |  Size: 6.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.2 KiB

After

Width:  |  Height:  |  Size: 8.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.6 KiB

After

Width:  |  Height:  |  Size: 7.6 KiB

View File

@ -56,5 +56,6 @@ objc_library(
deps = [
"//mediapipe/examples/ios/common:CommonMediaPipeAppLibrary",
"//mediapipe/graphs/edge_detection:mobile_calculators",
"//third_party/apple_frameworks:Metal",
],
)

View File

@ -631,7 +631,13 @@ absl::Status CalculatorGraph::PrepareServices() {
for (const auto& [key, request] : node->Contract().ServiceRequests()) {
auto packet = service_manager_.GetServicePacket(request.Service());
if (!packet.IsEmpty()) continue;
auto packet_or = request.Service().CreateDefaultObject();
absl::StatusOr<Packet> packet_or;
if (allow_service_default_initialization_) {
packet_or = request.Service().CreateDefaultObject();
} else {
packet_or = absl::FailedPreconditionError(
"Service default initialization is disallowed.");
}
if (packet_or.ok()) {
MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
request.Service(), std::move(packet_or).value()));

View File

@ -405,6 +405,34 @@ class CalculatorGraph {
return service_manager_.GetServiceObject(service);
}
// Disallows/disables default initialization of MediaPipe graph services.
//
// IMPORTANT: MediaPipe graph serices, essentially a graph-level singletons,
// are designed in the way, so they may provide default initialization. For
// example, this allows to run OpenGL processing wihtin the graph without
// provinging a praticular OpenGL context as it can be provided by
// default-initializable `kGpuService`. (One caveat here, you may still need
// to initialize it manually to share graph context with external context.)
//
// Even if calculators require some service optionally
// (`calculator_contract->UseService(kSomeService).Optional()`), it will be
// still initialized if it allows default initialization.
//
// So far, in rare cases, this may be unwanted and strict control of what
// services are allowed in the graph can be achieved by calling this method,
// following `SetServiceObject` call for services which are allowed in the
// graph.
//
// Recommendation: do not use unless you have to (for example, default
// initialization has side effects)
//
// NOTE: must be called before `StartRun`/`Run`, where services are checked
// and can be default-initialized.
absl::Status DisallowServiceDefaultInitialization() {
allow_service_default_initialization_ = false;
return absl::OkStatus();
}
// Sets a service object, essentially a graph-level singleton, which can be
// accessed by calculators and subgraphs without requiring an explicit
// connection.
@ -644,6 +672,9 @@ class CalculatorGraph {
// Object to manage graph services.
GraphServiceManager service_manager_;
// Indicates whether service default initialization is allowed.
bool allow_service_default_initialization_ = true;
// Vector of errors encountered while running graph. Always use RecordError()
// to add an error to this vector.
std::vector<absl::Status> errors_ ABSL_GUARDED_BY(error_mutex_);

View File

@ -136,6 +136,8 @@ message GraphTrace {
GPU_TASK_INVOKE = 16;
TPU_TASK_INVOKE = 17;
CPU_TASK_INVOKE = 18;
GPU_TASK_INVOKE_ADVANCED = 19;
TPU_TASK_INVOKE_ASYNC = 20;
}
// The timing for one packet set being processed at one caclulator node.

View File

@ -315,11 +315,11 @@ cc_library(
visibility = ["//visibility:private"],
deps = [
"@com_google_absl//absl/flags:flag",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework/port:status",
"//mediapipe/framework/deps:file_path",
"//mediapipe/framework/port:statusor",
] + select({
"//conditions:default": [
"//mediapipe/framework/port:file_helpers",
@ -328,8 +328,8 @@ cc_library(
"//mediapipe/framework/port:file_helpers",
],
"//mediapipe:android": [
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/java/com/google/mediapipe/framework/jni:jni_util",
],
"//mediapipe:apple": [
"//mediapipe/framework/port:file_helpers",

View File

@ -112,6 +112,10 @@ struct TraceEvent {
static constexpr EventType GPU_TASK_INVOKE = GraphTrace::GPU_TASK_INVOKE;
static constexpr EventType TPU_TASK_INVOKE = GraphTrace::TPU_TASK_INVOKE;
static constexpr EventType CPU_TASK_INVOKE = GraphTrace::CPU_TASK_INVOKE;
static constexpr EventType GPU_TASK_INVOKE_ADVANCED =
GraphTrace::GPU_TASK_INVOKE_ADVANCED;
static constexpr EventType TPU_TASK_INVOKE_ASYNC =
GraphTrace::TPU_TASK_INVOKE_ASYNC;
};
// Packet trace log buffer.

View File

@ -57,7 +57,6 @@ struct hash<mediapipe::TaskId> {
namespace mediapipe {
namespace {
void BasicTraceEventTypes(TraceEventRegistry* result) {
// The initializer arguments below are: event_type, description,
// is_packet_event, is_stream_event, id_event_data.
@ -84,6 +83,15 @@ void BasicTraceEventTypes(TraceEventRegistry* result) {
"A time measured by GPU clock and by CPU clock.", true, false},
{TraceEvent::PACKET_QUEUED, "An input queue size when a packet arrives.",
true, true, false},
{TraceEvent::GPU_TASK_INVOKE, "CPU timing for initiating a GPU task."},
{TraceEvent::TPU_TASK_INVOKE, "CPU timing for initiating a TPU task."},
{TraceEvent::CPU_TASK_INVOKE, "CPU timing for initiating a CPU task."},
{TraceEvent::GPU_TASK_INVOKE_ADVANCED,
"CPU timing for initiating a GPU task bypassing the TFLite "
"interpreter."},
{TraceEvent::TPU_TASK_INVOKE_ASYNC,
"CPU timing for async initiation of a TPU task."},
};
for (const TraceEventType& t : basic_types) {
(*result)[t.event_type()] = t;

View File

@ -77,7 +77,6 @@ mediapipe_proto_library(
name = "calculator_graph_template_proto",
srcs = ["calculator_graph_template.proto"],
def_options_lib = False,
def_py_proto = False,
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_options_proto",

View File

@ -204,7 +204,7 @@ def rewrite_mediapipe_proto(name, rewrite_proto, source_proto, **kwargs):
'import public "' + join_path + '";',
)
rewrite_ref = SubsituteCommand(
r"mediapipe\\.(" + rewrite_message_regex + ")",
r"mediapipe\.(" + rewrite_message_regex + ")",
r"mediapipe.\\1",
)
rewrite_objc = SubsituteCommand(
@ -284,7 +284,7 @@ def mediapipe_proto_library(
def_jspb_proto: define the jspb_proto_library target
def_go_proto: define the go_proto_library target
def_options_lib: define the mediapipe_options_library target
def_rewrite: define a sibbling mediapipe_proto_library with package "mediapipe"
def_rewrite: define a sibling mediapipe_proto_library with package "mediapipe"
"""
mediapipe_proto_library_impl(

View File

@ -183,12 +183,13 @@ absl::Status FindCorrespondingStreams(
// name, calculator, input_stream, output_stream, input_side_packet,
// output_side_packet, options.
// All other fields are only applicable to calculators.
// TODO: Check whether executor is not set in the subgraph node
// after this issues is properly solved.
absl::Status ValidateSubgraphFields(
const CalculatorGraphConfig::Node& subgraph_node) {
if (subgraph_node.source_layer() || subgraph_node.buffer_size_hint() ||
subgraph_node.has_output_stream_handler() ||
subgraph_node.input_stream_info_size() != 0 ||
!subgraph_node.executor().empty()) {
subgraph_node.input_stream_info_size() != 0) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Subgraph \"" << subgraph_node.name()
<< "\" has a field that is only applicable to calculators.";

View File

@ -272,14 +272,6 @@ selects.config_setting_group(
],
)
selects.config_setting_group(
name = "platform_ios_without_gpu",
match_all = [
":disable_gpu",
"//mediapipe:ios",
],
)
selects.config_setting_group(
name = "platform_macos_with_gpu",
match_all = [
@ -296,32 +288,33 @@ cc_library(
deps = [
":gpu_buffer_format",
":gpu_buffer_storage",
":gpu_buffer_storage_image_frame",
"@com_google_absl//absl/functional:bind_front",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging",
":gpu_buffer_storage_image_frame",
] + select({
"//conditions:default": [
":gl_texture_view",
":gl_texture_buffer",
":gl_texture_view",
],
":platform_ios_with_gpu": [
":gl_texture_view",
":gpu_buffer_storage_cv_pixel_buffer",
"//mediapipe/objc:util",
"//mediapipe/objc:CFHolder",
],
":platform_macos_with_gpu": [
"//mediapipe/objc:CFHolder",
":gl_texture_view",
":gl_texture_buffer",
],
":platform_ios_without_gpu": [
"//mediapipe/objc:util",
":gl_texture_view",
"//mediapipe/objc:CFHolder",
],
":disable_gpu": [],
}) + select({
"//conditions:default": [],
"//mediapipe:ios": [
"//mediapipe/objc:util",
],
}),
)
@ -331,9 +324,9 @@ cc_library(
hdrs = ["gpu_buffer_format.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework/formats:image_format_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/port:logging",
] + select({
"//conditions:default": [
@ -474,6 +467,7 @@ cc_library(
"//mediapipe/framework/formats:frame_buffer",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:yuv_image",
"//mediapipe/util/frame_buffer:frame_buffer_util",
"//third_party/libyuv",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
@ -619,22 +613,22 @@ cc_library(
}),
visibility = ["//visibility:private"],
deps = [
":gl_context_options_cc_proto",
":graph_support",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:executor",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/deps:no_destructor",
":gl_base",
":gl_context",
":gl_context_options_cc_proto",
":gpu_buffer_multi_pool",
":gpu_shared_data_header",
":graph_support",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework:executor",
"//mediapipe/framework/deps:no_destructor",
"//mediapipe/framework/port:ret_check",
] + select({
"//conditions:default": [],
"//mediapipe:apple": [
":metal_shared_resources",
":cv_texture_cache_manager",
":metal_shared_resources",
],
}),
)
@ -703,13 +697,13 @@ cc_library(
":gpu_buffer",
":gpu_shared_data_header",
":multi_pool",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework/port:logging",
"//mediapipe/util:resource_cache",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
] + select({
"//conditions:default": [
":gl_texture_buffer",
@ -725,9 +719,9 @@ cc_library(
"//mediapipe:macos": [
":cv_pixel_buffer_pool_wrapper",
":cv_texture_cache_manager",
":pixel_buffer_pool_util",
":gl_texture_buffer",
":gl_texture_buffer_pool",
":pixel_buffer_pool_util",
],
}),
)
@ -795,31 +789,31 @@ cc_library(
":gpu_buffer",
":gpu_buffer_format",
":gpu_buffer_multi_pool",
":gpu_shared_data_internal",
":gpu_service",
":gpu_shared_data_internal",
":graph_support",
":image_frame_view",
":shader_util",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework:calculator_contract",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_node",
"//mediapipe/framework:demangle",
"//mediapipe/framework:legacy_calculator_support",
"//mediapipe/framework:packet",
"//mediapipe/framework:packet_set",
"//mediapipe/framework:packet_type",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/synchronization",
"//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:map_util",
] + select({
"//conditions:default": [
],
@ -918,8 +912,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":gl_calculator_helper",
":gpu_buffer_storage_image_frame",
"//mediapipe/framework/api2:node",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:status",
@ -941,7 +933,7 @@ mediapipe_proto_library(
],
)
proto_library(
mediapipe_proto_library(
name = "gl_scaler_calculator_proto",
srcs = ["gl_scaler_calculator.proto"],
visibility = ["//visibility:public"],
@ -951,17 +943,6 @@ proto_library(
],
)
mediapipe_cc_proto_library(
name = "gl_scaler_calculator_cc_proto",
srcs = ["gl_scaler_calculator.proto"],
cc_deps = [
":scale_mode_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":gl_scaler_calculator_proto"],
)
cc_library(
name = "gl_scaler_calculator",
srcs = ["gl_scaler_calculator.cc"],

View File

@ -12,63 +12,73 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#ifdef __APPLE__
#include "mediapipe/objc/util.h"
#endif
namespace mediapipe {
namespace api2 {
class ImageFrameToGpuBufferCalculator
: public RegisteredNode<ImageFrameToGpuBufferCalculator> {
// Convert ImageFrame to GpuBuffer.
class ImageFrameToGpuBufferCalculator : public CalculatorBase {
public:
static constexpr Input<ImageFrame> kIn{""};
static constexpr Output<GpuBuffer> kOut{""};
ImageFrameToGpuBufferCalculator() {}
MEDIAPIPE_NODE_INTERFACE(ImageFrameToGpuBufferCalculator, kIn, kOut);
static absl::Status UpdateContract(CalculatorContract* cc);
static absl::Status GetContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
private:
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
GlCalculatorHelper helper_;
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
};
REGISTER_CALCULATOR(ImageFrameToGpuBufferCalculator);
// static
absl::Status ImageFrameToGpuBufferCalculator::UpdateContract(
absl::Status ImageFrameToGpuBufferCalculator::GetContract(
CalculatorContract* cc) {
cc->Inputs().Index(0).Set<ImageFrame>();
cc->Outputs().Index(0).Set<GpuBuffer>();
// Note: we call this method even on platforms where we don't use the helper,
// to ensure the calculator's contract is the same. In particular, the helper
// enables support for the legacy side packet, which several graphs still use.
return GlCalculatorHelper::UpdateContract(cc);
MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
return absl::OkStatus();
}
absl::Status ImageFrameToGpuBufferCalculator::Open(CalculatorContext* cc) {
// Inform the framework that we always output at the same timestamp
// as we receive a packet at.
cc->SetOffset(TimestampDiff(0));
#if !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
MP_RETURN_IF_ERROR(helper_.Open(cc));
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return absl::OkStatus();
}
absl::Status ImageFrameToGpuBufferCalculator::Process(CalculatorContext* cc) {
auto image_frame = std::const_pointer_cast<ImageFrame>(
mediapipe::SharedPtrWithPacket<ImageFrame>(kIn(cc).packet()));
auto gpu_buffer = api2::MakePacket<GpuBuffer>(
std::make_shared<mediapipe::GpuBufferStorageImageFrame>(
std::move(image_frame)))
.At(cc->InputTimestamp());
// This calculator's behavior has been to do the texture upload eagerly, and
// some graphs may rely on running this on a separate GL context to avoid
// blocking another context with the read operation. So let's request GPU
// access here to ensure that the behavior stays the same.
// TODO: have a better way to do this, or defer until later.
helper_.RunInGlContext(
[&gpu_buffer] { auto view = gpu_buffer->GetReadView<GlTextureView>(0); });
kOut(cc).Send(std::move(gpu_buffer));
#if MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
CFHolder<CVPixelBufferRef> buffer;
MP_RETURN_IF_ERROR(CreateCVPixelBufferForImageFramePacket(
cc->Inputs().Index(0).Value(), &buffer));
cc->Outputs().Index(0).Add(new GpuBuffer(buffer), cc->InputTimestamp());
#else
const auto& input = cc->Inputs().Index(0).Get<ImageFrame>();
helper_.RunInGlContext([this, &input, &cc]() {
auto src = helper_.CreateSourceTexture(input);
auto output = src.GetFrame<GpuBuffer>();
glFlush();
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
src.Release();
});
#endif // MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
return absl::OkStatus();
}
} // namespace api2
} // namespace mediapipe

View File

@ -66,7 +66,8 @@ public class GraphTextureFrame implements TextureFrame {
if (nativeBufferHandle == 0) {
return 0;
}
if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) {
long contextHandle = nativeGetCurrentExternalContextHandle();
if (contextHandle != 0 && activeConsumerContextHandleSet.add(contextHandle)) {
// Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
// PacketGetter.getTextureFrameDeferredSync().
if (deferredSync) {
@ -116,7 +117,14 @@ public class GraphTextureFrame implements TextureFrame {
GlSyncToken consumerToken = null;
// Note that this remove should be moved to the other overload of release when b/68808951 is
// addressed.
if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) {
final long contextHandle = nativeGetCurrentExternalContextHandle();
if (contextHandle == 0 && !activeConsumerContextHandleSet.isEmpty()) {
logger.atWarning().log(
"GraphTextureFrame is being released on non GL thread while having active consumers,"
+ " which may lead to external / internal GL contexts synchronization issues.");
}
if (contextHandle != 0 && activeConsumerContextHandleSet.remove(contextHandle)) {
consumerToken =
new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
}
@ -169,7 +177,9 @@ public class GraphTextureFrame implements TextureFrame {
private native void nativeReleaseBuffer(long nativeHandle);
private native int nativeGetTextureName(long nativeHandle);
private native int nativeGetWidth(long nativeHandle);
private native int nativeGetHeight(long nativeHandle);
private native void nativeGpuWait(long nativeHandle);

View File

@ -30,11 +30,11 @@ cc_library(
"compat_jni.cc",
"graph.cc",
"graph_jni.cc",
"graph_profiler_jni.cc",
"graph_service_jni.cc",
"packet_context_jni.cc",
"packet_creator_jni.cc",
"packet_getter_jni.cc",
"graph_profiler_jni.cc",
] + select({
"//conditions:default": [],
"//mediapipe:android": [
@ -54,11 +54,11 @@ cc_library(
"compat_jni.h",
"graph.h",
"graph_jni.h",
"graph_profiler_jni.h",
"graph_service_jni.h",
"packet_context_jni.h",
"packet_creator_jni.h",
"packet_getter_jni.h",
"graph_profiler_jni.h",
] + select({
"//conditions:default": [],
"//mediapipe:android": [
@ -84,40 +84,40 @@ cc_library(
deps = [
":class_registry",
":jni_util",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:calculator_framework",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@eigen_archive//:eigen3",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_profile_cc_proto",
"//mediapipe/framework:camera_intrinsics",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/formats:video_stream_header",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:name_util",
"//mediapipe/framework/tool:executor_util",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework/port:singleton",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:threadpool",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:executor_util",
"//mediapipe/framework/tool:name_util",
] + select({
"//conditions:default": [
"//mediapipe/framework/port:file_helpers",
],
"//mediapipe:android": [
"//mediapipe/util/android/file/base",
"//mediapipe/util/android:asset_manager_util",
"//mediapipe/util/android/file/base",
],
}) + select({
"//conditions:default": [
"//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:gl_surface_sink_calculator",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_shared_data_internal",
@ -153,9 +153,9 @@ cc_library(
srcs = ["class_registry.cc"],
hdrs = ["class_registry.h"],
deps = [
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/container:node_hash_map",
] + select({
"//conditions:default": [
],
@ -172,9 +172,9 @@ cc_library(
":class_registry",
":loose_headers",
":mediapipe_framework_jni",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/container:node_hash_map",
"//mediapipe/framework/port:logging",
] + select({
"//conditions:default": [

View File

@ -357,6 +357,22 @@ def mediapipe_java_proto_srcs(name = ""):
target = "//mediapipe/framework/formats:rect_java_proto_lite",
src_out = "com/google/mediapipe/formats/proto/RectProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:color_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/ColorProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:label_map_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/LabelMapProto.java",
))
proto_src_list.append(mediapipe_java_proto_src_extractor(
target = "//mediapipe/util:render_data_java_proto_lite",
src_out = "com/google/mediapipe/util/proto/RenderDataProto.java",
))
return proto_src_list
def mediapipe_logging_java_proto_srcs(name = ""):

View File

@ -0,0 +1,24 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"])
package(default_visibility = ["//mediapipe/model_maker/python/vision/face_stylizer:__subpackages__"])
filegroup(
name = "models",
srcs = glob([
"**",
]),
)

View File

@ -18,7 +18,7 @@ from __future__ import division
from __future__ import print_function
import functools
from typing import Callable, Optional, Tuple, TypeVar
from typing import Any, Callable, Optional, Tuple, TypeVar
# Dependency imports
import tensorflow as tf
@ -66,12 +66,14 @@ class Dataset(object):
"""
return self._size
def gen_tf_dataset(self,
def gen_tf_dataset(
self,
batch_size: int = 1,
is_training: bool = False,
shuffle: bool = False,
preprocess: Optional[Callable[..., bool]] = None,
drop_remainder: bool = False) -> tf.data.Dataset:
preprocess: Optional[Callable[..., Any]] = None,
drop_remainder: bool = False,
) -> tf.data.Dataset:
"""Generates a batched tf.data.Dataset for training/evaluation.
Args:

View File

@ -48,11 +48,13 @@ class Classifier(custom_model.CustomModel):
self._hparams: hp.BaseHParams = None
self._history: tf.keras.callbacks.History = None
def _train_model(self,
def _train_model(
self,
train_data: classification_ds.ClassificationDataset,
validation_data: classification_ds.ClassificationDataset,
preprocessor: Optional[Callable[..., bool]] = None,
checkpoint_path: Optional[str] = None):
preprocessor: Optional[Callable[..., Any]] = None,
checkpoint_path: Optional[str] = None,
):
"""Trains the classifier model.
Compiles and fits the tf.keras `_model` and records the `_history`.

View File

@ -115,9 +115,11 @@ def get_steps_per_epoch(steps_per_epoch: Optional[int] = None,
def convert_to_tflite(
model: tf.keras.Model,
quantization_config: Optional[quantization.QuantizationConfig] = None,
supported_ops: Tuple[tf.lite.OpsSet,
...] = (tf.lite.OpsSet.TFLITE_BUILTINS,),
preprocess: Optional[Callable[..., bool]] = None) -> bytearray:
supported_ops: Tuple[tf.lite.OpsSet, ...] = (
tf.lite.OpsSet.TFLITE_BUILTINS,
),
preprocess: Optional[Callable[..., Any]] = None,
) -> bytearray:
"""Converts the input Keras model to TFLite format.
Args:

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Placeholder for internal Python strict test compatibility macro.
# Placeholder for internal Python strict library and test compatibility macro.
licenses(["notice"])
package(default_visibility = ["//mediapipe:__subpackages__"])
filegroup(
name = "testdata",
srcs = glob([
"testdata/**",
]),
)
py_library(
name = "dataset",
srcs = ["dataset.py"],
deps = [
"//mediapipe/model_maker/python/core/data:classification_dataset",
"//mediapipe/model_maker/python/vision/core:image_utils",
],
)
py_test(
name = "dataset_test",
srcs = ["dataset_test.py"],
data = [
":testdata",
],
deps = [
":dataset",
"//mediapipe/tasks/python/test:test_utils",
],
)

View File

@ -0,0 +1,14 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaPipe Model Maker Python Public API For Face Stylization."""

View File

@ -0,0 +1,98 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Face stylizer dataset library."""
import logging
import os
import tensorflow as tf
from mediapipe.model_maker.python.core.data import classification_dataset
from mediapipe.model_maker.python.vision.core import image_utils
# TODO: Change to a unlabeled dataset if it makes sense.
class Dataset(classification_dataset.ClassificationDataset):
"""Dataset library for face stylizer fine tuning."""
@classmethod
def from_folder(
cls, dirname: str
) -> classification_dataset.ClassificationDataset:
"""Loads images from the given directory.
The style image dataset directory is expected to contain one subdirectory
whose name represents the label of the style. There can be one or multiple
images of the same style in that subdirectory. Supported input image formats
include 'jpg', 'jpeg', 'png'.
Args:
dirname: Name of the directory containing the image files.
Returns:
Dataset containing images and labels and other related info.
Raises:
ValueError: if the input data directory is empty.
"""
data_root = os.path.abspath(dirname)
# Assumes the image data of the same label are in the same subdirectory,
# gets image path and label names.
all_image_paths = list(tf.io.gfile.glob(data_root + r'/*/*'))
all_image_size = len(all_image_paths)
if all_image_size == 0:
raise ValueError('Invalid input data directory')
if not any(
fname.endswith(('.jpg', '.jpeg', '.png')) for fname in all_image_paths
):
raise ValueError('No images found under given directory')
label_names = sorted(
name
for name in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, name))
)
all_label_size = len(label_names)
index_by_label = dict(
(name, index) for index, name in enumerate(label_names)
)
# Get the style label from the subdirectory name.
all_image_labels = [
index_by_label[os.path.basename(os.path.dirname(path))]
for path in all_image_paths
]
path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
image_ds = path_ds.map(
image_utils.load_image, num_parallel_calls=tf.data.AUTOTUNE
)
# Load label
label_ds = tf.data.Dataset.from_tensor_slices(
tf.cast(all_image_labels, tf.int64)
)
# Create a dataset of (image, label) pairs
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))
logging.info(
'Load images dataset with size: %d, num_label: %d, labels: %s.',
all_image_size,
all_label_size,
', '.join(label_names),
)
return Dataset(
dataset=image_label_ds, size=all_image_size, label_names=label_names
)

View File

@ -0,0 +1,48 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow as tf
from mediapipe.model_maker.python.vision.face_stylizer import dataset
from mediapipe.tasks.python.test import test_utils
class DatasetTest(tf.test.TestCase):
def setUp(self):
super().setUp()
# TODO: Replace the stylize image dataset with licensed images.
self._test_data_dirname = 'testdata'
def test_from_folder(self):
input_data_dir = test_utils.get_test_data_path(self._test_data_dirname)
data = dataset.Dataset.from_folder(dirname=input_data_dir)
self.assertEqual(data.num_classes, 2)
self.assertEqual(data.label_names, ['cartoon', 'sketch'])
self.assertLen(data, 2)
def test_from_folder_raise_value_error_for_invalid_path(self):
with self.assertRaisesRegex(ValueError, 'Invalid input data directory'):
dataset.Dataset.from_folder(dirname='invalid')
def test_from_folder_raise_value_error_for_valid_no_data_path(self):
input_data_dir = test_utils.get_test_data_path('face_stylizer')
with self.assertRaisesRegex(
ValueError, 'No images found under given directory'
):
dataset.Dataset.from_folder(dirname=input_data_dir)
if __name__ == '__main__':
tf.test.main()

Binary file not shown.

After

Width:  |  Height:  |  Size: 347 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 336 KiB

View File

@ -15,6 +15,7 @@
import io
import os
import tempfile
import unittest
from unittest import mock as unittest_mock
import zipfile
@ -31,6 +32,7 @@ _TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdat
tf.keras.backend.experimental.enable_tf_random_generator()
@unittest.skip('b/273818271')
class GestureRecognizerTest(tf.test.TestCase):
def _load_data(self):
@ -72,8 +74,10 @@ class GestureRecognizerTest(tf.test.TestCase):
self._test_accuracy(model)
@unittest.skip('b/273818271')
@unittest_mock.patch.object(
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense
)
def test_gesture_recognizer_model_layer_widths(self, mock_dense):
layer_widths = [64, 32]
mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
@ -143,12 +147,14 @@ class GestureRecognizerTest(tf.test.TestCase):
hyperparameters,
'HParams',
autospec=True,
return_value=gesture_recognizer.HParams(epochs=1))
return_value=gesture_recognizer.HParams(epochs=1),
)
@unittest_mock.patch.object(
model_options,
'GestureRecognizerModelOptions',
autospec=True,
return_value=gesture_recognizer.ModelOptions())
return_value=gesture_recognizer.ModelOptions(),
)
def test_create_hparams_and_model_options_if_none_in_gesture_recognizer_options(
self, mock_hparams, mock_model_options):
options = gesture_recognizer.GestureRecognizerOptions()

View File

@ -28,7 +28,7 @@ class ModelSpec(object):
uri: str,
input_image_shape: Optional[List[int]] = None,
name: str = ''):
"""Initializes a new instance of the `ImageModelSpec` class.
"""Initializes a new instance of the image classifier `ModelSpec` class.
Args:
uri: str, URI to the pretrained model.

View File

@ -5,4 +5,4 @@ opencv-python
tensorflow>=2.10
tensorflow-datasets
tensorflow-hub
tf-models-official>=2.10.1
tf-models-official>=2.11.4

View File

@ -37,6 +37,7 @@ constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kLabelsTag[] = "LABELS";
constexpr char kLabelsCsvTag[] = "LABELS_CSV";
constexpr char kLabelMapTag[] = "LABEL_MAP";
using mediapipe::RE2;
using Detections = std::vector<Detection>;
@ -151,6 +152,11 @@ absl::Status FilterDetectionCalculator::GetContract(CalculatorContract* cc) {
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
cc->InputSidePackets().Tag(kLabelsCsvTag).Set<std::string>();
}
if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
cc->InputSidePackets()
.Tag(kLabelMapTag)
.Set<std::unique_ptr<std::map<int, std::string>>>();
}
return absl::OkStatus();
}
@ -158,7 +164,8 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ = cc->Options<FilterDetectionCalculatorOptions>();
limit_labels_ = cc->InputSidePackets().HasTag(kLabelsTag) ||
cc->InputSidePackets().HasTag(kLabelsCsvTag);
cc->InputSidePackets().HasTag(kLabelsCsvTag) ||
cc->InputSidePackets().HasTag(kLabelMapTag);
if (limit_labels_) {
Strings allowlist_labels;
if (cc->InputSidePackets().HasTag(kLabelsCsvTag)) {
@ -168,8 +175,16 @@ absl::Status FilterDetectionCalculator::Open(CalculatorContext* cc) {
for (auto& e : allowlist_labels) {
absl::StripAsciiWhitespace(&e);
}
} else {
} else if (cc->InputSidePackets().HasTag(kLabelsTag)) {
allowlist_labels = cc->InputSidePackets().Tag(kLabelsTag).Get<Strings>();
} else if (cc->InputSidePackets().HasTag(kLabelMapTag)) {
auto label_map = cc->InputSidePackets()
.Tag(kLabelMapTag)
.Get<std::unique_ptr<std::map<int, std::string>>>()
.get();
for (const auto& [_, v] : *label_map) {
allowlist_labels.push_back(v);
}
}
allowed_labels_.insert(allowlist_labels.begin(), allowlist_labels.end());
}

View File

@ -67,5 +67,68 @@ TEST(FilterDetectionCalculatorTest, DetectionFilterTest) {
));
}
TEST(FilterDetectionCalculatorTest, DetectionFilterLabelMapTest) {
auto runner = std::make_unique<CalculatorRunner>(
ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "FilterDetectionCalculator"
input_stream: "DETECTION:input"
input_side_packet: "LABEL_MAP:input_map"
output_stream: "DETECTION:output"
options {
[mediapipe.FilterDetectionCalculatorOptions.ext]: { min_score: 0.6 }
}
)pb"));
runner->MutableInputs()->Tag("DETECTION").packets = {
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "d"
score: 1
score: 0.8
score: 0.3
score: 0.9
)pb"))
.At(Timestamp(20)),
MakePacket<Detection>(ParseTextProtoOrDie<Detection>(R"pb(
label: "a"
label: "b"
label: "c"
label: "e"
score: 0.6
score: 0.4
score: 0.2
score: 0.7
)pb"))
.At(Timestamp(40)),
};
auto label_map = std::make_unique<std::map<int, std::string>>();
(*label_map)[0] = "a";
(*label_map)[1] = "b";
(*label_map)[2] = "c";
runner->MutableSidePackets()->Tag("LABEL_MAP") =
AdoptAsUniquePtr(label_map.release());
// Run graph.
MP_ASSERT_OK(runner->Run());
// Check output.
EXPECT_THAT(
runner->Outputs().Tag("DETECTION").packets,
ElementsAre(PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(20)),
EqualsProto(R"pb(
label: "a" label: "b" score: 1 score: 0.8
)pb")), // Packet 1 at timestamp 20.
PacketContainsTimestampAndPayload<Detection>(
Eq(Timestamp(40)),
EqualsProto(R"pb(
label: "a" score: 0.6
)pb")) // Packet 2 at timestamp 40.
));
}
} // namespace
} // namespace mediapipe

View File

@ -57,6 +57,7 @@ pybind_extension(
"//mediapipe/framework/formats:landmark_registration",
"//mediapipe/framework/formats:rect_registration",
"//mediapipe/modules/objectron/calculators:annotation_registration",
"//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
],
)
@ -95,6 +96,8 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter:image_segmenter_graph",
"//mediapipe/tasks/cc/vision/object_detector:object_detector_graph",
"//mediapipe/tasks/cc/vision/face_stylizer:face_stylizer_graph",
"//mediapipe/tasks/cc/vision/face_detector:face_detector_graph",
"//mediapipe/tasks/cc/vision/face_landmarker:face_landmarker_graph",
] + select({
# TODO: Build text_classifier_graph and text_embedder_graph on Windows.
"//mediapipe:windows": [],

View File

@ -30,7 +30,7 @@ constexpr absl::string_view kMediaPipeTasksPayload = "MediaPipeTasksStatus";
//
// At runtime, such codes are meant to be attached (where applicable) to a
// `absl::Status` in a key-value manner with `kMediaPipeTasksPayload` as key and
// stringifed error code as value (aka payload). This logic is encapsulated in
// stringified error code as value (aka payload). This logic is encapsulated in
// the `CreateStatusWithPayload` helper below for convenience.
//
// The returned status includes:

View File

@ -51,12 +51,11 @@ ModelAssetBundleResources::Create(
auto model_bundle_resources = absl::WrapUnique(
new ModelAssetBundleResources(tag, std::move(model_asset_bundle_file)));
MP_RETURN_IF_ERROR(
model_bundle_resources->ExtractModelFilesFromExternalFileProto());
model_bundle_resources->ExtractFilesFromExternalFileProto());
return model_bundle_resources;
}
absl::Status
ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
absl::Status ModelAssetBundleResources::ExtractFilesFromExternalFileProto() {
if (model_asset_bundle_file_->has_file_name()) {
// If the model asset bundle file name is a relative path, searches the file
// in a platform-specific location and returns the absolute path on success.
@ -72,34 +71,32 @@ ModelAssetBundleResources::ExtractModelFilesFromExternalFileProto() {
model_asset_bundle_file_handler_->GetFileContent().data();
size_t buffer_size =
model_asset_bundle_file_handler_->GetFileContent().size();
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size,
&model_files_);
return metadata::ExtractFilesfromZipFile(buffer_data, buffer_size, &files_);
}
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetModelFile(
absl::StatusOr<absl::string_view> ModelAssetBundleResources::GetFile(
const std::string& filename) const {
auto it = model_files_.find(filename);
if (it == model_files_.end()) {
auto model_files = ListModelFiles();
std::string all_model_files =
absl::StrJoin(model_files.begin(), model_files.end(), ", ");
auto it = files_.find(filename);
if (it == files_.end()) {
auto files = ListFiles();
std::string all_files = absl::StrJoin(files.begin(), files.end(), ", ");
return CreateStatusWithPayload(
StatusCode::kNotFound,
absl::StrFormat("No model file with name: %s. All model files in the "
"model asset bundle are: %s.",
filename, all_model_files),
absl::StrFormat("No file with name: %s. All files in the model asset "
"bundle are: %s.",
filename, all_files),
MediaPipeTasksStatus::kFileNotFoundError);
}
return it->second;
}
std::vector<std::string> ModelAssetBundleResources::ListModelFiles() const {
std::vector<std::string> model_names;
for (const auto& [model_name, _] : model_files_) {
model_names.push_back(model_name);
std::vector<std::string> ModelAssetBundleResources::ListFiles() const {
std::vector<std::string> file_names;
for (const auto& [file_name, _] : files_) {
file_names.push_back(file_name);
}
return model_names;
return file_names;
}
} // namespace core

View File

@ -28,8 +28,8 @@ namespace core {
// The mediapipe task model asset bundle resources class.
// A ModelAssetBundleResources object, created from an external file proto,
// contains model asset bundle related resources and the method to extract the
// tflite models or model asset bundles for the mediapipe sub-tasks. As the
// resources are owned by the ModelAssetBundleResources object
// tflite models, resource files or model asset bundles for the mediapipe
// sub-tasks. As the resources are owned by the ModelAssetBundleResources object
// callers must keep ModelAssetBundleResources alive while using any of the
// resources.
class ModelAssetBundleResources {
@ -50,14 +50,13 @@ class ModelAssetBundleResources {
// Returns the model asset bundle resources tag.
std::string GetTag() const { return tag_; }
// Gets the contents of the model file (either tflite model file or model
// bundle file) with the provided name. An error is returned if there is no
// such model file.
absl::StatusOr<absl::string_view> GetModelFile(
const std::string& filename) const;
// Gets the contents of the model file (either tflite model file, resource
// file or model bundle file) with the provided name. An error is returned if
// there is no such model file.
absl::StatusOr<absl::string_view> GetFile(const std::string& filename) const;
// Lists all the model file names in the model asset model.
std::vector<std::string> ListModelFiles() const;
// Lists all the file names in the model asset model.
std::vector<std::string> ListFiles() const;
private:
// Constructor.
@ -65,9 +64,9 @@ class ModelAssetBundleResources {
const std::string& tag,
std::unique_ptr<proto::ExternalFile> model_asset_bundle_file);
// Extracts the model files (either tflite model file or model bundle file)
// from the external file proto.
absl::Status ExtractModelFilesFromExternalFileProto();
// Extracts the model files (either tflite model file, resource file or model
// bundle file) from the external file proto.
absl::Status ExtractFilesFromExternalFileProto();
// The model asset bundle resources tag.
const std::string tag_;
@ -78,11 +77,11 @@ class ModelAssetBundleResources {
// The ExternalFileHandler for the model asset bundle.
std::unique_ptr<ExternalFileHandler> model_asset_bundle_file_handler_;
// The model files bundled in model asset bundle, as a map with the filename
// The files bundled in model asset bundle, as a map with the filename
// (corresponding to a basename, e.g. "hand_detector.tflite") as key and
// a pointer to the file contents as value. Each model file can be either
// a TFLite model file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> model_files_;
// a pointer to the file contents as value. Each file can be either a TFLite
// model file, resource file or a model bundle file for sub-task.
absl::flat_hash_map<std::string, absl::string_view> files_;
};
} // namespace core

View File

@ -66,10 +66,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromBinaryContent) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
@ -81,10 +80,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
@ -98,10 +96,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFileDescriptor) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
#endif // _WIN32
@ -115,10 +112,9 @@ TEST(ModelAssetBundleResourcesTest, CreateFromFilePointer) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task")
.status());
model_bundle_resources->GetFile("dummy_hand_landmarker.task").status());
MP_EXPECT_OK(
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite")
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite")
.status());
}
@ -147,7 +143,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_hand_landmarker.task");
model_bundle_resources->GetFile("dummy_hand_landmarker.task");
MP_EXPECT_OK(status_or_model_bundle_file.status());
// Creates sub-task model asset bundle resources.
@ -159,10 +155,10 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidModelBundleFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(hand_landmaker_model_file)));
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_detector.tflite")
->GetFile("dummy_hand_detector.tflite")
.status());
MP_EXPECT_OK(hand_landmaker_model_bundle_resources
->GetModelFile("dummy_hand_landmarker.tflite")
->GetFile("dummy_hand_landmarker.tflite")
.status());
}
@ -175,7 +171,7 @@ TEST(ModelAssetBundleResourcesTest, ExtractValidTFLiteModelFile) {
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status_or_model_bundle_file =
model_bundle_resources->GetModelFile("dummy_gesture_recognizer.tflite");
model_bundle_resources->GetFile("dummy_gesture_recognizer.tflite");
MP_EXPECT_OK(status_or_model_bundle_file.status());
// Verify tflite model works.
@ -200,11 +196,11 @@ TEST(ModelAssetBundleResourcesTest, ExtractInvalidModelFile) {
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto status = model_bundle_resources->GetModelFile("not_found.task").status();
auto status = model_bundle_resources->GetFile("not_found.task").status();
EXPECT_EQ(status.code(), absl::StatusCode::kNotFound);
EXPECT_THAT(status.message(),
testing::HasSubstr(
"No model file with name: not_found.task. All model files in "
EXPECT_THAT(
status.message(),
testing::HasSubstr("No file with name: not_found.task. All files in "
"the model asset bundle are: "));
EXPECT_THAT(status.GetPayload(kMediaPipeTasksPayload),
testing::Optional(absl::Cord(
@ -219,7 +215,7 @@ TEST(ModelAssetBundleResourcesTest, ListModelFiles) {
auto model_bundle_resources,
ModelAssetBundleResources::Create(kTestModelBundleResourcesTag,
std::move(model_file)));
auto model_files = model_bundle_resources->ListModelFiles();
auto model_files = model_bundle_resources->ListFiles();
std::vector<std::string> expected_model_files = {
"dummy_gesture_recognizer.tflite", "dummy_hand_landmarker.task"};
std::sort(model_files.begin(), model_files.end());

View File

@ -77,9 +77,11 @@ class ModelResourcesCalculator : public api2::Node {
if (options.has_model_file()) {
RET_CHECK(options.model_file().has_file_content() ||
options.model_file().has_file_descriptor_meta() ||
options.model_file().has_file_name())
options.model_file().has_file_name() ||
options.model_file().has_file_pointer_meta())
<< "'model_file' must specify at least one of "
"'file_content', 'file_descriptor_meta', or 'file_name'";
"'file_content', 'file_descriptor_meta', 'file_name', or "
"'file_pointer_meta'";
}
return absl::OkStatus();
}

View File

@ -179,9 +179,9 @@ TEST_F(ModelResourcesCalculatorTest, EmptyExternalFileProto) {
auto status = graph.Initialize(graph_config);
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr(
"'model_file' must specify at least one of "
"'file_content', 'file_descriptor_meta', or 'file_name'"));
testing::HasSubstr("'model_file' must specify at least one of "
"'file_content', 'file_descriptor_meta', "
"'file_name', or 'file_pointer_meta'"));
}
TEST_F(ModelResourcesCalculatorTest, GraphServiceNotAvailable) {

View File

@ -138,7 +138,7 @@ class InferenceSubgraph : public Subgraph {
delegate.mutable_tflite()->CopyFrom(acceleration.tflite());
break;
case Acceleration::DELEGATE_NOT_SET:
// Deafult inference calculator setting.
// Default inference calculator setting.
break;
}
return delegate;

View File

@ -124,10 +124,10 @@ class ModelTaskGraph : public Subgraph {
// Inserts a mediapipe task inference subgraph into the provided
// GraphBuilder. The returned node provides the following interfaces to the
// the rest of the graph:
// - a tensor vector (std::vector<MeidaPipe::Tensor>) input stream with tag
// - a tensor vector (std::vector<mediapipe::Tensor>) input stream with tag
// "TENSORS", representing the input tensors to be consumed by the
// inference engine.
// - a tensor vector (std::vector<MeidaPipe::Tensor>) output stream with tag
// - a tensor vector (std::vector<mediapipe::Tensor>) output stream with tag
// "TENSORS", representing the output tensors generated by the inference
// engine.
// - a MetadataExtractor output side packet with tag "METADATA_EXTRACTOR".

View File

@ -301,7 +301,7 @@ absl::Status TaskRunner::Close() {
}
is_running_ = false;
MP_RETURN_IF_ERROR(
AddPayload(graph_.CloseAllInputStreams(), "Fail to close intput streams",
AddPayload(graph_.CloseAllInputStreams(), "Fail to close input streams",
MediaPipeTasksStatus::kRunnerFailsToCloseError));
MP_RETURN_IF_ERROR(AddPayload(
graph_.WaitUntilDone(), "Fail to shutdown the MediaPipe graph.",

View File

@ -65,7 +65,7 @@ class TaskRunner {
// Creates the task runner with a CalculatorGraphConfig proto.
// If a tflite op resolver object is provided, the task runner will take
// it as the global op resolver for all models running within this task.
// The op resolver's owernship will be transferred into the pipeleine runner.
// The op resolver's ownership will be transferred into the pipeleine runner.
// When a user-defined PacketsCallback is provided, clients must use the
// asynchronous method, Send(), to provide the input packets. If the packets
// callback is absent, clients must use the synchronous method, Process(), to
@ -84,7 +84,7 @@ class TaskRunner {
// frames from a video file and an audio file. The call blocks the current
// thread until a failure status or a successful result is returned.
// If the input packets have no timestamp, an internal timestamp will be
// assigend per invocation. Otherwise, when the timestamp is set in the
// assigned per invocation. Otherwise, when the timestamp is set in the
// input packets, the caller must ensure that the input packet timestamps are
// greater than the timestamps of the previous invocation. This method is
// thread-unsafe and it is the caller's responsibility to synchronize access

View File

@ -64,7 +64,7 @@ class ModelMetadataPopulator {
// Loads associated files into the TFLite FlatBuffer model. The input is a map
// of {filename, file contents}.
//
// Warning: this method removes any previoulsy present associated files.
// Warning: this method removes any previously present associated files.
// Calling this method multiple time removes any associated files from
// previous calls, so this method should usually be called only once.
void LoadAssociatedFiles(

View File

@ -213,7 +213,7 @@ void UpdateMinimumVersionForTable<tflite::Content>(const tflite::Content* table,
Version* min_version) {
if (table == nullptr) return;
// Checks the ContenProperties field.
// Checks the ContentProperties field.
if (table->content_properties_type() == ContentProperties_AudioProperties) {
UpdateMinimumVersion(
GetMemberVersion(SchemaMembers::kContentPropertiesAudioProperties),

View File

@ -31,8 +31,8 @@ PYBIND11_MODULE(_pywrap_metadata_version, m) {
// Using pybind11 type conversions to convert between Python and native
// C++ types. There are other options to provide access to native Python types
// in C++ and vice versa. See the pybind 11 instrcution [1] for more details.
// Type converstions is recommended by pybind11, though the main downside
// in C++ and vice versa. See the pybind 11 instruction [1] for more details.
// Type conversions is recommended by pybind11, though the main downside
// is that a copy of the data must be made on every Python to C++ transition:
// this is needed since the C++ and Python versions of the same type generally
// wont have the same memory layout.

View File

@ -79,7 +79,7 @@ TEST(MetadataVersionTest,
auto metadata = metadata_builder.Finish();
FinishModelMetadataBuffer(builder, metadata);
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -100,7 +100,7 @@ TEST(MetadataVersionTest,
auto metadata = metadata_builder.Finish();
builder.Finish(metadata);
// Gets the mimimum metadata parser version and triggers error.
// Gets the minimum metadata parser version and triggers error.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -121,7 +121,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -147,7 +147,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -172,7 +172,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -203,7 +203,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -234,7 +234,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -265,7 +265,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -294,7 +294,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -323,7 +323,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -348,7 +348,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -373,7 +373,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -404,7 +404,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -431,7 +431,7 @@ TEST(MetadataVersionTest,
std::vector<Offset<TensorMetadata>>{tensor_builder.Finish()});
CreateModelWithMetadata(tensors, builder);
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -453,7 +453,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -476,7 +476,7 @@ TEST(MetadataVersionTest,
metadata_builder.add_associated_files(associated_files);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),
@ -504,7 +504,7 @@ TEST(MetadataVersionTest, GetMinimumMetadataParserVersionForOptions) {
metadata_builder.add_subgraph_metadata(subgraphs);
FinishModelMetadataBuffer(builder, metadata_builder.Finish());
// Gets the mimimum metadata parser version.
// Gets the minimum metadata parser version.
std::string min_version;
EXPECT_EQ(GetMinimumMetadataParserVersion(builder.GetBufferPointer(),
builder.GetSize(), &min_version),

View File

@ -42,3 +42,36 @@ cc_test(
"@org_tensorflow//tensorflow/lite/kernels:test_util",
],
)
cc_library(
name = "ngram_hash",
srcs = ["ngram_hash.cc"],
hdrs = ["ngram_hash.h"],
copts = tflite_copts(),
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils:ngram_hash_ops_utils",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:kernel_util",
],
alwayslink = 1,
)
cc_test(
name = "ngram_hash_test",
srcs = ["ngram_hash_test.cc"],
deps = [
":ngram_hash",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash:murmur",
"@com_google_absl//absl/types:optional",
"@flatbuffers",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite:string_util",
"@org_tensorflow//tensorflow/lite/c:common",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
"@org_tensorflow//tensorflow/lite/kernels:test_util",
],
)

View File

@ -0,0 +1,264 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <string>
#include <vector>
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace ngram_op {
namespace {
using ::flexbuffers::GetRoot;
using ::flexbuffers::Map;
using ::flexbuffers::TypedVector;
using ::mediapipe::tasks::text::language_detector::custom_ops::
LowercaseUnicodeStr;
using ::mediapipe::tasks::text::language_detector::custom_ops::Tokenize;
using ::mediapipe::tasks::text::language_detector::custom_ops::TokenizedOutput;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::tflite::GetString;
using ::tflite::StringRef;
constexpr int kInputMessage = 0;
constexpr int kOutputLabel = 0;
constexpr int kDefaultMaxSplits = 128;
// This op takes in a string, finds the character ngrams for it and then
// maps each of these ngrams to an index using the specified vocabulary sizes.
// Input(s):
// - input: Input string.
// - seeds: Seed for the random number generator.
// - ngram_lengths: Lengths of each of the ngrams. For example [1, 2, 3] would
// be interpreted as generating unigrams, bigrams, and trigrams.
// - vocab_sizes: Size of the vocabulary for each of the ngram features
// respectively. The op would generate vocab ids to be less than or equal to
// the vocab size. The index 0 implies an invalid ngram.
// - max_splits: Maximum number of tokens in the output. If this is unset, the
// limit is `kDefaultMaxSplits`.
// - lower_case_input: If this is set to true, the input string would be
// lower-cased before any processing.
// Output(s):
// - output: A tensor of size [number of ngrams, number of tokens + 2],
// where 2 tokens are reserved for the padding. If `max_splits` is set, this
// length is <= max_splits, otherwise it is <= `kDefaultMaxSplits`.
// Helper class used for pre-processing the input.
class NGramHashParams {
public:
NGramHashParams(const uint64_t seed, const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes, int max_splits,
bool lower_case_input)
: seed_(seed),
ngram_lengths_(ngram_lengths),
vocab_sizes_(vocab_sizes),
max_splits_(max_splits),
lower_case_input_(lower_case_input) {}
TfLiteStatus PreprocessInput(const TfLiteTensor* input_t,
TfLiteContext* context) {
if (input_t->bytes == 0) {
context->ReportError(context, "Empty input not supported.");
return kTfLiteError;
}
// Do sanity checks on the input.
if (ngram_lengths_.empty()) {
context->ReportError(context, "`ngram_lengths` must be non-empty.");
return kTfLiteError;
}
if (vocab_sizes_.empty()) {
context->ReportError(context, "`vocab_sizes` must be non-empty.");
return kTfLiteError;
}
if (ngram_lengths_.size() != vocab_sizes_.size()) {
context->ReportError(
context,
"Sizes of `ngram_lengths` and `vocab_sizes` must be the same.");
return kTfLiteError;
}
if (max_splits_ <= 0) {
context->ReportError(context, "`max_splits` must be > 0.");
return kTfLiteError;
}
// Obtain and tokenize the input.
StringRef inputref = GetString(input_t, /*string_index=*/0);
if (lower_case_input_) {
std::string lower_cased_str;
LowercaseUnicodeStr(inputref.str, inputref.len, &lower_cased_str);
tokenized_output_ =
Tokenize(lower_cased_str.c_str(), inputref.len, max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
} else {
tokenized_output_ = Tokenize(inputref.str, inputref.len, max_splits_,
/*exclude_nonalphaspace_tokens=*/true);
}
return kTfLiteOk;
}
uint64_t GetSeed() const { return seed_; }
int GetNumTokens() const { return tokenized_output_.tokens.size(); }
int GetNumNGrams() const { return ngram_lengths_.size(); }
std::vector<int> GetNGramLengths() const { return ngram_lengths_; }
std::vector<int> GetVocabSizes() const { return vocab_sizes_; }
const TokenizedOutput& GetTokenizedOutput() const {
return tokenized_output_;
}
TokenizedOutput tokenized_output_;
private:
const uint64_t seed_;
std::vector<int> ngram_lengths_;
std::vector<int> vocab_sizes_;
const int max_splits_;
const bool lower_case_input_;
};
// Convert the TypedVector into a regular std::vector.
std::vector<int> GetIntVector(TypedVector typed_vec) {
std::vector<int> vec(typed_vec.size());
for (int j = 0; j < typed_vec.size(); j++) {
vec[j] = typed_vec[j].AsInt32();
}
return vec;
}
void GetNGramHashIndices(NGramHashParams* params, int32_t* data) {
const int max_unicode_length = params->GetNumTokens();
const auto ngram_lengths = params->GetNGramLengths();
const auto vocab_sizes = params->GetVocabSizes();
const auto& tokenized_output = params->GetTokenizedOutput();
const auto seed = params->GetSeed();
// Compute for each ngram.
for (int ngram = 0; ngram < ngram_lengths.size(); ngram++) {
const int vocab_size = vocab_sizes[ngram];
const int ngram_length = ngram_lengths[ngram];
// Compute for each token within the input.
for (int start = 0; start < tokenized_output.tokens.size(); start++) {
// Compute the number of bytes for the ngram starting at the given
// token.
int num_bytes = 0;
for (int i = start;
i < tokenized_output.tokens.size() && i < (start + ngram_length);
i++) {
num_bytes += tokenized_output.tokens[i].second;
}
// Compute the hash for the ngram starting at the token.
const auto str_hash = MurmurHash64WithSeed(
tokenized_output.str.c_str() + tokenized_output.tokens[start].first,
num_bytes, seed);
// Map the hash to an index in the vocab.
data[ngram * max_unicode_length + start] = (str_hash % vocab_size) + 1;
}
}
}
} // namespace
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
const Map& m = GetRoot(buffer_t, length).AsMap();
const uint64_t seed = m["seed"].AsUInt64();
const std::vector<int> ngram_lengths =
GetIntVector(m["ngram_lengths"].AsTypedVector());
const std::vector<int> vocab_sizes =
GetIntVector(m["vocab_sizes"].AsTypedVector());
const int max_splits =
m["max_splits"].IsNull() ? kDefaultMaxSplits : m["max_splits"].AsInt32();
const bool lowercase_input =
m["lowercase_input"].IsNull() ? true : m["lowercase_input"].AsBool();
return new NGramHashParams(seed, ngram_lengths, vocab_sizes, max_splits,
lowercase_input);
}
void Free(TfLiteContext* context, void* buffer) {
delete reinterpret_cast<NGramHashParams*>(buffer);
}
TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
SetTensorToDynamic(output);
return kTfLiteOk;
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
NGramHashParams* params = reinterpret_cast<NGramHashParams*>(node->user_data);
TF_LITE_ENSURE_OK(
context,
params->PreprocessInput(GetInput(context, node, kInputMessage), context));
TfLiteTensor* output = GetOutput(context, node, kOutputLabel);
TF_LITE_ENSURE(context, output != nullptr);
if (IsDynamicTensor(output)) {
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = 1;
output_size->data[1] = params->GetNumNGrams();
output_size->data[2] = params->GetNumTokens();
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
} else {
context->ReportError(context, "Output must by dynamic.");
return kTfLiteError;
}
if (output->type == kTfLiteInt32) {
GetNGramHashIndices(params, output->data.i32);
} else {
context->ReportError(context, "Output type must be Int32.");
return kTfLiteError;
}
return kTfLiteOk;
}
} // namespace ngram_op
TfLiteRegistration* Register_NGRAM_HASH() {
static TfLiteRegistration r = {ngram_op::Init, ngram_op::Free,
ngram_op::Resize, ngram_op::Eval};
return &r;
}
} // namespace tflite::ops::custom

View File

@ -0,0 +1,27 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_
#include "tensorflow/lite/kernels/register.h"
namespace tflite::ops::custom {
TfLiteRegistration* Register_NGRAM_HASH();
} // namespace tflite::ops::custom
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_NGRAM_HASH_H_

View File

@ -0,0 +1,313 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/ngram_hash.h"
#include <cstdint>
#include <optional>
#include <string>
#include <vector>
#include "absl/types/optional.h"
#include "flatbuffers/flexbuffers.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/string_util.h"
namespace tflite::ops::custom {
namespace {
using ::flexbuffers::Builder;
using ::mediapipe::tasks::text::language_detector::custom_ops::hash::
MurmurHash64WithSeed;
using ::testing::ElementsAreArray;
using ::testing::Message;
// Helper class for testing the op.
class NGramHashModel : public SingleOpModel {
public:
explicit NGramHashModel(const uint64_t seed,
const std::vector<int>& ngram_lengths,
const std::vector<int>& vocab_sizes,
const absl::optional<int> max_splits = std::nullopt) {
// Setup the model inputs.
Builder fbb;
size_t start = fbb.StartMap();
fbb.UInt("seed", seed);
{
size_t start = fbb.StartVector("ngram_lengths");
for (const int& ngram_len : ngram_lengths) {
fbb.Int(ngram_len);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
{
size_t start = fbb.StartVector("vocab_sizes");
for (const int& vocab_size : vocab_sizes) {
fbb.Int(vocab_size);
}
fbb.EndVector(start, /*typed=*/true, /*fixed=*/false);
}
if (max_splits) {
fbb.Int("max_splits", *max_splits);
}
fbb.EndMap(start);
fbb.Finish();
output_ = AddOutput({TensorType_INT32, {}});
SetCustomOp("NGramHash", fbb.GetBuffer(), Register_NGRAM_HASH);
BuildInterpreter({GetShape(input_)});
}
void SetupInputTensor(const std::string& input) {
PopulateStringTensor(input_, {input});
CHECK(interpreter_->AllocateTensors() == kTfLiteOk)
<< "Cannot allocate tensors";
}
void Invoke(const std::string& input) {
SetupInputTensor(input);
CHECK_EQ(SingleOpModel::Invoke(), kTfLiteOk);
}
TfLiteStatus InvokeUnchecked(const std::string& input) {
SetupInputTensor(input);
return SingleOpModel::Invoke();
}
template <typename T>
std::vector<T> GetOutput() {
return ExtractVector<T>(output_);
}
std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
private:
int input_ = AddInput(TensorType_STRING);
int output_;
};
TEST(NGramHashTest, ReturnsExpectedValueWhenInputIsSane) {
// Checks that the op returns the expected value when the input is sane.
// Also checks that when `max_splits` is not specified, the entire string is
// tokenized.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::vector<std::string> testcase_inputs({
"hi",
"wow",
"!",
"HI",
});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "hi".
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow".
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "!" (which will get replaced by " ").
hash("^", 0),
hash(" ", 0),
hash("$", 0),
hash("^ ", 1),
hash(" $", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "HI" (which will get lower-cased).
hash("^", 0),
hash("h", 0),
hash("i", 0),
hash("$", 0),
hash("^h", 1),
hash("hi", 1),
hash("i$", 1),
hash("$", 1),
}});
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes);
for (int test_idx = 0; test_idx < testcase_inputs.size(); test_idx++) {
const string& testcase_input = testcase_inputs[test_idx];
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where the testcases' input is: "
<< testcase_input);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
static_cast<int>(testcase_input.size()) + /*padding*/ 2}));
}
}
TEST(NGramHashTest, ReturnsExpectedValueWhenMaxSplitsIsSpecified) {
// Checks that the op returns the expected value when the input is correct
// when `max_splits` is specified.
const uint64_t kSeed = 123;
const std::vector<int> vocab_sizes({100, 200});
std::vector<int> ngram_lengths({1, 2});
const std::string testcase_input = "wow";
const std::vector<int> max_splits({2, 3, 4, 5, 6});
// A hash function that maps the given string to an index in the embedding
// table denoted by `vocab_idx`.
auto hash = [vocab_sizes](std::string str, const int vocab_idx) {
const auto hash_value =
MurmurHash64WithSeed(str.c_str(), str.size(), kSeed);
return static_cast<int>((hash_value % vocab_sizes[vocab_idx]) + 1);
};
const std::vector<std::vector<int>> expected_testcase_outputs(
{{
// Unigram & Bigram output for "wow", when `max_splits` == 2.
// We cannot include any of the actual tokens, since `max_splits`
// only allows enough space for the delimiters.
hash("^", 0),
hash("$", 0),
hash("^$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 3.
// We can start to include some tokens from the input string.
hash("^", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 4.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("o$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 5.
// We can include the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
},
{
// Unigram & Bigram output for "wow", when `max_splits` == 6.
// `max_splits` is more than the full input string.
hash("^", 0),
hash("w", 0),
hash("o", 0),
hash("w", 0),
hash("$", 0),
hash("^w", 1),
hash("wo", 1),
hash("ow", 1),
hash("w$", 1),
hash("$", 1),
}});
for (int test_idx = 0; test_idx < max_splits.size(); test_idx++) {
const int testcase_max_splits = max_splits[test_idx];
NGramHashModel m(kSeed, ngram_lengths, vocab_sizes, testcase_max_splits);
m.Invoke(testcase_input);
SCOPED_TRACE(Message() << "Where `max_splits` is: " << testcase_max_splits);
EXPECT_THAT(m.GetOutput<int>(),
ElementsAreArray(expected_testcase_outputs[test_idx]));
EXPECT_THAT(
m.GetOutputShape(),
ElementsAreArray(
{/*batch_size=*/1, static_cast<int>(ngram_lengths.size()),
std::min(
// Longest possible tokenization when using the entire
// input.
static_cast<int>(testcase_input.size()) + /*padding*/ 2,
// Longest possible string when the `max_splits` value
// is < testcase_input.size() + 2 for padding.
testcase_max_splits)}));
}
}
TEST(NGramHashTest, InvalidMaxSplitsValue) {
// Check that the op errors out when given an invalid max splits value.
const std::vector<int> invalid_max_splits({0, -1, -5, -100});
for (const int max_splits : invalid_max_splits) {
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2}, /*max_splits=*/max_splits);
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
TEST(NGramHashTest, MismatchNgramLengthsAndVocabSizes) {
// Check that the op errors out when ngram lengths and vocab sizes mistmatch.
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200, 300},
/*vocab_sizes=*/{1, 2});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
{
NGramHashModel m(/*seed=*/123, /*ngram_lengths=*/{100, 200},
/*vocab_sizes=*/{1, 2, 3});
EXPECT_EQ(m.InvokeUnchecked("hi"), kTfLiteError);
}
}
} // namespace
} // namespace tflite::ops::custom

View File

@ -0,0 +1,42 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "ngram_hash_ops_utils",
srcs = [
"ngram_hash_ops_utils.cc",
],
hdrs = [
"ngram_hash_ops_utils.h",
],
deps = [
"//mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf",
],
)
cc_test(
name = "ngram_hash_ops_utils_test",
size = "small",
srcs = [
"ngram_hash_ops_utils_test.cc",
],
deps = [
":ngram_hash_ops_utils",
"//mediapipe/framework/port:gtest_main",
],
)

View File

@ -0,0 +1,38 @@
# Copyright 2023 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "murmur",
srcs = ["murmur.cc"],
hdrs = ["murmur.h"],
deps = [
"//mediapipe/framework/port:integral_types",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:endian",
],
)
cc_test(
name = "murmur_test",
srcs = ["murmur_test.cc"],
deps = [
":murmur",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:integral_types",
],
)

View File

@ -0,0 +1,95 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: aappleby@google.com (Austin Appleby)
// jyrki@google.com (Jyrki Alakuijala)
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include <cstdint>
#include "absl/base/internal/endian.h"
#include "absl/base/optimization.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
namespace {
using ::absl::little_endian::Load64;
// Murmur 2.0 multiplication constant.
static const uint64_t kMul = 0xc6a4a7935bd1e995ULL;
// We need to mix some of the bits that get propagated and mixed into the
// high bits by multiplication back into the low bits. 17 last bits get
// a more efficiently mixed with this.
inline uint64_t ShiftMix(uint64_t val) { return val ^ (val >> 47); }
// Accumulate 8 bytes into 64-bit Murmur hash
inline uint64_t MurmurStep(uint64_t hash, uint64_t data) {
hash ^= ShiftMix(data * kMul) * kMul;
hash *= kMul;
return hash;
}
// Build a uint64 from 1-8 bytes.
// 8 * len least significant bits are loaded from the memory with
// LittleEndian order. The 64 - 8 * len most significant bits are
// set all to 0.
// In latex-friendly words, this function returns:
// $\sum_{i=0}^{len-1} p[i] 256^{i}$, where p[i] is unsigned.
//
// This function is equivalent to:
// uint64 val = 0;
// memcpy(&val, p, len);
// return ToHost64(val);
//
// The caller needs to guarantee that 0 <= len <= 8.
uint64_t Load64VariableLength(const void* const p, int len) {
ABSL_ASSUME(len >= 0 && len <= 8);
uint64_t val = 0;
const uint8_t* const src = static_cast<const uint8_t*>(p);
for (int i = 0; i < len; ++i) {
val |= static_cast<uint64_t>(src[i]) << (8 * i);
}
return val;
}
} // namespace
unsigned long long MurmurHash64WithSeed(const char* buf, // NOLINT
const size_t len, const uint64_t seed) {
// Let's remove the bytes not divisible by the sizeof(uint64).
// This allows the inner loop to process the data as 64 bit integers.
const size_t len_aligned = len & ~0x7;
const char* const end = buf + len_aligned;
uint64_t hash = seed ^ (len * kMul);
for (const char* p = buf; p != end; p += 8) {
hash = MurmurStep(hash, Load64(p));
}
if ((len & 0x7) != 0) {
const uint64_t data = Load64VariableLength(end, len & 0x7);
hash ^= data;
hash *= kMul;
}
hash = ShiftMix(hash) * kMul;
hash = ShiftMix(hash);
return hash;
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash

View File

@ -0,0 +1,43 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Austin Appelby and Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: aappleby@google.com (Austin Appelby)
// jyrki@google.com (Jyrki Alakuijala)
//
// MurmurHash is a fast multiplication and shifting based algorithm,
// based on Austin Appleby's MurmurHash 2.0 algorithm.
#ifndef UTIL_HASH_MURMUR_H_
#define UTIL_HASH_MURMUR_H_
#include <stddef.h>
#include <stdlib.h> // for size_t.
#include <cstdint>
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
// Hash function for a byte array. Has a seed which allows this hash function to
// be used in algorithms that need a family of parameterized hash functions.
// e.g. Minhash.
unsigned long long MurmurHash64WithSeed(const char* buf, size_t len, // NOLINT
uint64_t seed);
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash
#endif // UTIL_HASH_MURMUR_H_

View File

@ -0,0 +1,66 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a test library written by Jyrki Alakuijala.
// Original copyright message below.
// Copyright 2009 Google Inc. All Rights Reserved.
// Author: jyrki@google.com (Jyrki Alakuijala)
//
// Tests for the fast hashing algorithm based on Austin Appleby's
// MurmurHash 2.0 algorithm. See http://murmurhash.googlepages.com/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/hash/murmur.h"
#include <string.h>
#include <cstdint>
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/integral_types.h"
namespace mediapipe::tasks::text::language_detector::custom_ops::hash {
TEST(Murmur, EmptyData64) {
EXPECT_EQ(uint64_t{0}, MurmurHash64WithSeed(nullptr, uint64_t{0}, 0));
}
TEST(Murmur, VaryWithDifferentSeeds) {
// While in theory different seeds could return the same
// hash for the same data this is unlikely.
char data1 = 'x';
EXPECT_NE(MurmurHash64WithSeed(&data1, 1, 100),
MurmurHash64WithSeed(&data1, 1, 101));
}
// Hashes don't change.
TEST(Murmur, Idempotence) {
const char data[] = "deadbeef";
const size_t dlen = strlen(data);
for (int i = 0; i < 10; i++) {
EXPECT_EQ(MurmurHash64WithSeed(data, dlen, i),
MurmurHash64WithSeed(data, dlen, i));
}
const char next_data[] = "deadbeef000---";
const size_t next_dlen = strlen(next_data);
for (int i = 0; i < 10; i++) {
EXPECT_EQ(MurmurHash64WithSeed(next_data, next_dlen, i),
MurmurHash64WithSeed(next_data, next_dlen, i));
}
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops::hash

View File

@ -0,0 +1,96 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include <utility>
#include <vector>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens) {
const std::string kPrefix = "^";
const std::string kSuffix = "$";
const std::string kReplacementToken = " ";
TokenizedOutput output;
size_t token_start = 0;
output.str.reserve(len + 2);
output.tokens.reserve(len + 2);
output.str.append(kPrefix);
output.tokens.push_back(std::make_pair(token_start, kPrefix.size()));
token_start += kPrefix.size();
Rune token;
for (int i = 0; i < len && output.tokens.size() + 1 < max_tokens;) {
// Use the standard UTF-8 library to find the next token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
// Stop processing, if we can't read any more tokens, or we have reached
// maximum allowed tokens, allocating one token for the suffix.
if (bytes_read == 0) {
break;
}
// If `exclude_nonalphaspace_tokens` is set to true, and the token is not
// alphanumeric, replace it with a replacement token.
if (exclude_nonalphaspace_tokens && !utf_isalpharune(token)) {
output.str.append(kReplacementToken);
output.tokens.push_back(
std::make_pair(token_start, kReplacementToken.size()));
token_start += kReplacementToken.size();
i += bytes_read;
continue;
}
// Append the token in the output string, and note its position and the
// number of bytes that token consumed.
output.str.append(input_str + i, bytes_read);
output.tokens.push_back(std::make_pair(token_start, bytes_read));
token_start += bytes_read;
i += bytes_read;
}
output.str.append(kSuffix);
output.tokens.push_back(std::make_pair(token_start, kSuffix.size()));
token_start += kSuffix.size();
return output;
}
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str) {
for (int i = 0; i < len;) {
Rune token;
// Tokenize the given string, and get the appropriate lowercase token.
size_t bytes_read = utf_charntorune(&token, input_str + i, len - i);
token = utf_isalpharune(token) ? utf_tolowerrune(token) : token;
// Write back the token to the output string.
char token_buf[UTFmax];
size_t bytes_to_write = utf_runetochar(token_buf, &token);
output_str->append(token_buf, bytes_to_write);
i += bytes_read;
}
}
} // namespace mediapipe::tasks::text::language_detector::custom_ops

View File

@ -0,0 +1,56 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#define MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_
#include <string>
#include <utility>
#include <vector>
namespace mediapipe::tasks::text::language_detector::custom_ops {
struct TokenizedOutput {
// The processed string (with necessary prefix, suffix, skipped tokens, etc.).
std::string str;
// This vector contains pairs, where each pair has two members. The first
// denoting the starting index of the token in the `str` string, and the
// second denoting the length of that token in bytes.
std::vector<std::pair<const size_t, const size_t>> tokens;
};
// Tokenizes the given input string on Unicode token boundaries, with a maximum
// of `max_tokens` tokens.
//
// If `exclude_nonalphaspace_tokens` is enabled, the tokenization ignores
// non-alphanumeric tokens, and replaces them with a replacement token (" ").
//
// The method returns the output in the `TokenizedOutput` struct, which stores
// both, the processed input string, and the indices and sizes of each token
// within that string.
TokenizedOutput Tokenize(const char* input_str, int len, int max_tokens,
bool exclude_nonalphaspace_tokens);
// Converts the given unicode string (`input_str`) with the specified length
// (`len`) to a lowercase string.
//
// The method populates the lowercased string in `output_str`.
void LowercaseUnicodeStr(const char* input_str, int len,
std::string* output_str);
} // namespace mediapipe::tasks::text::language_detector::custom_ops
#endif // MEDIAPIPE_TASKS_CC_TEXT_LANGUAGE_DETECTOR_CUSTOM_OPS_UTILS_NGRAM_HASH_OPS_UTILS_H_

View File

@ -0,0 +1,135 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/ngram_hash_ops_utils.h"
#include <string>
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
namespace mediapipe::tasks::text::language_detector::custom_ops {
namespace {
using ::testing::Values;
std::string ReconstructStringFromTokens(TokenizedOutput output) {
std::string reconstructed_str;
for (int i = 0; i < output.tokens.size(); i++) {
reconstructed_str.append(
output.str.c_str() + output.tokens[i].first,
output.str.c_str() + output.tokens[i].first + output.tokens[i].second);
}
return reconstructed_str;
}
struct TokenizeTestParams {
std::string input_str;
size_t max_tokens;
bool exclude_nonalphaspace_tokens;
std::string expected_output_str;
};
class TokenizeParameterizedTest
: public ::testing::Test,
public testing::WithParamInterface<TokenizeTestParams> {};
TEST_P(TokenizeParameterizedTest, Tokenize) {
// Checks that the Tokenize method returns the expected value.
const TokenizeTestParams params = TokenizeParameterizedTest::GetParam();
const TokenizedOutput output = Tokenize(
/*input_str=*/params.input_str.c_str(),
/*len=*/params.input_str.size(),
/*max_tokens=*/params.max_tokens,
/*exclude_nonalphaspace_tokens=*/params.exclude_nonalphaspace_tokens);
// The output string should have the necessary prefixes, and the "!" token
// should have been replaced with a " ".
EXPECT_EQ(output.str, params.expected_output_str);
EXPECT_EQ(ReconstructStringFromTokens(output), params.expected_output_str);
}
INSTANTIATE_TEST_SUITE_P(
TokenizeParameterizedTests, TokenizeParameterizedTest,
Values(
// Test including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_alphanonspace=*/false,
/*expected_output_str=*/"^hi!$"}),
// Test not including non-alphanumeric characters.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/100,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^hi $"}),
// Test with a maximum of 3 tokens.
TokenizeTestParams({/*input_str=*/"hi!", /*max_tokens=*/3,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^h$"}),
// Test with non-latin characters.
TokenizeTestParams({/*input_str=*/"ありがと", /*max_tokens=*/100,
/*exclude_alphanonspace=*/true,
/*expected_output_str=*/"^ありがと$"})));
TEST(LowercaseUnicodeTest, TestLowercaseUnicode) {
{
// Check that the method is a no-op when the string is lowercase.
std::string input_str = "hello";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method has uppercase characters.
std::string input_str = "hElLo";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "hello");
}
{
// Check that the method works with non-latin scripts.
// Cyrillic has the concept of cases, so it should change the input.
std::string input_str = "БЙп";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "бйп");
}
{
// Check that the method works with non-latin scripts.
// Japanese doesn't have the concept of cases, so it should not change.
std::string input_str = "ありがと";
std::string output_str;
LowercaseUnicodeStr(
/*input_str=*/input_str.c_str(),
/*len=*/input_str.size(),
/*output_str=*/&output_str);
EXPECT_EQ(output_str, "ありがと");
}
}
} // namespace
} // namespace mediapipe::tasks::text::language_detector::custom_ops

View File

@ -0,0 +1,27 @@
# Copyright 2022 The MediaPipe Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//mediapipe/tasks:internal"])
licenses(["notice"])
cc_library(
name = "utf",
srcs = [
"rune.c",
"runetype.c",
"runetypebody.h",
],
hdrs = ["utf.h"],
)

View File

@ -0,0 +1,233 @@
/* Copyright 2023 The MediaPipe Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Forked from a library written by Rob Pike and Ken Thompson. Original
// copyright message below.
/*
* The authors of this software are Rob Pike and Ken Thompson.
* Copyright (c) 2002 by Lucent Technologies.
* Permission to use, copy, modify, and distribute this software for any
* purpose without fee is hereby granted, provided that this entire notice
* is included in all copies of any software which is or includes a copy
* or modification of this software and in all copies of the supporting
* documentation for such software.
* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
*/
#include <stdarg.h>
#include <string.h>
#include "mediapipe/tasks/cc/text/language_detector/custom_ops/utils/utf/utf.h"
enum
{
Bit1 = 7,
Bitx = 6,
Bit2 = 5,
Bit3 = 4,
Bit4 = 3,
Bit5 = 2,
T1 = ((1<<(Bit1+1))-1) ^ 0xFF, /* 0000 0000 */
Tx = ((1<<(Bitx+1))-1) ^ 0xFF, /* 1000 0000 */
T2 = ((1<<(Bit2+1))-1) ^ 0xFF, /* 1100 0000 */
T3 = ((1<<(Bit3+1))-1) ^ 0xFF, /* 1110 0000 */
T4 = ((1<<(Bit4+1))-1) ^ 0xFF, /* 1111 0000 */
T5 = ((1<<(Bit5+1))-1) ^ 0xFF, /* 1111 1000 */
Rune1 = (1<<(Bit1+0*Bitx))-1, /* 0000 0000 0111 1111 */
Rune2 = (1<<(Bit2+1*Bitx))-1, /* 0000 0111 1111 1111 */
Rune3 = (1<<(Bit3+2*Bitx))-1, /* 1111 1111 1111 1111 */
Rune4 = (1<<(Bit4+3*Bitx))-1,
/* 0001 1111 1111 1111 1111 1111 */
Maskx = (1<<Bitx)-1, /* 0011 1111 */
Testx = Maskx ^ 0xFF, /* 1100 0000 */
Bad = Runeerror,
};
/*
* Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
* This is a slower but "safe" version of the old chartorune
* that works on strings that are not necessarily null-terminated.
*
* If you know for sure that your string is null-terminated,
* chartorune will be a bit faster.
*
* It is guaranteed not to attempt to access "length"
* past the incoming pointer. This is to avoid
* possible access violations. If the string appears to be
* well-formed but incomplete (i.e., to get the whole Rune
* we'd need to read past str+length) then we'll set the Rune
* to Bad and return 0.
*
* Note that if we have decoding problems for other
* reasons, we return 1 instead of 0.
*/
int
utf_charntorune(Rune *rune, const char *str, int length)
{
int c, c1, c2, c3;
long l;
/* When we're not allowed to read anything */
if(length <= 0) {
goto badlen;
}
/*
* one character sequence (7-bit value)
* 00000-0007F => T1
*/
c = *(uchar*)str;
if(c < Tx) {
*rune = c;
return 1;
}
// If we can't read more than one character we must stop
if(length <= 1) {
goto badlen;
}
/*
* two character sequence (11-bit value)
* 0080-07FF => T2 Tx
*/
c1 = *(uchar*)(str+1) ^ Tx;
if(c1 & Testx)
goto bad;
if(c < T3) {
if(c < T2)
goto bad;
l = ((c << Bitx) | c1) & Rune2;
if(l <= Rune1)
goto bad;
*rune = l;
return 2;
}
// If we can't read more than two characters we must stop
if(length <= 2) {
goto badlen;
}
/*
* three character sequence (16-bit value)
* 0800-FFFF => T3 Tx Tx
*/
c2 = *(uchar*)(str+2) ^ Tx;
if(c2 & Testx)
goto bad;
if(c < T4) {
l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
if(l <= Rune2)
goto bad;
*rune = l;
return 3;
}
if (length <= 3)
goto badlen;
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
c3 = *(uchar*)(str+3) ^ Tx;
if (c3 & Testx)
goto bad;
if (c < T5) {
l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
if (l <= Rune3)
goto bad;
if (l > Runemax)
goto bad;
*rune = l;
return 4;
}
// Support for 5-byte or longer UTF-8 would go here, but
// since we don't have that, we'll just fall through to bad.
/*
* bad decoding
*/
bad:
*rune = Bad;
return 1;
badlen:
*rune = Bad;
return 0;
}
int
utf_runetochar(char *str, const Rune *rune)
{
/* Runes are signed, so convert to unsigned for range check. */
unsigned long c;
/*
* one character sequence
* 00000-0007F => 00-7F
*/
c = *rune;
if(c <= Rune1) {
str[0] = c;
return 1;
}
/*
* two character sequence
* 0080-07FF => T2 Tx
*/
if(c <= Rune2) {
str[0] = T2 | (c >> 1*Bitx);
str[1] = Tx | (c & Maskx);
return 2;
}
/*
* If the Rune is out of range, convert it to the error rune.
* Do this test here because the error rune encodes to three bytes.
* Doing it earlier would duplicate work, since an out of range
* Rune wouldn't have fit in one or two bytes.
*/
if (c > Runemax)
c = Runeerror;
/*
* three character sequence
* 0800-FFFF => T3 Tx Tx
*/
if (c <= Rune3) {
str[0] = T3 | (c >> 2*Bitx);
str[1] = Tx | ((c >> 1*Bitx) & Maskx);
str[2] = Tx | (c & Maskx);
return 3;
}
/*
* four character sequence (21-bit value)
* 10000-1FFFFF => T4 Tx Tx Tx
*/
str[0] = T4 | (c >> 3*Bitx);
str[1] = Tx | ((c >> 2*Bitx) & Maskx);
str[2] = Tx | ((c >> 1*Bitx) & Maskx);
str[3] = Tx | (c & Maskx);
return 4;
}

Some files were not shown because too many files have changed in this diff Show More