Project import generated by Copybara.

GitOrigin-RevId: 19a829ffd755edb43e54d20c0e7b9348512d5108
This commit is contained in:
MediaPipe Team 2022-05-03 15:29:57 -07:00 committed by schmidt-sebastian
parent c6c80c3745
commit 7fb37c80e8
136 changed files with 2572 additions and 555 deletions

View File

@ -32,6 +32,9 @@ build:macos --copt=-w
# Sets the default Apple platform to macOS.
build --apple_platform_type=macos
# Compile ObjC++ files with C++17
build --per_file_copt=.*\.mm\$@-std=c++17
# Allow debugging with XCODE
build --apple_generate_dsym
@ -88,6 +91,10 @@ build:darwin_x86_64 --apple_platform_type=macos
build:darwin_x86_64 --macos_minimum_os=10.12
build:darwin_x86_64 --cpu=darwin_x86_64
build:darwin_arm64 --apple_platform_type=macos
build:darwin_arm64 --macos_minimum_os=10.16
build:darwin_arm64 --cpu=darwin_arm64
# This bazelrc file is meant to be written by a setup script.
try-import %workspace%/.configure.bazelrc

View File

@ -202,7 +202,10 @@ new_local_repository(
new_local_repository(
name = "macos_opencv",
build_file = "@//third_party:opencv_macos.BUILD",
path = "/usr/local/opt/opencv@3",
# For local MacOS builds, the path should point to an opencv@3 installation.
# If you edit the path here, you will also need to update the corresponding
# prefix in "opencv_macos.BUILD".
path = "/usr/local",
)
new_local_repository(

View File

@ -53,7 +53,7 @@ the following:
```bash
$ echo "android_sdk_repository(name = \"androidsdk\")" >> WORKSPACE
$ echo "android_ndk_repository(name = \"androidndk\")" >> WORKSPACE
$ echo "android_ndk_repository(name = \"androidndk\", api_level=21)" >> WORKSPACE
```
In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch

View File

@ -59,6 +59,21 @@ OpenGL ES profile shading language version string: OpenGL ES GLSL ES 3.20
OpenGL ES profile extensions:
```
If you have connected to your computer through SSH and find when you probe for
GPU information you see the output:
```bash
glxinfo | grep -i opengl
Error: unable to open display
```
Try re-establishing your SSH connection with the `-X` option and try again. For
example:
```bash
ssh -X <user>@<host>
```
*Notice the ES 3.20 text above.*
You need to see ES 3.1 or greater printed in order to perform TFLite inference

View File

@ -131,7 +131,7 @@ Create a `BUILD` file in the `$APPLICATION_PATH` and add the following build
rules:
```
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
load(
"@build_bazel_rules_apple//apple:ios.bzl",

View File

@ -32,9 +32,14 @@ example apps, start from, start from
xcode-select --install
```
3. Install [Bazel](https://bazel.build/).
3. Install [Bazelisk](https://github.com/bazelbuild/bazelisk)
.
We recommend using [Homebrew](https://brew.sh/) to get the latest version.
We recommend using [Homebrew](https://brew.sh/) to get the latest versions.
```bash
brew install bazelisk
```
4. Set Python 3.7 as the default Python version and install the Python "six"
library. This is needed for TensorFlow.
@ -187,6 +192,9 @@ Note: When you ask Xcode to run an app, by default it will use the Debug
configuration. Some of our demos are computationally heavy; you may want to use
the Release configuration for better performance.
Note: Due to an incompatibility caused by one of our dependencies, MediaPipe
cannot be used for apps running on the iPhone Simulator on Apple Silicon (M1).
Tip: To switch build configuration in Xcode, click on the target menu, choose
"Edit Scheme...", select the Run action, and switch the Build Configuration from
Debug to Release. Note that this is set independently for each target.

View File

@ -258,13 +258,14 @@ Many of the following settings are advanced and not recommended for general
usage. Consult [Enabling tracing and profiling](#enabling-tracing-and-profiling)
for a friendlier introduction.
histogram_interval_size_usec :Specifies the size of the runtimes histogram
intervals (in microseconds) to generate the histogram of the Process() time. The
last interval extends to +inf. If not specified, the interval is 1000000 usec =
1 sec.
histogram_interval_size_usec
: Specifies the size of the runtimes histogram intervals (in microseconds) to
generate the histogram of the `Process()` time. The last interval extends to
+inf. If not specified, the interval is 1000000 usec = 1 sec.
num_histogram_intervals :Specifies the number of intervals to generate the
histogram of the `Process()` runtime. If not specified, one interval is used.
num_histogram_intervals
: Specifies the number of intervals to generate the histogram of the
`Process()` runtime. If not specified, one interval is used.
enable_profiler
: If true, the profiler starts profiling when graph is initialized.
@ -288,7 +289,7 @@ trace_event_types_disabled
trace_log_path
: The output directory and base-name prefix for trace log files. Log files are
written to: StrCat(trace_log_path, index, "`.binarypb`")
written to: `StrCat(trace_log_path, index, ".binarypb")`
trace_log_count
: The number of trace log files retained. The trace log files are named
@ -310,8 +311,8 @@ trace_log_instant_events
trace_log_interval_count
: The number of trace log intervals per file. The total log duration is:
`trace_log_interval_usec * trace_log_file_count * trace_log_interval_count`.
The default value specifies 10 intervals per file.
`trace_log_interval_usec * trace_log_count * trace_log_interval_count`. The
default value specifies 10 intervals per file.
trace_log_disabled
: An option to turn ON/OFF writing trace files to disk. Saving trace files to

View File

@ -75,6 +75,7 @@ alias(
actual = select({
":macos_i386": ":macos_i386",
":macos_x86_64": ":macos_x86_64",
":macos_arm64": ":macos_arm64",
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
}),
visibility = ["//visibility:public"],
@ -119,6 +120,15 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "macos_arm64",
values = {
"apple_platform_type": "macos",
"cpu": "darwin_arm64",
},
visibility = ["//visibility:public"],
)
[
config_setting(
name = arch,

View File

@ -214,6 +214,7 @@ cc_library(
"//mediapipe/framework:collection_item_id",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/formats:rect_cc_proto",
"//mediapipe/framework/port:integral_types",
@ -1257,3 +1258,36 @@ cc_test(
"@com_google_absl//absl/time",
],
)
cc_library(
name = "get_vector_item_calculator",
srcs = ["get_vector_item_calculator.cc"],
hdrs = ["get_vector_item_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)
cc_library(
name = "vector_size_calculator",
srcs = ["vector_size_calculator.cc"],
hdrs = ["vector_size_calculator.h"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:classification_cc_proto",
"//mediapipe/framework/formats:landmark_cc_proto",
"//mediapipe/framework/port:status",
],
alwayslink = 1,
)

View File

@ -28,6 +28,10 @@ typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedLandmarkList>>
BeginLoopNormalizedLandmarkListVectorCalculator;
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkListVectorCalculator);
// A calculator to process std::vector<int>.
typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntCalculator;
REGISTER_CALCULATOR(BeginLoopIntCalculator);
// A calculator to process std::vector<NormalizedRect>.
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
BeginLoopNormalizedRectCalculator;

View File

@ -17,6 +17,7 @@
#include <vector>
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
#include "mediapipe/framework/formats/rect.pb.h"
#include "mediapipe/util/render_data.pb.h"
@ -50,4 +51,8 @@ REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator;
REGISTER_CALCULATOR(EndLoopTensorCalculator);
typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
EndLoopDetectionCalculator;
REGISTER_CALCULATOR(EndLoopDetectionCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/get_vector_item_calculator.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace api2 {
using GetLandmarkListVectorItemCalculator =
GetVectorItemCalculator<mediapipe::LandmarkList>;
REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator);
using GetClassificationListVectorItemCalculator =
GetVectorItemCalculator<mediapipe::ClassificationList>;
REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,77 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
#include <optional>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// A calculator to return an item from the vector by its index.
//
// Inputs:
// VECTOR - std::vector<T>
// Vector to take an item from.
// INDEX - int
// Index of the item to return.
//
// Outputs:
// ITEM - T
// Item from the vector at given index.
//
// Example config:
// node {
// calculator: "Get{SpecificType}VectorItemCalculator"
// input_stream: "VECTOR:vector"
// input_stream: "INDEX:index"
//   output_stream: "ITEM:item"
// }
//
template <typename T>
class GetVectorItemCalculator : public Node {
 public:
  static constexpr Input<std::vector<T>> kIn{"VECTOR"};
  static constexpr Input<int> kIdx{"INDEX"};
  static constexpr Output<T> kOut{"ITEM"};

  MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);

  // Emits items[idx] for the current timestamp. If either the vector or the
  // index packet is missing at this timestamp, no output is produced.
  // Returns an error status when the index is out of range.
  absl::Status Process(CalculatorContext* cc) final {
    if (kIn(cc).IsEmpty() || kIdx(cc).IsEmpty()) {
      // Wait until both inputs are present before emitting anything.
      return absl::OkStatus();
    }

    const std::vector<T>& items = kIn(cc).Get();
    const int idx = kIdx(cc).Get();

    // Reject negative indices explicitly. The original single comparison
    // relied on the signed->unsigned conversion wrapping a negative idx to a
    // huge value; the explicit check avoids the sign-compare warning and
    // yields a clearer error.
    RET_CHECK_GE(idx, 0);
    RET_CHECK_LT(static_cast<size_t>(idx), items.size());

    kOut(cc).Send(items[idx]);
    return absl::OkStatus();
  }
};
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_

View File

@ -83,4 +83,7 @@ REGISTER_CALCULATOR(SplitClassificationListVectorCalculator);
typedef SplitVectorCalculator<uint64_t, false> SplitUint64tVectorCalculator;
REGISTER_CALCULATOR(SplitUint64tVectorCalculator);
typedef SplitVectorCalculator<float, false> SplitFloatVectorCalculator;
REGISTER_CALCULATOR(SplitFloatVectorCalculator);
} // namespace mediapipe

View File

@ -0,0 +1,32 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/calculators/core/vector_size_calculator.h"
#include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/landmark.pb.h"
namespace mediapipe {
namespace api2 {
using LandmarkListVectorSizeCalculator =
VectorSizeCalculator<mediapipe::LandmarkList>;
REGISTER_CALCULATOR(LandmarkListVectorSizeCalculator);
using ClassificationListVectorSizeCalculator =
VectorSizeCalculator<mediapipe::ClassificationList>;
REGISTER_CALCULATOR(ClassificationListVectorSizeCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,64 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
#include <optional>
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
namespace api2 {
// A calculator to return vector size.
//
// Inputs:
// VECTOR - std::vector<T>
//   Vector whose size to return.
//
// Outputs:
// SIZE - int
// Size of the input vector.
//
// Example config:
// node {
// calculator: "{SpecificType}VectorSizeCalculator"
// input_stream: "VECTOR:vector"
// output_stream: "SIZE:vector_size"
// }
//
template <typename T>
class VectorSizeCalculator : public Node {
 public:
  static constexpr Input<std::vector<T>> kIn{"VECTOR"};
  static constexpr Output<int> kOut{"SIZE"};

  MEDIAPIPE_NODE_CONTRACT(kIn, kOut);

  // Emits the number of elements in the input vector for the current
  // timestamp. If no vector packet is present, nothing is emitted.
  absl::Status Process(CalculatorContext* cc) final {
    if (kIn(cc).IsEmpty()) {
      return absl::OkStatus();
    }
    // size() returns size_t; the SIZE stream carries int, so cast explicitly
    // instead of narrowing implicitly.
    kOut(cc).Send(static_cast<int>(kIn(cc).Get().size()));
    return absl::OkStatus();
  }
};
} // namespace api2
} // namespace mediapipe
#endif // MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_

View File

@ -421,6 +421,10 @@ absl::Status ScaleImageCalculator::InitializeFromOptions() {
alignment_boundary_ = options_.alignment_boundary();
}
if (options_.has_output_format()) {
output_format_ = options_.output_format();
}
downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient()));
return absl::OkStatus();
@ -433,13 +437,17 @@ absl::Status ScaleImageCalculator::ValidateImageFormats() const {
<< "The output image format was set to UNKNOWN.";
// TODO Remove these conditions.
RET_CHECK(output_format_ == ImageFormat::SRGB ||
output_format_ == ImageFormat::SRGBA ||
(input_format_ == output_format_ &&
output_format_ == ImageFormat::YCBCR420P))
<< "Outputting YCbCr420P images from SRGB input is not yet supported";
RET_CHECK(input_format_ == output_format_ ||
input_format_ == ImageFormat::YCBCR420P)
(input_format_ == ImageFormat::YCBCR420P &&
output_format_ == ImageFormat::SRGB) ||
(input_format_ == ImageFormat::SRGB &&
output_format_ == ImageFormat::SRGBA))
<< "Conversion of the color space (except from "
"YCbCr420P to SRGB) is not yet supported.";
"YCbCr420P to SRGB or SRGB to SRBGA) is not yet supported.";
return absl::OkStatus();
}
@ -604,6 +612,15 @@ absl::Status ScaleImageCalculator::Process(CalculatorContext* cc) {
.Add(output_image.release(), cc->InputTimestamp());
return absl::OkStatus();
}
} else if (input_format_ == ImageFormat::SRGB &&
output_format_ == ImageFormat::SRGBA) {
image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
cv::Mat input_mat = ::mediapipe::formats::MatView(image_frame);
converted_image_frame.Reset(ImageFormat::SRGBA, image_frame->Width(),
image_frame->Height(), alignment_boundary_);
cv::Mat output_mat = ::mediapipe::formats::MatView(&converted_image_frame);
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA, 4);
image_frame = &converted_image_frame;
} else {
image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
MP_RETURN_IF_ERROR(ValidateImageFrame(cc, *image_frame));

View File

@ -28,7 +28,9 @@ package(default_visibility = ["//visibility:private"])
exports_files(
glob(["testdata/image_to_tensor/*"]),
visibility = ["//mediapipe/calculators/image:__subpackages__"],
visibility = [
"//mediapipe/calculators/image:__subpackages__",
],
)
selects.config_setting_group(
@ -64,15 +66,16 @@ cc_library(
":inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:subgraph_expansion",
"//mediapipe/util/tflite:config",
"//mediapipe/util/tflite:tflite_model_loader",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
],
alwayslink = 1,
@ -91,6 +94,7 @@ cc_library(
"//mediapipe/util/tflite:tflite_gpu_runner",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
],
@ -142,6 +146,8 @@ cc_library(
":inference_calculator_interface",
"@com_google_absl//absl/memory",
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
"@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/c:c_api_types",
] + select({
"//conditions:default": [
"//mediapipe/util:cpu_util",

View File

@ -142,22 +142,35 @@ class ImageToTensorCalculator : public Node {
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
RET_CHECK(options.has_output_tensor_float_range() ||
options.has_output_tensor_int_range())
options.has_output_tensor_int_range() ||
options.has_output_tensor_uint_range())
<< "Output tensor range is required.";
if (options.has_output_tensor_float_range()) {
RET_CHECK_LT(options.output_tensor_float_range().min(),
options.output_tensor_float_range().max())
<< "Valid output float tensor range is required.";
}
if (options.has_output_tensor_uint_range()) {
RET_CHECK_LT(options.output_tensor_uint_range().min(),
options.output_tensor_uint_range().max())
<< "Valid output uint tensor range is required.";
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
<< "The minimum of the output uint tensor range must be "
"non-negative.";
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
<< "The maximum of the output uint tensor range must be less than or "
"equal to 255.";
}
if (options.has_output_tensor_int_range()) {
RET_CHECK_LT(options.output_tensor_int_range().min(),
options.output_tensor_int_range().max())
<< "Valid output int tensor range is required.";
RET_CHECK_GE(options.output_tensor_int_range().min(), 0)
<< "The minimum of the output int tensor range must be non-negative.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 255)
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
<< "The minimum of the output int tensor range must be greater than "
"or equal to -128.";
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
<< "The maximum of the output int tensor range must be less than or "
"equal to 255.";
"equal to 127.";
}
RET_CHECK_GT(options.output_tensor_width(), 0)
<< "Valid output tensor width is required.";
@ -187,15 +200,19 @@ class ImageToTensorCalculator : public Node {
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
output_width_ = options_.output_tensor_width();
output_height_ = options_.output_tensor_height();
is_int_output_ = options_.has_output_tensor_int_range();
range_min_ =
is_int_output_
? static_cast<float>(options_.output_tensor_int_range().min())
: options_.output_tensor_float_range().min();
range_max_ =
is_int_output_
? static_cast<float>(options_.output_tensor_int_range().max())
: options_.output_tensor_float_range().max();
is_float_output_ = options_.has_output_tensor_float_range();
if (options_.has_output_tensor_uint_range()) {
range_min_ =
static_cast<float>(options_.output_tensor_uint_range().min());
range_max_ =
static_cast<float>(options_.output_tensor_uint_range().max());
} else if (options_.has_output_tensor_int_range()) {
range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
} else {
range_min_ = options_.output_tensor_float_range().min();
range_max_ = options_.output_tensor_float_range().max();
}
return absl::OkStatus();
}
@ -275,6 +292,17 @@ class ImageToTensorCalculator : public Node {
}
}
Tensor::ElementType GetOutputTensorType() {
if (is_float_output_) {
return Tensor::ElementType::kFloat32;
}
if (range_min_ < 0) {
return Tensor::ElementType::kInt8;
} else {
return Tensor::ElementType::kUInt8;
}
}
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
CalculatorContext* cc) {
if (kIn(cc).IsConnected()) {
@ -305,7 +333,7 @@ class ImageToTensorCalculator : public Node {
const Image& image) {
// Lazy initialization of the GPU or CPU converter.
if (image.UsesGpu()) {
if (is_int_output_) {
if (!is_float_output_) {
return absl::UnimplementedError(
"ImageToTensorConverter for the input GPU image currently doesn't "
"support quantization.");
@ -337,11 +365,9 @@ class ImageToTensorCalculator : public Node {
} else {
if (!cpu_converter_) {
#if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(cpu_converter_,
CreateOpenCvConverter(
cc, GetBorderMode(),
is_int_output_ ? Tensor::ElementType::kUInt8
: Tensor::ElementType::kFloat32));
ASSIGN_OR_RETURN(
cpu_converter_,
CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
#else
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
"MEDIAPIPE_DISABLE_OPENCV is defined.";
@ -356,7 +382,7 @@ class ImageToTensorCalculator : public Node {
mediapipe::ImageToTensorCalculatorOptions options_;
int output_width_ = 0;
int output_height_ = 0;
bool is_int_output_ = false;
bool is_float_output_ = false;
float range_min_ = 0.0f;
float range_max_ = 1.0f;
};

View File

@ -39,6 +39,14 @@ message ImageToTensorCalculatorOptions {
optional int64 max = 2;
}
// Range of uint values [min, max].
// min, must be strictly less than max.
// Please note that UIntRange is supported for CPU tensors only.
message UIntRange {
optional uint64 min = 1;
optional uint64 max = 2;
}
// Pixel extrapolation methods. See @border_mode.
enum BorderMode {
BORDER_UNSPECIFIED = 0;
@ -58,6 +66,7 @@ message ImageToTensorCalculatorOptions {
oneof range {
FloatRange output_tensor_float_range = 4;
IntRange output_tensor_int_range = 7;
UIntRange output_tensor_uint_range = 8;
}
// For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs

View File

@ -76,12 +76,21 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
}
std::string output_tensor_range;
if (output_int_tensor) {
output_tensor_range = absl::Substitute(R"(output_tensor_int_range {
if (range_min < 0) {
output_tensor_range = absl::Substitute(R"(output_tensor_int_range {
min: $0
max: $1
})",
static_cast<int>(range_min),
static_cast<int>(range_max));
static_cast<int>(range_min),
static_cast<int>(range_max));
} else {
output_tensor_range = absl::Substitute(R"(output_tensor_uint_range {
min: $0
max: $1
})",
static_cast<uint>(range_min),
static_cast<uint>(range_max));
}
} else {
output_tensor_range = absl::Substitute(R"(output_tensor_float_range {
min: $0
@ -141,9 +150,15 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
auto view = tensor.GetCpuReadView();
cv::Mat tensor_mat;
if (output_int_tensor) {
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8);
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3,
const_cast<uint8*>(view.buffer<uint8>()));
if (range_min < 0) {
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8);
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3,
const_cast<int8*>(view.buffer<int8>()));
} else {
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8);
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3,
const_cast<uint8*>(view.buffer<uint8>()));
}
} else {
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32);
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3,
@ -190,25 +205,28 @@ const std::vector<InputType> kInputTypesToTest = {InputType::kImageFrame,
InputType::kImage};
void RunTest(cv::Mat input, cv::Mat expected_result,
std::vector<float> float_range, std::vector<int> int_range,
int tensor_width, int tensor_height, bool keep_aspect,
std::vector<std::pair<float, float>> float_ranges,
std::vector<std::pair<int, int>> int_ranges, int tensor_width,
int tensor_height, bool keep_aspect,
absl::optional<BorderMode> border_mode,
const mediapipe::NormalizedRect& roi) {
ASSERT_EQ(2, float_range.size());
ASSERT_EQ(2, int_range.size());
for (auto input_type : kInputTypesToTest) {
RunTestWithInputImagePacket(
input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
: MakeImagePacket(input),
expected_result, float_range[0], float_range[1], tensor_width,
tensor_height, keep_aspect, border_mode, roi,
/*output_int_tensor=*/false);
RunTestWithInputImagePacket(
input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
: MakeImagePacket(input),
expected_result, int_range[0], int_range[1], tensor_width,
tensor_height, keep_aspect, border_mode, roi,
/*output_int_tensor=*/true);
for (auto float_range : float_ranges) {
RunTestWithInputImagePacket(
input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
: MakeImagePacket(input),
expected_result, float_range.first, float_range.second, tensor_width,
tensor_height, keep_aspect, border_mode, roi,
/*output_int_tensor=*/false);
}
for (auto int_range : int_ranges) {
RunTestWithInputImagePacket(
input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
: MakeImagePacket(input),
expected_result, int_range.first, int_range.second, tensor_width,
tensor_height, keep_aspect, border_mode, roi,
/*output_int_tensor=*/true);
}
}
}
@ -224,8 +242,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) {
"tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
/*border mode*/ {}, roi);
}
@ -242,8 +260,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) {
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_keep_aspect_border_zero.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
BorderMode::kZero, roi);
}
@ -260,8 +278,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) {
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_keep_aspect_with_rotation.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
BorderMode::kReplicate, roi);
}
@ -279,8 +297,8 @@ TEST(ImageToTensorCalculatorTest,
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_keep_aspect_with_rotation_border_zero.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
BorderMode::kZero, roi);
}
@ -298,8 +316,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) {
GetRgb(
"/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"),
/*float_range=*/{-1.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{-1.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false,
BorderMode::kReplicate, roi);
}
@ -316,8 +334,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) {
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"medium_sub_rect_with_rotation_border_zero.png"),
/*float_range=*/{-1.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{-1.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false,
BorderMode::kZero, roi);
}
@ -333,8 +351,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) {
"tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/large_sub_rect.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false,
BorderMode::kReplicate, roi);
}
@ -351,8 +369,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) {
"tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false,
BorderMode::kZero, roi);
}
@ -369,8 +387,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) {
"tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kReplicate, roi);
}
@ -387,8 +405,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) {
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"large_sub_rect_keep_aspect_border_zero.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kZero, roi);
}
@ -405,8 +423,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) {
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"large_sub_rect_keep_aspect_with_rotation.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
/*border_mode=*/{}, roi);
}
@ -424,8 +442,8 @@ TEST(ImageToTensorCalculatorTest,
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/"
"large_sub_rect_keep_aspect_with_rotation_border_zero.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}},
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
/*border_mode=*/BorderMode::kZero, roi);
}
@ -441,8 +459,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) {
"tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/noop_except_range.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kReplicate, roi);
}
@ -458,8 +476,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) {
"tensor/testdata/image_to_tensor/input.jpg"),
GetRgb("/mediapipe/calculators/"
"tensor/testdata/image_to_tensor/noop_except_range.png"),
/*float_range=*/{0.0f, 1.0f},
/*int_range=*/{0, 255},
/*float_ranges=*/{{0.0f, 1.0f}},
/*int_ranges=*/{{0, 255}, {-128, 127}},
/*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
BorderMode::kZero, roi);
}

View File

@ -268,10 +268,12 @@ class GlProcessor : public ImageToTensorConverter {
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) {
return InvalidArgumentError(
absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ",
static_cast<uint32_t>(input.format())));
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
constexpr int kNumChannels = 3;

View File

@ -172,10 +172,12 @@ class GlProcessor : public ImageToTensorConverter {
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) {
return InvalidArgumentError(
absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ",
static_cast<uint32_t>(input.format())));
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
constexpr int kNumChannels = 3;

View File

@ -352,11 +352,12 @@ class MetalProcessor : public ImageToTensorConverter {
const RotatedRect& roi,
const Size& output_dims, float range_min,
float range_max) override {
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) {
return InvalidArgumentError(
absl::StrCat("Only BGRA/RGBA textures are supported, passed "
"format: ",
static_cast<uint32_t>(input.format())));
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
return InvalidArgumentError(absl::StrCat(
"Only 4-channel texture input formats are supported, passed format: ",
static_cast<uint32_t>(input.format())));
}
@autoreleasepool {

View File

@ -45,7 +45,19 @@ class OpenCvProcessor : public ImageToTensorConverter {
border_mode_ = cv::BORDER_CONSTANT;
break;
}
mat_type_ = tensor_type == Tensor::ElementType::kUInt8 ? CV_8UC3 : CV_32FC3;
switch (tensor_type_) {
case Tensor::ElementType::kInt8:
mat_type_ = CV_8SC3;
break;
case Tensor::ElementType::kFloat32:
mat_type_ = CV_32FC3;
break;
case Tensor::ElementType::kUInt8:
mat_type_ = CV_8UC3;
break;
default:
mat_type_ = -1;
}
}
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
@ -65,12 +77,22 @@ class OpenCvProcessor : public ImageToTensorConverter {
output_dims.width, kNumChannels});
auto buffer_view = tensor.GetCpuWriteView();
cv::Mat dst;
if (tensor_type_ == Tensor::ElementType::kUInt8) {
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<uint8>());
} else {
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<float>());
switch (tensor_type_) {
case Tensor::ElementType::kInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<int8>());
break;
case Tensor::ElementType::kFloat32:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<float>());
break;
case Tensor::ElementType::kUInt8:
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
buffer_view.buffer<uint8>());
break;
default:
return InvalidArgumentError(
absl::StrCat("Unsupported tensor type: ", tensor_type_));
}
const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y),
@ -124,6 +146,13 @@ class OpenCvProcessor : public ImageToTensorConverter {
absl::StatusOr<std::unique_ptr<ImageToTensorConverter>> CreateOpenCvConverter(
CalculatorContext* cc, BorderMode border_mode,
Tensor::ElementType tensor_type) {
if (tensor_type != Tensor::ElementType::kInt8 &&
tensor_type != Tensor::ElementType::kFloat32 &&
tensor_type != Tensor::ElementType::kUInt8) {
return absl::InvalidArgumentError(absl::StrCat(
"Tensor type is currently not supported by OpenCvProcessor, type: ",
tensor_type));
}
return absl::make_unique<OpenCvProcessor>(border_mode, tensor_type);
}

View File

@ -21,7 +21,9 @@
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/tool/subgraph_expansion.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
namespace api2 {
@ -67,5 +69,17 @@ absl::StatusOr<Packet<TfLiteModelPtr>> InferenceCalculator::GetModelAsPacket(
"Must specify TFLite model as path or loaded model.");
}
absl::StatusOr<Packet<tflite::OpResolver>>
InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
if (kSideInOpResolver(cc).IsConnected()) {
return kSideInOpResolver(cc).As<tflite::OpResolver>();
} else if (kSideInCustomOpResolver(cc).IsConnected()) {
return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
}
return PacketAdopting<tflite::OpResolver>(
std::make_unique<
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
}
} // namespace api2
} // namespace mediapipe

View File

@ -27,6 +27,7 @@
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/tflite_model_loader.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
@ -55,8 +56,11 @@ namespace api2 {
// TENSORS - Vector of Tensors
//
// Input side packet:
// DEPRECATED: Prefer to use the "OP_RESOLVER" input side packet instead.
// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver,
// instead of the builtin one.
// OP_RESOLVER (optional) - Use to provide tflite op resolver
// (tflite::OpResolver)
// MODEL (optional) - Use to specify TfLite model
// (std::unique_ptr<tflite::FlatBufferModel,
// std::function<void(tflite::FlatBufferModel*)>>)
@ -95,15 +99,21 @@ namespace api2 {
class InferenceCalculator : public NodeIntf {
public:
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
// Deprecated. Prefers to use "OP_RESOLVER" input side packet instead.
// TODO: Removes the "CUSTOM_OP_RESOLVER" side input after the
// migration.
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
"OP_RESOLVER"};
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
static constexpr SideInput<
mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{
"DELEGATE"};
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
kOutTensors, kDelegate);
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver,
kSideInOpResolver, kSideInModel, kOutTensors,
kDelegate);
protected:
using TfLiteDelegatePtr =
@ -111,6 +121,9 @@ class InferenceCalculator : public NodeIntf {
absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
CalculatorContext* cc);
absl::StatusOr<Packet<tflite::OpResolver>> GetOpResolverAsPacket(
CalculatorContext* cc);
};
struct InferenceCalculatorSelector : public InferenceCalculator {

View File

@ -116,6 +116,9 @@ message InferenceCalculatorOptions {
// to ensure there is no clash of the tokens. If unspecified, NNAPI will
// not try caching the compilation.
optional string model_token = 2;
// The name of an accelerator to be used for NNAPI delegate, e.g.
// "google-edgetpu". When not specified, it will be selected by NNAPI.
optional string accelerator_name = 3;
}
message Xnnpack {
// Number of threads for XNNPACK delegate. (By default, calculator tries

View File

@ -19,7 +19,7 @@
#include "absl/memory/memory.h"
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "tensorflow/lite/interpreter_builder.h"
#if defined(MEDIAPIPE_ANDROID)
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
#endif // ANDROID
@ -28,6 +28,7 @@
#include "mediapipe/util/cpu_util.h"
#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
namespace mediapipe {
@ -61,6 +62,17 @@ int GetXnnpackNumThreads(
return GetXnnpackDefaultNumThreads();
}
template <typename T>
void CopyTensorBuffer(const Tensor& input_tensor,
tflite::Interpreter* interpreter,
int input_tensor_index) {
auto input_tensor_view = input_tensor.GetCpuReadView();
auto input_tensor_buffer = input_tensor_view.buffer<T>();
T* local_tensor_buffer =
interpreter->typed_input_tensor<T>(input_tensor_index);
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
}
} // namespace
class InferenceCalculatorCpuImpl
@ -73,15 +85,16 @@ class InferenceCalculatorCpuImpl
absl::Status Close(CalculatorContext* cc) override;
private:
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitInterpreter(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc,
tflite::InterpreterBuilder* interpreter_builder);
absl::Status AllocateTensors();
// TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_;
std::unique_ptr<tflite::Interpreter> interpreter_;
TfLiteDelegatePtr delegate_;
bool has_quantized_input_;
TfLiteType input_tensor_type_ = TfLiteType::kTfLiteNoType;
};
absl::Status InferenceCalculatorCpuImpl::UpdateContract(
@ -94,8 +107,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract(
}
absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadModel(cc));
return LoadDelegateAndAllocateTensors(cc);
return InitInterpreter(cc);
}
absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
@ -108,19 +120,23 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
// Read CPU input into tensors.
for (int i = 0; i < input_tensors.size(); ++i) {
const Tensor* input_tensor = &input_tensors[i];
auto input_tensor_view = input_tensor->GetCpuReadView();
if (has_quantized_input_) {
// TODO: Support more quantized tensor types.
auto input_tensor_buffer = input_tensor_view.buffer<uint8>();
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes());
} else {
auto input_tensor_buffer = input_tensor_view.buffer<float>();
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes());
switch (input_tensor_type_) {
case TfLiteType::kTfLiteFloat16:
case TfLiteType::kTfLiteFloat32: {
CopyTensorBuffer<float>(input_tensors[i], interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteUInt8: {
CopyTensorBuffer<uint8>(input_tensors[i], interpreter_.get(), i);
break;
}
case TfLiteType::kTfLiteInt8: {
CopyTensorBuffer<int8>(input_tensors[i], interpreter_.get(), i);
break;
}
default:
return absl::InvalidArgumentError(
absl::StrCat("Unsupported input tensor type:", input_tensor_type_));
}
}
@ -150,39 +166,34 @@ absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) {
return absl::OkStatus();
}
absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) {
absl::Status InferenceCalculatorCpuImpl::InitInterpreter(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
RET_CHECK(interpreter_);
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
#if defined(__EMSCRIPTEN__)
interpreter_->SetNumThreads(1);
interpreter_builder.SetNumThreads(1);
#else
interpreter_->SetNumThreads(
interpreter_builder.SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__
return absl::OkStatus();
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_);
return AllocateTensors();
}
absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
absl::Status InferenceCalculatorCpuImpl::AllocateTensors() {
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
has_quantized_input_ =
interpreter_->tensor(interpreter_->inputs()[0])->quantization.type ==
kTfLiteAffineQuantization;
input_tensor_type_ = interpreter_->tensor(interpreter_->inputs()[0])->type;
return absl::OkStatus();
}
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
const auto& calculator_opts =
cc->Options<mediapipe::InferenceCalculatorOptions>();
auto opts_delegate = calculator_opts.delegate();
@ -211,18 +222,20 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
if (nnapi_requested) {
// Attempt to use NNAPI.
// If not supported, the default CPU delegate will be created and used.
interpreter_->SetAllowFp16PrecisionForFp32(1);
tflite::StatefulNnApiDelegate::Options options;
const auto& nnapi = opts_delegate.nnapi();
options.allow_fp16 = true;
// Set up cache_dir and model_token for NNAPI compilation cache.
options.cache_dir =
nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
options.model_token =
nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
options.accelerator_name = nnapi.has_accelerator_name()
? nnapi.accelerator_name().c_str()
: nullptr;
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
[](TfLiteDelegate*) {});
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
interpreter_builder->AddDelegate(delegate_.get());
return absl::OkStatus();
}
#endif // MEDIAPIPE_ANDROID
@ -239,8 +252,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
&TfLiteXNNPackDelegateDelete);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
interpreter_builder->AddDelegate(delegate_.get());
}
return absl::OkStatus();

View File

@ -22,6 +22,7 @@
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h"
#include "tensorflow/lite/interpreter_builder.h"
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
@ -52,9 +53,11 @@ class InferenceCalculatorGlImpl
private:
absl::Status ReadGpuCaches();
absl::Status SaveGpuCaches();
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitInterpreter(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc,
tflite::InterpreterBuilder* interpreter_builder);
absl::Status BindBuffersToTensors();
absl::Status AllocateTensors();
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is.
@ -137,17 +140,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
#endif // MEDIAPIPE_ANDROID
}
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
// for everything.
if (!use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR(LoadModel(cc));
}
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
: LoadDelegateAndAllocateTensors(cc);
: InitInterpreter(cc);
}));
return absl::OkStatus();
}
@ -292,12 +289,6 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
// Create runner
tflite::gpu::InferenceOptions options;
options.priority1 = allow_precision_loss_
@ -335,6 +326,10 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
break;
}
}
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
@ -355,31 +350,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
RET_CHECK(interpreter_);
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
#if defined(__EMSCRIPTEN__)
interpreter_->SetNumThreads(1);
interpreter_builder.SetNumThreads(1);
#else
interpreter_->SetNumThreads(
interpreter_builder.SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_);
MP_RETURN_IF_ERROR(BindBuffersToTensors());
MP_RETURN_IF_ERROR(AllocateTensors());
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors.
RET_CHECK_NE(
@ -388,7 +379,8 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
absl::Status InferenceCalculatorGlImpl::LoadDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
// Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed =
@ -399,7 +391,11 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
options.compile_options.inline_parameters = 1;
delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
&TfLiteGpuDelegateDelete);
interpreter_builder->AddDelegate(delegate_.get());
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
// Get input image sizes.
const auto& input_indices = interpreter_->inputs();
for (int i = 0; i < input_indices.size(); ++i) {
@ -431,11 +427,6 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
output_indices[i]),
kTfLiteOk);
}
// Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
return absl::OkStatus();
}

View File

@ -90,9 +90,10 @@ class InferenceCalculatorMetalImpl
absl::Status Close(CalculatorContext* cc) override;
private:
absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitInterpreter(CalculatorContext* cc);
void AddDelegate(CalculatorContext* cc,
tflite::InterpreterBuilder* interpreter_builder);
absl::Status CreateConverters(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_;
@ -127,11 +128,9 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
MP_RETURN_IF_ERROR(LoadModel(cc));
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
return LoadDelegateAndAllocateTensors(cc);
return InitInterpreter(cc);
}
absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
@ -199,27 +198,20 @@ absl::Status InferenceCalculatorMetalImpl::Close(CalculatorContext* cc) {
return absl::OkStatus();
}
absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) {
absl::Status InferenceCalculatorMetalImpl::InitInterpreter(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
const auto& op_resolver = op_resolver_packet.Get();
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
AddDelegate(cc, &interpreter_builder);
interpreter_builder.SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
RET_CHECK(interpreter_);
interpreter_->SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
return absl::OkStatus();
}
absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
CalculatorContext* cc) {
MP_RETURN_IF_ERROR(LoadDelegate(cc));
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
MP_RETURN_IF_ERROR(CreateConverters(cc));
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
// TODO: Support quantized tensors.
RET_CHECK_NE(
@ -228,7 +220,8 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
return absl::OkStatus();
}
absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
void InferenceCalculatorMetalImpl::AddDelegate(
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
const auto& calculator_opts =
cc->Options<mediapipe::InferenceCalculatorOptions>();
@ -242,9 +235,11 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait;
delegate_ =
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
kTfLiteOk);
interpreter_builder->AddDelegate(delegate_.get());
}
absl::Status InferenceCalculatorMetalImpl::CreateConverters(
CalculatorContext* cc) {
id<MTLDevice> device = gpu_helper_.mtlDevice;
// Get input image sizes.

View File

@ -91,6 +91,40 @@ void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
}
}
absl::Status CheckCustomTensorMapping(
const TensorsToDetectionsCalculatorOptions::TensorMapping& tensor_mapping) {
RET_CHECK(tensor_mapping.has_detections_tensor_index() &&
tensor_mapping.has_scores_tensor_index());
int bitmap = 0;
bitmap |= 1 << tensor_mapping.detections_tensor_index();
bitmap |= 1 << tensor_mapping.scores_tensor_index();
if (!tensor_mapping.has_num_detections_tensor_index() &&
!tensor_mapping.has_classes_tensor_index() &&
!tensor_mapping.has_anchors_tensor_index()) {
// Only allows the output tensor index 0 and 1 to be occupied.
RET_CHECK_EQ(3, bitmap) << "The custom output tensor indices should only "
"cover index 0 and 1.";
} else if (tensor_mapping.has_anchors_tensor_index()) {
RET_CHECK(!tensor_mapping.has_classes_tensor_index() &&
!tensor_mapping.has_num_detections_tensor_index());
bitmap |= 1 << tensor_mapping.anchors_tensor_index();
// If the"anchors" tensor will be available, only allows the output tensor
// index 0, 1, 2 to be occupied.
RET_CHECK_EQ(7, bitmap) << "The custom output tensor indices should only "
"cover index 0, 1 and 2.";
} else {
RET_CHECK(tensor_mapping.has_classes_tensor_index() &&
tensor_mapping.has_num_detections_tensor_index());
// If the "classes" and the "number of detections" tensors will be
// available, only allows the output tensor index 0, 1, 2, 3 to be occupied.
bitmap |= 1 << tensor_mapping.classes_tensor_index();
bitmap |= 1 << tensor_mapping.num_detections_tensor_index();
RET_CHECK_EQ(15, bitmap) << "The custom output tensor indices should only "
"cover index 0, 1, 2 and 3.";
}
return absl::OkStatus();
}
} // namespace
// Convert result Tensors from object detection models into MediaPipe
@ -170,13 +204,27 @@ class TensorsToDetectionsCalculator : public Node {
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
float box_xmax, float score, int class_id,
bool flip_vertically);
bool IsClassIndexAllowed(int class_index);
int num_classes_ = 0;
int num_boxes_ = 0;
int num_coords_ = 0;
std::set<int> ignore_classes_;
int max_results_ = -1;
::mediapipe::TensorsToDetectionsCalculatorOptions options_;
// Set of allowed or ignored class indices.
struct ClassIndexSet {
absl::flat_hash_set<int> values;
bool is_allowlist;
};
// Allowed or ignored class indices based on provided options or side packet.
// These are used to filter out the output detection results.
ClassIndexSet class_index_set_;
TensorsToDetectionsCalculatorOptions options_;
bool scores_tensor_index_is_set_ = false;
TensorsToDetectionsCalculatorOptions::TensorMapping tensor_mapping_;
std::vector<int> box_indices_ = {0, 1, 2, 3};
bool has_custom_box_indices_ = false;
std::vector<Anchor> anchors_;
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
@ -239,6 +287,21 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
}
}
}
const int num_input_tensors = kInTensors(cc)->size();
if (!scores_tensor_index_is_set_) {
if (num_input_tensors == 2 ||
num_input_tensors == kNumInputTensorsWithAnchors) {
tensor_mapping_.set_scores_tensor_index(1);
} else {
tensor_mapping_.set_scores_tensor_index(2);
}
scores_tensor_index_is_set_ = true;
}
if (gpu_processing || num_input_tensors != 4) {
// Allows custom bounding box indices when receiving 4 cpu tensors.
// Uses the default bbox indices in other cases.
RET_CHECK(!has_custom_box_indices_);
}
if (gpu_processing) {
if (!gpu_inited_) {
@ -263,13 +326,15 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// Postprocessing on CPU for model without postprocessing op. E.g. output
// raw score tensor and box tensor. Anchor decoding will be handled below.
// TODO: Add flexible input tensor size handling.
auto raw_box_tensor = &input_tensors[0];
auto raw_box_tensor =
&input_tensors[tensor_mapping_.detections_tensor_index()];
RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
auto raw_score_tensor = &input_tensors[1];
auto raw_score_tensor =
&input_tensors[tensor_mapping_.scores_tensor_index()];
RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
@ -282,7 +347,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// TODO: Support other options to load anchors.
if (!anchors_init_) {
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
auto anchor_tensor = &input_tensors[2];
auto anchor_tensor =
&input_tensors[tensor_mapping_.anchors_tensor_index()];
RET_CHECK_EQ(anchor_tensor->shape().dims.size(), 2);
RET_CHECK_EQ(anchor_tensor->shape().dims[0], num_boxes_);
RET_CHECK_EQ(anchor_tensor->shape().dims[1], kNumCoordsPerBox);
@ -308,7 +374,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
float max_score = -std::numeric_limits<float>::max();
// Find the top score for box i.
for (int score_idx = 0; score_idx < num_classes_; ++score_idx) {
if (ignore_classes_.find(score_idx) == ignore_classes_.end()) {
if (IsClassIndexAllowed(score_idx)) {
auto score = raw_scores[i * num_classes_ + score_idx];
if (options_.sigmoid_score()) {
if (options_.has_score_clipping_thresh()) {
@ -338,23 +404,26 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
// Postprocessing on CPU with postprocessing op (e.g. anchor decoding and
// non-maximum suppression) within the model.
RET_CHECK_EQ(input_tensors.size(), 4);
auto num_boxes_tensor = &input_tensors[3];
auto num_boxes_tensor =
&input_tensors[tensor_mapping_.num_detections_tensor_index()];
RET_CHECK_EQ(num_boxes_tensor->shape().dims.size(), 1);
RET_CHECK_EQ(num_boxes_tensor->shape().dims[0], 1);
auto detection_boxes_tensor = &input_tensors[0];
auto detection_boxes_tensor =
&input_tensors[tensor_mapping_.detections_tensor_index()];
RET_CHECK_EQ(detection_boxes_tensor->shape().dims.size(), 3);
RET_CHECK_EQ(detection_boxes_tensor->shape().dims[0], 1);
const int max_detections = detection_boxes_tensor->shape().dims[1];
RET_CHECK_EQ(detection_boxes_tensor->shape().dims[2], num_coords_);
auto detection_classes_tensor = &input_tensors[1];
auto detection_classes_tensor =
&input_tensors[tensor_mapping_.classes_tensor_index()];
RET_CHECK_EQ(detection_classes_tensor->shape().dims.size(), 2);
RET_CHECK_EQ(detection_classes_tensor->shape().dims[0], 1);
RET_CHECK_EQ(detection_classes_tensor->shape().dims[1], max_detections);
auto detection_scores_tensor = &input_tensors[2];
auto detection_scores_tensor =
&input_tensors[tensor_mapping_.scores_tensor_index()];
RET_CHECK_EQ(detection_scores_tensor->shape().dims.size(), 2);
RET_CHECK_EQ(detection_scores_tensor->shape().dims[0], 1);
RET_CHECK_EQ(detection_scores_tensor->shape().dims[1], max_detections);
@ -394,12 +463,14 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
-> absl::Status {
if (!anchors_init_) {
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
auto read_view = input_tensors[2].GetOpenGlBufferReadView();
auto read_view = input_tensors[tensor_mapping_.anchors_tensor_index()]
.GetOpenGlBufferReadView();
glBindBuffer(GL_COPY_READ_BUFFER, read_view.name());
auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView();
glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name());
glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
input_tensors[2].bytes());
glCopyBufferSubData(
GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
input_tensors[tensor_mapping_.anchors_tensor_index()].bytes());
} else if (!kInAnchors(cc).IsEmpty()) {
const auto& anchors = *kInAnchors(cc);
auto anchors_view = raw_anchors_buffer_->GetCpuWriteView();
@ -418,7 +489,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
auto decoded_boxes_view =
decoded_boxes_buffer_->GetOpenGlBufferWriteView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, decoded_boxes_view.name());
auto input0_view = input_tensors[0].GetOpenGlBufferReadView();
auto input0_view =
input_tensors[tensor_mapping_.detections_tensor_index()]
.GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input0_view.name());
auto raw_anchors_view = raw_anchors_buffer_->GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, raw_anchors_view.name());
@ -427,7 +500,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
// Score boxes.
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, scored_boxes_view.name());
auto input1_view = input_tensors[1].GetOpenGlBufferReadView();
auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()]
.GetOpenGlBufferReadView();
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input1_view.name());
glUseProgram(score_program_);
glDispatchCompute(num_boxes_, 1, 1);
@ -459,7 +533,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
auto command_buffer = [gpu_helper_ commandBuffer];
auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer);
auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()]
.GetMtlBufferReadView(command_buffer);
auto dest_buffer =
raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer);
id<MTLBlitCommandEncoder> blit_command =
@ -468,7 +543,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
sourceOffset:0
toBuffer:dest_buffer.buffer()
destinationOffset:0
size:input_tensors[2].bytes()];
size:input_tensors[tensor_mapping_
.anchors_tensor_index()]
.bytes()];
[blit_command endEncoding];
[command_buffer commit];
} else if (!kInAnchors(cc).IsEmpty()) {
@ -495,7 +572,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
auto decoded_boxes_view =
decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer);
[command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0];
auto input0_view = input_tensors[0].GetMtlBufferReadView(command_buffer);
auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()]
.GetMtlBufferReadView(command_buffer);
[command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1];
auto raw_anchors_view =
raw_anchors_buffer_->GetMtlBufferReadView(command_buffer);
@ -507,7 +585,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
[command_encoder setComputePipelineState:score_program_];
[command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0];
auto input1_view = input_tensors[1].GetMtlBufferReadView(command_buffer);
auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()]
.GetMtlBufferReadView(command_buffer);
[command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1];
MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1);
MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
@ -570,6 +649,10 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
num_classes_ = options_.num_classes();
num_boxes_ = options_.num_boxes();
num_coords_ = options_.num_coords();
CHECK_NE(options_.max_results(), 0)
<< "The maximum number of the top-scored detection results must be "
"non-zero.";
max_results_ = options_.max_results();
// Currently only support 2D when num_values_per_keypoint equals to 2.
CHECK_EQ(options_.num_values_per_keypoint(), 2);
@ -581,15 +664,55 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
if (kSideInIgnoreClasses(cc).IsConnected()) {
RET_CHECK(!kSideInIgnoreClasses(cc).IsEmpty());
RET_CHECK(options_.allow_classes().empty());
class_index_set_.is_allowlist = false;
for (int ignore_class : *kSideInIgnoreClasses(cc)) {
ignore_classes_.insert(ignore_class);
class_index_set_.values.insert(ignore_class);
}
} else if (!options_.allow_classes().empty()) {
RET_CHECK(options_.ignore_classes().empty());
class_index_set_.is_allowlist = true;
for (int i = 0; i < options_.allow_classes_size(); ++i) {
class_index_set_.values.insert(options_.allow_classes(i));
}
} else {
class_index_set_.is_allowlist = false;
for (int i = 0; i < options_.ignore_classes_size(); ++i) {
ignore_classes_.insert(options_.ignore_classes(i));
class_index_set_.values.insert(options_.ignore_classes(i));
}
}
if (options_.has_tensor_mapping()) {
RET_CHECK_OK(CheckCustomTensorMapping(options_.tensor_mapping()));
tensor_mapping_ = options_.tensor_mapping();
scores_tensor_index_is_set_ = true;
} else {
// Assigns the default tensor indices.
tensor_mapping_.set_detections_tensor_index(0);
tensor_mapping_.set_classes_tensor_index(1);
tensor_mapping_.set_anchors_tensor_index(2);
tensor_mapping_.set_num_detections_tensor_index(3);
// The scores tensor index needs to be determined based on the number of
// model's output tensors, which will be available in the first invocation
// of the Process() method.
tensor_mapping_.set_scores_tensor_index(-1);
scores_tensor_index_is_set_ = false;
}
if (options_.has_box_boundaries_indices()) {
box_indices_ = {options_.box_boundaries_indices().ymin(),
options_.box_boundaries_indices().xmin(),
options_.box_boundaries_indices().ymax(),
options_.box_boundaries_indices().xmax()};
int bitmap = 0;
for (int i : box_indices_) {
bitmap |= 1 << i;
}
RET_CHECK_EQ(bitmap, 15) << "The custom box boundaries indices should only "
"cover index 0, 1, 2, and 3.";
has_custom_box_indices_ = true;
}
return absl::OkStatus();
}
@ -661,14 +784,22 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
const float* detection_boxes, const float* detection_scores,
const int* detection_classes, std::vector<Detection>* output_detections) {
for (int i = 0; i < num_boxes_; ++i) {
if (max_results_ > 0 && output_detections->size() == max_results_) {
break;
}
if (options_.has_min_score_thresh() &&
detection_scores[i] < options_.min_score_thresh()) {
continue;
}
if (!IsClassIndexAllowed(detection_classes[i])) {
continue;
}
const int box_offset = i * num_coords_;
Detection detection = ConvertToDetection(
detection_boxes[box_offset + 0], detection_boxes[box_offset + 1],
detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
/*box_ymin=*/detection_boxes[box_offset + box_indices_[0]],
/*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
detection_scores[i], detection_classes[i], options_.flip_vertically());
const auto& bbox = detection.location_data().relative_bounding_box();
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
@ -910,7 +1041,7 @@ void main() {
options_.has_score_clipping_thresh() ? 1 : 0,
options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
: 0,
!ignore_classes_.empty() ? 1 : 0);
!IsClassIndexAllowed(0));
// # filter classes supported is hardware dependent.
int max_wg_size; // typically <= 1024
@ -919,7 +1050,14 @@ void main() {
CHECK_LT(num_classes_, max_wg_size)
<< "# classes must be < " << max_wg_size;
// TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed";
if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
<< "Only all classes >= class 0 or >= class 1";
} else {
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1)
<< "Only ignore class 0 is allowed";
}
// Shader program
{
@ -1126,10 +1264,17 @@ kernel void scoreKernel(
options_.has_score_clipping_thresh() ? 1 : 0,
options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
: 0,
ignore_classes_.size() ? 1 : 0);
!IsClassIndexAllowed(0));
// TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed";
if (class_index_set_.is_allowlist) {
CHECK_EQ(class_index_set_.values.size(),
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
<< "Only all classes >= class 0 or >= class 1";
} else {
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1)
<< "Only ignore class 0 is allowed";
}
{
// Shader program
@ -1161,5 +1306,16 @@ kernel void scoreKernel(
return absl::OkStatus();
}
// Returns true if detections with `class_index` should be kept.
// An empty filter set means no filtering at all; otherwise membership in
// `class_index_set_.values` is interpreted as allow or deny depending on
// `is_allowlist`.
bool TensorsToDetectionsCalculator::IsClassIndexAllowed(int class_index) {
  if (class_index_set_.values.empty()) {
    return true;
  }
  const bool listed = class_index_set_.values.contains(class_index);
  return class_index_set_.is_allowlist ? listed : !listed;
}
} // namespace api2
} // namespace mediapipe

View File

@ -57,7 +57,12 @@ message TensorsToDetectionsCalculatorOptions {
optional bool reverse_output_order = 14 [default = false];
// The ids of classes that should be ignored during decoding the score for
// each predicted box. Can be overridden with IGNORE_CLASSES side packet.
// `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 ignore_classes = 8;
// The ids of classes that should be allowed during decoding the score for
// each predicted box. `ignore_classes` and `allow_classes` are mutually
// exclusive.
repeated int32 allow_classes = 21 [packed = true];
optional bool sigmoid_score = 15 [default = false];
optional float score_clipping_thresh = 16;
@ -71,4 +76,40 @@ message TensorsToDetectionsCalculatorOptions {
// Score threshold for preserving decoded detections.
optional float min_score_thresh = 19;
// The maximum number of the detection results to return. If < 0, all
// available results will be returned.
// For the detection models that have built-in non max suppression op, the
// output detections are the top-scored results. Otherwise, the output
// detections are the first N results that have higher scores than
// `min_score_thresh`.
optional int32 max_results = 20 [default = -1];
// The custom model output tensor mapping.
// The indices of the "detections" tensor and the "scores" tensor are always
// required. If the model outputs an "anchors" tensor, `anchors_tensor_index`
// must be specified. If the model outputs both "classes" tensor and "number
// of detections" tensors, `classes_tensor_index` and
// `num_detections_tensor_index` must be set.
message TensorMapping {
  // Index of the model output tensor that holds the bounding box
  // coordinates. Defaults to 0 when no mapping is provided.
  optional int32 detections_tensor_index = 1;
  // Index of the tensor that holds the per-detection class indices.
  // Defaults to 1 when no mapping is provided.
  optional int32 classes_tensor_index = 2;
  // Index of the tensor that holds the per-detection scores. There is no
  // usable default; when no mapping is provided the calculator determines
  // this index from the model's output tensor count at runtime.
  optional int32 scores_tensor_index = 3;
  // Index of the tensor that holds the number of valid detections.
  // Defaults to 3 when no mapping is provided.
  optional int32 num_detections_tensor_index = 4;
  // Index of the tensor that holds the anchors. Defaults to 2 when no
  // mapping is provided.
  optional int32 anchors_tensor_index = 5;
}
optional TensorMapping tensor_mapping = 22;
// Represents the bounding box by using the combination of boundaries,
// {ymin, xmin, ymax, xmax}.
// The default order is {ymin, xmin, ymax, xmax}.
message BoxBoundariesIndices {
  // Position of each boundary within a detection's 4-value coordinate
  // block. The four indices together must form a permutation of
  // {0, 1, 2, 3} (enforced by the calculator at load time).
  optional int32 ymin = 1 [default = 0];
  optional int32 xmin = 2 [default = 1];
  optional int32 ymax = 3 [default = 2];
  optional int32 xmax = 4 [default = 3];
}
oneof box_indices {
BoxBoundariesIndices box_boundaries_indices = 23;
}
}

View File

@ -121,8 +121,12 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
if (d > 255) d = 255;
buffer[i] = d;
}
output = ::absl::make_unique<ImageFrame>(format, width, height,
width * depth, buffer.release());
output = ::absl::make_unique<ImageFrame>(
format, width, height, width * depth, buffer.release(),
[total_size](uint8* ptr) {
::operator delete[](ptr, total_size,
std::align_val_t(EIGEN_MAX_ALIGN_BYTES));
});
} else if (input_tensor.dtype() == tensorflow::DT_UINT8) {
if (scale_factor_ != 1.0) {
return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor");

View File

@ -121,10 +121,11 @@ cc_library(
deps = [
":tflite_custom_op_resolver_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/util/tflite:cpu_op_resolver",
"//mediapipe/util/tflite:op_resolver",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
alwayslink = 1,
)

View File

@ -12,14 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/tflite/cpu_op_resolver.h"
#include "mediapipe/util/tflite/op_resolver.h"
#include "tensorflow/lite/core/api/op_resolver.h"
namespace mediapipe {
namespace {
constexpr char kOpResolverTag[] = "OP_RESOLVER";
} // namespace
// This calculator creates a custom op resolver as a side packet that can be
// used in TfLiteInferenceCalculator. Current custom op resolver supports the
// following custom op on CPU and GPU:
@ -27,7 +35,9 @@ namespace mediapipe {
// MaxPoolArgmax
// MaxUnpooling
//
// Usage example:
// Usage examples:
//
// For using with TfliteInferenceCalculator:
// node {
// calculator: "TfLiteCustomOpResolverCalculator"
// output_side_packet: "op_resolver"
@ -37,12 +47,27 @@ namespace mediapipe {
// }
// }
// }
//
// For using with InferenceCalculator:
// node {
// calculator: "TfLiteCustomOpResolverCalculator"
// output_side_packet: "OP_RESOLVER:op_resolver"
// node_options: {
// [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
// use_gpu: true
// }
// }
// }
class TfLiteCustomOpResolverCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->OutputSidePackets()
.Index(0)
.Set<tflite::ops::builtin::BuiltinOpResolver>();
if (cc->OutputSidePackets().HasTag(kOpResolverTag)) {
cc->OutputSidePackets().Tag(kOpResolverTag).Set<tflite::OpResolver>();
} else {
cc->OutputSidePackets()
.Index(0)
.Set<tflite::ops::builtin::BuiltinOpResolver>();
}
return absl::OkStatus();
}
@ -59,7 +84,14 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
op_resolver = absl::make_unique<mediapipe::CpuOpResolver>();
}
cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release()));
if (cc->OutputSidePackets().HasTag(kOpResolverTag)) {
cc->OutputSidePackets()
.Tag(kOpResolverTag)
.Set(mediapipe::api2::PacketAdopting<tflite::OpResolver>(
std::move(op_resolver)));
} else {
cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release()));
}
return absl::OkStatus();
}

View File

@ -54,6 +54,7 @@ mediapipe_proto_library(
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
"//mediapipe/util:label_map_proto",
],
)
@ -304,6 +305,7 @@ cc_library(
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet",
"//mediapipe/util:resource_util",
"//mediapipe/util:label_map_cc_proto",
] + select({
"//mediapipe:android": [
"//mediapipe/util/android/file/base",
@ -350,6 +352,40 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "detection_transformation_calculator",
srcs = ["detection_transformation_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
],
alwayslink = 1,
)
cc_test(
name = "detection_transformation_calculator_test",
size = "small",
srcs = ["detection_transformation_calculator_test.cc"],
deps = [
":detection_transformation_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/formats:detection_cc_proto",
"//mediapipe/framework/formats:location_data_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
],
)
cc_library(
name = "non_max_suppression_calculator",
srcs = ["non_max_suppression_calculator.cc"],

View File

@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "absl/container/node_hash_map.h"
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_MOBILE)
@ -53,8 +53,11 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) override;
private:
absl::node_hash_map<int, std::string> label_map_;
::mediapipe::DetectionLabelIdToTextCalculatorOptions options_;
// Local label map built from the calculator options' `label_map_path` or
// `label` field.
LabelMap local_label_map_;
bool keep_label_id_;
const LabelMap& GetLabelMap(CalculatorContext* cc);
};
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
@ -69,13 +72,16 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0));
options_ =
const auto& options =
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
if (options_.has_label_map_path()) {
if (options.has_label_map_path()) {
RET_CHECK(!options.has_label_map() && options.label().empty())
<< "Only can set one of the following fields in the CalculatorOptions: "
"label_map_path, label, and label_map.";
std::string string_path;
ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options_.label_map_path()));
PathToResourceAsFile(options.label_map_path()));
std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
@ -83,13 +89,21 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
std::string line;
int i = 0;
while (std::getline(stream, line)) {
label_map_[i++] = line;
LabelMapItem item;
item.set_name(line);
(*local_label_map_.mutable_index_to_item())[i++] = item;
}
} else {
for (int i = 0; i < options_.label_size(); ++i) {
label_map_[i] = options_.label(i);
} else if (!options.label().empty()) {
RET_CHECK(!options.has_label_map())
<< "Only can set one of the following fields in the CalculatorOptions: "
"label_map_path, label, and label_map.";
for (int i = 0; i < options.label_size(); ++i) {
LabelMapItem item;
item.set_name(options.label(i));
(*local_label_map_.mutable_index_to_item())[i] = item;
}
}
keep_label_id_ = options.keep_label_id();
return absl::OkStatus();
}
@ -101,13 +115,18 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
Detection& output_detection = output_detections.back();
bool has_text_label = false;
for (const int32 label_id : output_detection.label_id()) {
if (label_map_.find(label_id) != label_map_.end()) {
output_detection.add_label(label_map_[label_id]);
if (GetLabelMap(cc).index_to_item().find(label_id) !=
GetLabelMap(cc).index_to_item().end()) {
auto item = GetLabelMap(cc).index_to_item().at(label_id);
output_detection.add_label(item.name());
if (item.has_display_name()) {
output_detection.add_display_name(item.display_name());
}
has_text_label = true;
}
}
// Remove label_id field if text labels exist.
if (has_text_label && !options_.keep_label_id()) {
if (has_text_label && !keep_label_id_) {
output_detection.clear_label_id();
}
}
@ -117,4 +136,13 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus();
}
// Returns the label map to use for lookups: the locally built map (from
// `label_map_path` or the repeated `label` field) when it is non-empty,
// otherwise the `label_map` field straight from the calculator options.
const LabelMap& DetectionLabelIdToTextCalculator::GetLabelMap(
    CalculatorContext* cc) {
  if (!local_label_map_.index_to_item().empty()) {
    return local_label_map_;
  }
  return cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>()
      .label_map();
}
} // namespace mediapipe

View File

@ -17,6 +17,7 @@ syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
import "mediapipe/util/label_map.proto";
message DetectionLabelIdToTextCalculatorOptions {
extend mediapipe.CalculatorOptions {
@ -26,7 +27,7 @@ message DetectionLabelIdToTextCalculatorOptions {
// Path to a label map file for getting the actual name of detected classes.
optional string label_map_path = 1;
// Alternative way to specify label map
// Alternative way to specify label map.
// label: "label for id 0"
// label: "label for id 1"
// ...
@ -36,4 +37,7 @@ message DetectionLabelIdToTextCalculatorOptions {
// could be found. By setting this field to true, it is always copied to the
// output detections.
optional bool keep_label_id = 3;
// Label map.
optional LabelMap label_map = 4;
}

View File

@ -0,0 +1,298 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
namespace mediapipe {
namespace api2 {
namespace {
// Clamps `value` into the closed interval [0, upper_bound]: the result is
// capped at `upper_bound` from above and at zero from below. (When
// `upper_bound` is itself negative, the result is therefore 0.)
template <typename T>
T BoundedValue(T value, T upper_bound) {
  const T capped = std::min(value, upper_bound);
  return capped < 0 ? T{0} : capped;
}
// Converts the detection's RELATIVE_BOUNDING_BOX into an absolute pixel
// BOUNDING_BOX, scaling each coordinate by the matching dimension of
// `image_size` ({width, height}) and clamping it into [0, dimension] via
// BoundedValue. On success the relative box is cleared and the location
// data format is switched to BOUNDING_BOX. Always returns OkStatus.
absl::Status ConvertRelativeBoundingBoxToBoundingBox(
    const std::pair<int, int>& image_size, Detection* detection) {
  const int image_width = image_size.first;
  const int image_height = image_size.second;
  // Read the relative box before mutating the location data; the reference
  // stays valid while the sibling bounding_box field is populated.
  const auto& relative_bbox =
      detection->location_data().relative_bounding_box();
  auto* bbox = detection->mutable_location_data()->mutable_bounding_box();
  bbox->set_xmin(
      BoundedValue<int>(relative_bbox.xmin() * image_width, image_width));
  bbox->set_ymin(
      BoundedValue<int>(relative_bbox.ymin() * image_height, image_height));
  bbox->set_width(
      BoundedValue<int>(relative_bbox.width() * image_width, image_width));
  bbox->set_height(
      BoundedValue<int>(relative_bbox.height() * image_height, image_height));
  // Switch the format and drop the (now stale) relative box.
  detection->mutable_location_data()->set_format(LocationData::BOUNDING_BOX);
  detection->mutable_location_data()->clear_relative_bounding_box();
  return absl::OkStatus();
}
// Converts the detection's pixel BOUNDING_BOX into a RELATIVE_BOUNDING_BOX
// by dividing each coordinate by the matching dimension of `image_size`
// ({width, height}) and clamping the result into [0.0f, 1.0f] via
// BoundedValue. On success the pixel box is cleared and the location data
// format is switched to RELATIVE_BOUNDING_BOX. Always returns OkStatus.
absl::Status ConvertBoundingBoxToRelativeBoundingBox(
    const std::pair<int, int>& image_size, Detection* detection) {
  const int image_width = image_size.first;
  const int image_height = image_size.second;
  // Read the pixel box before mutating the location data; the reference
  // stays valid while the sibling relative_bounding_box field is populated.
  const auto& bbox = detection->location_data().bounding_box();
  auto* relative_bbox =
      detection->mutable_location_data()->mutable_relative_bounding_box();
  // static_cast (rather than a C-style cast) makes the intentional
  // int -> float promotion before division explicit and greppable.
  relative_bbox->set_xmin(BoundedValue<float>(
      static_cast<float>(bbox.xmin()) / image_width, 1.0f));
  relative_bbox->set_ymin(BoundedValue<float>(
      static_cast<float>(bbox.ymin()) / image_height, 1.0f));
  relative_bbox->set_width(BoundedValue<float>(
      static_cast<float>(bbox.width()) / image_width, 1.0f));
  relative_bbox->set_height(BoundedValue<float>(
      static_cast<float>(bbox.height()) / image_height, 1.0f));
  // Switch the format and drop the (now stale) pixel box.
  detection->mutable_location_data()->clear_bounding_box();
  detection->mutable_location_data()->set_format(
      LocationData::RELATIVE_BOUNDING_BOX);
  return absl::OkStatus();
}
// Returns the location data format of `detection` after validating it is
// one of the two formats this calculator can transform. Returns
// InvalidArgumentError when the detection has no location data at all, and
// a RET_CHECK failure status when the format is neither
// RELATIVE_BOUNDING_BOX nor BOUNDING_BOX.
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
    const Detection& detection) {
  if (!detection.has_location_data()) {
    return absl::InvalidArgumentError("Detection must have location data.");
  }
  LocationData::Format format = detection.location_data().format();
  RET_CHECK(format == LocationData::RELATIVE_BOUNDING_BOX ||
            format == LocationData::BOUNDING_BOX)
      << "Detection's location data format must be either "
         "RELATIVE_BOUNDING_BOX or BOUNDING_BOX";
  return format;
}
// Returns the single location data format shared by every detection in
// `detections`. RET_CHECK-fails on an empty vector; propagates the
// single-detection overload's error for missing/unsupported location data;
// returns InvalidArgumentError when the detections mix formats.
//
// The parameter is now const& — the function only reads the detections, and
// a const reference still accepts every existing caller's argument.
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
    const std::vector<Detection>& detections) {
  RET_CHECK(!detections.empty());
  LocationData::Format output_format;
  ASSIGN_OR_RETURN(output_format, GetLocationDataFormat(detections[0]));
  // All detections must agree on one format; otherwise a single conversion
  // direction for the batch would be ill-defined.
  for (size_t i = 1; i < detections.size(); ++i) {
    ASSIGN_OR_RETURN(LocationData::Format format,
                     GetLocationDataFormat(detections[i]));
    if (output_format != format) {
      return absl::InvalidArgumentError(
          "Input detections have different location data formats.");
    }
  }
  return output_format;
}
// Converts `detection`'s bounding box in place, in whichever direction its
// current format dictates: a relative box becomes a pixel box and a pixel
// box becomes a relative box (both using `image_size` = {width, height}).
// Returns InvalidArgumentError when the detection has no location data or
// its format is not one of the two supported kinds.
absl::Status ConvertBoundingBox(const std::pair<int, int>& image_size,
                                Detection* detection) {
  if (!detection->has_location_data()) {
    return absl::InvalidArgumentError("Detection must have location data.");
  }
  switch (detection->location_data().format()) {
    case LocationData::RELATIVE_BOUNDING_BOX:
      return ConvertRelativeBoundingBoxToBoundingBox(image_size, detection);
    case LocationData::BOUNDING_BOX:
      return ConvertBoundingBoxToRelativeBoundingBox(image_size, detection);
    default:
      return absl::InvalidArgumentError(
          "Detection's location data format must be either "
          "RELATIVE_BOUNDING_BOX or BOUNDING_BOX.");
  }
}
} // namespace
// Transforms relative bounding box(es) to pixel bounding box(es) in a detection
// proto/detection list/detection vector, or vice versa.
//
// Inputs:
// One of the following:
// DETECTION: A Detection proto.
// DETECTIONS: An std::vector<Detection>/ a DetectionList proto.
// IMAGE_SIZE: A std::pair<int, int> representing image width and height.
//
// Outputs:
// At least one of the following:
// PIXEL_DETECTION: A Detection proto with pixel bounding box.
// PIXEL_DETECTIONS: An std::vector<Detection> with pixel bounding boxes.
// PIXEL_DETECTION_LIST: A DetectionList proto with pixel bounding boxes.
// RELATIVE_DETECTION: A Detection proto with relative bounding box.
// RELATIVE_DETECTIONS: An std::vector<Detection> with relative bounding boxes.
// RELATIVE_DETECTION_LIST: A DetectionList proto with relative bounding boxes.
//
// Example config:
// For input detection(s) with relative bounding box(es):
// node {
// calculator: "DetectionTransformationCalculator"
// input_stream: "DETECTION:input_detection"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "PIXEL_DETECTION:output_detection"
// output_stream: "PIXEL_DETECTIONS:output_detections"
// output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
// }
//
// For input detection(s) with pixel bounding box(es):
// node {
// calculator: "DetectionTransformationCalculator"
// input_stream: "DETECTION:input_detection"
// input_stream: "IMAGE_SIZE:image_size"
// output_stream: "RELATIVE_DETECTION:output_detection"
// output_stream: "RELATIVE_DETECTIONS:output_detections"
// output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
// }
class DetectionTransformationCalculator : public Node {
public:
static constexpr Input<Detection>::Optional kInDetection{"DETECTION"};
static constexpr Input<OneOf<DetectionList, std::vector<Detection>>>::Optional
kInDetections{"DETECTIONS"};
static constexpr Input<std::pair<int, int>> kInImageSize{"IMAGE_SIZE"};
static constexpr Output<Detection>::Optional kOutPixelDetection{
"PIXEL_DETECTION"};
static constexpr Output<std::vector<Detection>>::Optional kOutPixelDetections{
"PIXEL_DETECTIONS"};
static constexpr Output<DetectionList>::Optional kOutPixelDetectionList{
"PIXEL_DETECTION_LIST"};
static constexpr Output<Detection>::Optional kOutRelativeDetection{
"RELATIVE_DETECTION"};
static constexpr Output<std::vector<Detection>>::Optional
kOutRelativeDetections{"RELATIVE_DETECTIONS"};
static constexpr Output<DetectionList>::Optional kOutRelativeDetectionList{
"RELATIVE_DETECTION_LIST"};
MEDIAPIPE_NODE_CONTRACT(kInDetection, kInDetections, kInImageSize,
kOutPixelDetection, kOutPixelDetections,
kOutPixelDetectionList, kOutRelativeDetection,
kOutRelativeDetections, kOutRelativeDetectionList);
static absl::Status UpdateContract(CalculatorContract* cc) {
RET_CHECK(kInImageSize(cc).IsConnected()) << "Image size must be provided.";
RET_CHECK(kInDetections(cc).IsConnected() ^ kInDetection(cc).IsConnected());
if (kInDetections(cc).IsConnected()) {
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
kOutPixelDetectionList(cc).IsConnected() ||
kOutRelativeDetections(cc).IsConnected() ||
kOutRelativeDetectionList(cc).IsConnected())
<< "Output must be a container of detections.";
}
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
kOutPixelDetectionList(cc).IsConnected() ||
kOutPixelDetection(cc).IsConnected() ||
kOutRelativeDetections(cc).IsConnected() ||
kOutRelativeDetectionList(cc).IsConnected() ||
kOutRelativeDetection(cc).IsConnected())
<< "Must connect at least one output stream.";
return absl::OkStatus();
}
absl::Status Open(CalculatorContext* cc) override {
output_pixel_bounding_boxes_ = kOutPixelDetections(cc).IsConnected() ||
kOutPixelDetectionList(cc).IsConnected() ||
kOutPixelDetection(cc).IsConnected();
output_relative_bounding_boxes_ =
kOutRelativeDetections(cc).IsConnected() ||
kOutRelativeDetectionList(cc).IsConnected() ||
kOutRelativeDetection(cc).IsConnected();
RET_CHECK(output_pixel_bounding_boxes_ ^ output_relative_bounding_boxes_)
<< "All output streams must have the same stream tag prefix, either "
"\"PIXEL\" or \"RELATIVE_\".";
return absl::OkStatus();
}
absl::Status Process(CalculatorContext* cc) override {
std::pair<int, int> image_size = kInImageSize(cc).Get();
std::vector<Detection> transformed_detections;
LocationData::Format input_location_data_format;
if (kInDetections(cc).IsConnected()) {
transformed_detections = kInDetections(cc).Visit(
[&](const DetectionList& detection_list) {
return std::vector<Detection>(detection_list.detection().begin(),
detection_list.detection().end());
},
[&](const std::vector<Detection>& detection_vector) {
return detection_vector;
});
ASSIGN_OR_RETURN(input_location_data_format,
GetLocationDataFormat(transformed_detections));
for (Detection& detection : transformed_detections) {
MP_RETURN_IF_ERROR(ConvertBoundingBox(image_size, &detection));
}
} else {
ASSIGN_OR_RETURN(input_location_data_format,
GetLocationDataFormat(kInDetection(cc).Get()));
Detection transformed_detection(kInDetection(cc).Get());
MP_RETURN_IF_ERROR(
ConvertBoundingBox(image_size, &transformed_detection));
transformed_detections.push_back(transformed_detection);
}
if (input_location_data_format == LocationData::RELATIVE_BOUNDING_BOX) {
RET_CHECK(!output_relative_bounding_boxes_)
<< "Input detections are with relative bounding box(es), and the "
"output detections must have pixel bounding box(es).";
if (kOutPixelDetection(cc).IsConnected()) {
kOutPixelDetection(cc).Send(transformed_detections[0]);
}
if (kOutPixelDetections(cc).IsConnected()) {
kOutPixelDetections(cc).Send(transformed_detections);
}
if (kOutPixelDetectionList(cc).IsConnected()) {
DetectionList detection_list;
for (const auto& detection : transformed_detections) {
detection_list.add_detection()->CopyFrom(detection);
}
kOutPixelDetectionList(cc).Send(detection_list);
}
} else {
RET_CHECK(!output_pixel_bounding_boxes_)
<< "Input detections are with pixel bounding box(es), and the "
"output detections must have relative bounding box(es).";
if (kOutRelativeDetection(cc).IsConnected()) {
kOutRelativeDetection(cc).Send(transformed_detections[0]);
}
if (kOutRelativeDetections(cc).IsConnected()) {
kOutRelativeDetections(cc).Send(transformed_detections);
}
if (kOutRelativeDetectionList(cc).IsConnected()) {
DetectionList detection_list;
for (const auto& detection : transformed_detections) {
detection_list.add_detection()->CopyFrom(detection);
}
kOutRelativeDetectionList(cc).Send(detection_list);
}
}
return absl::OkStatus();
}
private:
bool output_relative_bounding_boxes_;
bool output_pixel_bounding_boxes_;
};
MEDIAPIPE_REGISTER_NODE(DetectionTransformationCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,287 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <vector>
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/formats/location_data.pb.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";
constexpr char kPixelDetectionTag[] = "PIXEL_DETECTION";
constexpr char kPixelDetectionListTag[] = "PIXEL_DETECTION_LIST";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";
constexpr char kRelativeDetectionListTag[] = "RELATIVE_DETECTION_LIST";
constexpr char kRelativeDetectionsTag[] = "RELATIVE_DETECTIONS";
// Builds a Detection carrying a pixel-space BOUNDING_BOX with the given
// origin (xmin, ymin) and size (width, height).
Detection DetectionWithBoundingBox(int32 xmin, int32 ymin, int32 width,
                                   int32 height) {
  Detection detection;
  LocationData* location_data = detection.mutable_location_data();
  location_data->set_format(LocationData::BOUNDING_BOX);
  auto* bbox = location_data->mutable_bounding_box();
  bbox->set_xmin(xmin);
  bbox->set_ymin(ymin);
  bbox->set_width(width);
  bbox->set_height(height);
  return detection;
}
// Builds a Detection carrying a RELATIVE_BOUNDING_BOX (coordinates normalized
// to [0, 1]) with the given origin (xmin, ymin) and size (width, height).
Detection DetectionWithRelativeBoundingBox(float xmin, float ymin, float width,
                                           float height) {
  Detection detection;
  LocationData* location_data = detection.mutable_location_data();
  location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
  auto* relative_bbox = location_data->mutable_relative_bounding_box();
  relative_bbox->set_xmin(xmin);
  relative_bbox->set_ymin(ymin);
  relative_bbox->set_width(width);
  relative_bbox->set_height(height);
  return detection;
}
// Copies the detections of `detection_list` into a std::vector.
std::vector<Detection> ConvertToDetectionVector(
    const DetectionList& detection_list) {
  return std::vector<Detection>(detection_list.detection().begin(),
                                detection_list.detection().end());
}
void CheckBoundingBox(const Detection& output, const Detection& expected) {
const auto& output_bbox = output.location_data().bounding_box();
const auto& expected_bbox = output.location_data().bounding_box();
EXPECT_THAT(output_bbox.xmin(), testing::Eq(expected_bbox.xmin()));
EXPECT_THAT(output_bbox.ymin(), testing::Eq(expected_bbox.ymin()));
EXPECT_THAT(output_bbox.width(), testing::Eq(expected_bbox.width()));
EXPECT_THAT(output_bbox.height(), testing::Eq(expected_bbox.height()));
}
void CheckRelativeBoundingBox(const Detection& output,
const Detection& expected) {
const auto& output_bbox = output.location_data().relative_bounding_box();
const auto& expected_bbox = output.location_data().relative_bounding_box();
EXPECT_THAT(output_bbox.xmin(), testing::FloatEq(expected_bbox.xmin()));
EXPECT_THAT(output_bbox.ymin(), testing::FloatEq(expected_bbox.ymin()));
EXPECT_THAT(output_bbox.width(), testing::FloatEq(expected_bbox.width()));
EXPECT_THAT(output_bbox.height(), testing::FloatEq(expected_bbox.height()));
}
void CheckOutputDetections(const std::vector<Detection>& expected,
const std::vector<Detection>& output) {
ASSERT_EQ(output.size(), expected.size());
for (int i = 0; i < output.size(); ++i) {
auto output_format = output[i].location_data().format();
ASSERT_TRUE(output_format == LocationData::RELATIVE_BOUNDING_BOX ||
output_format == LocationData::BOUNDING_BOX);
ASSERT_EQ(output_format, expected[i].location_data().format());
if (output_format == LocationData::RELATIVE_BOUNDING_BOX) {
CheckRelativeBoundingBox(output[i], expected[i]);
}
if (output_format == LocationData::BOUNDING_BOX) {
CheckBoundingBox(output[i], expected[i]);
}
}
}
// The calculator requires an IMAGE_SIZE input stream; a graph config without
// it must fail with an explanatory error.
TEST(DetectionsTransformationCalculatorTest, MissingImageSize) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTIONS:detections"
    output_stream: "PIXEL_DETECTION:detection"
  )pb"));
  auto run_status = runner.Run();
  ASSERT_FALSE(run_status.ok());
  EXPECT_THAT(run_status.message(),
              testing::HasSubstr("Image size must be provided"));
}
// A vector input (DETECTIONS) with a single-detection output
// (PIXEL_DETECTION) is an invalid combination and must be rejected.
TEST(DetectionsTransformationCalculatorTest, WrongOutputType) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTIONS:detections"
    input_stream: "IMAGE_SIZE:image_size"
    output_stream: "PIXEL_DETECTION:detection"
  )pb"));
  auto run_status = runner.Run();
  ASSERT_FALSE(run_status.ok());
  EXPECT_THAT(run_status.message(),
              testing::HasSubstr("Output must be a container of detections"));
}
// A detection whose location data is neither BOUNDING_BOX nor
// RELATIVE_BOUNDING_BOX (here: GLOBAL) must be rejected at run time.
TEST(DetectionsTransformationCalculatorTest, WrongLocationDataFormat) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTION:input_detection"
    input_stream: "IMAGE_SIZE:image_size"
    output_stream: "PIXEL_DETECTION:output_detection"
  )pb"));
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(MakePacket<std::pair<int, int>>(std::make_pair(2000, 1000))
                             .At(Timestamp(0)));
  Detection global_detection;
  global_detection.mutable_location_data()->set_format(LocationData::GLOBAL);
  runner.MutableInputs()
      ->Tag(kDetectionTag)
      .packets.push_back(
          MakePacket<Detection>(global_detection).At(Timestamp(0)));
  auto run_status = runner.Run();
  ASSERT_FALSE(run_status.ok());
  EXPECT_THAT(run_status.message(),
              testing::HasSubstr("location data format must be either "
                                 "RELATIVE_BOUNDING_BOX or BOUNDING_BOX"));
}
// Pixel boxes divided by a 2000x1000 image size become relative boxes; both
// the RELATIVE_DETECTIONS vector and the RELATIVE_DETECTION_LIST outputs are
// checked against the same expected values.
TEST(DetectionsTransformationCalculatorTest,
     ConvertBoundingBoxToRelativeBoundingBox) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTIONS:input_detections"
    input_stream: "IMAGE_SIZE:image_size"
    output_stream: "RELATIVE_DETECTIONS:output_detections"
    output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
  )pb"));
  auto input_detections = absl::make_unique<std::vector<Detection>>();
  input_detections->push_back(DetectionWithBoundingBox(100, 200, 400, 300));
  input_detections->push_back(DetectionWithBoundingBox(0, 0, 2000, 1000));
  runner.MutableInputs()
      ->Tag(kDetectionsTag)
      .packets.push_back(Adopt(input_detections.release()).At(Timestamp(0)));
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(MakePacket<std::pair<int, int>>(std::make_pair(2000, 1000))
                             .At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());

  const std::vector<Detection> expected(
      {DetectionWithRelativeBoundingBox(0.05, 0.2, 0.2, 0.3),
       DetectionWithRelativeBoundingBox(0, 0, 1, 1)});
  const std::vector<Packet>& vector_packets =
      runner.Outputs().Tag(kRelativeDetectionsTag).packets;
  ASSERT_EQ(1, vector_packets.size());
  CheckOutputDetections(expected,
                        vector_packets[0].Get<std::vector<Detection>>());
  const std::vector<Packet>& list_packets =
      runner.Outputs().Tag(kRelativeDetectionListTag).packets;
  ASSERT_EQ(1, list_packets.size());
  CheckOutputDetections(
      expected,
      ConvertToDetectionVector(list_packets[0].Get<DetectionList>()));
}
// Relative boxes scaled by a 2000x1000 image size become pixel boxes; both
// the PIXEL_DETECTIONS vector and the PIXEL_DETECTION_LIST outputs are
// checked against the same expected values.
TEST(DetectionsTransformationCalculatorTest,
     ConvertRelativeBoundingBoxToBoundingBox) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTIONS:input_detections"
    input_stream: "IMAGE_SIZE:image_size"
    output_stream: "PIXEL_DETECTIONS:output_detections"
    output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
  )pb"));
  auto detections(absl::make_unique<std::vector<Detection>>());
  detections->push_back(DetectionWithRelativeBoundingBox(0.1, 0.2, 0.3, 0.4));
  detections->push_back(DetectionWithRelativeBoundingBox(0, 0, 1, 1));
  std::pair<int, int> image_size({2000, 1000});
  runner.MutableInputs()
      ->Tag(kDetectionsTag)
      .packets.push_back(Adopt(detections.release()).At(Timestamp(0)));
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(
          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());
  // Bug fix: (0.1, 0.2, 0.3, 0.4) scaled by (2000, 1000) yields
  // (200, 200, 600, 400); the previous expectation of (100, 200, 400, 300)
  // was masked by the vacuous Check* helpers comparing output to itself.
  std::vector<Detection> expected({DetectionWithBoundingBox(200, 200, 600, 400),
                                   DetectionWithBoundingBox(0, 0, 2000, 1000)});
  const std::vector<Packet>& detections_output =
      runner.Outputs().Tag(kPixelDetectionsTag).packets;
  ASSERT_EQ(1, detections_output.size());
  CheckOutputDetections(expected,
                        detections_output[0].Get<std::vector<Detection>>());
  const std::vector<Packet>& detection_list_output =
      runner.Outputs().Tag(kPixelDetectionListTag).packets;
  ASSERT_EQ(1, detection_list_output.size());
  CheckOutputDetections(
      expected,
      ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
}
// A single DETECTION input is converted and emitted on all three pixel
// output streams: the single detection, the vector, and the DetectionList.
TEST(DetectionsTransformationCalculatorTest, ConvertSingleDetection) {
  // Fixed stream-name typo: "outpu_detection" -> "output_detection".
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTION:input_detection"
    input_stream: "IMAGE_SIZE:image_size"
    output_stream: "PIXEL_DETECTION:output_detection"
    output_stream: "PIXEL_DETECTIONS:output_detections"
    output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
  )pb"));
  runner.MutableInputs()
      ->Tag(kDetectionTag)
      .packets.push_back(MakePacket<Detection>(DetectionWithRelativeBoundingBox(
                                                   0.05, 0.2, 0.2, 0.3))
                             .At(Timestamp(0)));
  std::pair<int, int> image_size({2000, 1000});
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(
          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());
  // (0.05, 0.2, 0.2, 0.3) scaled by (2000, 1000) -> (100, 200, 400, 300).
  std::vector<Detection> expected(
      {DetectionWithBoundingBox(100, 200, 400, 300)});
  const std::vector<Packet>& detection_output =
      runner.Outputs().Tag(kPixelDetectionTag).packets;
  ASSERT_EQ(1, detection_output.size());
  CheckOutputDetections(expected, {detection_output[0].Get<Detection>()});
  const std::vector<Packet>& detections_output =
      runner.Outputs().Tag(kPixelDetectionsTag).packets;
  ASSERT_EQ(1, detections_output.size());
  CheckOutputDetections(expected,
                        detections_output[0].Get<std::vector<Detection>>());
  const std::vector<Packet>& detection_list_output =
      runner.Outputs().Tag(kPixelDetectionListTag).packets;
  ASSERT_EQ(1, detection_list_output.size());
  CheckOutputDetections(
      expected,
      ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
}
} // namespace
} // namespace mediapipe

View File

@ -181,7 +181,7 @@ class TrackingGraphTest : public Test {
// Each image is shifted to the right and bottom by kTranslationStep
// pixels compared with the previous image.
static constexpr int kTranslationStep = 10;
static constexpr float kEqualityTolerance = 3e-4f;
static constexpr float kEqualityTolerance = 1e-3f;
};
void TrackingGraphTest::ExpectBoxAtFrame(const TimedBoxProto& box, float frame,

View File

@ -85,7 +85,7 @@ class KinematicPathSolver {
double current_position_px_;
double prior_position_px_;
double current_velocity_deg_per_s_;
uint64 current_time_;
uint64 current_time_ = 0;
// History of observations (second) and their time (first).
std::deque<std::pair<uint64, int>> raw_positions_at_time_;
// Current target position.

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Example of reading a MediaSequence dataset.
"""

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "facedetectioncpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "facedetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "faceeffect",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "facemeshgpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "handdetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "handtrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "helloworld",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "holistictrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "iristrackinggpu",

View File

@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""This script is used to set up automatic provisioning for iOS examples.
It scans the provisioning profiles used by Xcode, looking for one matching the

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "objectdetectioncpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"]) # Apache 2.0
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "objectdetectiongpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "objectdetectiontrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "posetrackinggpu",

View File

@ -24,7 +24,7 @@ load(
licenses(["notice"])
MIN_IOS_VERSION = "10.0"
MIN_IOS_VERSION = "11.0"
alias(
name = "selfiesegmentationgpu",

View File

@ -234,7 +234,9 @@ cc_library(
"//mediapipe/framework/tool:options_map",
"//mediapipe/framework/tool:packet_generator_wrapper_calculator_cc_proto",
"//mediapipe/framework/tool:tag_map",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
)
@ -348,6 +350,7 @@ cc_library(
"//mediapipe/framework/tool:validate",
"//mediapipe/framework/tool:validate_name",
"//mediapipe/gpu:graph_support",
"//mediapipe/gpu:gpu_service",
"//mediapipe/util:cpu_util",
] + select({
"//conditions:default": ["//mediapipe/gpu:gpu_shared_data_internal"],
@ -416,7 +419,6 @@ cc_library(
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/tool:tag_map",
"//mediapipe/framework/tool:validate_name",
"//mediapipe/gpu:graph_support",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
@ -613,7 +615,11 @@ cc_library(
hdrs = ["graph_service.h"],
visibility = [":mediapipe_internal"],
deps = [
":packet",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/strings",
],
)

View File

@ -167,7 +167,6 @@ struct IsCompatibleType<V, OneOf<U...>>
template <typename T>
inline Packet<T> PacketBase::As() const {
if (!payload_) return Packet<T>().At(timestamp_);
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
internal::CheckCompatibleType(*payload_, internal::Wrap<T>{});
return Packet<T>(payload_).At(timestamp_);
}
@ -217,8 +216,8 @@ class Packet : public Packet<internal::Generic> {
const T& operator*() const { return Get(); }
const T* operator->() const { return &Get(); }
template <typename U>
T GetOr(U&& v) const {
template <typename U, typename TT = T>
std::enable_if_t<!std::is_abstract_v<TT>, TT> GetOr(U&& v) const {
return IsEmpty() ? static_cast<T>(absl::forward<U>(v)) : **this;
}

View File

@ -4,11 +4,15 @@ namespace api2 {
namespace {
#if defined(TEST_NO_ASSIGN_WRONG_PACKET_TYPE)
void AssignWrongPacketType() { Packet<int> p = MakePacket<float>(1.0); }
int AssignWrongPacketType() {
Packet<int> p = MakePacket<float>(1.0);
return *p;
}
#elif defined(TEST_NO_ASSIGN_GENERIC_TO_SPECIFIC)
void AssignWrongPacketType() {
int AssignWrongPacketType() {
Packet<> p = MakePacket<float>(1.0);
Packet<int> p2 = p;
return *p2;
}
#endif

View File

@ -264,6 +264,23 @@ TEST(PacketTest, Polymorphism) {
EXPECT_EQ((**mutable_base).name(), "Derived");
}
// Abstract interface used to exercise Packet with a non-constructible
// (abstract) payload type.
class AbstractBase {
 public:
  virtual ~AbstractBase() = default;
  virtual absl::string_view name() const = 0;
};
// Concrete implementation of AbstractBase; name() identifies the dynamic type.
class ConcreteDerived : public AbstractBase {
 public:
  absl::string_view name() const override { return "ConcreteDerived"; }
};
// A Packet declared with an abstract type can adopt a derived instance and
// dispatch virtual calls through it.
TEST(PacketTest, PolymorphismAbstract) {
  Packet<AbstractBase> base =
      PacketAdopting<AbstractBase>(absl::make_unique<ConcreteDerived>());
  EXPECT_EQ(base->name(), "ConcreteDerived");
}
} // namespace
} // namespace api2
} // namespace mediapipe

View File

@ -40,6 +40,17 @@ TEST(PortTest, DeletedCopyConstructorInput) {
EXPECT_EQ(std::string(kSideOutputPort.Tag()), "SIDE_OUTPUT");
}
// Abstract type used to verify that ports can be declared with
// non-constructible payload types.
class AbstractBase {
 public:
  virtual ~AbstractBase() = default;
  virtual absl::string_view name() const = 0;
};
// An Input port parameterized on an abstract type still compiles and
// reports its tag correctly.
TEST(PortTest, Abstract) {
  static constexpr Input<AbstractBase> kInputPort{"INPUT"};
  EXPECT_EQ(std::string(kInputPort.Tag()), "INPUT");
}
} // namespace
} // namespace api2
} // namespace mediapipe

View File

@ -21,6 +21,8 @@
#include <typeindex>
// TODO: Move protos in another CL after the C++ code migration.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/graph_service.h"
#include "mediapipe/framework/mediapipe_options.pb.h"
@ -147,7 +149,7 @@ class CalculatorContract {
bool IsOptional() const { return optional_; }
private:
GraphServiceBase service_;
const GraphServiceBase& service_;
bool optional_ = false;
};
@ -156,9 +158,12 @@ class CalculatorContract {
return it->second;
}
const std::map<std::string, GraphServiceRequest>& ServiceRequests() const {
return service_requests_;
}
// A GraphService's key is always a static constant, so we can use string_view
// as the key type without lifetime issues.
using ServiceReqMap =
absl::flat_hash_map<absl::string_view, GraphServiceRequest>;
const ServiceReqMap& ServiceRequests() const { return service_requests_; }
private:
template <class T>
@ -180,7 +185,7 @@ class CalculatorContract {
std::string input_stream_handler_;
MediaPipeOptions input_stream_handler_options_;
std::string node_name_;
std::map<std::string, GraphServiceRequest> service_requests_;
ServiceReqMap service_requests_;
bool process_timestamps_ = false;
TimestampDiff timestamp_offset_ = TimestampDiff::Unset();

View File

@ -226,6 +226,16 @@ absl::Status CalculatorGraph::InitializeStreams() {
return absl::OkStatus();
}
// Hack for backwards compatibility with ancient GPU calculators. Can it
// be retired yet?
// Legacy GPU calculators declare the GPU shared input side packet
// (kGpuSharedTagName) instead of requesting the GPU service; rewrite their
// contract so they are treated like GPU-service users. The const_cast is
// needed because only a const view of the contract is exposed here.
static void MaybeFixupLegacyGpuNodeContract(CalculatorNode& node) {
#if !MEDIAPIPE_DISABLE_GPU
  if (node.Contract().InputSidePackets().HasTag(kGpuSharedTagName)) {
    const_cast<CalculatorContract&>(node.Contract()).UseService(kGpuService);
  }
#endif  // !MEDIAPIPE_DISABLE_GPU
}
absl::Status CalculatorGraph::InitializeCalculatorNodes() {
// Check if the user has specified a maximum queue size for an input stream.
max_queue_size_ = validated_graph_->Config().max_queue_size();
@ -246,6 +256,7 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_);
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
if (buffer_size_hint > 0) {
max_queue_size_ = std::max(max_queue_size_, buffer_size_hint);
}
@ -283,6 +294,7 @@ absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
validated_graph_.get(), node_ref, input_stream_managers_.get(),
output_stream_managers_.get(), output_side_packets_.get(),
&buffer_size_hint, profiler_);
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
if (!result.ok()) {
// Collect as many errors as we can before failing.
errors.push_back(result);
@ -495,9 +507,8 @@ absl::StatusOr<Packet> CalculatorGraph::GetOutputSidePacket(
<< "\" because it doesn't exist.";
}
Packet output_packet;
if (scheduler_.IsTerminated()) {
// Side-packets from calculators can be retrieved only after the graph is
// done.
if (!output_side_packets_[side_packet_index].GetPacket().IsEmpty() ||
scheduler_.IsTerminated()) {
output_packet = output_side_packets_[side_packet_index].GetPacket();
}
if (output_packet.IsEmpty()) {
@ -546,6 +557,7 @@ absl::Status CalculatorGraph::StartRun(
#if !MEDIAPIPE_DISABLE_GPU
absl::Status CalculatorGraph::SetGpuResources(
std::shared_ptr<::mediapipe::GpuResources> resources) {
RET_CHECK_NE(resources, nullptr);
auto gpu_service = service_manager_.GetServiceObject(kGpuService);
RET_CHECK_EQ(gpu_service, nullptr)
<< "The GPU resources have already been configured.";
@ -557,68 +569,89 @@ std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
return service_manager_.GetServiceObject(kGpuService);
}
absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
static Packet GetLegacyGpuSharedSidePacket(
const std::map<std::string, Packet>& side_packets) {
std::map<std::string, Packet> additional_side_packets;
bool update_sp = false;
bool uses_gpu = false;
for (const auto& node : nodes_) {
if (node->UsesGpu()) {
uses_gpu = true;
break;
}
auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
if (legacy_sp_iter == side_packets.end()) return {};
// Note that, because of b/116875321, the legacy side packet may be set but
// empty. But it's ok, because here we return an empty packet to indicate the
// missing case anyway.
return legacy_sp_iter->second;
}
absl::Status CalculatorGraph::MaybeSetUpGpuServiceFromLegacySidePacket(
Packet legacy_sp) {
if (legacy_sp.IsEmpty()) return absl::OkStatus();
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (gpu_resources) {
LOG(WARNING)
<< "::mediapipe::GpuSharedData provided as a side packet while the "
<< "graph already had one; ignoring side packet";
return absl::OkStatus();
}
if (uses_gpu) {
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
gpu_resources = legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources;
return service_manager_.SetServiceObject(kGpuService, gpu_resources);
}
auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
// Workaround for b/116875321: CalculatorRunner provides an empty packet,
// instead of just leaving it unset.
bool has_legacy_sp = legacy_sp_iter != side_packets.end() &&
!legacy_sp_iter->second.IsEmpty();
if (gpu_resources) {
if (has_legacy_sp) {
LOG(WARNING)
<< "::mediapipe::GpuSharedData provided as a side packet while the "
<< "graph already had one; ignoring side packet";
}
update_sp = true;
} else {
if (has_legacy_sp) {
gpu_resources =
legacy_sp_iter->second.Get<::mediapipe::GpuSharedData*>()
->gpu_resources;
} else {
ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create());
update_sp = true;
}
MP_RETURN_IF_ERROR(
service_manager_.SetServiceObject(kGpuService, gpu_resources));
}
// Create or replace the legacy side packet if needed.
if (update_sp) {
legacy_gpu_shared_.reset(new ::mediapipe::GpuSharedData(gpu_resources));
additional_side_packets[kGpuSharedSidePacketName] =
MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get());
}
// Set up executors.
for (auto& node : nodes_) {
if (node->UsesGpu()) {
MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
}
}
for (const auto& name_executor : gpu_resources->GetGpuExecutors()) {
MP_RETURN_IF_ERROR(
SetExecutorInternal(name_executor.first, name_executor.second));
}
std::map<std::string, Packet> CalculatorGraph::MaybeCreateLegacyGpuSidePacket(
Packet legacy_sp) {
std::map<std::string, Packet> additional_side_packets;
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
if (gpu_resources &&
(legacy_sp.IsEmpty() ||
legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources !=
gpu_resources)) {
legacy_gpu_shared_ =
absl::make_unique<mediapipe::GpuSharedData>(gpu_resources);
additional_side_packets[kGpuSharedSidePacketName] =
MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get());
}
return additional_side_packets;
}
// Returns whether `node`'s contract requests the GPU service.
static bool UsesGpu(const CalculatorNode& node) {
  return node.Contract().ServiceRequests().contains(kGpuService.key);
}
// Prepares GPU support for the graph. No-op when no GpuResources object has
// been installed through the service manager; otherwise lets the resources
// attach per-node GPU state and registers the GPU executors with the graph.
absl::Status CalculatorGraph::PrepareGpu() {
  auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
  if (!gpu_resources) return absl::OkStatus();
  // Attach GPU state to every node that requested the GPU service.
  for (auto& node : nodes_) {
    if (!UsesGpu(*node)) continue;
    MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
  }
  // Make the GPU executors available to the scheduler.
  for (const auto& [executor_name, executor] :
       gpu_resources->GetGpuExecutors()) {
    MP_RETURN_IF_ERROR(SetExecutorInternal(executor_name, executor));
  }
  return absl::OkStatus();
}
#endif // !MEDIAPIPE_DISABLE_GPU
// Resolves all service requests declared by the graph's nodes. For each
// requested service without an installed object, tries to create a default
// object; failure to create one is an error unless the request is optional.
absl::Status CalculatorGraph::PrepareServices() {
  for (const auto& node : nodes_) {
    for (const auto& [key, request] : node->Contract().ServiceRequests()) {
      // Skip services that already have an object (e.g. provided by the
      // client before the run).
      auto packet = service_manager_.GetServicePacket(request.Service());
      if (!packet.IsEmpty()) continue;
      auto packet_or = request.Service().CreateDefaultObject();
      if (packet_or.ok()) {
        MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
            request.Service(), std::move(packet_or).value()));
      } else if (request.IsOptional()) {
        // Optional services may legitimately remain unset.
        continue;
      } else {
        return absl::InternalError(absl::StrCat(
            "Service \"", request.Service().key, "\", required by node ",
            node->DebugName(), ", was not provided and cannot be created: ",
            std::move(packet_or).status().message()));
      }
    }
  }
  return absl::OkStatus();
}
absl::Status CalculatorGraph::PrepareForRun(
const std::map<std::string, Packet>& extra_side_packets,
const std::map<std::string, Packet>& stream_headers) {
@ -637,7 +670,13 @@ absl::Status CalculatorGraph::PrepareForRun(
std::map<std::string, Packet> additional_side_packets;
#if !MEDIAPIPE_DISABLE_GPU
ASSIGN_OR_RETURN(additional_side_packets, PrepareGpu(extra_side_packets));
auto legacy_sp = GetLegacyGpuSharedSidePacket(extra_side_packets);
MP_RETURN_IF_ERROR(MaybeSetUpGpuServiceFromLegacySidePacket(legacy_sp));
#endif // !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(PrepareServices());
#if !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(PrepareGpu());
additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp);
#endif // !MEDIAPIPE_DISABLE_GPU
const std::map<std::string, Packet>* input_side_packets;

View File

@ -165,10 +165,13 @@ class CalculatorGraph {
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name,
bool observe_timestamp_bounds = false);
// Gets output side packet by name after the graph is done. However, base
// packets (generated by PacketGenerators) can be retrieved before
// graph is done. Returns error if the graph is still running (for non-base
// packets) or the output side packet is not found or empty.
// Gets output side packet by name. The output side packet can be successfully
// retrevied in one of the following situations:
// - The graph is done.
// - The output side packet has been generated by a calculator and the graph
// is currently idle.
// - The side packet is a base packet generated by a PacketGenerator.
// Returns error if the the output side packet is not found or empty.
absl::StatusOr<Packet> GetOutputSidePacket(const std::string& packet_name);
// Runs the graph after adding the given extra input side packets. All
@ -367,13 +370,8 @@ class CalculatorGraph {
std::shared_ptr<GpuResources> GetGpuResources() const;
absl::Status SetGpuResources(std::shared_ptr<GpuResources> resources);
// Helper for PrepareForRun. If it returns a non-empty map, those packets
// must be added to the existing side packets, replacing existing values
// that have the same key.
absl::StatusOr<std::map<std::string, Packet>> PrepareGpu(
const std::map<std::string, Packet>& side_packets);
#endif // !MEDIAPIPE_DISABLE_GPU
template <typename T>
absl::Status SetServiceObject(const GraphService<T>& service,
std::shared_ptr<T> object) {
@ -495,6 +493,18 @@ class CalculatorGraph {
const std::map<std::string, Packet>& extra_side_packets,
const std::map<std::string, Packet>& stream_headers);
absl::Status PrepareServices();
#if !MEDIAPIPE_DISABLE_GPU
absl::Status MaybeSetUpGpuServiceFromLegacySidePacket(Packet legacy_sp);
// Helper for PrepareForRun. If it returns a non-empty map, those packets
// must be added to the existing side packets, replacing existing values
// that have the same key.
std::map<std::string, Packet> MaybeCreateLegacyGpuSidePacket(
Packet legacy_sp);
absl::Status PrepareGpu();
#endif // !MEDIAPIPE_DISABLE_GPU
// Cleans up any remaining state after the run and returns any errors that may
// have occurred during the run. Called after the scheduler has terminated.
absl::Status FinishRun();

View File

@ -732,11 +732,12 @@ TEST(CalculatorGraph, GetOutputSidePacket) {
status_or_packet = graph.GetOutputSidePacket("unknown");
EXPECT_FALSE(status_or_packet.ok());
EXPECT_EQ(absl::StatusCode::kNotFound, status_or_packet.status().code());
// Should return UNAVAILABLE before graph is done for valid non-base
// packets.
// Should return the packet after the graph becomes idle.
MP_ASSERT_OK(graph.WaitUntilIdle());
status_or_packet = graph.GetOutputSidePacket("num_of_packets");
EXPECT_FALSE(status_or_packet.ok());
EXPECT_EQ(absl::StatusCode::kUnavailable, status_or_packet.status().code());
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(max_count, status_or_packet.value().Get<int>());
EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp());
// Should stil return a base even before graph is done.
status_or_packet = graph.GetOutputSidePacket("output_uint64");
MP_ASSERT_OK(status_or_packet);
@ -896,5 +897,23 @@ TEST(CalculatorGraph, GeneratorAfterCalculatorProcess) {
}
}
// An output side packet produced by a calculator should be retrievable once
// the graph is idle, without waiting for the whole run to finish.
TEST(CalculatorGraph, GetOutputSidePacketAfterCalculatorIsOpened) {
  CalculatorGraph graph;
  CalculatorGraphConfig config =
      mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
        node {
          calculator: "IntegerOutputSidePacketCalculator"
          output_side_packet: "offset"
        }
      )pb");
  MP_ASSERT_OK(graph.Initialize(config));
  MP_ASSERT_OK(graph.StartRun({}));
  // Must be called to ensure that the calculator is opened.
  MP_ASSERT_OK(graph.WaitUntilIdle());
  absl::StatusOr<Packet> status_or_packet = graph.GetOutputSidePacket("offset");
  MP_ASSERT_OK(status_or_packet);
  EXPECT_EQ(1, status_or_packet.value().Get<int>());
}
} // namespace
} // namespace mediapipe

View File

@ -46,7 +46,6 @@
#include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/tool/tag_map.h"
#include "mediapipe/framework/tool/validate_name.h"
#include "mediapipe/gpu/graph_support.h"
namespace mediapipe {
@ -155,11 +154,6 @@ absl::Status CalculatorNode::Initialize(
const CalculatorContract& contract = node_type_info_->Contract();
uses_gpu_ =
node_type_info_->InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
ContainsKey(node_type_info_->Contract().ServiceRequests(),
kGpuService.key);
// TODO Propagate types between calculators when SetAny is used.
MP_RETURN_IF_ERROR(InitializeOutputSidePackets(
@ -397,7 +391,7 @@ absl::Status CalculatorNode::PrepareForRun(
std::move(schedule_callback), error_callback);
output_stream_handler_->PrepareForRun(error_callback);
const auto& contract = node_type_info_->Contract();
const auto& contract = Contract();
input_side_packet_types_ = RemoveOmittedPacketTypes(
contract.InputSidePackets(), all_side_packets, validated_graph_);
MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun(

View File

@ -195,9 +195,6 @@ class CalculatorNode {
// Called by SchedulerQueue when a node is opened.
void NodeOpened() ABSL_LOCKS_EXCLUDED(status_mutex_);
// Returns whether this is a GPU calculator node.
bool UsesGpu() const { return uses_gpu_; }
// Returns the scheduler queue the node is assigned to.
internal::SchedulerQueue* GetSchedulerQueue() const {
return scheduler_queue_;
@ -234,6 +231,12 @@ class CalculatorNode {
return *calculator_state_;
}
// Returns the node's contract.
// Must not be called before the CalculatorNode is initialized.
const CalculatorContract& Contract() const {
return node_type_info_->Contract();
}
private:
// Sets up the output side packets from the main flat array.
absl::Status InitializeOutputSidePackets(
@ -363,9 +366,6 @@ class CalculatorNode {
std::unique_ptr<OutputStreamHandler> output_stream_handler_;
// Whether this is a GPU calculator.
bool uses_gpu_ = false;
// True if CleanupAfterRun() needs to call CloseNode().
bool needs_to_close_ = false;

View File

@ -187,6 +187,21 @@ cc_library(
],
)
config_setting(
name = "opencv",
define_values = {
"use_opencv": "true",
},
)
config_setting(
name = "portable_opencv",
define_values = {
"use_portable_opencv": "true",
"use_opencv": "false",
},
)
cc_library(
name = "location",
srcs = ["location.cc"],
@ -194,6 +209,8 @@ cc_library(
defines = select({
"//conditions:default": [],
"//mediapipe:android": ["MEDIAPIPE_ANDROID_OPENCV"],
":portable_opencv": ["MEDIAPIPE_ANDROID_OPENCV"],
":opencv": [],
}),
visibility = ["//visibility:public"],
deps = [

View File

@ -76,7 +76,7 @@ class Tensor {
public:
// No resources are allocated here.
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8 };
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8 };
struct Shape {
Shape() = default;
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
@ -217,6 +217,8 @@ class Tensor {
return sizeof(float);
case ElementType::kUInt8:
return 1;
case ElementType::kInt8:
return 1;
}
}
int bytes() const { return shape_.num_elements() * element_size(); }

View File

@ -16,6 +16,12 @@
#define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_
#include <memory>
#include <type_traits>
#include <utility>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/status.h"
namespace mediapipe {
@ -27,18 +33,74 @@ namespace mediapipe {
// IMPORTANT: this is an experimental API. Get in touch with the MediaPipe team
// if you want to use it. In most cases, you should use a side packet instead.
struct GraphServiceBase {
class GraphServiceBase {
public:
// TODO: fix services for which default init is broken, remove
// this setting.
enum DefaultInitSupport {
kAllowDefaultInitialization,
kDisallowDefaultInitialization
};
constexpr GraphServiceBase(const char* key) : key(key) {}
virtual ~GraphServiceBase() = default;
inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
return DefaultInitializationUnsupported();
}
const char* key;
protected:
absl::Status DefaultInitializationUnsupported() const {
return absl::UnimplementedError(absl::StrCat(
"Graph service '", key, "' does not support default initialization"));
}
};
template <typename T>
struct GraphService : public GraphServiceBase {
class GraphService : public GraphServiceBase {
public:
using type = T;
using packet_type = std::shared_ptr<T>;
constexpr GraphService(const char* key) : GraphServiceBase(key) {}
constexpr GraphService(const char* my_key, DefaultInitSupport default_init =
kDisallowDefaultInitialization)
: GraphServiceBase(my_key), default_init_(default_init) {}
absl::StatusOr<Packet> CreateDefaultObject() const override {
if (default_init_ != kAllowDefaultInitialization) {
return DefaultInitializationUnsupported();
}
auto packet_or = CreateDefaultObjectInternal();
if (packet_or.ok()) {
return MakePacket<std::shared_ptr<T>>(std::move(packet_or).value());
} else {
return packet_or.status();
}
}
private:
absl::StatusOr<std::shared_ptr<T>> CreateDefaultObjectInternal() const {
auto call_create = [](auto x) -> decltype(decltype(x)::type::Create()) {
return decltype(x)::type::Create();
};
if constexpr (std::is_invocable_r_v<absl::StatusOr<std::shared_ptr<T>>,
decltype(call_create), type_tag<T>>) {
return T::Create();
}
if constexpr (std::is_default_constructible_v<T>) {
return std::make_shared<T>();
}
return DefaultInitializationUnsupported();
}
template <class U>
struct type_tag {
using type = U;
};
DefaultInitSupport default_init_;
};
template <typename T>

View File

@ -35,6 +35,8 @@ class GraphServiceManager {
Packet GetServicePacket(const GraphServiceBase& service) const;
std::map<std::string, Packet> service_packets_;
friend class CalculatorGraph;
};
} // namespace mediapipe

View File

@ -6,11 +6,13 @@
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
const GraphService<int> kIntService("mediapipe::IntService");
} // namespace
TEST(GraphServiceManager, SetGetServiceObject) {
GraphServiceManager service_manager;
constexpr GraphService<int> kIntService("mediapipe::IntService");
EXPECT_EQ(service_manager.GetServiceObject(kIntService), nullptr);
MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
@ -22,8 +24,6 @@ TEST(GraphServiceManager, SetGetServiceObject) {
TEST(GraphServiceManager, SetServicePacket) {
GraphServiceManager service_manager;
constexpr GraphService<int> kIntService("mediapipe::IntService");
MP_EXPECT_OK(service_manager.SetServicePacket(
kIntService,
mediapipe::MakePacket<std::shared_ptr<int>>(std::make_shared<int>(100))));
@ -36,8 +36,6 @@ TEST(GraphServiceManager, ServicePackets) {
EXPECT_TRUE(service_manager.ServicePackets().empty());
constexpr GraphService<int> kIntService("mediapipe::IntService");
MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
std::make_shared<int>(100)));

View File

@ -150,5 +150,12 @@ TEST_F(GraphServiceTest, OptionalIsAvailable) {
EXPECT_EQ(PacketValues<int>(output_packets_), (std::vector<int>{108}));
}
// Exercises GraphService default-object creation (see test_service.cc):
// - kTestService is declared with kDisallowDefaultInitialization, so
//   CreateDefaultObject() must fail.
// - kAnotherService allows default init and int is default-constructible.
// - kNoDefaultService allows default init, but NoDefaultConstructor has a
//   deleted constructor and no Create() factory, so creation fails.
// - kNeedsCreateService allows default init via NeedsCreateMethod::Create().
TEST_F(GraphServiceTest, CreateDefault) {
EXPECT_FALSE(kTestService.CreateDefaultObject().ok());
MP_EXPECT_OK(kAnotherService.CreateDefaultObject());
EXPECT_FALSE(kNoDefaultService.CreateDefaultObject().ok());
MP_EXPECT_OK(kNeedsCreateService.CreateDefaultObject());
}
} // namespace
} // namespace mediapipe

View File

@ -50,15 +50,18 @@ absl::Status InputStreamHandler::SetupInputShards(
return absl::OkStatus();
}
std::vector<std::pair<std::string, int>>
std::vector<std::tuple<std::string, int, int, Timestamp>>
InputStreamHandler::GetMonitoringInfo() {
std::vector<std::pair<std::string, int>> monitoring_info_vector;
std::vector<std::tuple<std::string, int, int, Timestamp>>
monitoring_info_vector;
for (auto& stream : input_stream_managers_) {
if (!stream) {
continue;
}
monitoring_info_vector.emplace_back(
std::pair<std::string, int>(stream->Name(), stream->QueueSize()));
std::tuple<std::string, int, int, Timestamp>(
stream->Name(), stream->QueueSize(), stream->NumPacketsAdded(),
stream->MinTimestampOrBound(nullptr)));
}
return monitoring_info_vector;
}

View File

@ -94,7 +94,7 @@ class InputStreamHandler {
// Returns a vector of tuples of stream name, queue size, number of packets
// added, and minimum timestamp or bound, for monitoring purposes.
std::vector<std::pair<std::string, int>> GetMonitoringInfo();
std::vector<std::tuple<std::string, int, int, Timestamp>> GetMonitoringInfo();
// Resets the input stream handler and its underlying input streams for
// another run of the graph.

View File

@ -329,6 +329,11 @@ Packet InputStreamManager::PopQueueHead(bool* stream_is_done) {
return packet;
}
// Returns the cumulative number of packets that have been added to this
// stream. Unlike QueueSize(), this count does not decrease when packets are
// popped or dropped (see InputStreamManagerTest.QueueSizeTest).
// Thread-safe: reads num_packets_added_ under stream_mutex_.
int InputStreamManager::NumPacketsAdded() const {
absl::MutexLock lock(&stream_mutex_);
return num_packets_added_;
}
int InputStreamManager::QueueSize() const {
absl::MutexLock lock(&stream_mutex_);
return static_cast<int>(queue_.size());

View File

@ -87,12 +87,14 @@ class InputStreamManager {
// Timestamp::PostStream(), the packet must be the only packet in the
// stream.
// Violation of any of these conditions causes an error status.
absl::Status AddPackets(const std::list<Packet>& container, bool* notify);
absl::Status AddPackets(const std::list<Packet>& container, bool* notify)
ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Move a list of timestamped packets. Sets "notify" to true if the queue
// becomes non-empty. Does nothing if the input stream is closed. After the
// move, all packets in the container must be empty.
absl::Status MovePackets(std::list<Packet>* container, bool* notify);
absl::Status MovePackets(std::list<Packet>* container, bool* notify)
ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Closes the input stream. This function can be called multiple times.
void Close() ABSL_LOCKS_EXCLUDED(stream_mutex_);
@ -140,6 +142,9 @@ class InputStreamManager {
// Timestamp::Done() after the pop.
Packet PopQueueHead(bool* stream_is_done) ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Returns the cumulative number of packets added to this stream.
int NumPacketsAdded() const ABSL_LOCKS_EXCLUDED(stream_mutex_);
// Returns the number of packets in the queue.
int QueueSize() const ABSL_LOCKS_EXCLUDED(stream_mutex_);

View File

@ -767,6 +767,7 @@ TEST_F(InputStreamManagerTest, QueueSizeTest) {
EXPECT_EQ(3, num_packets_dropped_);
EXPECT_TRUE(input_stream_manager_->IsEmpty());
EXPECT_FALSE(stream_is_done_);
EXPECT_EQ(3, input_stream_manager_->NumPacketsAdded());
packets.clear();
packets.push_back(MakePacket<std::string>("packet 4").At(Timestamp(60)));
@ -776,6 +777,7 @@ TEST_F(InputStreamManagerTest, QueueSizeTest) {
input_stream_manager_->AddPackets(packets, &notify_)); // Notification
EXPECT_FALSE(input_stream_manager_->IsEmpty());
EXPECT_TRUE(notify_);
EXPECT_EQ(5, input_stream_manager_->NumPacketsAdded());
expected_queue_becomes_full_count_ = 2;
expected_queue_becomes_not_full_count_ = 1;

View File

@ -12,6 +12,8 @@ def mediapipe_cc_test(
timeout = None,
args = [],
additional_deps = DEFAULT_ADDITIONAL_TEST_DEPS,
platforms = ["linux", "android", "ios", "wasm"],
exclude_platforms = None,
# ios_unit_test arguments
ios_minimum_os_version = "9.0",
# android_cc_test arguments

View File

@ -412,8 +412,7 @@ cc_library(
name = "status_matchers",
testonly = 1,
hdrs = ["status_matchers.h"],
# Use this library through "mediapipe/framework/port:gtest_main".
visibility = ["//mediapipe/framework/port:__pkg__"],
visibility = ["//visibility:private"],
deps = [
":status",
"@com_google_googletest//:gtest",

View File

@ -16,8 +16,14 @@
namespace mediapipe {
const GraphService<TestServiceObject> kTestService("test_service");
const GraphService<int> kAnotherService("another_service");
const GraphService<TestServiceObject> kTestService(
"test_service", GraphServiceBase::kDisallowDefaultInitialization);
const GraphService<int> kAnotherService(
"another_service", GraphServiceBase::kAllowDefaultInitialization);
const GraphService<NoDefaultConstructor> kNoDefaultService(
"no_default_service", GraphServiceBase::kAllowDefaultInitialization);
const GraphService<NeedsCreateMethod> kNeedsCreateService(
"needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();

View File

@ -16,6 +16,7 @@
#define MEDIAPIPE_FRAMEWORK_TEST_SERVICE_H_
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/graph_service.h"
namespace mediapipe {
@ -24,6 +25,23 @@ using TestServiceObject = std::map<std::string, int>;
extern const GraphService<TestServiceObject> kTestService;
extern const GraphService<int> kAnotherService;
// Test service-object type that cannot be default-constructed and has no
// Create() factory; kNoDefaultService uses it to verify that default
// initialization of such a service fails.
class NoDefaultConstructor {
public:
NoDefaultConstructor() = delete;
};
extern const GraphService<NoDefaultConstructor> kNoDefaultService;
// Test service-object type that can only be built through its static
// Create() factory (its constructor is private); kNeedsCreateService uses it
// to verify that default initialization falls back to T::Create().
class NeedsCreateMethod {
public:
static absl::StatusOr<std::shared_ptr<NeedsCreateMethod>> Create() {
return std::shared_ptr<NeedsCreateMethod>(new NeedsCreateMethod());
}
private:
NeedsCreateMethod() = default;
};
extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
// Use a service.
class TestServiceCalculator : public CalculatorBase {
public:

View File

@ -134,7 +134,7 @@ cc_library(
name = "name_util",
srcs = ["name_util.cc"],
hdrs = ["name_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"],
visibility = ["//visibility:public"],
deps = [
":validate_name",
"//mediapipe/framework:calculator_cc_proto",

View File

@ -225,7 +225,7 @@ std::string GetTestOutputsDir() {
return output_dir;
}
std::string GetTestDataDir(const std::string& package_base_path) {
std::string GetTestDataDir(absl::string_view package_base_path) {
return file::JoinPath(GetTestRootDir(), package_base_path, "testdata/");
}
@ -270,7 +270,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
format, width, height, width * output_channels, data, stbi_image_free);
}
std::unique_ptr<ImageFrame> LoadTestPng(const std::string& path,
std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
ImageFormat::Format format) {
return nullptr;
}

View File

@ -63,7 +63,7 @@ std::string GetTestFilePath(absl::string_view relative_path);
// directory.
// This handles the different paths where test data ends up when using
// ion_cc_test on various platforms.
std::string GetTestDataDir(const std::string& package_base_path);
std::string GetTestDataDir(absl::string_view package_base_path);
// Loads a binary graph from path. Returns true iff successful.
bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path);
@ -75,7 +75,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
// Loads a PNG image from path using the given ImageFormat. Returns nullptr in
// case of failure.
std::unique_ptr<ImageFrame> LoadTestPng(
const std::string& path, ImageFormat::Format format = ImageFormat::SRGBA);
absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
// Returns the luminance image of |original_image|.
// The format of |original_image| must be sRGB or sRGBA.

View File

@ -38,14 +38,19 @@ cc_library(
srcs = ["gpu_service.cc"],
hdrs = ["gpu_service.h"],
visibility = ["//visibility:public"],
deps = ["//mediapipe/framework:graph_service"],
deps = ["//mediapipe/framework:graph_service"] + select({
"//conditions:default": [
":gpu_shared_data_internal",
],
"//mediapipe/gpu:disable_gpu": [],
}),
)
cc_library(
name = "graph_support",
hdrs = ["graph_support.h"],
visibility = ["//visibility:public"],
deps = [":gpu_service"],
deps = ["//mediapipe/framework:graph_service"],
)
GL_BASE_LINK_OPTS = select({
@ -366,7 +371,6 @@ objc_library(
hdrs = ["pixel_buffer_pool_util.h"],
copts = [
"-Wno-shorten-64-to-32",
"-std=c++17",
],
sdk_frameworks = [
"Accelerate",
@ -389,7 +393,6 @@ objc_library(
copts = [
"-x objective-c++",
"-Wno-shorten-64-to-32",
"-std=c++17",
],
features = ["-layering_check"],
sdk_frameworks = [
@ -425,7 +428,6 @@ objc_library(
copts = [
"-x objective-c++",
"-Wno-shorten-64-to-32",
"-std=c++17",
],
sdk_frameworks = [
"CoreVideo",
@ -691,7 +693,6 @@ objc_library(
name = "gl_calculator_helper_ios",
copts = [
"-Wno-shorten-64-to-32",
"-std=c++17",
],
visibility = ["//visibility:public"],
deps = [
@ -707,7 +708,6 @@ objc_library(
hdrs = ["MPPMetalHelper.h"],
copts = [
"-Wno-shorten-64-to-32",
"-std=c++17",
],
features = ["-layering_check"],
sdk_frameworks = [
@ -801,7 +801,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":gl_calculator_helper",
":gpu_buffer_storage_image_frame",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:status",
@ -927,7 +926,6 @@ mediapipe_cc_proto_library(
objc_library(
name = "metal_copy_calculator",
srcs = ["MetalCopyCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
@ -946,7 +944,6 @@ objc_library(
objc_library(
name = "metal_rgb_weight_calculator",
srcs = ["MetalRgbWeightCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
@ -964,7 +961,6 @@ objc_library(
objc_library(
name = "metal_sobel_calculator",
srcs = ["MetalSobelCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
@ -982,7 +978,6 @@ objc_library(
objc_library(
name = "metal_sobel_compute_calculator",
srcs = ["MetalSobelComputeCalculator.mm"],
copts = ["-std=c++17"],
features = ["-layering_check"],
sdk_frameworks = [
"CoreVideo",
@ -1018,7 +1013,6 @@ objc_library(
objc_library(
name = "mps_threshold_calculator",
srcs = ["MPSThresholdCalculator.mm"],
copts = ["-std=c++17"],
sdk_frameworks = [
"CoreVideo",
"Metal",
@ -1053,7 +1047,6 @@ objc_library(
],
copts = [
"-Wno-shorten-64-to-32",
"-std=c++17",
],
data = [
"//mediapipe/objc:testdata/googlelogo_color_272x92dp.png",

View File

@ -23,6 +23,7 @@
#include "absl/base/dynamic_annotations.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
@ -358,6 +359,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) {
GlContext::GlContext() {}
GlContext::~GlContext() {
destructing_ = true;
// Note: on Apple platforms, this object contains Objective-C objects.
// The destructor will release them, but ARC must be on.
#ifdef __OBJC__
@ -366,11 +368,16 @@ GlContext::~GlContext() {
#endif
#endif // __OBJC__
auto clear_attachments = [this] {
attachments_.clear();
if (profiling_helper_) {
profiling_helper_->LogAllTimestamps();
}
};
if (thread_) {
auto status = thread_->Run([this] {
if (profiling_helper_) {
profiling_helper_->LogAllTimestamps();
}
auto status = thread_->Run([this, clear_attachments] {
clear_attachments();
return ExitContext(nullptr);
});
LOG_IF(ERROR, !status.ok())
@ -378,6 +385,17 @@ GlContext::~GlContext() {
if (thread_->IsCurrentThread()) {
thread_.release()->SelfDestruct();
}
} else {
if (IsCurrent()) {
clear_attachments();
} else {
ContextBinding saved_context;
auto status = SwitchContextAndRun([&clear_attachments] {
clear_attachments();
return absl::OkStatus();
});
LOG_IF(ERROR, !status.ok()) << status;
}
}
DestroyContext();
}
@ -501,6 +519,14 @@ absl::Status GlContext::SwitchContext(ContextBinding* saved_context,
}
}
// Builds a ContextBinding that can make this GlContext current. Starts from
// the platform-specific fields (ThisContextBindingPlatform(), which must not
// set context_object) and, unless the context is being destroyed, also stores
// a shared_ptr to this GlContext so the binding keeps it alive.
// shared_from_this() is not valid inside the destructor, hence the
// destructing_ check.
GlContext::ContextBinding GlContext::ThisContextBinding() {
GlContext::ContextBinding result = ThisContextBindingPlatform();
if (!destructing_) {
result.context_object = shared_from_this();
}
return result;
}
absl::Status GlContext::EnterContext(ContextBinding* saved_context) {
DCHECK(HasContext());
return SwitchContext(saved_context, ThisContextBinding());

View File

@ -21,6 +21,7 @@
#include <functional>
#include <memory>
#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/executor.h"
#include "mediapipe/framework/mediapipe_profiling.h"
@ -285,6 +286,48 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// Sets default texture filtering parameters.
void SetStandardTextureParams(GLenum target, GLint internal_format);
// Type-erased owning pointer for per-context attachment objects. The deleter
// is a std::function so the pointer can be stored as AttachmentPtr<void> in
// the attachments_ map while still deleting through the concrete type.
template <class T>
using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;
// Constructs a T in place and wraps it in an AttachmentPtr whose deleter
// restores the concrete type before deleting. Disabled for array types.
template <class T, class... Args>
static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
MakeAttachmentPtr(Args&&... args) {
return {new T(std::forward<Args>(args)...),
[](void* ptr) { delete static_cast<T*>(ptr); }};
}
// Non-template base so attachments of different types can share one map; the
// Attachment's address serves as the cache key.
class AttachmentBase {};
// A key-plus-factory for an object cached on a GlContext. Each GlContext
// lazily builds its own instance with the factory on first access
// (GetCachedAttachment). Not copyable or movable, since the object's address
// is used as the cache key.
template <class T>
class Attachment : public AttachmentBase {
public:
using FactoryT = std::function<AttachmentPtr<T>(GlContext&)>;
Attachment(FactoryT factory) : factory_(factory) {}
Attachment(const Attachment&) = delete;
Attachment(Attachment&&) = delete;
Attachment& operator=(const Attachment&) = delete;
Attachment& operator=(Attachment&&) = delete;
// Convenience accessor: returns ctx's cached instance for this attachment.
T& Get(GlContext& ctx) const { return ctx.GetCachedAttachment(*this); }
const FactoryT& factory() const { return factory_; }
private:
FactoryT factory_;
};
// TODO: const result?
// Returns this context's cached instance for `attachment`, invoking the
// attachment's factory on first use. Must be called with this context
// current (enforced by the DCHECK).
template <class T>
T& GetCachedAttachment(const Attachment<T>& attachment) {
DCHECK(IsCurrent());
AttachmentPtr<void>& entry = attachments_[&attachment];
if (entry == nullptr) {
entry = attachment.factory()(*this);
}
return *static_cast<T*>(entry.get());
}
// These are used for testing specific SyncToken implementations. Do not use
// outside of tests.
enum class SyncTokenTypeForTest {
@ -387,6 +430,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// A binding that can be used to make this GlContext current.
ContextBinding ThisContextBinding();
// Fill in platform-specific fields. Must _not_ set context_obj.
ContextBinding ThisContextBindingPlatform();
// Fills in a ContextBinding with platform-specific information about which
// context is current on this thread.
static void GetCurrentContextBinding(ContextBinding* binding);
@ -409,6 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// better mechanism?
bool can_linear_filter_float_textures_;
absl::flat_hash_map<const AttachmentBase*, AttachmentPtr<void>> attachments_;
// Number of glFinish calls completed on the GL thread.
// Changes should be guarded by mutex_. However, we use simple atomic
// loads for efficiency on the fast path.
@ -428,6 +475,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
absl::CondVar wait_for_gl_finish_cv_ ABSL_GUARDED_BY(mutex_);
std::unique_ptr<mediapipe::GlProfilingHelper> profiling_helper_ = nullptr;
bool destructing_ = false;
};
// For backward compatibility. TODO: migrate remaining callers.

View File

@ -84,9 +84,8 @@ void GlContext::DestroyContext() {
}
}
GlContext::ContextBinding GlContext::ThisContextBinding() {
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.context = context_;
return result;
}

View File

@ -269,9 +269,8 @@ void GlContext::DestroyContext() {
#endif // __ANDROID__
}
GlContext::ContextBinding GlContext::ThisContextBinding() {
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.display = display_;
result.draw_surface = surface_;
result.read_surface = surface_;

View File

@ -134,9 +134,8 @@ void GlContext::DestroyContext() {
}
}
GlContext::ContextBinding GlContext::ThisContextBinding() {
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.context = context_;
return result;
}

View File

@ -173,9 +173,8 @@ void GlContext::DestroyContext() {
}
}
GlContext::ContextBinding GlContext::ThisContextBinding() {
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
GlContext::ContextBinding result;
result.context_object = shared_from_this();
result.context = context_;
return result;
}

View File

@ -111,7 +111,7 @@ absl::Status QuadRenderer::GlRender(float frame_width, float frame_height,
FrameScaleMode scale_mode,
FrameRotation rotation,
bool flip_horizontal, bool flip_vertical,
bool flip_texture) {
bool flip_texture) const {
RET_CHECK(program_) << "Must setup the program before rendering.";
glUseProgram(program_);

View File

@ -72,7 +72,7 @@ class QuadRenderer {
absl::Status GlRender(float frame_width, float frame_height, float view_width,
float view_height, FrameScaleMode scale_mode,
FrameRotation rotation, bool flip_horizontal,
bool flip_vertical, bool flip_texture);
bool flip_vertical, bool flip_texture) const;
// Deletes the rendering program. Must be called within the GL context where
// it was created.
void GlTeardown();

View File

@ -144,7 +144,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
}},
{GpuBufferFormat::kRGBAFloat128,
{
{GL_RGBA, GL_RGBA, GL_FLOAT, 1},
{GL_RGBA32F, GL_RGBA, GL_FLOAT, 1},
}},
}};

View File

@ -16,6 +16,7 @@
namespace mediapipe {
const GraphService<GpuResources> kGpuService("kGpuService");
const GraphService<GpuResources> kGpuService(
"kGpuService", GraphServiceBase::kAllowDefaultInitialization);
} // namespace mediapipe

View File

@ -17,9 +17,18 @@
#include "mediapipe/framework/graph_service.h"
#if !MEDIAPIPE_DISABLE_GPU
#include "mediapipe/gpu/gpu_shared_data_internal.h"
#endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe {
class GpuResources;
#if MEDIAPIPE_DISABLE_GPU
// Placeholder type used when GPU support is compiled out, so that the
// kGpuService declaration below still has a type to refer to; it can never
// be instantiated (deleted, private constructor).
class GpuResources {
GpuResources() = delete;
};
#endif // MEDIAPIPE_DISABLE_GPU
extern const GraphService<GpuResources> kGpuService;
} // namespace mediapipe

Some files were not shown because too many files have changed in this diff Show More