Project import generated by Copybara.

GitOrigin-RevId: 19a829ffd755edb43e54d20c0e7b9348512d5108
This commit is contained in:
MediaPipe Team 2022-05-03 15:29:57 -07:00 committed by schmidt-sebastian
parent c6c80c3745
commit 7fb37c80e8
136 changed files with 2572 additions and 555 deletions

View File

@ -32,6 +32,9 @@ build:macos --copt=-w
# Sets the default Apple platform to macOS. # Sets the default Apple platform to macOS.
build --apple_platform_type=macos build --apple_platform_type=macos
# Compile ObjC++ files with C++17
build --per_file_copt=.*\.mm\$@-std=c++17
# Allow debugging with XCODE # Allow debugging with XCODE
build --apple_generate_dsym build --apple_generate_dsym
@ -88,6 +91,10 @@ build:darwin_x86_64 --apple_platform_type=macos
build:darwin_x86_64 --macos_minimum_os=10.12 build:darwin_x86_64 --macos_minimum_os=10.12
build:darwin_x86_64 --cpu=darwin_x86_64 build:darwin_x86_64 --cpu=darwin_x86_64
build:darwin_arm64 --apple_platform_type=macos
build:darwin_arm64 --macos_minimum_os=10.16
build:darwin_arm64 --cpu=darwin_arm64
# This bazelrc file is meant to be written by a setup script. # This bazelrc file is meant to be written by a setup script.
try-import %workspace%/.configure.bazelrc try-import %workspace%/.configure.bazelrc

View File

@ -202,7 +202,10 @@ new_local_repository(
new_local_repository( new_local_repository(
name = "macos_opencv", name = "macos_opencv",
build_file = "@//third_party:opencv_macos.BUILD", build_file = "@//third_party:opencv_macos.BUILD",
path = "/usr/local/opt/opencv@3", # For local MacOS builds, the path should point to an opencv@3 installation.
# If you edit the path here, you will also need to update the corresponding
# prefix in "opencv_macos.BUILD".
path = "/usr/local",
) )
new_local_repository( new_local_repository(

View File

@ -53,7 +53,7 @@ the following:
```bash ```bash
$ echo "android_sdk_repository(name = \"androidsdk\")" >> WORKSPACE $ echo "android_sdk_repository(name = \"androidsdk\")" >> WORKSPACE
$ echo "android_ndk_repository(name = \"androidndk\")" >> WORKSPACE $ echo "android_ndk_repository(name = \"androidndk\", api_level=21)" >> WORKSPACE
``` ```
In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch

View File

@ -59,6 +59,21 @@ OpenGL ES profile shading language version string: OpenGL ES GLSL ES 3.20
OpenGL ES profile extensions: OpenGL ES profile extensions:
``` ```
If you have connected to your computer through SSH and find when you probe for
GPU information you see the output:
```bash
glxinfo | grep -i opengl
Error: unable to open display
```
Try re-establishing your SSH connection with the `-X` option and try again. For
example:
```bash
ssh -X <user>@<host>
```
*Notice the ES 3.20 text above.* *Notice the ES 3.20 text above.*
You need to see ES 3.1 or greater printed in order to perform TFLite inference You need to see ES 3.1 or greater printed in order to perform TFLite inference

View File

@ -131,7 +131,7 @@ Create a `BUILD` file in the `$APPLICATION_PATH` and add the following build
rules: rules:
``` ```
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
load( load(
"@build_bazel_rules_apple//apple:ios.bzl", "@build_bazel_rules_apple//apple:ios.bzl",

View File

@ -32,9 +32,14 @@ example apps, start from, start from
xcode-select --install xcode-select --install
``` ```
3. Install [Bazel](https://bazel.build/). 3. Install [Bazelisk](https://github.com/bazelbuild/bazelisk)
.
We recommend using [Homebrew](https://brew.sh/) to get the latest version. We recommend using [Homebrew](https://brew.sh/) to get the latest versions.
```bash
brew install bazelisk
```
4. Set Python 3.7 as the default Python version and install the Python "six" 4. Set Python 3.7 as the default Python version and install the Python "six"
library. This is needed for TensorFlow. library. This is needed for TensorFlow.
@ -187,6 +192,9 @@ Note: When you ask Xcode to run an app, by default it will use the Debug
configuration. Some of our demos are computationally heavy; you may want to use configuration. Some of our demos are computationally heavy; you may want to use
the Release configuration for better performance. the Release configuration for better performance.
Note: Due to an imcoptibility caused by one of our dependencies, MediaPipe
cannot be used for apps running on the iPhone Simulator on Apple Silicon (M1).
Tip: To switch build configuration in Xcode, click on the target menu, choose Tip: To switch build configuration in Xcode, click on the target menu, choose
"Edit Scheme...", select the Run action, and switch the Build Configuration from "Edit Scheme...", select the Run action, and switch the Build Configuration from
Debug to Release. Note that this is set independently for each target. Debug to Release. Note that this is set independently for each target.

View File

@ -258,13 +258,14 @@ Many of the following settings are advanced and not recommended for general
usage. Consult [Enabling tracing and profiling](#enabling-tracing-and-profiling) usage. Consult [Enabling tracing and profiling](#enabling-tracing-and-profiling)
for a friendlier introduction. for a friendlier introduction.
histogram_interval_size_usec :Specifies the size of the runtimes histogram histogram_interval_size_usec
intervals (in microseconds) to generate the histogram of the Process() time. The : Specifies the size of the runtimes histogram intervals (in microseconds) to
last interval extends to +inf. If not specified, the interval is 1000000 usec = generate the histogram of the `Process()` time. The last interval extends to
1 sec. +inf. If not specified, the interval is 1000000 usec = 1 sec.
num_histogram_intervals :Specifies the number of intervals to generate the num_histogram_intervals
histogram of the `Process()` runtime. If not specified, one interval is used. : Specifies the number of intervals to generate the histogram of the
`Process()` runtime. If not specified, one interval is used.
enable_profiler enable_profiler
: If true, the profiler starts profiling when graph is initialized. : If true, the profiler starts profiling when graph is initialized.
@ -288,7 +289,7 @@ trace_event_types_disabled
trace_log_path trace_log_path
: The output directory and base-name prefix for trace log files. Log files are : The output directory and base-name prefix for trace log files. Log files are
written to: StrCat(trace_log_path, index, "`.binarypb`") written to: `StrCat(trace_log_path, index, ".binarypb")`
trace_log_count trace_log_count
: The number of trace log files retained. The trace log files are named : The number of trace log files retained. The trace log files are named
@ -310,8 +311,8 @@ trace_log_instant_events
trace_log_interval_count trace_log_interval_count
: The number of trace log intervals per file. The total log duration is: : The number of trace log intervals per file. The total log duration is:
`trace_log_interval_usec * trace_log_file_count * trace_log_interval_count`. `trace_log_interval_usec * trace_log_count * trace_log_interval_count`. The
The default value specifies 10 intervals per file. default value specifies 10 intervals per file.
trace_log_disabled trace_log_disabled
: An option to turn ON/OFF writing trace files to disk. Saving trace files to : An option to turn ON/OFF writing trace files to disk. Saving trace files to

View File

@ -75,6 +75,7 @@ alias(
actual = select({ actual = select({
":macos_i386": ":macos_i386", ":macos_i386": ":macos_i386",
":macos_x86_64": ":macos_x86_64", ":macos_x86_64": ":macos_x86_64",
":macos_arm64": ":macos_arm64",
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above. "//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@ -119,6 +120,15 @@ config_setting(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )
config_setting(
name = "macos_arm64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_arm64",
},
visibility = ["//visibility:public"],
)
[ [
config_setting( config_setting(
name = arch, name = arch,

View File

@ -214,6 +214,7 @@ cc_library(
"//mediapipe/framework:collection_item_id", "//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework/formats:classification_cc_proto", "//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto", "//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto", "//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
@ -1257,3 +1258,36 @@ cc_test(
"@com_google_absl//absl/time", "@com_google_absl//absl/time",
], ],
) )
cc_library(
name = "get_vector_item_calculator",
srcs = ["get_vector_item_calculator.cc"],
hdrs = ["get_vector_item_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "vector_size_calculator",
srcs = ["vector_size_calculator.cc"],
hdrs = ["vector_size_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)

View File

@ -28,6 +28,10 @@ typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedLandmarkList>>
BeginLoopNormalizedLandmarkListVectorCalculator; BeginLoopNormalizedLandmarkListVectorCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkListVectorCalculator); REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkListVectorCalculator);
// A calculator to process std::vector<int>.
typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntCalculator;
REGISTER_CALCULATOR(BeginLoopIntCalculator);
// A calculator to process std::vector<NormalizedRect>. // A calculator to process std::vector<NormalizedRect>.
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>> typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
BeginLoopNormalizedRectCalculator; BeginLoopNormalizedRectCalculator;

View File

@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h" #include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h" #include "mediapipe/util/render_data.pb.h"
@ -50,4 +51,8 @@ REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator; typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator;
REGISTER_CALCULATOR(EndLoopTensorCalculator); REGISTER_CALCULATOR(EndLoopTensorCalculator);
typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
EndLoopDetectionCalculator;
REGISTER_CALCULATOR(EndLoopDetectionCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/get_vector_item_calculator.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace api2 {
using GetLandmarkListVectorItemCalculator =
GetVectorItemCalculator<mediapipe::LandmarkList>;
REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator);
using GetClassificationListVectorItemCalculator =
GetVectorItemCalculator<mediapipe::ClassificationList>;
REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,77 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
#include <optional>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// A calcutlator to return an item from the vector by its index.
//
// Inputs:
// VECTOR - std::vector<T>
// Vector to take an item from.
// INDEX - int
// Index of the item to return.
//
// Outputs:
// ITEM - T
// Item from the vector at given index.
//
// Example config:
// node {
// calculator: "Get{SpecificType}VectorItemCalculator"
// input_stream: "VECTOR:vector"
// input_stream: "INDEX:index"
// input_stream: "ITEM:item"
// }
//
template <typename T>
class GetVectorItemCalculator : public Node {
public:
static constexpr Input<std::vector<T>> kIn{"VECTOR"};
static constexpr Input<int> kIdx{"INDEX"};
static constexpr Output<T> kOut{"ITEM"};
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
absl::Status Process(CalculatorContext* cc) final {
if (kIn(cc).IsEmpty() || kIdx(cc).IsEmpty()) {
return absl::OkStatus();
}
const std::vector<T>& items = kIn(cc).Get();
const int idx = kIdx(cc).Get();
RET_CHECK_LT(idx, items.size());
kOut(cc).Send(items[idx]);
return absl::OkStatus();
}
};
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_

View File

@ -83,4 +83,7 @@ REGISTER_CALCULATOR(SplitClassificationListVectorCalculator);
typedef SplitVectorCalculator<uint64_t, false> SplitUint64tVectorCalculator; typedef SplitVectorCalculator<uint64_t, false> SplitUint64tVectorCalculator;
REGISTER_CALCULATOR(SplitUint64tVectorCalculator); REGISTER_CALCULATOR(SplitUint64tVectorCalculator);
typedef SplitVectorCalculator<float, false> SplitFloatVectorCalculator;
REGISTER_CALCULATOR(SplitFloatVectorCalculator);
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/vector_size_calculator.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace api2 {
using LandmarkListVectorSizeCalculator =
VectorSizeCalculator<mediapipe::LandmarkList>;
REGISTER_CALCULATOR(LandmarkListVectorSizeCalculator);
using ClassificationListVectorSizeCalculator =
VectorSizeCalculator<mediapipe::ClassificationList>;
REGISTER_CALCULATOR(ClassificationListVectorSizeCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,64 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
#include <optional>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// A calcutlator to return vector size.
//
// Inputs:
// VECTOR - std::vector<T>
// Vector which size to return.
//
// Outputs:
// SIZE - int
// Size of the input vector.
//
// Example config:
// node {
// calculator: "{SpecificType}VectorSizeCalculator"
// input_stream: "VECTOR:vector"
// output_stream: "SIZE:vector_size"
// }
//
template <typename T>
class VectorSizeCalculator : public Node {
public:
static constexpr Input<std::vector<T>> kIn{"VECTOR"};
static constexpr Output<int> kOut{"SIZE"};
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
absl::Status Process(CalculatorContext* cc) final {
if (kIn(cc).IsEmpty()) {
return absl::OkStatus();
}
kOut(cc).Send(kIn(cc).Get().size());
return absl::OkStatus();
}
};
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_

View File

@ -421,6 +421,10 @@ absl::Status ScaleImageCalculator::InitializeFromOptions() {
alignment_boundary_ = options_.alignment_boundary(); alignment_boundary_ = options_.alignment_boundary();
} }
if (options_.has_output_format()) {
output_format_ = options_.output_format();
}
downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient())); downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient()));
return absl::OkStatus(); return absl::OkStatus();
@ -433,13 +437,17 @@ absl::Status ScaleImageCalculator::ValidateImageFormats() const {
<< "The output image format was set to UNKNOWN."; << "The output image format was set to UNKNOWN.";
// TODO Remove these conditions. // TODO Remove these conditions.
RET_CHECK(output_format_ == ImageFormat::SRGB || RET_CHECK(output_format_ == ImageFormat::SRGB ||
output_format_ == ImageFormat::SRGBA ||
(input_format_ == output_format_ && (input_format_ == output_format_ &&
output_format_ == ImageFormat::YCBCR420P)) output_format_ == ImageFormat::YCBCR420P))
<< "Outputting YCbCr420P images from SRGB input is not yet supported"; << "Outputting YCbCr420P images from SRGB input is not yet supported";
RET_CHECK(input_format_ == output_format_ || RET_CHECK(input_format_ == output_format_ ||
input_format_ == ImageFormat::YCBCR420P) (input_format_ == ImageFormat::YCBCR420P &&
output_format_ == ImageFormat::SRGB) ||
(input_format_ == ImageFormat::SRGB &&
output_format_ == ImageFormat::SRGBA))
<< "Conversion of the color space (except from " << "Conversion of the color space (except from "
"YCbCr420P to SRGB) is not yet supported."; "YCbCr420P to SRGB or SRGB to SRBGA) is not yet supported.";
return absl::OkStatus(); return absl::OkStatus();
} }
@ -604,6 +612,15 @@ absl::Status ScaleImageCalculator::Process(CalculatorContext* cc) {
.Add(output_image.release(), cc->InputTimestamp()); .Add(output_image.release(), cc->InputTimestamp());
return absl::OkStatus(); return absl::OkStatus();
} }
} else if (input_format_ == ImageFormat::SRGB &&
output_format_ == ImageFormat::SRGBA) {
image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
cv::Mat input_mat = ::mediapipe::formats::MatView(image_frame);
converted_image_frame.Reset(ImageFormat::SRGBA, image_frame->Width(),
image_frame->Height(), alignment_boundary_);
cv::Mat output_mat = ::mediapipe::formats::MatView(&converted_image_frame);
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA, 4);
image_frame = &converted_image_frame;
} else { } else {
image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>(); image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
MP_RETURN_IF_ERROR(ValidateImageFrame(cc, *image_frame)); MP_RETURN_IF_ERROR(ValidateImageFrame(cc, *image_frame));

View File

@ -28,7 +28,9 @@ package(default_visibility = ["//visibility:private"])
exports_files( exports_files(
glob(["testdata/image_to_tensor/*"]), glob(["testdata/image_to_tensor/*"]),
visibility = ["//mediapipe/calculators/image:__subpackages__"], visibility = [
"//mediapipe/calculators/image:__subpackages__",
],
) )
selects.config_setting_group( selects.config_setting_group(
@ -64,15 +66,16 @@ cc_library(
":inference_calculator_cc_proto", ":inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:subgraph_expansion", "//mediapipe/framework/tool:subgraph_expansion",
"//mediapipe/util/tflite:config",
"//mediapipe/util/tflite:tflite_model_loader", "//mediapipe/util/tflite:tflite_model_loader",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
], ],
alwayslink = 1, alwayslink = 1,
@ -91,6 +94,7 @@ cc_library(
"//mediapipe/util/tflite:tflite_gpu_runner", "//mediapipe/util/tflite:tflite_gpu_runner",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
], ],
@ -142,6 +146,8 @@ cc_library(
":inference_calculator_interface", ":inference_calculator_interface",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
] + select({ ] + select({
"//conditions:default": [ "//conditions:default": [
"//mediapipe/util:cpu_util", "//mediapipe/util:cpu_util",

View File

@ -142,22 +142,35 @@ class ImageToTensorCalculator : public Node {
cc->Options<mediapipe::ImageToTensorCalculatorOptions>(); cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
RET_CHECK(options.has_output_tensor_float_range() || RET_CHECK(options.has_output_tensor_float_range() ||
options.has_output_tensor_int_range()) options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required."; << "Output tensor range is required.";
if (options.has_output_tensor_float_range()) { if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(), RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max()) options.output_tensor_float_range().max())
<< "Valid output float tensor range is required."; << "Valid output float tensor range is required.";
} }
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) { if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(), RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max()) options.output_tensor_int_range().max())
<< "Valid output int tensor range is required."; << "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), 0) RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be non-negative."; << "The minimum of the output int tensor range must be greater than "
RET_CHECK_LE(options.output_tensor_int_range().max(), 255) "or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or " << "The maximum of the output int tensor range must be less than or "
"equal to 255."; "equal to 127.";
} }
RET_CHECK_GT(options.output_tensor_width(), 0) RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required."; << "Valid output tensor width is required.";
@ -187,15 +200,19 @@ class ImageToTensorCalculator : public Node {
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>(); options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
output_width_ = options_.output_tensor_width(); output_width_ = options_.output_tensor_width();
output_height_ = options_.output_tensor_height(); output_height_ = options_.output_tensor_height();
is_int_output_ = options_.has_output_tensor_int_range(); is_float_output_ = options_.has_output_tensor_float_range();
if (options_.has_output_tensor_uint_range()) {
range_min_ = range_min_ =
is_int_output_ static_cast<float>(options_.output_tensor_uint_range().min());
? static_cast<float>(options_.output_tensor_int_range().min())
: options_.output_tensor_float_range().min();
range_max_ = range_max_ =
is_int_output_ static_cast<float>(options_.output_tensor_uint_range().max());
? static_cast<float>(options_.output_tensor_int_range().max()) } else if (options_.has_output_tensor_int_range()) {
: options_.output_tensor_float_range().max(); range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
} else {
range_min_ = options_.output_tensor_float_range().min();
range_max_ = options_.output_tensor_float_range().max();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -275,6 +292,17 @@ class ImageToTensorCalculator : public Node {
} }
} }
Tensor::ElementType GetOutputTensorType() {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage( absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
CalculatorContext* cc) { CalculatorContext* cc) {
if (kIn(cc).IsConnected()) { if (kIn(cc).IsConnected()) {
@ -305,7 +333,7 @@ class ImageToTensorCalculator : public Node {
const Image& image) { const Image& image) {
// Lazy initialization of the GPU or CPU converter. // Lazy initialization of the GPU or CPU converter.
if (image.UsesGpu()) { if (image.UsesGpu()) {
if (is_int_output_) { if (!is_float_output_) {
return absl::UnimplementedError( return absl::UnimplementedError(
"ImageToTensorConverter for the input GPU image currently doesn't " "ImageToTensorConverter for the input GPU image currently doesn't "
"support quantization."); "support quantization.");
@ -337,11 +365,9 @@ class ImageToTensorCalculator : public Node {
} else { } else {
if (!cpu_converter_) { if (!cpu_converter_) {
#if !MEDIAPIPE_DISABLE_OPENCV #if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(cpu_converter_, ASSIGN_OR_RETURN(
CreateOpenCvConverter( cpu_converter_,
cc, GetBorderMode(), CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
is_int_output_ ? Tensor::ElementType::kUInt8
: Tensor::ElementType::kFloat32));
#else #else
LOG(FATAL) << "Cannot create image to tensor opencv converter since " LOG(FATAL) << "Cannot create image to tensor opencv converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined."; "MEDIAPIPE_DISABLE_OPENCV is defined.";
@ -356,7 +382,7 @@ class ImageToTensorCalculator : public Node {
mediapipe::ImageToTensorCalculatorOptions options_; mediapipe::ImageToTensorCalculatorOptions options_;
int output_width_ = 0; int output_width_ = 0;
int output_height_ = 0; int output_height_ = 0;
bool is_int_output_ = false; bool is_float_output_ = false;
float range_min_ = 0.0f; float range_min_ = 0.0f;
float range_max_ = 1.0f; float range_max_ = 1.0f;
}; };

View File

@ -39,6 +39,14 @@ message ImageToTensorCalculatorOptions {
optional int64 max = 2; optional int64 max = 2;
} }
// Range of uint values [min, max].
// min, must be strictly less than max.
// Please note that UIntRange is supported for CPU tensors only.
message UIntRange {
optional uint64 min = 1;
optional uint64 max = 2;
}
// Pixel extrapolation methods. See @border_mode. // Pixel extrapolation methods. See @border_mode.
enum BorderMode { enum BorderMode {
BORDER_UNSPECIFIED = 0; BORDER_UNSPECIFIED = 0;
@ -58,6 +66,7 @@ message ImageToTensorCalculatorOptions {
oneof range { oneof range {
FloatRange output_tensor_float_range = 4; FloatRange output_tensor_float_range = 4;
IntRange output_tensor_int_range = 7; IntRange output_tensor_int_range = 7;
UIntRange output_tensor_uint_range = 8;
} }
// For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs // For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs

View File

@ -76,12 +76,21 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
} }
std::string output_tensor_range; std::string output_tensor_range;
if (output_int_tensor) { if (output_int_tensor) {
if (range_min < 0) {
output_tensor_range = absl::Substitute(R"(output_tensor_int_range { output_tensor_range = absl::Substitute(R"(output_tensor_int_range {
min: $0 min: $0
max: $1 max: $1
})", })",
static_cast<int>(range_min), static_cast<int>(range_min),
static_cast<int>(range_max)); static_cast<int>(range_max));
} else {
output_tensor_range = absl::Substitute(R"(output_tensor_uint_range {
min: $0
max: $1
})",
static_cast<uint>(range_min),
static_cast<uint>(range_max));
}
} else { } else {
output_tensor_range = absl::Substitute(R"(output_tensor_float_range { output_tensor_range = absl::Substitute(R"(output_tensor_float_range {
min: $0 min: $0
@ -141,9 +150,15 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
auto view = tensor.GetCpuReadView(); auto view = tensor.GetCpuReadView();
cv::Mat tensor_mat; cv::Mat tensor_mat;
if (output_int_tensor) { if (output_int_tensor) {
if (range_min < 0) {
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8);
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3,
const_cast<int8*>(view.buffer<int8>()));
} else {
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8); EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8);
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3, tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3,
const_cast<uint8*>(view.buffer<uint8>())); const_cast<uint8*>(view.buffer<uint8>()));
}
} else { } else {
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32); EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32);
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3, tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3,
@ -190,26 +205,29 @@ const std::vector<InputType> kInputTypesToTest = {InputType::kImageFrame,
InputType::kImage}; InputType::kImage};
void RunTest(cv::Mat input, cv::Mat expected_result, void RunTest(cv::Mat input, cv::Mat expected_result,
std::vector<float> float_range, std::vector<int> int_range, std::vector<std::pair<float, float>> float_ranges,
int tensor_width, int tensor_height, bool keep_aspect, std::vector<std::pair<int, int>> int_ranges, int tensor_width,
int tensor_height, bool keep_aspect,
absl::optional<BorderMode> border_mode, absl::optional<BorderMode> border_mode,
const mediapipe::NormalizedRect& roi) { const mediapipe::NormalizedRect& roi) {
ASSERT_EQ(2, float_range.size());
ASSERT_EQ(2, int_range.size());
for (auto input_type : kInputTypesToTest) { for (auto input_type : kInputTypesToTest) {
for (auto float_range : float_ranges) {
RunTestWithInputImagePacket( RunTestWithInputImagePacket(
input_type == InputType::kImageFrame ? MakeImageFramePacket(input) input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
: MakeImagePacket(input), : MakeImagePacket(input),
expected_result, float_range[0], float_range[1], tensor_width, expected_result, float_range.first, float_range.second, tensor_width,
tensor_height, keep_aspect, border_mode, roi, tensor_height, keep_aspect, border_mode, roi,
/*output_int_tensor=*/false); /*output_int_tensor=*/false);
}
for (auto int_range : int_ranges) {
RunTestWithInputImagePacket( RunTestWithInputImagePacket(
input_type == InputType::kImageFrame ? MakeImageFramePacket(input) input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
: MakeImagePacket(input), : MakeImagePacket(input),
expected_result, int_range[0], int_range[1], tensor_width, expected_result, int_range.first, int_range.second, tensor_width,
tensor_height, keep_aspect, border_mode, roi, tensor_height, keep_aspect, border_mode, roi,
/*output_int_tensor=*/true); /*output_int_tensor=*/true);
} }
}
} }
TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) { TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) {
@ -224,8 +242,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) {
"tensor/testdata/image_to_tensor/input.jpg"), "tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"), "tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
/*border mode*/ {}, roi); /*border mode*/ {}, roi);
} }
@ -242,8 +260,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) {
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/" "tensor/testdata/image_to_tensor/"
"medium_sub_rect_keep_aspect_border_zero.png"), "medium_sub_rect_keep_aspect_border_zero.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
BorderMode::kZero, roi); BorderMode::kZero, roi);
} }
@ -260,8 +278,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) {
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/" "tensor/testdata/image_to_tensor/"
"medium_sub_rect_keep_aspect_with_rotation.png"), "medium_sub_rect_keep_aspect_with_rotation.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
BorderMode::kReplicate, roi); BorderMode::kReplicate, roi);
} }
@ -279,8 +297,8 @@ TEST(ImageToTensorCalculatorTest,
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/" "tensor/testdata/image_to_tensor/"
"medium_sub_rect_keep_aspect_with_rotation_border_zero.png"), "medium_sub_rect_keep_aspect_with_rotation_border_zero.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
BorderMode::kZero, roi); BorderMode::kZero, roi);
} }
@ -298,8 +316,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) {
GetRgb( GetRgb(
"/mediapipe/calculators/" "/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"), "tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"),
/*float_range=*/{-1.0f, 1.0f}, /*float_ranges=*/{{-1.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false,
BorderMode::kReplicate, roi); BorderMode::kReplicate, roi);
} }
@ -316,8 +334,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) {
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/" "tensor/testdata/image_to_tensor/"
"medium_sub_rect_with_rotation_border_zero.png"), "medium_sub_rect_with_rotation_border_zero.png"),
/*float_range=*/{-1.0f, 1.0f}, /*float_ranges=*/{{-1.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false, /*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false,
BorderMode::kZero, roi); BorderMode::kZero, roi);
} }
@ -333,8 +351,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) {
"tensor/testdata/image_to_tensor/input.jpg"), "tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/large_sub_rect.png"), "tensor/testdata/image_to_tensor/large_sub_rect.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false,
BorderMode::kReplicate, roi); BorderMode::kReplicate, roi);
} }
@ -351,8 +369,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) {
"tensor/testdata/image_to_tensor/input.jpg"), "tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"), "tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false,
BorderMode::kZero, roi); BorderMode::kZero, roi);
} }
@ -369,8 +387,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) {
"tensor/testdata/image_to_tensor/input.jpg"), "tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"), "tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kReplicate, roi); BorderMode::kReplicate, roi);
} }
@ -387,8 +405,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) {
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/" "tensor/testdata/image_to_tensor/"
"large_sub_rect_keep_aspect_border_zero.png"), "large_sub_rect_keep_aspect_border_zero.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kZero, roi); BorderMode::kZero, roi);
} }
@ -405,8 +423,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) {
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/" "tensor/testdata/image_to_tensor/"
"large_sub_rect_keep_aspect_with_rotation.png"), "large_sub_rect_keep_aspect_with_rotation.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
/*border_mode=*/{}, roi); /*border_mode=*/{}, roi);
} }
@ -424,8 +442,8 @@ TEST(ImageToTensorCalculatorTest,
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/" "tensor/testdata/image_to_tensor/"
"large_sub_rect_keep_aspect_with_rotation_border_zero.png"), "large_sub_rect_keep_aspect_with_rotation_border_zero.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true, /*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
/*border_mode=*/BorderMode::kZero, roi); /*border_mode=*/BorderMode::kZero, roi);
} }
@ -441,8 +459,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) {
"tensor/testdata/image_to_tensor/input.jpg"), "tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/noop_except_range.png"), "tensor/testdata/image_to_tensor/noop_except_range.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kReplicate, roi); BorderMode::kReplicate, roi);
} }
@ -458,8 +476,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) {
"tensor/testdata/image_to_tensor/input.jpg"), "tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/" GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/noop_except_range.png"), "tensor/testdata/image_to_tensor/noop_except_range.png"),
/*float_range=*/{0.0f, 1.0f}, /*float_ranges=*/{{0.0f, 1.0f}},
/*int_range=*/{0, 255}, /*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true, /*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kZero, roi); BorderMode::kZero, roi);
} }

View File

@ -268,9 +268,11 @@ class GlProcessor : public ImageToTensorConverter {
const RotatedRect& roi, const RotatedRect& roi,
const Size& output_dims, float range_min, const Size& output_dims, float range_min,
float range_max) override { float range_max) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
return InvalidArgumentError( input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format()))); static_cast<uint32_t>(input.format())));
} }

View File

@ -172,9 +172,11 @@ class GlProcessor : public ImageToTensorConverter {
const RotatedRect& roi, const RotatedRect& roi,
const Size& output_dims, float range_min, const Size& output_dims, float range_min,
float range_max) override { float range_max) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
return InvalidArgumentError( input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ", input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format()))); static_cast<uint32_t>(input.format())));
} }

View File

@ -352,10 +352,11 @@ class MetalProcessor : public ImageToTensorConverter {
const RotatedRect& roi, const RotatedRect& roi,
const Size& output_dims, float range_min, const Size& output_dims, float range_min,
float range_max) override { float range_max) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) { if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
return InvalidArgumentError( input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
absl::StrCat("Only BGRA/RGBA textures are supported, passed " input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
"format: ", return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format()))); static_cast<uint32_t>(input.format())));
} }

View File

@ -45,7 +45,19 @@ class OpenCvProcessor : public ImageToTensorConverter {
border_mode_ = cv::BORDER_CONSTANT; border_mode_ = cv::BORDER_CONSTANT;
break; break;
} }
mat_type_ = tensor_type == Tensor::ElementType::kUInt8 ? CV_8UC3 : CV_32FC3; switch (tensor_type_) {
case Tensor::ElementType::kInt8:
mat_type_ = CV_8SC3;
break;
case Tensor::ElementType::kFloat32:
mat_type_ = CV_32FC3;
break;
case Tensor::ElementType::kUInt8:
mat_type_ = CV_8UC3;
break;
default:
mat_type_ = -1;
}
} }
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input, absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
@ -65,12 +77,22 @@ class OpenCvProcessor : public ImageToTensorConverter {
output_dims.width, kNumChannels}); output_dims.width, kNumChannels});
auto buffer_view = tensor.GetCpuWriteView(); auto buffer_view = tensor.GetCpuWriteView();
cv::Mat dst; cv::Mat dst;
if (tensor_type_ == Tensor::ElementType::kUInt8) { switch (tensor_type_) {
case Tensor::ElementType::kInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<uint8>()); buffer_view.buffer<int8>());
} else { break;
case Tensor::ElementType::kFloat32:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_, dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<float>()); buffer_view.buffer<float>());
break;
case Tensor::ElementType::kUInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<uint8>());
break;
default:
return InvalidArgumentError(
absl::StrCat("Unsupported tensor type: ", tensor_type_));
} }
const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y), const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y),
@ -124,6 +146,13 @@ class OpenCvProcessor : public ImageToTensorConverter {
absl::StatusOr<std::unique_ptr<ImageToTensorConverter>> CreateOpenCvConverter( absl::StatusOr<std::unique_ptr<ImageToTensorConverter>> CreateOpenCvConverter(
CalculatorContext* cc, BorderMode border_mode, CalculatorContext* cc, BorderMode border_mode,
Tensor::ElementType tensor_type) { Tensor::ElementType tensor_type) {
if (tensor_type != Tensor::ElementType::kInt8 &&
tensor_type != Tensor::ElementType::kFloat32 &&
tensor_type != Tensor::ElementType::kUInt8) {
return absl::InvalidArgumentError(absl::StrCat(
"Tensor type is currently not supported by OpenCvProcessor, type: ",
tensor_type));
}
return absl::make_unique<OpenCvProcessor>(border_mode, tensor_type); return absl::make_unique<OpenCvProcessor>(border_mode, tensor_type);
} }

View File

@ -21,7 +21,9 @@
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/tool/subgraph_expansion.h" #include "mediapipe/framework/tool/subgraph_expansion.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -67,5 +69,17 @@ absl::StatusOr<Packet<TfLiteModelPtr>> InferenceCalculator::GetModelAsPacket(
"Must specify TFLite model as path or loaded model."); "Must specify TFLite model as path or loaded model.");
} }
absl::StatusOr<Packet<tflite::OpResolver>>
InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
if (kSideInOpResolver(cc).IsConnected()) {
return kSideInOpResolver(cc).As<tflite::OpResolver>();
} else if (kSideInCustomOpResolver(cc).IsConnected()) {
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
}
return PacketAdopting<tflite::OpResolver>(
std::make_unique<
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
}
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -27,6 +27,7 @@
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/tflite_model_loader.h" #include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
@ -55,8 +56,11 @@ namespace api2 {
// TENSORS - Vector of Tensors // TENSORS - Vector of Tensors
// //
// Input side packet: // Input side packet:
// DEPRECATED: Prefer to use the "OP_RESOLVER" input side packet instead.
// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver, // CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver,
// instead of the builtin one. // instead of the builtin one.
// OP_RESOLVER (optional) - Use to provide tflite op resolver
// (tflite::OpResolver)
// MODEL (optional) - Use to specify TfLite model // MODEL (optional) - Use to specify TfLite model
// (std::unique_ptr<tflite::FlatBufferModel, // (std::unique_ptr<tflite::FlatBufferModel,
// std::function<void(tflite::FlatBufferModel*)>>) // std::function<void(tflite::FlatBufferModel*)>>)
@ -95,15 +99,21 @@ namespace api2 {
class InferenceCalculator : public NodeIntf { class InferenceCalculator : public NodeIntf {
public: public:
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"}; static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
// migration.
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"}; kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
"OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"}; static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"}; static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
static constexpr SideInput< static constexpr SideInput<
mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{ mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{
"DELEGATE"}; "DELEGATE"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel, MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver,
kOutTensors, kDelegate); kSideInOpResolver, kSideInModel, kOutTensors,
kDelegate);
protected: protected:
using TfLiteDelegatePtr = using TfLiteDelegatePtr =
@ -111,6 +121,9 @@ class InferenceCalculator : public NodeIntf {
absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket( absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
CalculatorContext* cc); CalculatorContext* cc);
absl::StatusOr<Packet<tflite::OpResolver>> GetOpResolverAsPacket(
CalculatorContext* cc);
}; };
struct InferenceCalculatorSelector : public InferenceCalculator { struct InferenceCalculatorSelector : public InferenceCalculator {

View File

@ -116,6 +116,9 @@ message InferenceCalculatorOptions {
// to ensure there is no clash of the tokens. If unspecified, NNAPI will // to ensure there is no clash of the tokens. If unspecified, NNAPI will
// not try caching the compilation. // not try caching the compilation.
optional string model_token = 2; optional string model_token = 2;
// The name of an accelerator to be used for NNAPI delegate, e.g.
// "google-edgetpu". When not specified, it will be selected by NNAPI.
optional string accelerator_name = 3;
} }
message Xnnpack { message Xnnpack {
// Number of threads for XNNPACK delegate. (By default, calculator tries // Number of threads for XNNPACK delegate. (By default, calculator tries

View File

@ -19,7 +19,7 @@
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.h"
#include "tensorflow/lite/interpreter_builder.h"
#if defined(MEDIAPIPE_ANDROID) #if defined(MEDIAPIPE_ANDROID)
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#endif // ANDROID #endif // ANDROID
@ -28,6 +28,7 @@
#include "mediapipe/util/cpu_util.h" #include "mediapipe/util/cpu_util.h"
#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__ #endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
namespace mediapipe { namespace mediapipe {
@ -61,6 +62,17 @@ int GetXnnpackNumThreads(
return GetXnnpackDefaultNumThreads(); return GetXnnpackDefaultNumThreads();
} }
template <typename T>
void CopyTensorBuffer(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
int input_tensor_index) {
auto input_tensor_view = input_tensor.GetCpuReadView();
auto input_tensor_buffer = input_tensor_view.buffer<T>();
T* local_tensor_buffer =
interpreter->typed_input_tensor<T>(input_tensor_index);
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
}
} // namespace } // namespace
class InferenceCalculatorCpuImpl class InferenceCalculatorCpuImpl
@ -73,15 +85,16 @@ class InferenceCalculatorCpuImpl
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
private: private:
absl::Status LoadModel(CalculatorContext* cc); absl::Status InitInterpreter(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc,
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); tflite::InterpreterBuilder* interpreter_builder);
absl::Status AllocateTensors();
// TfLite requires us to keep the model alive as long as the interpreter is. // TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_; Packet<TfLiteModelPtr> model_packet_;
std::unique_ptr<tflite::Interpreter> interpreter_; std::unique_ptr<tflite::Interpreter> interpreter_;
TfLiteDelegatePtr delegate_; TfLiteDelegatePtr delegate_;
bool has_quantized_input_; TfLiteType input_tensor_type_ = TfLiteType::kTfLiteNoType;
}; };
absl::Status InferenceCalculatorCpuImpl::UpdateContract( absl::Status InferenceCalculatorCpuImpl::UpdateContract(
@ -94,8 +107,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract(
} }
absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) { absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadModel(cc)); return InitInterpreter(cc);
return LoadDelegateAndAllocateTensors(cc);
} }
absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) { absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
@ -108,19 +120,23 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
// Read CPU input into tensors. // Read CPU input into tensors.
for (int i = 0; i < input_tensors.size(); ++i) { for (int i = 0; i < input_tensors.size(); ++i) {
const Tensor* input_tensor = &input_tensors[i]; switch (input_tensor_type_) {
auto input_tensor_view = input_tensor->GetCpuReadView(); case TfLiteType::kTfLiteFloat16:
if (has_quantized_input_) { case TfLiteType::kTfLiteFloat32: {
// TODO: Support more quantized tensor types. CopyTensorBuffer<float>(input_tensors[i], interpreter_.get(), i);
auto input_tensor_buffer = input_tensor_view.buffer<uint8>(); break;
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i); }
std::memcpy(local_tensor_buffer, input_tensor_buffer, case TfLiteType::kTfLiteUInt8: {
input_tensor->bytes()); CopyTensorBuffer<uint8>(input_tensors[i], interpreter_.get(), i);
} else { break;
auto input_tensor_buffer = input_tensor_view.buffer<float>(); }
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i); case TfLiteType::kTfLiteInt8: {
std::memcpy(local_tensor_buffer, input_tensor_buffer, CopyTensorBuffer<int8>(input_tensors[i], interpreter_.get(), i);
input_tensor->bytes()); break;
}
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported input tensor type:", input_tensor_type_));
} }
} }
@ -150,39 +166,34 @@ absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) { absl::Status InferenceCalculatorCpuImpl::InitInterpreter(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get(); const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver = ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
kSideInCustomOpResolver(cc).GetOr( const auto& op_resolver = op_resolver_packet.Get();
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
RET_CHECK(interpreter_);
#if defined(__EMSCRIPTEN__) #if defined(__EMSCRIPTEN__)
interpreter_->SetNumThreads(1); interpreter_builder.SetNumThreads(1);
#else #else
interpreter_->SetNumThreads( interpreter_builder.SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread()); cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
return absl::OkStatus(); RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_);
return AllocateTensors();
} }
absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors( absl::Status InferenceCalculatorCpuImpl::AllocateTensors() {
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
has_quantized_input_ = input_tensor_type_ = interpreter_->tensor(interpreter_->inputs()[0])->type;
interpreter_->tensor(interpreter_->inputs()[0])->quantization.type ==
kTfLiteAffineQuantization;
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) { absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
const auto& calculator_opts = const auto& calculator_opts =
cc->Options<mediapipe::InferenceCalculatorOptions>(); cc->Options<mediapipe::InferenceCalculatorOptions>();
auto opts_delegate = calculator_opts.delegate(); auto opts_delegate = calculator_opts.delegate();
@ -211,18 +222,20 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
if (nnapi_requested) { if (nnapi_requested) {
// Attempt to use NNAPI. // Attempt to use NNAPI.
// If not supported, the default CPU delegate will be created and used. // If not supported, the default CPU delegate will be created and used.
interpreter_->SetAllowFp16PrecisionForFp32(1);
tflite::StatefulNnApiDelegate::Options options; tflite::StatefulNnApiDelegate::Options options;
const auto& nnapi = opts_delegate.nnapi(); const auto& nnapi = opts_delegate.nnapi();
options.allow_fp16 = true;
// Set up cache_dir and model_token for NNAPI compilation cache. // Set up cache_dir and model_token for NNAPI compilation cache.
options.cache_dir = options.cache_dir =
nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr; nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
options.model_token = options.model_token =
nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr; nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
options.accelerator_name = nnapi.has_accelerator_name()
? nnapi.accelerator_name().c_str()
: nullptr;
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options), delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
[](TfLiteDelegate*) {}); [](TfLiteDelegate*) {});
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), interpreter_builder->AddDelegate(delegate_.get());
kTfLiteOk);
return absl::OkStatus(); return absl::OkStatus();
} }
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
@ -239,8 +252,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
GetXnnpackNumThreads(opts_has_delegate, opts_delegate); GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts), delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
&TfLiteXNNPackDelegateDelete); &TfLiteXNNPackDelegateDelete);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), interpreter_builder->AddDelegate(delegate_.get());
kTfLiteOk);
} }
return absl::OkStatus(); return absl::OkStatus();

View File

@ -22,6 +22,7 @@
#include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h" #include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/interpreter_builder.h"
#if MEDIAPIPE_TFLITE_GL_INFERENCE #if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
@ -52,9 +53,11 @@ class InferenceCalculatorGlImpl
private: private:
absl::Status ReadGpuCaches(); absl::Status ReadGpuCaches();
absl::Status SaveGpuCaches(); absl::Status SaveGpuCaches();
absl::Status LoadModel(CalculatorContext* cc); absl::Status InitInterpreter(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc,
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); tflite::InterpreterBuilder* interpreter_builder);
absl::Status BindBuffersToTensors();
absl::Status AllocateTensors();
absl::Status InitTFLiteGPURunner(CalculatorContext* cc); absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is. // TfLite requires us to keep the model alive as long as the interpreter is.
@ -137,17 +140,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
#endif // MEDIAPIPE_ANDROID #endif // MEDIAPIPE_ANDROID
} }
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
// for everything.
if (!use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR(LoadModel(cc));
}
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
: LoadDelegateAndAllocateTensors(cc); : InitInterpreter(cc);
})); }));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -292,12 +289,6 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner( absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
CalculatorContext* cc) { CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
// Create runner // Create runner
tflite::gpu::InferenceOptions options; tflite::gpu::InferenceOptions options;
options.priority1 = allow_precision_loss_ options.priority1 = allow_precision_loss_
@ -335,6 +326,10 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
break; break;
} }
} }
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel( MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true)); model, op_resolver, /*allow_quant_ops=*/true));
@ -355,31 +350,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get(); const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver = ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
kSideInCustomOpResolver(cc).GetOr( const auto& op_resolver = op_resolver_packet.Get();
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
RET_CHECK(interpreter_);
#if defined(__EMSCRIPTEN__) #if defined(__EMSCRIPTEN__)
interpreter_->SetNumThreads(1); interpreter_builder.SetNumThreads(1);
#else #else
interpreter_->SetNumThreads( interpreter_builder.SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread()); cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__ #endif // __EMSCRIPTEN__
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_);
MP_RETURN_IF_ERROR(BindBuffersToTensors());
MP_RETURN_IF_ERROR(AllocateTensors());
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors( absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors. // TODO: Support quantized tensors.
RET_CHECK_NE( RET_CHECK_NE(
@ -388,7 +379,8 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::LoadDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
// Configure and create the delegate. // Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = options.compile_options.precision_loss_allowed =
@ -399,7 +391,11 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
options.compile_options.inline_parameters = 1; options.compile_options.inline_parameters = 1;
delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options), delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
&TfLiteGpuDelegateDelete); &TfLiteGpuDelegateDelete);
interpreter_builder->AddDelegate(delegate_.get());
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
// Get input image sizes. // Get input image sizes.
const auto& input_indices = interpreter_->inputs(); const auto& input_indices = interpreter_->inputs();
for (int i = 0; i < input_indices.size(); ++i) { for (int i = 0; i < input_indices.size(); ++i) {
@ -431,11 +427,6 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
output_indices[i]), output_indices[i]),
kTfLiteOk); kTfLiteOk);
} }
// Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -90,9 +90,10 @@ class InferenceCalculatorMetalImpl
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
private: private:
absl::Status LoadModel(CalculatorContext* cc); absl::Status InitInterpreter(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); void AddDelegate(CalculatorContext* cc,
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); tflite::InterpreterBuilder* interpreter_builder);
absl::Status CreateConverters(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is. // TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_; Packet<TfLiteModelPtr> model_packet_;
@ -127,11 +128,9 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>(); const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss(); allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
MP_RETURN_IF_ERROR(LoadModel(cc));
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_); RET_CHECK(gpu_helper_);
return LoadDelegateAndAllocateTensors(cc); return InitInterpreter(cc);
} }
absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) { absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
@ -199,27 +198,20 @@ absl::Status InferenceCalculatorMetalImpl::Close(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) { absl::Status InferenceCalculatorMetalImpl::InitInterpreter(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc)); ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get(); const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver = ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
kSideInCustomOpResolver(cc).GetOr( const auto& op_resolver = op_resolver_packet.Get();
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates()); tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
AddDelegate(cc, &interpreter_builder);
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_); interpreter_builder.SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_); RET_CHECK(interpreter_);
interpreter_->SetNumThreads( MP_RETURN_IF_ERROR(CreateConverters(cc));
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
return absl::OkStatus();
}
absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk); RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors. // TODO: Support quantized tensors.
RET_CHECK_NE( RET_CHECK_NE(
@ -228,7 +220,8 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) { void InferenceCalculatorMetalImpl::AddDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
const auto& calculator_opts = const auto& calculator_opts =
cc->Options<mediapipe::InferenceCalculatorOptions>(); cc->Options<mediapipe::InferenceCalculatorOptions>();
@ -242,9 +235,11 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait; options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait;
delegate_ = delegate_ =
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete); TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()), interpreter_builder->AddDelegate(delegate_.get());
kTfLiteOk); }
absl::Status InferenceCalculatorMetalImpl::CreateConverters(
CalculatorContext* cc) {
id<MTLDevice> device = gpu_helper_.mtlDevice; id<MTLDevice> device = gpu_helper_.mtlDevice;
// Get input image sizes. // Get input image sizes.

View File

@ -91,6 +91,40 @@ void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
} }
} }
absl::Status CheckCustomTensorMapping(
const TensorsToDetectionsCalculatorOptions::TensorMapping& tensor_mapping) {
RET_CHECK(tensor_mapping.has_detections_tensor_index() &&
tensor_mapping.has_scores_tensor_index());
int bitmap = 0;
bitmap |= 1 << tensor_mapping.detections_tensor_index();
bitmap |= 1 << tensor_mapping.scores_tensor_index();
if (!tensor_mapping.has_num_detections_tensor_index() &&
!tensor_mapping.has_classes_tensor_index() &&
!tensor_mapping.has_anchors_tensor_index()) {
// Only allows the output tensor index 0 and 1 to be occupied.
RET_CHECK_EQ(3, bitmap) << "The custom output tensor indices should only "
"cover index 0 and 1.";
} else if (tensor_mapping.has_anchors_tensor_index()) {
RET_CHECK(!tensor_mapping.has_classes_tensor_index() &&
!tensor_mapping.has_num_detections_tensor_index());
bitmap |= 1 << tensor_mapping.anchors_tensor_index();
// If the"anchors" tensor will be available, only allows the output tensor
// index 0, 1, 2 to be occupied.
RET_CHECK_EQ(7, bitmap) << "The custom output tensor indices should only "
"cover index 0, 1 and 2.";
} else {
RET_CHECK(tensor_mapping.has_classes_tensor_index() &&
tensor_mapping.has_num_detections_tensor_index());
// If the "classes" and the "number of detections" tensors will be
// available, only allows the output tensor index 0, 1, 2, 3 to be occupied.
bitmap |= 1 << tensor_mapping.classes_tensor_index();
bitmap |= 1 << tensor_mapping.num_detections_tensor_index();
RET_CHECK_EQ(15, bitmap) << "The custom output tensor indices should only "
"cover index 0, 1, 2 and 3.";
}
return absl::OkStatus();
}
} // namespace } // namespace
// Convert result Tensors from object detection models into MediaPipe // Convert result Tensors from object detection models into MediaPipe
@ -170,13 +204,27 @@ class TensorsToDetectionsCalculator : public Node {
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax, Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
float box_xmax, float score, int class_id, float box_xmax, float score, int class_id,
bool flip_vertically); bool flip_vertically);
bool IsClassIndexAllowed(int class_index);
int num_classes_ = 0; int num_classes_ = 0;
int num_boxes_ = 0; int num_boxes_ = 0;
int num_coords_ = 0; int num_coords_ = 0;
std::set<int> ignore_classes_; int max_results_ = -1;
::mediapipe::TensorsToDetectionsCalculatorOptions options_; // Set of allowed or ignored class indices.
struct ClassIndexSet {
absl::flat_hash_set<int> values;
bool is_allowlist;
};
// Allowed or ignored class indices based on provided options or side packet.
// These are used to filter out the output detection results.
ClassIndexSet class_index_set_;
TensorsToDetectionsCalculatorOptions options_;
bool scores_tensor_index_is_set_ = false;
TensorsToDetectionsCalculatorOptions::TensorMapping tensor_mapping_;
std::vector<int> box_indices_ = {0, 1, 2, 3};
bool has_custom_box_indices_ = false;
std::vector<Anchor> anchors_; std::vector<Anchor> anchors_;
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE #ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
@ -239,6 +287,21 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
} }
} }
} }
const int num_input_tensors = kInTensors(cc)->size();
if (!scores_tensor_index_is_set_) {
if (num_input_tensors == 2 ||
num_input_tensors == kNumInputTensorsWithAnchors) {
tensor_mapping_.set_scores_tensor_index(1);
} else {
tensor_mapping_.set_scores_tensor_index(2);
}
scores_tensor_index_is_set_ = true;
}
if (gpu_processing || num_input_tensors != 4) {
// Allows custom bounding box indices when receiving 4 cpu tensors.
// Uses the default bbox indices in other cases.
RET_CHECK(!has_custom_box_indices_);
}
if (gpu_processing) { if (gpu_processing) {
if (!gpu_inited_) { if (!gpu_inited_) {
@ -263,13 +326,15 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// Postprocessing on CPU for model without postprocessing op. E.g. output // Postprocessing on CPU for model without postprocessing op. E.g. output
// raw score tensor and box tensor. Anchor decoding will be handled below. // raw score tensor and box tensor. Anchor decoding will be handled below.
// TODO: Add flexible input tensor size handling. // TODO: Add flexible input tensor size handling.
auto raw_box_tensor = &input_tensors[0]; auto raw_box_tensor =
&input_tensors[tensor_mapping_.detections_tensor_index()];
RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3); RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1); RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options"; RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_); RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_); RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
auto raw_score_tensor = &input_tensors[1]; auto raw_score_tensor =
&input_tensors[tensor_mapping_.scores_tensor_index()];
RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3); RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1); RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_); RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
@ -282,7 +347,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// TODO: Support other options to load anchors. // TODO: Support other options to load anchors.
if (!anchors_init_) { if (!anchors_init_) {
if (input_tensors.size() == kNumInputTensorsWithAnchors) { if (input_tensors.size() == kNumInputTensorsWithAnchors) {
auto anchor_tensor = &input_tensors[2]; auto anchor_tensor =
&input_tensors[tensor_mapping_.anchors_tensor_index()];
RET_CHECK_EQ(anchor_tensor->shape().dims.size(), 2); RET_CHECK_EQ(anchor_tensor->shape().dims.size(), 2);
RET_CHECK_EQ(anchor_tensor->shape().dims[0], num_boxes_); RET_CHECK_EQ(anchor_tensor->shape().dims[0], num_boxes_);
RET_CHECK_EQ(anchor_tensor->shape().dims[1], kNumCoordsPerBox); RET_CHECK_EQ(anchor_tensor->shape().dims[1], kNumCoordsPerBox);
@ -308,7 +374,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
float max_score = -std::numeric_limits<float>::max(); float max_score = -std::numeric_limits<float>::max();
// Find the top score for box i. // Find the top score for box i.
for (int score_idx = 0; score_idx < num_classes_; ++score_idx) { for (int score_idx = 0; score_idx < num_classes_; ++score_idx) {
if (ignore_classes_.find(score_idx) == ignore_classes_.end()) { if (IsClassIndexAllowed(score_idx)) {
auto score = raw_scores[i * num_classes_ + score_idx]; auto score = raw_scores[i * num_classes_ + score_idx];
if (options_.sigmoid_score()) { if (options_.sigmoid_score()) {
if (options_.has_score_clipping_thresh()) { if (options_.has_score_clipping_thresh()) {
@ -338,23 +404,26 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// Postprocessing on CPU with postprocessing op (e.g. anchor decoding and // Postprocessing on CPU with postprocessing op (e.g. anchor decoding and
// non-maximum suppression) within the model. // non-maximum suppression) within the model.
RET_CHECK_EQ(input_tensors.size(), 4); RET_CHECK_EQ(input_tensors.size(), 4);
auto num_boxes_tensor =
auto num_boxes_tensor = &input_tensors[3]; &input_tensors[tensor_mapping_.num_detections_tensor_index()];
RET_CHECK_EQ(num_boxes_tensor->shape().dims.size(), 1); RET_CHECK_EQ(num_boxes_tensor->shape().dims.size(), 1);
RET_CHECK_EQ(num_boxes_tensor->shape().dims[0], 1); RET_CHECK_EQ(num_boxes_tensor->shape().dims[0], 1);
auto detection_boxes_tensor = &input_tensors[0]; auto detection_boxes_tensor =
&input_tensors[tensor_mapping_.detections_tensor_index()];
RET_CHECK_EQ(detection_boxes_tensor->shape().dims.size(), 3); RET_CHECK_EQ(detection_boxes_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(detection_boxes_tensor->shape().dims[0], 1); RET_CHECK_EQ(detection_boxes_tensor->shape().dims[0], 1);
const int max_detections = detection_boxes_tensor->shape().dims[1]; const int max_detections = detection_boxes_tensor->shape().dims[1];
RET_CHECK_EQ(detection_boxes_tensor->shape().dims[2], num_coords_); RET_CHECK_EQ(detection_boxes_tensor->shape().dims[2], num_coords_);
auto detection_classes_tensor = &input_tensors[1]; auto detection_classes_tensor =
&input_tensors[tensor_mapping_.classes_tensor_index()];
RET_CHECK_EQ(detection_classes_tensor->shape().dims.size(), 2); RET_CHECK_EQ(detection_classes_tensor->shape().dims.size(), 2);
RET_CHECK_EQ(detection_classes_tensor->shape().dims[0], 1); RET_CHECK_EQ(detection_classes_tensor->shape().dims[0], 1);
RET_CHECK_EQ(detection_classes_tensor->shape().dims[1], max_detections); RET_CHECK_EQ(detection_classes_tensor->shape().dims[1], max_detections);
auto detection_scores_tensor = &input_tensors[2]; auto detection_scores_tensor =
&input_tensors[tensor_mapping_.scores_tensor_index()];
RET_CHECK_EQ(detection_scores_tensor->shape().dims.size(), 2); RET_CHECK_EQ(detection_scores_tensor->shape().dims.size(), 2);
RET_CHECK_EQ(detection_scores_tensor->shape().dims[0], 1); RET_CHECK_EQ(detection_scores_tensor->shape().dims[0], 1);
RET_CHECK_EQ(detection_scores_tensor->shape().dims[1], max_detections); RET_CHECK_EQ(detection_scores_tensor->shape().dims[1], max_detections);
@ -394,12 +463,14 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
-> absl::Status { -> absl::Status {
if (!anchors_init_) { if (!anchors_init_) {
if (input_tensors.size() == kNumInputTensorsWithAnchors) { if (input_tensors.size() == kNumInputTensorsWithAnchors) {
auto read_view = input_tensors[2].GetOpenGlBufferReadView(); auto read_view = input_tensors[tensor_mapping_.anchors_tensor_index()]
.GetOpenGlBufferReadView();
glBindBuffer(GL_COPY_READ_BUFFER, read_view.name()); glBindBuffer(GL_COPY_READ_BUFFER, read_view.name());
auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView(); auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView();
glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name()); glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name());
glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0, glCopyBufferSubData(
input_tensors[2].bytes()); GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
input_tensors[tensor_mapping_.anchors_tensor_index()].bytes());
} else if (!kInAnchors(cc).IsEmpty()) { } else if (!kInAnchors(cc).IsEmpty()) {
const auto& anchors = *kInAnchors(cc); const auto& anchors = *kInAnchors(cc);
auto anchors_view = raw_anchors_buffer_->GetCpuWriteView(); auto anchors_view = raw_anchors_buffer_->GetCpuWriteView();
@ -418,7 +489,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
auto decoded_boxes_view = auto decoded_boxes_view =
decoded_boxes_buffer_->GetOpenGlBufferWriteView(); decoded_boxes_buffer_->GetOpenGlBufferWriteView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, decoded_boxes_view.name()); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, decoded_boxes_view.name());
auto input0_view = input_tensors[0].GetOpenGlBufferReadView(); auto input0_view =
input_tensors[tensor_mapping_.detections_tensor_index()]
.GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input0_view.name()); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input0_view.name());
auto raw_anchors_view = raw_anchors_buffer_->GetOpenGlBufferReadView(); auto raw_anchors_view = raw_anchors_buffer_->GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, raw_anchors_view.name()); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, raw_anchors_view.name());
@ -427,7 +500,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
// Score boxes. // Score boxes.
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, scored_boxes_view.name()); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, scored_boxes_view.name());
auto input1_view = input_tensors[1].GetOpenGlBufferReadView(); auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()]
.GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input1_view.name()); glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input1_view.name());
glUseProgram(score_program_); glUseProgram(score_program_);
glDispatchCompute(num_boxes_, 1, 1); glDispatchCompute(num_boxes_, 1, 1);
@ -459,7 +533,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
if (input_tensors.size() == kNumInputTensorsWithAnchors) { if (input_tensors.size() == kNumInputTensorsWithAnchors) {
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors); RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
auto command_buffer = [gpu_helper_ commandBuffer]; auto command_buffer = [gpu_helper_ commandBuffer];
auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer); auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()]
.GetMtlBufferReadView(command_buffer);
auto dest_buffer = auto dest_buffer =
raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer); raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer);
id<MTLBlitCommandEncoder> blit_command = id<MTLBlitCommandEncoder> blit_command =
@ -468,7 +543,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
sourceOffset:0 sourceOffset:0
toBuffer:dest_buffer.buffer() toBuffer:dest_buffer.buffer()
destinationOffset:0 destinationOffset:0
size:input_tensors[2].bytes()]; size:input_tensors[tensor_mapping_
.anchors_tensor_index()]
.bytes()];
[blit_command endEncoding]; [blit_command endEncoding];
[command_buffer commit]; [command_buffer commit];
} else if (!kInAnchors(cc).IsEmpty()) { } else if (!kInAnchors(cc).IsEmpty()) {
@ -495,7 +572,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
auto decoded_boxes_view = auto decoded_boxes_view =
decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer); decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer);
[command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0]; [command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0];
auto input0_view = input_tensors[0].GetMtlBufferReadView(command_buffer); auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()]
.GetMtlBufferReadView(command_buffer);
[command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1]; [command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1];
auto raw_anchors_view = auto raw_anchors_view =
raw_anchors_buffer_->GetMtlBufferReadView(command_buffer); raw_anchors_buffer_->GetMtlBufferReadView(command_buffer);
@ -507,7 +585,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
[command_encoder setComputePipelineState:score_program_]; [command_encoder setComputePipelineState:score_program_];
[command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0]; [command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0];
auto input1_view = input_tensors[1].GetMtlBufferReadView(command_buffer); auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()]
.GetMtlBufferReadView(command_buffer);
[command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1]; [command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1];
MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1); MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1);
MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1); MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
@ -570,6 +649,10 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
num_classes_ = options_.num_classes(); num_classes_ = options_.num_classes();
num_boxes_ = options_.num_boxes(); num_boxes_ = options_.num_boxes();
num_coords_ = options_.num_coords(); num_coords_ = options_.num_coords();
CHECK_NE(options_.max_results(), 0)
<< "The maximum number of the top-scored detection results must be "
"non-zero.";
max_results_ = options_.max_results();
// Currently only support 2D when num_values_per_keypoint equals to 2. // Currently only support 2D when num_values_per_keypoint equals to 2.
CHECK_EQ(options_.num_values_per_keypoint(), 2); CHECK_EQ(options_.num_values_per_keypoint(), 2);
@ -581,15 +664,55 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
if (kSideInIgnoreClasses(cc).IsConnected()) { if (kSideInIgnoreClasses(cc).IsConnected()) {
RET_CHECK(!kSideInIgnoreClasses(cc).IsEmpty()); RET_CHECK(!kSideInIgnoreClasses(cc).IsEmpty());
RET_CHECK(options_.allow_classes().empty());
class_index_set_.is_allowlist = false;
for (int ignore_class : *kSideInIgnoreClasses(cc)) { for (int ignore_class : *kSideInIgnoreClasses(cc)) {
ignore_classes_.insert(ignore_class); class_index_set_.values.insert(ignore_class);
}
} else if (!options_.allow_classes().empty()) {
RET_CHECK(options_.ignore_classes().empty());
class_index_set_.is_allowlist = true;
for (int i = 0; i < options_.allow_classes_size(); ++i) {
class_index_set_.values.insert(options_.allow_classes(i));
} }
} else { } else {
class_index_set_.is_allowlist = false;
for (int i = 0; i < options_.ignore_classes_size(); ++i) { for (int i = 0; i < options_.ignore_classes_size(); ++i) {
ignore_classes_.insert(options_.ignore_classes(i)); class_index_set_.values.insert(options_.ignore_classes(i));
} }
} }
if (options_.has_tensor_mapping()) {
RET_CHECK_OK(CheckCustomTensorMapping(options_.tensor_mapping()));
tensor_mapping_ = options_.tensor_mapping();
scores_tensor_index_is_set_ = true;
} else {
// Assigns the default tensor indices.
tensor_mapping_.set_detections_tensor_index(0);
tensor_mapping_.set_classes_tensor_index(1);
tensor_mapping_.set_anchors_tensor_index(2);
tensor_mapping_.set_num_detections_tensor_index(3);
// The scores tensor index needs to be determined based on the number of
// model's output tensors, which will be available in the first invocation
// of the Process() method.
tensor_mapping_.set_scores_tensor_index(-1);
scores_tensor_index_is_set_ = false;
}
if (options_.has_box_boundaries_indices()) {
box_indices_ = {options_.box_boundaries_indices().ymin(),
options_.box_boundaries_indices().xmin(),
options_.box_boundaries_indices().ymax(),
options_.box_boundaries_indices().xmax()};
int bitmap = 0;
for (int i : box_indices_) {
bitmap |= 1 << i;
}
RET_CHECK_EQ(bitmap, 15) << "The custom box boundaries indices should only "
"cover index 0, 1, 2, and 3.";
has_custom_box_indices_ = true;
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -661,14 +784,22 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
const float* detection_boxes, const float* detection_scores, const float* detection_boxes, const float* detection_scores,
const int* detection_classes, std::vector<Detection>* output_detections) { const int* detection_classes, std::vector<Detection>* output_detections) {
for (int i = 0; i < num_boxes_; ++i) { for (int i = 0; i < num_boxes_; ++i) {
if (max_results_ > 0 && output_detections->size() == max_results_) {
break;
}
if (options_.has_min_score_thresh() && if (options_.has_min_score_thresh() &&
detection_scores[i] < options_.min_score_thresh()) { detection_scores[i] < options_.min_score_thresh()) {
continue; continue;
} }
if (!IsClassIndexAllowed(detection_classes[i])) {
continue;
}
const int box_offset = i * num_coords_; const int box_offset = i * num_coords_;
Detection detection = ConvertToDetection( Detection detection = ConvertToDetection(
detection_boxes[box_offset + 0], detection_boxes[box_offset + 1], /*box_ymin=*/detection_boxes[box_offset + box_indices_[0]],
detection_boxes[box_offset + 2], detection_boxes[box_offset + 3], /*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
detection_scores[i], detection_classes[i], options_.flip_vertically()); detection_scores[i], detection_classes[i], options_.flip_vertically());
const auto& bbox = detection.location_data().relative_bounding_box(); const auto& bbox = detection.location_data().relative_bounding_box();
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) || if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
@ -910,7 +1041,7 @@ void main() {
options_.has_score_clipping_thresh() ? 1 : 0, options_.has_score_clipping_thresh() ? 1 : 0,
options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
: 0, : 0,
!ignore_classes_.empty() ? 1 : 0); !IsClassIndexAllowed(0));
// # filter classes supported is hardware dependent. // # filter classes supported is hardware dependent.
int max_wg_size; // typically <= 1024 int max_wg_size; // typically <= 1024
@ -919,7 +1050,14 @@ void main() {
CHECK_LT(num_classes_, max_wg_size) CHECK_LT(num_classes_, max_wg_size)
<< "# classes must be < " << max_wg_size; << "# classes must be < " << max_wg_size;
// TODO support better filtering. // TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
<< "Only all classes >= class 0 or >= class 1";
} else {
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1)
<< "Only ignore class 0 is allowed";
}
// Shader program // Shader program
{ {
@ -1126,10 +1264,17 @@ kernel void scoreKernel(
options_.has_score_clipping_thresh() ? 1 : 0, options_.has_score_clipping_thresh() ? 1 : 0,
options_.has_score_clipping_thresh() ? options_.score_clipping_thresh() options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
: 0, : 0,
ignore_classes_.size() ? 1 : 0); !IsClassIndexAllowed(0));
// TODO support better filtering. // TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
<< "Only all classes >= class 0 or >= class 1";
} else {
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1)
<< "Only ignore class 0 is allowed";
}
{ {
// Shader program // Shader program
@ -1161,5 +1306,16 @@ kernel void scoreKernel(
return absl::OkStatus(); return absl::OkStatus();
} }
bool TensorsToDetectionsCalculator::IsClassIndexAllowed(int class_index) {
if (class_index_set_.values.empty()) {
return true;
}
if (class_index_set_.is_allowlist) {
return class_index_set_.values.contains(class_index);
} else {
return !class_index_set_.values.contains(class_index);
}
}
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -57,7 +57,12 @@ message TensorsToDetectionsCalculatorOptions {
optional bool reverse_output_order = 14 [default = false]; optional bool reverse_output_order = 14 [default = false];
// The ids of classes that should be ignored during decoding the score for // The ids of classes that should be ignored during decoding the score for
// each predicted box. Can be overridden with IGNORE_CLASSES side packet. // each predicted box. Can be overridden with IGNORE_CLASSES side packet.
// `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 ignore_classes = 8; repeated int32 ignore_classes = 8;
// The ids of classes that should be allowed during decoding the score for
// each predicted box. `ignore_classes` and `allow_classes` are mutually
// exclusive.
repeated int32 allow_classes = 21 [packed = true];
optional bool sigmoid_score = 15 [default = false]; optional bool sigmoid_score = 15 [default = false];
optional float score_clipping_thresh = 16; optional float score_clipping_thresh = 16;
@ -71,4 +76,40 @@ message TensorsToDetectionsCalculatorOptions {
// Score threshold for perserving decoded detections. // Score threshold for perserving decoded detections.
optional float min_score_thresh = 19; optional float min_score_thresh = 19;
// The maximum number of the detection results to return. If < 0, all
// available results will be returned.
// For the detection models that have built-in non max suppression op, the
// output detections are the top-scored results. Otherwise, the output
// detections are the first N results that have higher scores than
// `min_score_thresh`.
optional int32 max_results = 20 [default = -1];
// The custom model output tensor mapping.
// The indices of the "detections" tensor and the "scores" tensor are always
// required. If the model outputs an "anchors" tensor, `anchors_tensor_index`
// must be specified. If the model outputs both "classes" tensor and "number
// of detections" tensors, `classes_tensor_index` and
// `num_detections_tensor_index` must be set.
message TensorMapping {
optional int32 detections_tensor_index = 1;
optional int32 classes_tensor_index = 2;
optional int32 scores_tensor_index = 3;
optional int32 num_detections_tensor_index = 4;
optional int32 anchors_tensor_index = 5;
}
optional TensorMapping tensor_mapping = 22;
// Represents the bounding box by using the combination of boundaries,
// {ymin, xmin, ymax, xmax}.
// The default order is {ymin, xmin, ymax, xmax}.
message BoxBoundariesIndices {
optional int32 ymin = 1 [default = 0];
optional int32 xmin = 2 [default = 1];
optional int32 ymax = 3 [default = 2];
optional int32 xmax = 4 [default = 3];
}
oneof box_indices {
BoxBoundariesIndices box_boundaries_indices = 23;
}
} }

View File

@ -121,8 +121,12 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
if (d > 255) d = 255; if (d > 255) d = 255;
buffer[i] = d; buffer[i] = d;
} }
output = ::absl::make_unique<ImageFrame>(format, width, height, output = ::absl::make_unique<ImageFrame>(
width * depth, buffer.release()); format, width, height, width * depth, buffer.release(),
[total_size](uint8* ptr) {
::operator delete[](ptr, total_size,
std::align_val_t(EIGEN_MAX_ALIGN_BYTES));
});
} else if (input_tensor.dtype() == tensorflow::DT_UINT8) { } else if (input_tensor.dtype() == tensorflow::DT_UINT8) {
if (scale_factor_ != 1.0) { if (scale_factor_ != 1.0) {
return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor"); return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor");

View File

@ -121,10 +121,11 @@ cc_library(
deps = [ deps = [
":tflite_custom_op_resolver_calculator_cc_proto", ":tflite_custom_op_resolver_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util/tflite:cpu_op_resolver", "//mediapipe/util/tflite:cpu_op_resolver",
"//mediapipe/util/tflite:op_resolver", "//mediapipe/util/tflite:op_resolver",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
], ],
alwayslink = 1, alwayslink = 1,
) )

View File

@ -12,14 +12,22 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <memory>
#include "mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/cpu_op_resolver.h" #include "mediapipe/util/tflite/cpu_op_resolver.h"
#include "mediapipe/util/tflite/op_resolver.h" #include "mediapipe/util/tflite/op_resolver.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe { namespace mediapipe {
namespace {
constexpr char kOpResolverTag[] = "OP_RESOLVER";
} // namespace
// This calculator creates a custom op resolver as a side packet that can be // This calculator creates a custom op resolver as a side packet that can be
// used in TfLiteInferenceCalculator. Current custom op resolver supports the // used in TfLiteInferenceCalculator. Current custom op resolver supports the
// following custom op on CPU and GPU: // following custom op on CPU and GPU:
@ -27,7 +35,9 @@ namespace mediapipe {
// MaxPoolArgmax // MaxPoolArgmax
// MaxUnpooling // MaxUnpooling
// //
// Usage example: // Usage examples:
//
// For using with TfliteInferenceCalculator:
// node { // node {
// calculator: "TfLiteCustomOpResolverCalculator" // calculator: "TfLiteCustomOpResolverCalculator"
// output_side_packet: "op_resolver" // output_side_packet: "op_resolver"
@ -37,12 +47,27 @@ namespace mediapipe {
// } // }
// } // }
// } // }
//
// For using with InferenceCalculator:
// node {
// calculator: "TfLiteCustomOpResolverCalculator"
// output_side_packet: "OP_RESOLVER:op_resolver"
// node_options: {
// [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
// use_gpu: true
// }
// }
// }
class TfLiteCustomOpResolverCalculator : public CalculatorBase { class TfLiteCustomOpResolverCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
if (cc->OutputSidePackets().HasTag(kOpResolverTag)) {
cc->OutputSidePackets().Tag(kOpResolverTag).Set<tflite::OpResolver>();
} else {
cc->OutputSidePackets() cc->OutputSidePackets()
.Index(0) .Index(0)
.Set<tflite::ops::builtin::BuiltinOpResolver>(); .Set<tflite::ops::builtin::BuiltinOpResolver>();
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -59,7 +84,14 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
op_resolver = absl::make_unique<mediapipe::CpuOpResolver>(); op_resolver = absl::make_unique<mediapipe::CpuOpResolver>();
} }
if (cc->OutputSidePackets().HasTag(kOpResolverTag)) {
cc->OutputSidePackets()
.Tag(kOpResolverTag)
.Set(mediapipe::api2::PacketAdopting<tflite::OpResolver>(
std::move(op_resolver)));
} else {
cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release())); cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release()));
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -54,6 +54,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/util:label_map_proto",
], ],
) )
@ -304,6 +305,7 @@ cc_library(
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"//mediapipe/util:label_map_cc_proto",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe:android": [
"//mediapipe/util/android/file/base", "//mediapipe/util/android/file/base",
@ -350,6 +352,40 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_library(
name = "detection_transformation_calculator",
srcs = ["detection_transformation_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "detection_transformation_calculator_test",
size = "small",
srcs = ["detection_transformation_calculator_test.cc"],
deps = [
":detection_transformation_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)
cc_library( cc_library(
name = "non_max_suppression_calculator", name = "non_max_suppression_calculator",
srcs = ["non_max_suppression_calculator.cc"], srcs = ["non_max_suppression_calculator.cc"],

View File

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "absl/container/node_hash_map.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h" #include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_MOBILE) #if defined(MEDIAPIPE_MOBILE)
@ -53,8 +53,11 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) override; absl::Status Process(CalculatorContext* cc) override;
private: private:
absl::node_hash_map<int, std::string> label_map_; // Local label map built from the calculator options' `label_map_path` or
::mediapipe::DetectionLabelIdToTextCalculatorOptions options_; // `label` field.
LabelMap local_label_map_;
bool keep_label_id_;
const LabelMap& GetLabelMap(CalculatorContext* cc);
}; };
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
@ -69,13 +72,16 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
options_ = const auto& options =
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>(); cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
if (options_.has_label_map_path()) { if (options.has_label_map_path()) {
RET_CHECK(!options.has_label_map() && options.label().empty())
<< "Only can set one of the following fields in the CalculatorOptions: "
"label_map_path, label, and label_map.";
std::string string_path; std::string string_path;
ASSIGN_OR_RETURN(string_path, ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options_.label_map_path())); PathToResourceAsFile(options.label_map_path()));
std::string label_map_string; std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string)); MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
@ -83,13 +89,21 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
std::string line; std::string line;
int i = 0; int i = 0;
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
label_map_[i++] = line; LabelMapItem item;
item.set_name(line);
(*local_label_map_.mutable_index_to_item())[i++] = item;
} }
} else { } else if (!options.label().empty()) {
for (int i = 0; i < options_.label_size(); ++i) { RET_CHECK(!options.has_label_map())
label_map_[i] = options_.label(i); << "Only can set one of the following fields in the CalculatorOptions: "
"label_map_path, label, and label_map.";
for (int i = 0; i < options.label_size(); ++i) {
LabelMapItem item;
item.set_name(options.label(i));
(*local_label_map_.mutable_index_to_item())[i] = item;
} }
} }
keep_label_id_ = options.keep_label_id();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -101,13 +115,18 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
Detection& output_detection = output_detections.back(); Detection& output_detection = output_detections.back();
bool has_text_label = false; bool has_text_label = false;
for (const int32 label_id : output_detection.label_id()) { for (const int32 label_id : output_detection.label_id()) {
if (label_map_.find(label_id) != label_map_.end()) { if (GetLabelMap(cc).index_to_item().find(label_id) !=
output_detection.add_label(label_map_[label_id]); GetLabelMap(cc).index_to_item().end()) {
auto item = GetLabelMap(cc).index_to_item().at(label_id);
output_detection.add_label(item.name());
if (item.has_display_name()) {
output_detection.add_display_name(item.display_name());
}
has_text_label = true; has_text_label = true;
} }
} }
// Remove label_id field if text labels exist. // Remove label_id field if text labels exist.
if (has_text_label && !options_.keep_label_id()) { if (has_text_label && !keep_label_id_) {
output_detection.clear_label_id(); output_detection.clear_label_id();
} }
} }
@ -117,4 +136,13 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
const LabelMap& DetectionLabelIdToTextCalculator::GetLabelMap(
CalculatorContext* cc) {
return !local_label_map_.index_to_item().empty()
? local_label_map_
: cc->Options<
::mediapipe::DetectionLabelIdToTextCalculatorOptions>()
.label_map();
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -17,6 +17,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/util/label_map.proto";
message DetectionLabelIdToTextCalculatorOptions { message DetectionLabelIdToTextCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
@ -26,7 +27,7 @@ message DetectionLabelIdToTextCalculatorOptions {
// Path to a label map file for getting the actual name of detected classes. // Path to a label map file for getting the actual name of detected classes.
optional string label_map_path = 1; optional string label_map_path = 1;
// Alternative way to specify label map // Alternative way to specify label map.
// label: "label for id 0" // label: "label for id 0"
// label: "label for id 1" // label: "label for id 1"
// ... // ...
@ -36,4 +37,7 @@ message DetectionLabelIdToTextCalculatorOptions {
// could be found. By setting this field to true, it is always copied to the // could be found. By setting this field to true, it is always copied to the
// output detections. // output detections.
optional bool keep_label_id = 3; optional bool keep_label_id = 3;
// Label map.
optional LabelMap label_map = 4;
} }

View File

@ -0,0 +1,298 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
template <typename T>
T BoundedValue(T value, T upper_bound) {
T output = std::min(value, upper_bound);
if (output < 0) {
return 0;
}
return output;
}
absl::Status ConvertRelativeBoundingBoxToBoundingBox(
const std::pair<int, int>& image_size, Detection* detection) {
const int image_width = image_size.first;
const int image_height = image_size.second;
const auto& relative_bbox =
detection->location_data().relative_bounding_box();
auto* bbox = detection->mutable_location_data()->mutable_bounding_box();
bbox->set_xmin(
BoundedValue<int>(relative_bbox.xmin() * image_width, image_width));
bbox->set_ymin(
BoundedValue<int>(relative_bbox.ymin() * image_height, image_height));
bbox->set_width(
BoundedValue<int>(relative_bbox.width() * image_width, image_width));
bbox->set_height(
BoundedValue<int>(relative_bbox.height() * image_height, image_height));
detection->mutable_location_data()->set_format(LocationData::BOUNDING_BOX);
detection->mutable_location_data()->clear_relative_bounding_box();
return absl::OkStatus();
}
absl::Status ConvertBoundingBoxToRelativeBoundingBox(
const std::pair<int, int>& image_size, Detection* detection) {
int image_width = image_size.first;
int image_height = image_size.second;
const auto& bbox = detection->location_data().bounding_box();
auto* relative_bbox =
detection->mutable_location_data()->mutable_relative_bounding_box();
relative_bbox->set_xmin(
BoundedValue<float>((float)bbox.xmin() / image_width, 1.0f));
relative_bbox->set_ymin(
BoundedValue<float>((float)bbox.ymin() / image_height, 1.0f));
relative_bbox->set_width(
BoundedValue<float>((float)bbox.width() / image_width, 1.0f));
relative_bbox->set_height(
BoundedValue<float>((float)bbox.height() / image_height, 1.0f));
detection->mutable_location_data()->clear_bounding_box();
detection->mutable_location_data()->set_format(
LocationData::RELATIVE_BOUNDING_BOX);
return absl::OkStatus();
}
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
const Detection& detection) {
if (!detection.has_location_data()) {
return absl::InvalidArgumentError("Detection must have location data.");
}
LocationData::Format format = detection.location_data().format();
RET_CHECK(format == LocationData::RELATIVE_BOUNDING_BOX ||
format == LocationData::BOUNDING_BOX)
<< "Detection's location data format must be either "
"RELATIVE_BOUNDING_BOX or BOUNDING_BOX";
return format;
}
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
std::vector<Detection>& detections) {
RET_CHECK(!detections.empty());
LocationData::Format output_format;
ASSIGN_OR_RETURN(output_format, GetLocationDataFormat(detections[0]));
for (int i = 1; i < detections.size(); ++i) {
ASSIGN_OR_RETURN(LocationData::Format format,
GetLocationDataFormat(detections[i]));
if (output_format != format) {
return absl::InvalidArgumentError(
"Input detections have different location data formats.");
}
}
return output_format;
}
absl::Status ConvertBoundingBox(const std::pair<int, int>& image_size,
Detection* detection) {
if (!detection->has_location_data()) {
return absl::InvalidArgumentError("Detection must have location data.");
}
switch (detection->location_data().format()) {
case LocationData::RELATIVE_BOUNDING_BOX:
return ConvertRelativeBoundingBoxToBoundingBox(image_size, detection);
case LocationData::BOUNDING_BOX:
return ConvertBoundingBoxToRelativeBoundingBox(image_size, detection);
default:
return absl::InvalidArgumentError(
"Detection's location data format must be either "
"RELATIVE_BOUNDING_BOX or BOUNDING_BOX.");
}
}
} // namespace
// Transforms relative bounding box(es) to pixel bounding box(es) in a detection
// proto/detection list/detection vector, or vice versa.
//
// Inputs:
// One of the following:
// DETECTION: A Detection proto.
// DETECTIONS: An std::vector<Detection>/ a DetectionList proto.
// IMAGE_SIZE: A std::pair<int, int> represention image width and height.
//
// Outputs:
// At least one of the following:
// PIXEL_DETECTION: A Detection proto with pixel bounding box.
// PIXEL_DETECTIONS: An std::vector<Detection> with pixel bounding boxes.
// PIXEL_DETECTION_LIST: A DetectionList proto with pixel bounding boxes.
// RELATIVE_DETECTION: A Detection proto with relative bounding box.
// RELATIVE_DETECTIONS: An std::vector<Detection> with relative bounding boxes.
// RELATIVE_DETECTION_LIST: A DetectionList proto with relative bounding boxes.
//
// Example config:
// For input detection(s) with relative bounding box(es):
// node {
// calculator: "DetectionTransformationCalculator"
// input_stream: "DETECTION:input_detection"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "PIXEL_DETECTION:output_detection"
// output_stream: "PIXEL_DETECTIONS:output_detections"
// output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
// }
//
// For input detection(s) with pixel bounding box(es):
// node {
// calculator: "DetectionTransformationCalculator"
// input_stream: "DETECTION:input_detection"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "RELATIVE_DETECTION:output_detection"
// output_stream: "RELATIVE_DETECTIONS:output_detections"
// output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
// }
class DetectionTransformationCalculator : public Node {
public:
static constexpr Input<Detection>::Optional kInDetection{"DETECTION"};
static constexpr Input<OneOf<DetectionList, std::vector<Detection>>>::Optional
kInDetections{"DETECTIONS"};
static constexpr Input<std::pair<int, int>> kInImageSize{"IMAGE_SIZE"};
static constexpr Output<Detection>::Optional kOutPixelDetection{
"PIXEL_DETECTION"};
static constexpr Output<std::vector<Detection>>::Optional kOutPixelDetections{
"PIXEL_DETECTIONS"};
static constexpr Output<DetectionList>::Optional kOutPixelDetectionList{
"PIXEL_DETECTION_LIST"};
static constexpr Output<Detection>::Optional kOutRelativeDetection{
"RELATIVE_DETECTION"};
static constexpr Output<std::vector<Detection>>::Optional
kOutRelativeDetections{"RELATIVE_DETECTIONS"};
static constexpr Output<DetectionList>::Optional kOutRelativeDetectionList{
"RELATIVE_DETECTION_LIST"};
MEDIAPIPE_NODE_CONTRACT(kInDetection, kInDetections, kInImageSize,
kOutPixelDetection, kOutPixelDetections,
kOutPixelDetectionList, kOutRelativeDetection,
kOutRelativeDetections, kOutRelativeDetectionList);
static absl::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK(kInImageSize(cc).IsConnected()) << "Image size must be provided.";
RET_CHECK(kInDetections(cc).IsConnected() ^ kInDetection(cc).IsConnected());
if (kInDetections(cc).IsConnected()) {
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
kOutPixelDetectionList(cc).IsConnected() ||
kOutRelativeDetections(cc).IsConnected() ||
kOutRelativeDetectionList(cc).IsConnected())
<< "Output must be a container of detections.";
}
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
kOutPixelDetectionList(cc).IsConnected() ||
kOutPixelDetection(cc).IsConnected() ||
kOutRelativeDetections(cc).IsConnected() ||
kOutRelativeDetectionList(cc).IsConnected() ||
kOutRelativeDetection(cc).IsConnected())
<< "Must connect at least one output stream.";
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
output_pixel_bounding_boxes_ = kOutPixelDetections(cc).IsConnected() ||
kOutPixelDetectionList(cc).IsConnected() ||
kOutPixelDetection(cc).IsConnected();
output_relative_bounding_boxes_ =
kOutRelativeDetections(cc).IsConnected() ||
kOutRelativeDetectionList(cc).IsConnected() ||
kOutRelativeDetection(cc).IsConnected();
RET_CHECK(output_pixel_bounding_boxes_ ^ output_relative_bounding_boxes_)
<< "All output streams must have the same stream tag prefix, either "
"\"PIXEL\" or \"RELATIVE_\".";
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
std::pair<int, int> image_size = kInImageSize(cc).Get();
std::vector<Detection> transformed_detections;
LocationData::Format input_location_data_format;
if (kInDetections(cc).IsConnected()) {
transformed_detections = kInDetections(cc).Visit(
[&](const DetectionList& detection_list) {
return std::vector<Detection>(detection_list.detection().begin(),
detection_list.detection().end());
},
[&](const std::vector<Detection>& detection_vector) {
return detection_vector;
});
ASSIGN_OR_RETURN(input_location_data_format,
GetLocationDataFormat(transformed_detections));
for (Detection& detection : transformed_detections) {
MP_RETURN_IF_ERROR(ConvertBoundingBox(image_size, &detection));
}
} else {
ASSIGN_OR_RETURN(input_location_data_format,
GetLocationDataFormat(kInDetection(cc).Get()));
Detection transformed_detection(kInDetection(cc).Get());
MP_RETURN_IF_ERROR(
ConvertBoundingBox(image_size, &transformed_detection));
transformed_detections.push_back(transformed_detection);
}
if (input_location_data_format == LocationData::RELATIVE_BOUNDING_BOX) {
RET_CHECK(!output_relative_bounding_boxes_)
<< "Input detections are with relative bounding box(es), and the "
"output detections must have pixel bounding box(es).";
if (kOutPixelDetection(cc).IsConnected()) {
kOutPixelDetection(cc).Send(transformed_detections[0]);
}
if (kOutPixelDetections(cc).IsConnected()) {
kOutPixelDetections(cc).Send(transformed_detections);
}
if (kOutPixelDetectionList(cc).IsConnected()) {
DetectionList detection_list;
for (const auto& detection : transformed_detections) {
detection_list.add_detection()->CopyFrom(detection);
}
kOutPixelDetectionList(cc).Send(detection_list);
}
} else {
RET_CHECK(!output_pixel_bounding_boxes_)
<< "Input detections are with pixel bounding box(es), and the "
"output detections must have relative bounding box(es).";
if (kOutRelativeDetection(cc).IsConnected()) {
kOutRelativeDetection(cc).Send(transformed_detections[0]);
}
if (kOutRelativeDetections(cc).IsConnected()) {
kOutRelativeDetections(cc).Send(transformed_detections);
}
if (kOutRelativeDetectionList(cc).IsConnected()) {
DetectionList detection_list;
for (const auto& detection : transformed_detections) {
detection_list.add_detection()->CopyFrom(detection);
}
kOutRelativeDetectionList(cc).Send(detection_list);
}
}
return absl::OkStatus();
}
private:
bool output_relative_bounding_boxes_;
bool output_pixel_bounding_boxes_;
};
MEDIAPIPE_REGISTER_NODE(DetectionTransformationCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,287 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <vector>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kPixelDetectionTag[] = "PIXEL_DETECTION";
constexpr char kPixelDetectionListTag[] = "PIXEL_DETECTION_LIST";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
constexpr char kRelativeDetectionListTag[] = "RELATIVE_DETECTION_LIST";
constexpr char kRelativeDetectionsTag[] = "RELATIVE_DETECTIONS";
Detection DetectionWithBoundingBox(int32 xmin, int32 ymin, int32 width,
int32 height) {
Detection detection;
LocationData* location_data = detection.mutable_location_data();
location_data->set_format(LocationData::BOUNDING_BOX);
location_data->mutable_bounding_box()->set_xmin(xmin);
location_data->mutable_bounding_box()->set_ymin(ymin);
location_data->mutable_bounding_box()->set_width(width);
location_data->mutable_bounding_box()->set_height(height);
return detection;
}
Detection DetectionWithRelativeBoundingBox(float xmin, float ymin, float width,
float height) {
Detection detection;
LocationData* location_data = detection.mutable_location_data();
location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
location_data->mutable_relative_bounding_box()->set_xmin(xmin);
location_data->mutable_relative_bounding_box()->set_ymin(ymin);
location_data->mutable_relative_bounding_box()->set_width(width);
location_data->mutable_relative_bounding_box()->set_height(height);
return detection;
}
std::vector<Detection> ConvertToDetectionVector(
const DetectionList& detection_list) {
std::vector<Detection> output;
for (const auto& detection : detection_list.detection()) {
output.push_back(detection);
}
return output;
}
void CheckBoundingBox(const Detection& output, const Detection& expected) {
const auto& output_bbox = output.location_data().bounding_box();
const auto& expected_bbox = output.location_data().bounding_box();
EXPECT_THAT(output_bbox.xmin(), testing::Eq(expected_bbox.xmin()));
EXPECT_THAT(output_bbox.ymin(), testing::Eq(expected_bbox.ymin()));
EXPECT_THAT(output_bbox.width(), testing::Eq(expected_bbox.width()));
EXPECT_THAT(output_bbox.height(), testing::Eq(expected_bbox.height()));
}
void CheckRelativeBoundingBox(const Detection& output,
const Detection& expected) {
const auto& output_bbox = output.location_data().relative_bounding_box();
const auto& expected_bbox = output.location_data().relative_bounding_box();
EXPECT_THAT(output_bbox.xmin(), testing::FloatEq(expected_bbox.xmin()));
EXPECT_THAT(output_bbox.ymin(), testing::FloatEq(expected_bbox.ymin()));
EXPECT_THAT(output_bbox.width(), testing::FloatEq(expected_bbox.width()));
EXPECT_THAT(output_bbox.height(), testing::FloatEq(expected_bbox.height()));
}
void CheckOutputDetections(const std::vector<Detection>& expected,
const std::vector<Detection>& output) {
ASSERT_EQ(output.size(), expected.size());
for (int i = 0; i < output.size(); ++i) {
auto output_format = output[i].location_data().format();
ASSERT_TRUE(output_format == LocationData::RELATIVE_BOUNDING_BOX ||
output_format == LocationData::BOUNDING_BOX);
ASSERT_EQ(output_format, expected[i].location_data().format());
if (output_format == LocationData::RELATIVE_BOUNDING_BOX) {
CheckRelativeBoundingBox(output[i], expected[i]);
}
if (output_format == LocationData::BOUNDING_BOX) {
CheckBoundingBox(output[i], expected[i]);
}
}
}
TEST(DetectionsTransformationCalculatorTest, MissingImageSize) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DetectionTransformationCalculator"
input_stream: "DETECTIONS:detections"
output_stream: "PIXEL_DETECTION:detection"
)pb"));
auto status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("Image size must be provided"));
}
TEST(DetectionsTransformationCalculatorTest, WrongOutputType) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DetectionTransformationCalculator"
input_stream: "DETECTIONS:detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "PIXEL_DETECTION:detection"
)pb"));
auto status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("Output must be a container of detections"));
}
TEST(DetectionsTransformationCalculatorTest, WrongLocationDataFormat) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DetectionTransformationCalculator"
input_stream: "DETECTION:input_detection"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "PIXEL_DETECTION:output_detection"
)pb"));
Detection detection;
detection.mutable_location_data()->set_format(LocationData::GLOBAL);
runner.MutableInputs()
->Tag(kDetectionTag)
.packets.push_back(MakePacket<Detection>(detection).At(Timestamp(0)));
std::pair<int, int> image_size({2000, 1000});
runner.MutableInputs()
->Tag(kImageSizeTag)
.packets.push_back(
MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
auto status = runner.Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.message(),
testing::HasSubstr("location data format must be either "
"RELATIVE_BOUNDING_BOX or BOUNDING_BOX"));
}
TEST(DetectionsTransformationCalculatorTest,
ConvertBoundingBoxToRelativeBoundingBox) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DetectionTransformationCalculator"
input_stream: "DETECTIONS:input_detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "RELATIVE_DETECTIONS:output_detections"
output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
)pb"));
auto detections(absl::make_unique<std::vector<Detection>>());
detections->push_back(DetectionWithBoundingBox(100, 200, 400, 300));
detections->push_back(DetectionWithBoundingBox(0, 0, 2000, 1000));
std::pair<int, int> image_size({2000, 1000});
runner.MutableInputs()
->Tag(kDetectionsTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(0)));
runner.MutableInputs()
->Tag(kImageSizeTag)
.packets.push_back(
MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
std::vector<Detection> expected(
{DetectionWithRelativeBoundingBox(0.05, 0.2, 0.2, 0.3),
DetectionWithRelativeBoundingBox(0, 0, 1, 1)});
const std::vector<Packet>& detections_output =
runner.Outputs().Tag(kRelativeDetectionsTag).packets;
ASSERT_EQ(1, detections_output.size());
CheckOutputDetections(expected,
detections_output[0].Get<std::vector<Detection>>());
const std::vector<Packet>& detection_list_output =
runner.Outputs().Tag(kRelativeDetectionListTag).packets;
ASSERT_EQ(1, detection_list_output.size());
CheckOutputDetections(
expected,
ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
}
TEST(DetectionsTransformationCalculatorTest,
ConvertRelativeBoundingBoxToBoundingBox) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DetectionTransformationCalculator"
input_stream: "DETECTIONS:input_detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "PIXEL_DETECTIONS:output_detections"
output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
)pb"));
auto detections(absl::make_unique<std::vector<Detection>>());
detections->push_back(DetectionWithRelativeBoundingBox(0.1, 0.2, 0.3, 0.4));
detections->push_back(DetectionWithRelativeBoundingBox(0, 0, 1, 1));
std::pair<int, int> image_size({2000, 1000});
runner.MutableInputs()
->Tag(kDetectionsTag)
.packets.push_back(Adopt(detections.release()).At(Timestamp(0)));
runner.MutableInputs()
->Tag(kImageSizeTag)
.packets.push_back(
MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
std::vector<Detection> expected({DetectionWithBoundingBox(100, 200, 400, 300),
DetectionWithBoundingBox(0, 0, 2000, 1000)});
const std::vector<Packet>& detections_output =
runner.Outputs().Tag(kPixelDetectionsTag).packets;
ASSERT_EQ(1, detections_output.size());
CheckOutputDetections(expected,
detections_output[0].Get<std::vector<Detection>>());
const std::vector<Packet>& detection_list_output =
runner.Outputs().Tag(kPixelDetectionListTag).packets;
ASSERT_EQ(1, detection_list_output.size());
CheckOutputDetections(
expected,
ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
}
TEST(DetectionsTransformationCalculatorTest, ConvertSingleDetection) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
calculator: "DetectionTransformationCalculator"
input_stream: "DETECTION:input_detection"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "PIXEL_DETECTION:outpu_detection"
output_stream: "PIXEL_DETECTIONS:output_detections"
output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
)pb"));
runner.MutableInputs()
->Tag(kDetectionTag)
.packets.push_back(MakePacket<Detection>(DetectionWithRelativeBoundingBox(
0.05, 0.2, 0.2, 0.3))
.At(Timestamp(0)));
std::pair<int, int> image_size({2000, 1000});
runner.MutableInputs()
->Tag(kImageSizeTag)
.packets.push_back(
MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
std::vector<Detection> expected(
{DetectionWithBoundingBox(100, 200, 400, 300)});
const std::vector<Packet>& detection_output =
runner.Outputs().Tag(kPixelDetectionTag).packets;
ASSERT_EQ(1, detection_output.size());
CheckOutputDetections(expected, {detection_output[0].Get<Detection>()});
const std::vector<Packet>& detections_output =
runner.Outputs().Tag(kPixelDetectionsTag).packets;
ASSERT_EQ(1, detections_output.size());
CheckOutputDetections(expected,
detections_output[0].Get<std::vector<Detection>>());
const std::vector<Packet>& detection_list_output =
runner.Outputs().Tag(kPixelDetectionListTag).packets;
ASSERT_EQ(1, detection_list_output.size());
CheckOutputDetections(
expected,
ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
}
} // namespace
} // namespace mediapipe

View File

@ -181,7 +181,7 @@ class TrackingGraphTest : public Test {
// Each image is shifted to the right and bottom by kTranslationStep // Each image is shifted to the right and bottom by kTranslationStep
// pixels compared with the previous image. // pixels compared with the previous image.
static constexpr int kTranslationStep = 10; static constexpr int kTranslationStep = 10;
static constexpr float kEqualityTolerance = 3e-4f; static constexpr float kEqualityTolerance = 1e-3f;
}; };
void TrackingGraphTest::ExpectBoxAtFrame(const TimedBoxProto& box, float frame, void TrackingGraphTest::ExpectBoxAtFrame(const TimedBoxProto& box, float frame,

View File

@ -85,7 +85,7 @@ class KinematicPathSolver {
double current_position_px_; double current_position_px_;
double prior_position_px_; double prior_position_px_;
double current_velocity_deg_per_s_; double current_velocity_deg_per_s_;
uint64 current_time_; uint64 current_time_ = 0;
// History of observations (second) and their time (first). // History of observations (second) and their time (first).
std::deque<std::pair<uint64, int>> raw_positions_at_time_; std::deque<std::pair<uint64, int>> raw_positions_at_time_;
// Current target position. // Current target position.

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Example of reading a MediaSequence dataset. """Example of reading a MediaSequence dataset.
""" """

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "facedetectioncpu", name = "facedetectioncpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "facedetectiongpu", name = "facedetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "faceeffect", name = "faceeffect",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "facemeshgpu", name = "facemeshgpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "handdetectiongpu", name = "handdetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "handtrackinggpu", name = "handtrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "helloworld", name = "helloworld",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "holistictrackinggpu", name = "holistictrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "iristrackinggpu", name = "iristrackinggpu",

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""This script is used to set up automatic provisioning for iOS examples. """This script is used to set up automatic provisioning for iOS examples.
It scans the provisioning profiles used by Xcode, looking for one matching the It scans the provisioning profiles used by Xcode, looking for one matching the

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "objectdetectioncpu", name = "objectdetectioncpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "objectdetectiongpu", name = "objectdetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "objectdetectiontrackinggpu", name = "objectdetectiontrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "posetrackinggpu", name = "posetrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) licenses(["notice"])
MIN_IOS_VERSION = "10.0" MIN_IOS_VERSION = "11.0"
alias( alias(
name = "selfiesegmentationgpu", name = "selfiesegmentationgpu",

View File

@ -234,7 +234,9 @@ cc_library(
"//mediapipe/framework/tool:options_map", "//mediapipe/framework/tool:options_map",
"//mediapipe/framework/tool:packet_generator_wrapper_calculator_cc_proto", "//mediapipe/framework/tool:packet_generator_wrapper_calculator_cc_proto",
"//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
], ],
) )
@ -348,6 +350,7 @@ cc_library(
"//mediapipe/framework/tool:validate", "//mediapipe/framework/tool:validate",
"//mediapipe/framework/tool:validate_name", "//mediapipe/framework/tool:validate_name",
"//mediapipe/gpu:graph_support", "//mediapipe/gpu:graph_support",
"//mediapipe/gpu:gpu_service",
"//mediapipe/util:cpu_util", "//mediapipe/util:cpu_util",
] + select({ ] + select({
"//conditions:default": ["//mediapipe/gpu:gpu_shared_data_internal"], "//conditions:default": ["//mediapipe/gpu:gpu_shared_data_internal"],
@ -416,7 +419,6 @@ cc_library(
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"//mediapipe/framework/tool:tag_map", "//mediapipe/framework/tool:tag_map",
"//mediapipe/framework/tool:validate_name", "//mediapipe/framework/tool:validate_name",
"//mediapipe/gpu:graph_support",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
@ -613,7 +615,11 @@ cc_library(
hdrs = ["graph_service.h"], hdrs = ["graph_service.h"],
visibility = [":mediapipe_internal"], visibility = [":mediapipe_internal"],
deps = [ deps = [
":packet",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
], ],
) )

View File

@ -167,7 +167,6 @@ struct IsCompatibleType<V, OneOf<U...>>
template <typename T> template <typename T>
inline Packet<T> PacketBase::As() const { inline Packet<T> PacketBase::As() const {
if (!payload_) return Packet<T>().At(timestamp_); if (!payload_) return Packet<T>().At(timestamp_);
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
internal::CheckCompatibleType(*payload_, internal::Wrap<T>{}); internal::CheckCompatibleType(*payload_, internal::Wrap<T>{});
return Packet<T>(payload_).At(timestamp_); return Packet<T>(payload_).At(timestamp_);
} }
@ -217,8 +216,8 @@ class Packet : public Packet<internal::Generic> {
const T& operator*() const { return Get(); } const T& operator*() const { return Get(); }
const T* operator->() const { return &Get(); } const T* operator->() const { return &Get(); }
template <typename U> template <typename U, typename TT = T>
T GetOr(U&& v) const { std::enable_if_t<!std::is_abstract_v<TT>, TT> GetOr(U&& v) const {
return IsEmpty() ? static_cast<T>(absl::forward<U>(v)) : **this; return IsEmpty() ? static_cast<T>(absl::forward<U>(v)) : **this;
} }

View File

@ -4,11 +4,15 @@ namespace api2 {
namespace { namespace {
#if defined(TEST_NO_ASSIGN_WRONG_PACKET_TYPE) #if defined(TEST_NO_ASSIGN_WRONG_PACKET_TYPE)
void AssignWrongPacketType() { Packet<int> p = MakePacket<float>(1.0); } int AssignWrongPacketType() {
Packet<int> p = MakePacket<float>(1.0);
return *p;
}
#elif defined(TEST_NO_ASSIGN_GENERIC_TO_SPECIFIC) #elif defined(TEST_NO_ASSIGN_GENERIC_TO_SPECIFIC)
void AssignWrongPacketType() { int AssignWrongPacketType() {
Packet<> p = MakePacket<float>(1.0); Packet<> p = MakePacket<float>(1.0);
Packet<int> p2 = p; Packet<int> p2 = p;
return *p2;
} }
#endif #endif

View File

@ -264,6 +264,23 @@ TEST(PacketTest, Polymorphism) {
EXPECT_EQ((**mutable_base).name(), "Derived"); EXPECT_EQ((**mutable_base).name(), "Derived");
} }
class AbstractBase {
public:
virtual ~AbstractBase() = default;
virtual absl::string_view name() const = 0;
};
class ConcreteDerived : public AbstractBase {
public:
absl::string_view name() const override { return "ConcreteDerived"; }
};
TEST(PacketTest, PolymorphismAbstract) {
Packet<AbstractBase> base =
PacketAdopting<AbstractBase>(absl::make_unique<ConcreteDerived>());
EXPECT_EQ(base->name(), "ConcreteDerived");
}
} // namespace } // namespace
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -40,6 +40,17 @@ TEST(PortTest, DeletedCopyConstructorInput) {
EXPECT_EQ(std::string(kSideOutputPort.Tag()), "SIDE_OUTPUT"); EXPECT_EQ(std::string(kSideOutputPort.Tag()), "SIDE_OUTPUT");
} }
class AbstractBase {
public:
virtual ~AbstractBase() = default;
virtual absl::string_view name() const = 0;
};
TEST(PortTest, Abstract) {
static constexpr Input<AbstractBase> kInputPort{"INPUT"};
EXPECT_EQ(std::string(kInputPort.Tag()), "INPUT");
}
} // namespace } // namespace
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -21,6 +21,8 @@
#include <typeindex> #include <typeindex>
// TODO: Move protos in another CL after the C++ code migration. // TODO: Move protos in another CL after the C++ code migration.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator.pb.h" #include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/graph_service.h" #include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/mediapipe_options.pb.h" #include "mediapipe/framework/mediapipe_options.pb.h"
@ -147,7 +149,7 @@ class CalculatorContract {
bool IsOptional() const { return optional_; } bool IsOptional() const { return optional_; }
private: private:
GraphServiceBase service_; const GraphServiceBase& service_;
bool optional_ = false; bool optional_ = false;
}; };
@ -156,9 +158,12 @@ class CalculatorContract {
return it->second; return it->second;
} }
const std::map<std::string, GraphServiceRequest>& ServiceRequests() const { // A GraphService's key is always a static constant, so we can use string_view
return service_requests_; // as the key type without lifetime issues.
} using ServiceReqMap =
absl::flat_hash_map<absl::string_view, GraphServiceRequest>;
const ServiceReqMap& ServiceRequests() const { return service_requests_; }
private: private:
template <class T> template <class T>
@ -180,7 +185,7 @@ class CalculatorContract {
std::string input_stream_handler_; std::string input_stream_handler_;
MediaPipeOptions input_stream_handler_options_; MediaPipeOptions input_stream_handler_options_;
std::string node_name_; std::string node_name_;
std::map<std::string, GraphServiceRequest> service_requests_; ServiceReqMap service_requests_;
bool process_timestamps_ = false; bool process_timestamps_ = false;
TimestampDiff timestamp_offset_ = TimestampDiff::Unset(); TimestampDiff timestamp_offset_ = TimestampDiff::Unset();

View File

@ -226,6 +226,16 @@ absl::Status CalculatorGraph::InitializeStreams() {
return absl::OkStatus(); return absl::OkStatus();
} }
// Hack for backwards compatibility with ancient GPU calculators. Can it
// be retired yet?
static void MaybeFixupLegacyGpuNodeContract(CalculatorNode& node) {
#if !MEDIAPIPE_DISABLE_GPU
if (node.Contract().InputSidePackets().HasTag(kGpuSharedTagName)) {
const_cast<CalculatorContract&>(node.Contract()).UseService(kGpuService);
}
#endif // !MEDIAPIPE_DISABLE_GPU
}
absl::Status CalculatorGraph::InitializeCalculatorNodes() { absl::Status CalculatorGraph::InitializeCalculatorNodes() {
// Check if the user has specified a maximum queue size for an input stream. // Check if the user has specified a maximum queue size for an input stream.
max_queue_size_ = validated_graph_->Config().max_queue_size(); max_queue_size_ = validated_graph_->Config().max_queue_size();
@ -246,6 +256,7 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
validated_graph_.get(), node_ref, input_stream_managers_.get(), validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(), output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_); &buffer_size_hint, profiler_);
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
if (buffer_size_hint > 0) { if (buffer_size_hint > 0) {
max_queue_size_ = std::max(max_queue_size_, buffer_size_hint); max_queue_size_ = std::max(max_queue_size_, buffer_size_hint);
} }
@ -283,6 +294,7 @@ absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
validated_graph_.get(), node_ref, input_stream_managers_.get(), validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(), output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_); &buffer_size_hint, profiler_);
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
if (!result.ok()) { if (!result.ok()) {
// Collect as many errors as we can before failing. // Collect as many errors as we can before failing.
errors.push_back(result); errors.push_back(result);
@ -495,9 +507,8 @@ absl::StatusOr<Packet> CalculatorGraph::GetOutputSidePacket(
<< "\" because it doesn't exist."; << "\" because it doesn't exist.";
} }
Packet output_packet; Packet output_packet;
if (scheduler_.IsTerminated()) { if (!output_side_packets_[side_packet_index].GetPacket().IsEmpty() ||
// Side-packets from calculators can be retrieved only after the graph is scheduler_.IsTerminated()) {
// done.
output_packet = output_side_packets_[side_packet_index].GetPacket(); output_packet = output_side_packets_[side_packet_index].GetPacket();
} }
if (output_packet.IsEmpty()) { if (output_packet.IsEmpty()) {
@ -546,6 +557,7 @@ absl::Status CalculatorGraph::StartRun(
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
absl::Status CalculatorGraph::SetGpuResources( absl::Status CalculatorGraph::SetGpuResources(
std::shared_ptr<::mediapipe::GpuResources> resources) { std::shared_ptr<::mediapipe::GpuResources> resources) {
RET_CHECK_NE(resources, nullptr);
auto gpu_service = service_manager_.GetServiceObject(kGpuService); auto gpu_service = service_manager_.GetServiceObject(kGpuService);
RET_CHECK_EQ(gpu_service, nullptr) RET_CHECK_EQ(gpu_service, nullptr)
<< "The GPU resources have already been configured."; << "The GPU resources have already been configured.";
@ -557,56 +569,56 @@ std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
return service_manager_.GetServiceObject(kGpuService); return service_manager_.GetServiceObject(kGpuService);
} }
absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu( static Packet GetLegacyGpuSharedSidePacket(
const std::map<std::string, Packet>& side_packets) { const std::map<std::string, Packet>& side_packets) {
std::map<std::string, Packet> additional_side_packets;
bool update_sp = false;
bool uses_gpu = false;
for (const auto& node : nodes_) {
if (node->UsesGpu()) {
uses_gpu = true;
break;
}
}
if (uses_gpu) {
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName); auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
// Workaround for b/116875321: CalculatorRunner provides an empty packet, if (legacy_sp_iter == side_packets.end()) return {};
// instead of just leaving it unset. // Note that, because of b/116875321, the legacy side packet may be set but
bool has_legacy_sp = legacy_sp_iter != side_packets.end() && // empty. But it's ok, because here we return an empty packet to indicate the
!legacy_sp_iter->second.IsEmpty(); // missing case anyway.
return legacy_sp_iter->second;
}
absl::Status CalculatorGraph::MaybeSetUpGpuServiceFromLegacySidePacket(
Packet legacy_sp) {
if (legacy_sp.IsEmpty()) return absl::OkStatus();
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (gpu_resources) { if (gpu_resources) {
if (has_legacy_sp) {
LOG(WARNING) LOG(WARNING)
<< "::mediapipe::GpuSharedData provided as a side packet while the " << "::mediapipe::GpuSharedData provided as a side packet while the "
<< "graph already had one; ignoring side packet"; << "graph already had one; ignoring side packet";
return absl::OkStatus();
} }
update_sp = true; gpu_resources = legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources;
} else { return service_manager_.SetServiceObject(kGpuService, gpu_resources);
if (has_legacy_sp) { }
gpu_resources =
legacy_sp_iter->second.Get<::mediapipe::GpuSharedData*>()
->gpu_resources;
} else {
ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create());
update_sp = true;
}
MP_RETURN_IF_ERROR(
service_manager_.SetServiceObject(kGpuService, gpu_resources));
}
// Create or replace the legacy side packet if needed. std::map<std::string, Packet> CalculatorGraph::MaybeCreateLegacyGpuSidePacket(
if (update_sp) { Packet legacy_sp) {
legacy_gpu_shared_.reset(new ::mediapipe::GpuSharedData(gpu_resources)); std::map<std::string, Packet> additional_side_packets;
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (gpu_resources &&
(legacy_sp.IsEmpty() ||
legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources !=
gpu_resources)) {
legacy_gpu_shared_ =
absl::make_unique<mediapipe::GpuSharedData>(gpu_resources);
additional_side_packets[kGpuSharedSidePacketName] = additional_side_packets[kGpuSharedSidePacketName] =
MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get()); MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get());
} }
return additional_side_packets;
}
static bool UsesGpu(const CalculatorNode& node) {
return node.Contract().ServiceRequests().contains(kGpuService.key);
}
absl::Status CalculatorGraph::PrepareGpu() {
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (!gpu_resources) return absl::OkStatus();
// Set up executors. // Set up executors.
for (auto& node : nodes_) { for (auto& node : nodes_) {
if (node->UsesGpu()) { if (UsesGpu(*node)) {
MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get())); MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
} }
} }
@ -614,11 +626,32 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
SetExecutorInternal(name_executor.first, name_executor.second)); SetExecutorInternal(name_executor.first, name_executor.second));
} }
} return absl::OkStatus();
return additional_side_packets;
} }
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
absl::Status CalculatorGraph::PrepareServices() {
for (const auto& node : nodes_) {
for (const auto& [key, request] : node->Contract().ServiceRequests()) {
auto packet = service_manager_.GetServicePacket(request.Service());
if (!packet.IsEmpty()) continue;
auto packet_or = request.Service().CreateDefaultObject();
if (packet_or.ok()) {
MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
request.Service(), std::move(packet_or).value()));
} else if (request.IsOptional()) {
continue;
} else {
return absl::InternalError(absl::StrCat(
"Service \"", request.Service().key, "\", required by node ",
node->DebugName(), ", was not provided and cannot be created: ",
std::move(packet_or).status().message()));
}
}
}
return absl::OkStatus();
}
absl::Status CalculatorGraph::PrepareForRun( absl::Status CalculatorGraph::PrepareForRun(
const std::map<std::string, Packet>& extra_side_packets, const std::map<std::string, Packet>& extra_side_packets,
const std::map<std::string, Packet>& stream_headers) { const std::map<std::string, Packet>& stream_headers) {
@ -637,7 +670,13 @@ absl::Status CalculatorGraph::PrepareForRun(
std::map<std::string, Packet> additional_side_packets; std::map<std::string, Packet> additional_side_packets;
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
ASSIGN_OR_RETURN(additional_side_packets, PrepareGpu(extra_side_packets)); auto legacy_sp = GetLegacyGpuSharedSidePacket(extra_side_packets);
MP_RETURN_IF_ERROR(MaybeSetUpGpuServiceFromLegacySidePacket(legacy_sp));
#endif // !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(PrepareServices());
#if !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(PrepareGpu());
additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp);
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
const std::map<std::string, Packet>* input_side_packets; const std::map<std::string, Packet>* input_side_packets;

View File

@ -165,10 +165,13 @@ class CalculatorGraph {
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name, StatusOrPoller AddOutputStreamPoller(const std::string& stream_name,
bool observe_timestamp_bounds = false); bool observe_timestamp_bounds = false);
// Gets output side packet by name after the graph is done. However, base // Gets output side packet by name. The output side packet can be successfully
// packets (generated by PacketGenerators) can be retrieved before // retrevied in one of the following situations:
// graph is done. Returns error if the graph is still running (for non-base // - The graph is done.
// packets) or the output side packet is not found or empty. // - The output side packet has been generated by a calculator and the graph
// is currently idle.
// - The side packet is a base packet generated by a PacketGenerator.
// Returns error if the the output side packet is not found or empty.
absl::StatusOr<Packet> GetOutputSidePacket(const std::string& packet_name); absl::StatusOr<Packet> GetOutputSidePacket(const std::string& packet_name);
// Runs the graph after adding the given extra input side packets. All // Runs the graph after adding the given extra input side packets. All
@ -367,13 +370,8 @@ class CalculatorGraph {
std::shared_ptr<GpuResources> GetGpuResources() const; std::shared_ptr<GpuResources> GetGpuResources() const;
absl::Status SetGpuResources(std::shared_ptr<GpuResources> resources); absl::Status SetGpuResources(std::shared_ptr<GpuResources> resources);
// Helper for PrepareForRun. If it returns a non-empty map, those packets
// must be added to the existing side packets, replacing existing values
// that have the same key.
absl::StatusOr<std::map<std::string, Packet>> PrepareGpu(
const std::map<std::string, Packet>& side_packets);
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
template <typename T> template <typename T>
absl::Status SetServiceObject(const GraphService<T>& service, absl::Status SetServiceObject(const GraphService<T>& service,
std::shared_ptr<T> object) { std::shared_ptr<T> object) {
@ -495,6 +493,18 @@ class CalculatorGraph {
const std::map<std::string, Packet>& extra_side_packets, const std::map<std::string, Packet>& extra_side_packets,
const std::map<std::string, Packet>& stream_headers); const std::map<std::string, Packet>& stream_headers);
absl::Status PrepareServices();
#if !MEDIAPIPE_DISABLE_GPU
absl::Status MaybeSetUpGpuServiceFromLegacySidePacket(Packet legacy_sp);
// Helper for PrepareForRun. If it returns a non-empty map, those packets
// must be added to the existing side packets, replacing existing values
// that have the same key.
std::map<std::string, Packet> MaybeCreateLegacyGpuSidePacket(
Packet legacy_sp);
absl::Status PrepareGpu();
#endif // !MEDIAPIPE_DISABLE_GPU
// Cleans up any remaining state after the run and returns any errors that may // Cleans up any remaining state after the run and returns any errors that may
// have occurred during the run. Called after the scheduler has terminated. // have occurred during the run. Called after the scheduler has terminated.
absl::Status FinishRun(); absl::Status FinishRun();

View File

@ -732,11 +732,12 @@ TEST(CalculatorGraph, GetOutputSidePacket) {
status_or_packet = graph.GetOutputSidePacket("unknown"); status_or_packet = graph.GetOutputSidePacket("unknown");
EXPECT_FALSE(status_or_packet.ok()); EXPECT_FALSE(status_or_packet.ok());
EXPECT_EQ(absl::StatusCode::kNotFound, status_or_packet.status().code()); EXPECT_EQ(absl::StatusCode::kNotFound, status_or_packet.status().code());
// Should return UNAVAILABLE before graph is done for valid non-base // Should return the packet after the graph becomes idle.
// packets. MP_ASSERT_OK(graph.WaitUntilIdle());
status_or_packet = graph.GetOutputSidePacket("num_of_packets"); status_or_packet = graph.GetOutputSidePacket("num_of_packets");
EXPECT_FALSE(status_or_packet.ok()); MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(absl::StatusCode::kUnavailable, status_or_packet.status().code()); EXPECT_EQ(max_count, status_or_packet.value().Get<int>());
EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp());
// Should stil return a base even before graph is done. // Should stil return a base even before graph is done.
status_or_packet = graph.GetOutputSidePacket("output_uint64"); status_or_packet = graph.GetOutputSidePacket("output_uint64");
MP_ASSERT_OK(status_or_packet); MP_ASSERT_OK(status_or_packet);
@ -896,5 +897,23 @@ TEST(CalculatorGraph, GeneratorAfterCalculatorProcess) {
} }
} }
TEST(CalculatorGraph, GetOutputSidePacketAfterCalculatorIsOpened) {
CalculatorGraph graph;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
node {
calculator: "IntegerOutputSidePacketCalculator"
output_side_packet: "offset"
}
)pb");
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
// Must be called to ensure that the calculator is opened.
MP_ASSERT_OK(graph.WaitUntilIdle());
absl::StatusOr<Packet> status_or_packet = graph.GetOutputSidePacket("offset");
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(1, status_or_packet.value().Get<int>());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -46,7 +46,6 @@
#include "mediapipe/framework/tool/status_util.h" #include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/tool/tag_map.h" #include "mediapipe/framework/tool/tag_map.h"
#include "mediapipe/framework/tool/validate_name.h" #include "mediapipe/framework/tool/validate_name.h"
#include "mediapipe/gpu/graph_support.h"
namespace mediapipe { namespace mediapipe {
@ -155,11 +154,6 @@ absl::Status CalculatorNode::Initialize(
const CalculatorContract& contract = node_type_info_->Contract(); const CalculatorContract& contract = node_type_info_->Contract();
uses_gpu_ =
node_type_info_->InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
ContainsKey(node_type_info_->Contract().ServiceRequests(),
kGpuService.key);
// TODO Propagate types between calculators when SetAny is used. // TODO Propagate types between calculators when SetAny is used.
MP_RETURN_IF_ERROR(InitializeOutputSidePackets( MP_RETURN_IF_ERROR(InitializeOutputSidePackets(
@ -397,7 +391,7 @@ absl::Status CalculatorNode::PrepareForRun(
std::move(schedule_callback), error_callback); std::move(schedule_callback), error_callback);
output_stream_handler_->PrepareForRun(error_callback); output_stream_handler_->PrepareForRun(error_callback);
const auto& contract = node_type_info_->Contract(); const auto& contract = Contract();
input_side_packet_types_ = RemoveOmittedPacketTypes( input_side_packet_types_ = RemoveOmittedPacketTypes(
contract.InputSidePackets(), all_side_packets, validated_graph_); contract.InputSidePackets(), all_side_packets, validated_graph_);
MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun( MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun(

View File

@ -195,9 +195,6 @@ class CalculatorNode {
// Called by SchedulerQueue when a node is opened. // Called by SchedulerQueue when a node is opened.
void NodeOpened() ABSL_LOCKS_EXCLUDED(status_mutex_); void NodeOpened() ABSL_LOCKS_EXCLUDED(status_mutex_);
// Returns whether this is a GPU calculator node.
bool UsesGpu() const { return uses_gpu_; }
// Returns the scheduler queue the node is assigned to. // Returns the scheduler queue the node is assigned to.
internal::SchedulerQueue* GetSchedulerQueue() const { internal::SchedulerQueue* GetSchedulerQueue() const {
return scheduler_queue_; return scheduler_queue_;
@ -234,6 +231,12 @@ class CalculatorNode {
return *calculator_state_; return *calculator_state_;
} }
// Returns the node's contract.
// Must not be called before the CalculatorNode is initialized.
const CalculatorContract& Contract() const {
return node_type_info_->Contract();
}
private: private:
// Sets up the output side packets from the main flat array. // Sets up the output side packets from the main flat array.
absl::Status InitializeOutputSidePackets( absl::Status InitializeOutputSidePackets(
@ -363,9 +366,6 @@ class CalculatorNode {
std::unique_ptr<OutputStreamHandler> output_stream_handler_; std::unique_ptr<OutputStreamHandler> output_stream_handler_;
// Whether this is a GPU calculator.
bool uses_gpu_ = false;
// True if CleanupAfterRun() needs to call CloseNode(). // True if CleanupAfterRun() needs to call CloseNode().
bool needs_to_close_ = false; bool needs_to_close_ = false;

View File

@ -187,6 +187,21 @@ cc_library(
], ],
) )
config_setting(
name = "opencv",
define_values = {
"use_opencv": "true",
},
)
config_setting(
name = "portable_opencv",
define_values = {
"use_portable_opencv": "true",
"use_opencv": "false",
},
)
cc_library( cc_library(
name = "location", name = "location",
srcs = ["location.cc"], srcs = ["location.cc"],
@ -194,6 +209,8 @@ cc_library(
defines = select({ defines = select({
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:android": ["MEDIAPIPE_ANDROID_OPENCV"], "//mediapipe:android": ["MEDIAPIPE_ANDROID_OPENCV"],
":portable_opencv": ["MEDIAPIPE_ANDROID_OPENCV"],
":opencv": [],
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [

View File

@ -76,7 +76,7 @@ class Tensor {
public: public:
// No resources are allocated here. // No resources are allocated here.
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8 }; enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8 };
struct Shape { struct Shape {
Shape() = default; Shape() = default;
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {} Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
@ -217,6 +217,8 @@ class Tensor {
return sizeof(float); return sizeof(float);
case ElementType::kUInt8: case ElementType::kUInt8:
return 1; return 1;
case ElementType::kInt8:
return 1;
} }
} }
int bytes() const { return shape_.num_elements() * element_size(); } int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -16,6 +16,12 @@
#define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_ #define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_
#include <memory> #include <memory>
#include <type_traits>
#include <utility>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe { namespace mediapipe {
@ -27,18 +33,74 @@ namespace mediapipe {
// IMPORTANT: this is an experimental API. Get in touch with the MediaPipe team // IMPORTANT: this is an experimental API. Get in touch with the MediaPipe team
// if you want to use it. In most cases, you should use a side packet instead. // if you want to use it. In most cases, you should use a side packet instead.
struct GraphServiceBase { class GraphServiceBase {
public:
// TODO: fix services for which default init is broken, remove
// this setting.
enum DefaultInitSupport {
kAllowDefaultInitialization,
kDisallowDefaultInitialization
};
constexpr GraphServiceBase(const char* key) : key(key) {} constexpr GraphServiceBase(const char* key) : key(key) {}
virtual ~GraphServiceBase() = default;
inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
return DefaultInitializationUnsupported();
}
const char* key; const char* key;
protected:
absl::Status DefaultInitializationUnsupported() const {
return absl::UnimplementedError(absl::StrCat(
"Graph service '", key, "' does not support default initialization"));
}
}; };
template <typename T> template <typename T>
struct GraphService : public GraphServiceBase { class GraphService : public GraphServiceBase {
public:
using type = T; using type = T;
using packet_type = std::shared_ptr<T>; using packet_type = std::shared_ptr<T>;
constexpr GraphService(const char* key) : GraphServiceBase(key) {} constexpr GraphService(const char* my_key, DefaultInitSupport default_init =
kDisallowDefaultInitialization)
: GraphServiceBase(my_key), default_init_(default_init) {}
absl::StatusOr<Packet> CreateDefaultObject() const override {
if (default_init_ != kAllowDefaultInitialization) {
return DefaultInitializationUnsupported();
}
auto packet_or = CreateDefaultObjectInternal();
if (packet_or.ok()) {
return MakePacket<std::shared_ptr<T>>(std::move(packet_or).value());
} else {
return packet_or.status();
}
}
private:
absl::StatusOr<std::shared_ptr<T>> CreateDefaultObjectInternal() const {
auto call_create = [](auto x) -> decltype(decltype(x)::type::Create()) {
return decltype(x)::type::Create();
};
if constexpr (std::is_invocable_r_v<absl::StatusOr<std::shared_ptr<T>>,
decltype(call_create), type_tag<T>>) {
return T::Create();
}
if constexpr (std::is_default_constructible_v<T>) {
return std::make_shared<T>();
}
return DefaultInitializationUnsupported();
}
template <class U>
struct type_tag {
using type = U;
};
DefaultInitSupport default_init_;
}; };
template <typename T> template <typename T>

View File

@ -35,6 +35,8 @@ class GraphServiceManager {
Packet GetServicePacket(const GraphServiceBase& service) const; Packet GetServicePacket(const GraphServiceBase& service) const;
std::map<std::string, Packet> service_packets_; std::map<std::string, Packet> service_packets_;
friend class CalculatorGraph;
}; };
} // namespace mediapipe } // namespace mediapipe

View File

@ -6,11 +6,13 @@
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe { namespace mediapipe {
namespace {
const GraphService<int> kIntService("mediapipe::IntService");
} // namespace
TEST(GraphServiceManager, SetGetServiceObject) { TEST(GraphServiceManager, SetGetServiceObject) {
GraphServiceManager service_manager; GraphServiceManager service_manager;
constexpr GraphService<int> kIntService("mediapipe::IntService");
EXPECT_EQ(service_manager.GetServiceObject(kIntService), nullptr); EXPECT_EQ(service_manager.GetServiceObject(kIntService), nullptr);
MP_EXPECT_OK(service_manager.SetServiceObject(kIntService, MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
@ -22,8 +24,6 @@ TEST(GraphServiceManager, SetGetServiceObject) {
TEST(GraphServiceManager, SetServicePacket) { TEST(GraphServiceManager, SetServicePacket) {
GraphServiceManager service_manager; GraphServiceManager service_manager;
constexpr GraphService<int> kIntService("mediapipe::IntService");
MP_EXPECT_OK(service_manager.SetServicePacket( MP_EXPECT_OK(service_manager.SetServicePacket(
kIntService, kIntService,
mediapipe::MakePacket<std::shared_ptr<int>>(std::make_shared<int>(100)))); mediapipe::MakePacket<std::shared_ptr<int>>(std::make_shared<int>(100))));
@ -36,8 +36,6 @@ TEST(GraphServiceManager, ServicePackets) {
EXPECT_TRUE(service_manager.ServicePackets().empty()); EXPECT_TRUE(service_manager.ServicePackets().empty());
constexpr GraphService<int> kIntService("mediapipe::IntService");
MP_EXPECT_OK(service_manager.SetServiceObject(kIntService, MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
std::make_shared<int>(100))); std::make_shared<int>(100)));

View File

@ -150,5 +150,12 @@ TEST_F(GraphServiceTest, OptionalIsAvailable) {
EXPECT_EQ(PacketValues<int>(output_packets_), (std::vector<int>{108})); EXPECT_EQ(PacketValues<int>(output_packets_), (std::vector<int>{108}));
} }
TEST_F(GraphServiceTest, CreateDefault) {
EXPECT_FALSE(kTestService.CreateDefaultObject().ok());
MP_EXPECT_OK(kAnotherService.CreateDefaultObject());
EXPECT_FALSE(kNoDefaultService.CreateDefaultObject().ok());
MP_EXPECT_OK(kNeedsCreateService.CreateDefaultObject());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -50,15 +50,18 @@ absl::Status InputStreamHandler::SetupInputShards(
return absl::OkStatus(); return absl::OkStatus();
} }
std::vector<std::pair<std::string, int>> std::vector<std::tuple<std::string, int, int, Timestamp>>
InputStreamHandler::GetMonitoringInfo() { InputStreamHandler::GetMonitoringInfo() {
std::vector<std::pair<std::string, int>> monitoring_info_vector; std::vector<std::tuple<std::string, int, int, Timestamp>>
monitoring_info_vector;
for (auto& stream : input_stream_managers_) { for (auto& stream : input_stream_managers_) {
if (!stream) { if (!stream) {
continue; continue;
} }
monitoring_info_vector.emplace_back( monitoring_info_vector.emplace_back(
std::pair<std::string, int>(stream->Name(), stream->QueueSize())); std::tuple<std::string, int, int, Timestamp>(
stream->Name(), stream->QueueSize(), stream->NumPacketsAdded(),
stream->MinTimestampOrBound(nullptr)));
} }
return monitoring_info_vector; return monitoring_info_vector;
} }

View File

@ -94,7 +94,7 @@ class InputStreamHandler {
// Returns a vector of pairs of stream name and queue size for monitoring // Returns a vector of pairs of stream name and queue size for monitoring
// purpose. // purpose.
std::vector<std::pair<std::string, int>> GetMonitoringInfo(); std::vector<std::tuple<std::string, int, int, Timestamp>> GetMonitoringInfo();
// Resets the input stream handler and its underlying input streams for // Resets the input stream handler and its underlying input streams for
// another run of the graph. // another run of the graph.

View File

@ -329,6 +329,11 @@ Packet InputStreamManager::PopQueueHead(bool* stream_is_done) {
return packet; return packet;
} }
int InputStreamManager::NumPacketsAdded() const {
absl::MutexLock lock(&stream_mutex_);
return num_packets_added_;
}
int InputStreamManager::QueueSize() const { int InputStreamManager::QueueSize() const {
absl::MutexLock lock(&stream_mutex_); absl::MutexLock lock(&stream_mutex_);
return static_cast<int>(queue_.size()); return static_cast<int>(queue_.size());

View File

@ -87,12 +87,14 @@ class InputStreamManager {
// Timestamp::PostStream(), the packet must be the only packet in the // Timestamp::PostStream(), the packet must be the only packet in the
// stream. // stream.
// Violation of any of these conditions causes an error status. // Violation of any of these conditions causes an error status.
absl::Status AddPackets(const std::list<Packet>& container, bool* notify); absl::Status AddPackets(const std::list<Packet>& container, bool* notify)
ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Move a list of timestamped packets. Sets "notify" to true if the queue // Move a list of timestamped packets. Sets "notify" to true if the queue
// becomes non-empty. Does nothing if the input stream is closed. After the // becomes non-empty. Does nothing if the input stream is closed. After the
// move, all packets in the container must be empty. // move, all packets in the container must be empty.
absl::Status MovePackets(std::list<Packet>* container, bool* notify); absl::Status MovePackets(std::list<Packet>* container, bool* notify)
ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Closes the input stream. This function can be called multiple times. // Closes the input stream. This function can be called multiple times.
void Close() ABSL_LOCKS_EXCLUDED(stream_mutex_); void Close() ABSL_LOCKS_EXCLUDED(stream_mutex_);
@ -140,6 +142,9 @@ class InputStreamManager {
// Timestamp::Done() after the pop. // Timestamp::Done() after the pop.
Packet PopQueueHead(bool* stream_is_done) ABSL_LOCKS_EXCLUDED(stream_mutex_); Packet PopQueueHead(bool* stream_is_done) ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Returns the number of packets in the queue.
int NumPacketsAdded() const ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Returns the number of packets in the queue. // Returns the number of packets in the queue.
int QueueSize() const ABSL_LOCKS_EXCLUDED(stream_mutex_); int QueueSize() const ABSL_LOCKS_EXCLUDED(stream_mutex_);

View File

@ -767,6 +767,7 @@ TEST_F(InputStreamManagerTest, QueueSizeTest) {
EXPECT_EQ(3, num_packets_dropped_); EXPECT_EQ(3, num_packets_dropped_);
EXPECT_TRUE(input_stream_manager_->IsEmpty()); EXPECT_TRUE(input_stream_manager_->IsEmpty());
EXPECT_FALSE(stream_is_done_); EXPECT_FALSE(stream_is_done_);
EXPECT_EQ(3, input_stream_manager_->NumPacketsAdded());
packets.clear(); packets.clear();
packets.push_back(MakePacket<std::string>("packet 4").At(Timestamp(60))); packets.push_back(MakePacket<std::string>("packet 4").At(Timestamp(60)));
@ -776,6 +777,7 @@ TEST_F(InputStreamManagerTest, QueueSizeTest) {
input_stream_manager_->AddPackets(packets, &notify_)); // Notification input_stream_manager_->AddPackets(packets, &notify_)); // Notification
EXPECT_FALSE(input_stream_manager_->IsEmpty()); EXPECT_FALSE(input_stream_manager_->IsEmpty());
EXPECT_TRUE(notify_); EXPECT_TRUE(notify_);
EXPECT_EQ(5, input_stream_manager_->NumPacketsAdded());
expected_queue_becomes_full_count_ = 2; expected_queue_becomes_full_count_ = 2;
expected_queue_becomes_not_full_count_ = 1; expected_queue_becomes_not_full_count_ = 1;

View File

@ -12,6 +12,8 @@ def mediapipe_cc_test(
timeout = None, timeout = None,
args = [], args = [],
additional_deps = DEFAULT_ADDITIONAL_TEST_DEPS, additional_deps = DEFAULT_ADDITIONAL_TEST_DEPS,
platforms = ["linux", "android", "ios", "wasm"],
exclude_platforms = None,
# ios_unit_test arguments # ios_unit_test arguments
ios_minimum_os_version = "9.0", ios_minimum_os_version = "9.0",
# android_cc_test arguments # android_cc_test arguments

View File

@ -412,8 +412,7 @@ cc_library(
name = "status_matchers", name = "status_matchers",
testonly = 1, testonly = 1,
hdrs = ["status_matchers.h"], hdrs = ["status_matchers.h"],
# Use this library through "mediapipe/framework/port:gtest_main". visibility = ["//visibility:private"],
visibility = ["//mediapipe/framework/port:__pkg__"],
deps = [ deps = [
":status", ":status",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",

View File

@ -16,8 +16,14 @@
namespace mediapipe { namespace mediapipe {
const GraphService<TestServiceObject> kTestService("test_service"); const GraphService<TestServiceObject> kTestService(
const GraphService<int> kAnotherService("another_service"); "test_service", GraphServiceBase::kDisallowDefaultInitialization);
const GraphService<int> kAnotherService(
"another_service", GraphServiceBase::kAllowDefaultInitialization);
const GraphService<NoDefaultConstructor> kNoDefaultService(
"no_default_service", GraphServiceBase::kAllowDefaultInitialization);
const GraphService<NeedsCreateMethod> kNeedsCreateService(
"needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) { absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>(); cc->Inputs().Index(0).Set<int>();

View File

@ -16,6 +16,7 @@
#define MEDIAPIPE_FRAMEWORK_TEST_SERVICE_H_ #define MEDIAPIPE_FRAMEWORK_TEST_SERVICE_H_
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/graph_service.h"
namespace mediapipe { namespace mediapipe {
@ -24,6 +25,23 @@ using TestServiceObject = std::map<std::string, int>;
extern const GraphService<TestServiceObject> kTestService; extern const GraphService<TestServiceObject> kTestService;
extern const GraphService<int> kAnotherService; extern const GraphService<int> kAnotherService;
class NoDefaultConstructor {
public:
NoDefaultConstructor() = delete;
};
extern const GraphService<NoDefaultConstructor> kNoDefaultService;
class NeedsCreateMethod {
public:
static absl::StatusOr<std::shared_ptr<NeedsCreateMethod>> Create() {
return std::shared_ptr<NeedsCreateMethod>(new NeedsCreateMethod());
}
private:
NeedsCreateMethod() = default;
};
extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
// Use a service. // Use a service.
class TestServiceCalculator : public CalculatorBase { class TestServiceCalculator : public CalculatorBase {
public: public:

View File

@ -134,7 +134,7 @@ cc_library(
name = "name_util", name = "name_util",
srcs = ["name_util.cc"], srcs = ["name_util.cc"],
hdrs = ["name_util.h"], hdrs = ["name_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//visibility:public"],
deps = [ deps = [
":validate_name", ":validate_name",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",

View File

@ -225,7 +225,7 @@ std::string GetTestOutputsDir() {
return output_dir; return output_dir;
} }
std::string GetTestDataDir(const std::string& package_base_path) { std::string GetTestDataDir(absl::string_view package_base_path) {
return file::JoinPath(GetTestRootDir(), package_base_path, "testdata/"); return file::JoinPath(GetTestRootDir(), package_base_path, "testdata/");
} }
@ -270,7 +270,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
format, width, height, width * output_channels, data, stbi_image_free); format, width, height, width * output_channels, data, stbi_image_free);
} }
std::unique_ptr<ImageFrame> LoadTestPng(const std::string& path, std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
ImageFormat::Format format) { ImageFormat::Format format) {
return nullptr; return nullptr;
} }

View File

@ -63,7 +63,7 @@ std::string GetTestFilePath(absl::string_view relative_path);
// directory. // directory.
// This handles the different paths where test data ends up when using // This handles the different paths where test data ends up when using
// ion_cc_test on various platforms. // ion_cc_test on various platforms.
std::string GetTestDataDir(const std::string& package_base_path); std::string GetTestDataDir(absl::string_view package_base_path);
// Loads a binary graph from path. Returns true iff successful. // Loads a binary graph from path. Returns true iff successful.
bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path); bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path);
@ -75,7 +75,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
// Loads a PNG image from path using the given ImageFormat. Returns nullptr in // Loads a PNG image from path using the given ImageFormat. Returns nullptr in
// case of failure. // case of failure.
std::unique_ptr<ImageFrame> LoadTestPng( std::unique_ptr<ImageFrame> LoadTestPng(
const std::string& path, ImageFormat::Format format = ImageFormat::SRGBA); absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
// Returns the luminance image of |original_image|. // Returns the luminance image of |original_image|.
// The format of |original_image| must be sRGB or sRGBA. // The format of |original_image| must be sRGB or sRGBA.

View File

@ -38,14 +38,19 @@ cc_library(
srcs = ["gpu_service.cc"], srcs = ["gpu_service.cc"],
hdrs = ["gpu_service.h"], hdrs = ["gpu_service.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:graph_service"], deps = ["//mediapipe/framework:graph_service"] + select({
"//conditions:default": [
":gpu_shared_data_internal",
],
"//mediapipe/gpu:disable_gpu": [],
}),
) )
cc_library( cc_library(
name = "graph_support", name = "graph_support",
hdrs = ["graph_support.h"], hdrs = ["graph_support.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":gpu_service"], deps = ["//mediapipe/framework:graph_service"],
) )
GL_BASE_LINK_OPTS = select({ GL_BASE_LINK_OPTS = select({
@ -366,7 +371,6 @@ objc_library(
hdrs = ["pixel_buffer_pool_util.h"], hdrs = ["pixel_buffer_pool_util.h"],
copts = [ copts = [
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-std=c++17",
], ],
sdk_frameworks = [ sdk_frameworks = [
"Accelerate", "Accelerate",
@ -389,7 +393,6 @@ objc_library(
copts = [ copts = [
"-x objective-c++", "-x objective-c++",
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-std=c++17",
], ],
features = ["-layering_check"], features = ["-layering_check"],
sdk_frameworks = [ sdk_frameworks = [
@ -425,7 +428,6 @@ objc_library(
copts = [ copts = [
"-x objective-c++", "-x objective-c++",
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-std=c++17",
], ],
sdk_frameworks = [ sdk_frameworks = [
"CoreVideo", "CoreVideo",
@ -691,7 +693,6 @@ objc_library(
name = "gl_calculator_helper_ios", name = "gl_calculator_helper_ios",
copts = [ copts = [
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-std=c++17",
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
@ -707,7 +708,6 @@ objc_library(
hdrs = ["MPPMetalHelper.h"], hdrs = ["MPPMetalHelper.h"],
copts = [ copts = [
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-std=c++17",
], ],
features = ["-layering_check"], features = ["-layering_check"],
sdk_frameworks = [ sdk_frameworks = [
@ -801,7 +801,6 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":gl_calculator_helper", ":gl_calculator_helper",
":gpu_buffer_storage_image_frame",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
@ -927,7 +926,6 @@ mediapipe_cc_proto_library(
objc_library( objc_library(
name = "metal_copy_calculator", name = "metal_copy_calculator",
srcs = ["MetalCopyCalculator.mm"], srcs = ["MetalCopyCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"], features = ["-layering_check"],
sdk_frameworks = [ sdk_frameworks = [
"CoreVideo", "CoreVideo",
@ -946,7 +944,6 @@ objc_library(
objc_library( objc_library(
name = "metal_rgb_weight_calculator", name = "metal_rgb_weight_calculator",
srcs = ["MetalRgbWeightCalculator.mm"], srcs = ["MetalRgbWeightCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"], features = ["-layering_check"],
sdk_frameworks = [ sdk_frameworks = [
"CoreVideo", "CoreVideo",
@ -964,7 +961,6 @@ objc_library(
objc_library( objc_library(
name = "metal_sobel_calculator", name = "metal_sobel_calculator",
srcs = ["MetalSobelCalculator.mm"], srcs = ["MetalSobelCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"], features = ["-layering_check"],
sdk_frameworks = [ sdk_frameworks = [
"CoreVideo", "CoreVideo",
@ -982,7 +978,6 @@ objc_library(
objc_library( objc_library(
name = "metal_sobel_compute_calculator", name = "metal_sobel_compute_calculator",
srcs = ["MetalSobelComputeCalculator.mm"], srcs = ["MetalSobelComputeCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"], features = ["-layering_check"],
sdk_frameworks = [ sdk_frameworks = [
"CoreVideo", "CoreVideo",
@ -1018,7 +1013,6 @@ objc_library(
objc_library( objc_library(
name = "mps_threshold_calculator", name = "mps_threshold_calculator",
srcs = ["MPSThresholdCalculator.mm"], srcs = ["MPSThresholdCalculator.mm"],
copts = ["-std=c++17"],
sdk_frameworks = [ sdk_frameworks = [
"CoreVideo", "CoreVideo",
"Metal", "Metal",
@ -1053,7 +1047,6 @@ objc_library(
], ],
copts = [ copts = [
"-Wno-shorten-64-to-32", "-Wno-shorten-64-to-32",
"-std=c++17",
], ],
data = [ data = [
"//mediapipe/objc:testdata/googlelogo_color_272x92dp.png", "//mediapipe/objc:testdata/googlelogo_color_272x92dp.png",

View File

@ -23,6 +23,7 @@
#include "absl/base/dynamic_annotations.h" #include "absl/base/dynamic_annotations.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
@ -358,6 +359,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) {
GlContext::GlContext() {} GlContext::GlContext() {}
GlContext::~GlContext() { GlContext::~GlContext() {
destructing_ = true;
// Note: on Apple platforms, this object contains Objective-C objects. // Note: on Apple platforms, this object contains Objective-C objects.
// The destructor will release them, but ARC must be on. // The destructor will release them, but ARC must be on.
#ifdef __OBJC__ #ifdef __OBJC__
@ -366,11 +368,16 @@ GlContext::~GlContext() {
#endif #endif
#endif // __OBJC__ #endif // __OBJC__
if (thread_) { auto clear_attachments = [this] {
auto status = thread_->Run([this] { attachments_.clear();
if (profiling_helper_) { if (profiling_helper_) {
profiling_helper_->LogAllTimestamps(); profiling_helper_->LogAllTimestamps();
} }
};
if (thread_) {
auto status = thread_->Run([this, clear_attachments] {
clear_attachments();
return ExitContext(nullptr); return ExitContext(nullptr);
}); });
LOG_IF(ERROR, !status.ok()) LOG_IF(ERROR, !status.ok())
@ -378,6 +385,17 @@ GlContext::~GlContext() {
if (thread_->IsCurrentThread()) { if (thread_->IsCurrentThread()) {
thread_.release()->SelfDestruct(); thread_.release()->SelfDestruct();
} }
} else {
if (IsCurrent()) {
clear_attachments();
} else {
ContextBinding saved_context;
auto status = SwitchContextAndRun([&clear_attachments] {
clear_attachments();
return absl::OkStatus();
});
LOG_IF(ERROR, !status.ok()) << status;
}
} }
DestroyContext(); DestroyContext();
} }
@ -501,6 +519,14 @@ absl::Status GlContext::SwitchContext(ContextBinding* saved_context,
} }
} }
GlContext::ContextBinding GlContext::ThisContextBinding() {
GlContext::ContextBinding result = ThisContextBindingPlatform();
if (!destructing_) {
result.context_object = shared_from_this();
}
return result;
}
absl::Status GlContext::EnterContext(ContextBinding* saved_context) { absl::Status GlContext::EnterContext(ContextBinding* saved_context) {
DCHECK(HasContext()); DCHECK(HasContext());
return SwitchContext(saved_context, ThisContextBinding()); return SwitchContext(saved_context, ThisContextBinding());

View File

@ -21,6 +21,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/executor.h" #include "mediapipe/framework/executor.h"
#include "mediapipe/framework/mediapipe_profiling.h" #include "mediapipe/framework/mediapipe_profiling.h"
@ -285,6 +286,48 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// Sets default texture filtering parameters. // Sets default texture filtering parameters.
void SetStandardTextureParams(GLenum target, GLint internal_format); void SetStandardTextureParams(GLenum target, GLint internal_format);
template <class T>
using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;
template <class T, class... Args>
static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
MakeAttachmentPtr(Args&&... args) {
return {new T(std::forward<Args>(args)...),
[](void* ptr) { delete static_cast<T*>(ptr); }};
}
class AttachmentBase {};
template <class T>
class Attachment : public AttachmentBase {
public:
using FactoryT = std::function<AttachmentPtr<T>(GlContext&)>;
Attachment(FactoryT factory) : factory_(factory) {}
Attachment(const Attachment&) = delete;
Attachment(Attachment&&) = delete;
Attachment& operator=(const Attachment&) = delete;
Attachment& operator=(Attachment&&) = delete;
T& Get(GlContext& ctx) const { return ctx.GetCachedAttachment(*this); }
const FactoryT& factory() const { return factory_; }
private:
FactoryT factory_;
};
// TOOD: const result?
template <class T>
T& GetCachedAttachment(const Attachment<T>& attachment) {
DCHECK(IsCurrent());
AttachmentPtr<void>& entry = attachments_[&attachment];
if (entry == nullptr) {
entry = attachment.factory()(*this);
}
return *static_cast<T*>(entry.get());
}
// These are used for testing specific SyncToken implementations. Do not use // These are used for testing specific SyncToken implementations. Do not use
// outside of tests. // outside of tests.
enum class SyncTokenTypeForTest { enum class SyncTokenTypeForTest {
@ -387,6 +430,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// A binding that can be used to make this GlContext current. // A binding that can be used to make this GlContext current.
ContextBinding ThisContextBinding(); ContextBinding ThisContextBinding();
// Fill in platform-specific fields. Must _not_ set context_obj.
ContextBinding ThisContextBindingPlatform();
// Fills in a ContextBinding with platform-specific information about which // Fills in a ContextBinding with platform-specific information about which
// context is current on this thread. // context is current on this thread.
static void GetCurrentContextBinding(ContextBinding* binding); static void GetCurrentContextBinding(ContextBinding* binding);
@ -409,6 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// better mechanism? // better mechanism?
bool can_linear_filter_float_textures_; bool can_linear_filter_float_textures_;
absl::flat_hash_map<const AttachmentBase*, AttachmentPtr<void>> attachments_;
// Number of glFinish calls completed on the GL thread. // Number of glFinish calls completed on the GL thread.
// Changes should be guarded by mutex_. However, we use simple atomic // Changes should be guarded by mutex_. However, we use simple atomic
// loads for efficiency on the fast path. // loads for efficiency on the fast path.
@ -428,6 +475,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
absl::CondVar wait_for_gl_finish_cv_ ABSL_GUARDED_BY(mutex_); absl::CondVar wait_for_gl_finish_cv_ ABSL_GUARDED_BY(mutex_);
std::unique_ptr<mediapipe::GlProfilingHelper> profiling_helper_ = nullptr; std::unique_ptr<mediapipe::GlProfilingHelper> profiling_helper_ = nullptr;
bool destructing_ = false;
}; };
// For backward compatibility. TODO: migrate remaining callers. // For backward compatibility. TODO: migrate remaining callers.

View File

@ -84,9 +84,8 @@ void GlContext::DestroyContext() {
} }
} }
GlContext::ContextBinding GlContext::ThisContextBinding() { GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result; GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.context = context_; result.context = context_;
return result; return result;
} }

View File

@ -269,9 +269,8 @@ void GlContext::DestroyContext() {
#endif // __ANDROID__ #endif // __ANDROID__
} }
GlContext::ContextBinding GlContext::ThisContextBinding() { GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result; GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.display = display_; result.display = display_;
result.draw_surface = surface_; result.draw_surface = surface_;
result.read_surface = surface_; result.read_surface = surface_;

View File

@ -134,9 +134,8 @@ void GlContext::DestroyContext() {
} }
} }
GlContext::ContextBinding GlContext::ThisContextBinding() { GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result; GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.context = context_; result.context = context_;
return result; return result;
} }

View File

@ -173,9 +173,8 @@ void GlContext::DestroyContext() {
} }
} }
GlContext::ContextBinding GlContext::ThisContextBinding() { GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result; GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.context = context_; result.context = context_;
return result; return result;
} }

View File

@ -111,7 +111,7 @@ absl::Status QuadRenderer::GlRender(float frame_width, float frame_height,
FrameScaleMode scale_mode, FrameScaleMode scale_mode,
FrameRotation rotation, FrameRotation rotation,
bool flip_horizontal, bool flip_vertical, bool flip_horizontal, bool flip_vertical,
bool flip_texture) { bool flip_texture) const {
RET_CHECK(program_) << "Must setup the program before rendering."; RET_CHECK(program_) << "Must setup the program before rendering.";
glUseProgram(program_); glUseProgram(program_);

View File

@ -72,7 +72,7 @@ class QuadRenderer {
absl::Status GlRender(float frame_width, float frame_height, float view_width, absl::Status GlRender(float frame_width, float frame_height, float view_width,
float view_height, FrameScaleMode scale_mode, float view_height, FrameScaleMode scale_mode,
FrameRotation rotation, bool flip_horizontal, FrameRotation rotation, bool flip_horizontal,
bool flip_vertical, bool flip_texture); bool flip_vertical, bool flip_texture) const;
// Deletes the rendering program. Must be called withn the GL context where // Deletes the rendering program. Must be called withn the GL context where
// it was created. // it was created.
void GlTeardown(); void GlTeardown();

View File

@ -144,7 +144,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
}}, }},
{GpuBufferFormat::kRGBAFloat128, {GpuBufferFormat::kRGBAFloat128,
{ {
{GL_RGBA, GL_RGBA, GL_FLOAT, 1}, {GL_RGBA32F, GL_RGBA, GL_FLOAT, 1},
}}, }},
}}; }};

View File

@ -16,6 +16,7 @@
namespace mediapipe { namespace mediapipe {
const GraphService<GpuResources> kGpuService("kGpuService"); const GraphService<GpuResources> kGpuService(
"kGpuService", GraphServiceBase::kAllowDefaultInitialization);
} // namespace mediapipe } // namespace mediapipe

View File

@ -17,9 +17,18 @@
#include "mediapipe/framework/graph_service.h" #include "mediapipe/framework/graph_service.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
class GpuResources; #if MEDIAPIPE_DISABLE_GPU
class GpuResources {
GpuResources() = delete;
};
#endif // !MEDIAPIPE_DISABLE_GPU
extern const GraphService<GpuResources> kGpuService; extern const GraphService<GpuResources> kGpuService;
} // namespace mediapipe } // namespace mediapipe

Some files were not shown because too many files have changed in this diff Show More