Project import generated by Copybara.
GitOrigin-RevId: 19a829ffd755edb43e54d20c0e7b9348512d5108
This commit is contained in:
parent
c6c80c3745
commit
7fb37c80e8
7
.bazelrc
7
.bazelrc
|
@ -32,6 +32,9 @@ build:macos --copt=-w
|
|||
# Sets the default Apple platform to macOS.
|
||||
build --apple_platform_type=macos
|
||||
|
||||
# Compile ObjC++ files with C++17
|
||||
build --per_file_copt=.*\.mm\$@-std=c++17
|
||||
|
||||
# Allow debugging with Xcode
|
||||
build --apple_generate_dsym
|
||||
|
||||
|
@ -88,6 +91,10 @@ build:darwin_x86_64 --apple_platform_type=macos
|
|||
build:darwin_x86_64 --macos_minimum_os=10.12
|
||||
build:darwin_x86_64 --cpu=darwin_x86_64
|
||||
|
||||
build:darwin_arm64 --apple_platform_type=macos
|
||||
build:darwin_arm64 --macos_minimum_os=10.16
|
||||
build:darwin_arm64 --cpu=darwin_arm64
|
||||
|
||||
# This bazelrc file is meant to be written by a setup script.
|
||||
try-import %workspace%/.configure.bazelrc
|
||||
|
||||
|
|
|
@ -202,7 +202,10 @@ new_local_repository(
|
|||
new_local_repository(
|
||||
name = "macos_opencv",
|
||||
build_file = "@//third_party:opencv_macos.BUILD",
|
||||
path = "/usr/local/opt/opencv@3",
|
||||
# For local MacOS builds, the path should point to an opencv@3 installation.
|
||||
# If you edit the path here, you will also need to update the corresponding
|
||||
# prefix in "opencv_macos.BUILD".
|
||||
path = "/usr/local",
|
||||
)
|
||||
|
||||
new_local_repository(
|
||||
|
|
|
@ -53,7 +53,7 @@ the following:
|
|||
|
||||
```bash
|
||||
$ echo "android_sdk_repository(name = \"androidsdk\")" >> WORKSPACE
|
||||
$ echo "android_ndk_repository(name = \"androidndk\")" >> WORKSPACE
|
||||
$ echo "android_ndk_repository(name = \"androidndk\", api_level=21)" >> WORKSPACE
|
||||
```
|
||||
|
||||
In order to use MediaPipe on earlier Android versions, MediaPipe needs to switch
|
||||
|
|
|
@ -59,6 +59,21 @@ OpenGL ES profile shading language version string: OpenGL ES GLSL ES 3.20
|
|||
OpenGL ES profile extensions:
|
||||
```
|
||||
|
||||
If you have connected to your computer through SSH and find when you probe for
|
||||
GPU information you see the output:
|
||||
|
||||
```bash
|
||||
glxinfo | grep -i opengl
|
||||
Error: unable to open display
|
||||
```
|
||||
|
||||
Try re-establishing your SSH connection with the `-X` option and try again. For
|
||||
example:
|
||||
|
||||
```bash
|
||||
ssh -X <user>@<host>
|
||||
```
|
||||
|
||||
*Notice the ES 3.20 text above.*
|
||||
|
||||
You need to see ES 3.1 or greater printed in order to perform TFLite inference
|
||||
|
|
|
@ -131,7 +131,7 @@ Create a `BUILD` file in the `$APPLICATION_PATH` and add the following build
|
|||
rules:
|
||||
|
||||
```
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
load(
|
||||
"@build_bazel_rules_apple//apple:ios.bzl",
|
||||
|
|
|
@ -32,9 +32,14 @@ example apps, start from, start from
|
|||
xcode-select --install
|
||||
```
|
||||
|
||||
3. Install [Bazel](https://bazel.build/).
|
||||
3. Install [Bazelisk](https://github.com/bazelbuild/bazelisk)
|
||||
.
|
||||
|
||||
We recommend using [Homebrew](https://brew.sh/) to get the latest version.
|
||||
We recommend using [Homebrew](https://brew.sh/) to get the latest versions.
|
||||
|
||||
```bash
|
||||
brew install bazelisk
|
||||
```
|
||||
|
||||
4. Set Python 3.7 as the default Python version and install the Python "six"
|
||||
library. This is needed for TensorFlow.
|
||||
|
@ -187,6 +192,9 @@ Note: When you ask Xcode to run an app, by default it will use the Debug
|
|||
configuration. Some of our demos are computationally heavy; you may want to use
|
||||
the Release configuration for better performance.
|
||||
|
||||
Note: Due to an incompatibility caused by one of our dependencies, MediaPipe
|
||||
cannot be used for apps running on the iPhone Simulator on Apple Silicon (M1).
|
||||
|
||||
Tip: To switch build configuration in Xcode, click on the target menu, choose
|
||||
"Edit Scheme...", select the Run action, and switch the Build Configuration from
|
||||
Debug to Release. Note that this is set independently for each target.
|
||||
|
|
|
@ -258,13 +258,14 @@ Many of the following settings are advanced and not recommended for general
|
|||
usage. Consult [Enabling tracing and profiling](#enabling-tracing-and-profiling)
|
||||
for a friendlier introduction.
|
||||
|
||||
histogram_interval_size_usec :Specifies the size of the runtimes histogram
|
||||
intervals (in microseconds) to generate the histogram of the Process() time. The
|
||||
last interval extends to +inf. If not specified, the interval is 1000000 usec =
|
||||
1 sec.
|
||||
histogram_interval_size_usec
|
||||
: Specifies the size of the runtimes histogram intervals (in microseconds) to
|
||||
generate the histogram of the `Process()` time. The last interval extends to
|
||||
+inf. If not specified, the interval is 1000000 usec = 1 sec.
|
||||
|
||||
num_histogram_intervals :Specifies the number of intervals to generate the
|
||||
histogram of the `Process()` runtime. If not specified, one interval is used.
|
||||
num_histogram_intervals
|
||||
: Specifies the number of intervals to generate the histogram of the
|
||||
`Process()` runtime. If not specified, one interval is used.
|
||||
|
||||
enable_profiler
|
||||
: If true, the profiler starts profiling when graph is initialized.
|
||||
|
@ -288,7 +289,7 @@ trace_event_types_disabled
|
|||
|
||||
trace_log_path
|
||||
: The output directory and base-name prefix for trace log files. Log files are
|
||||
written to: StrCat(trace_log_path, index, "`.binarypb`")
|
||||
written to: `StrCat(trace_log_path, index, ".binarypb")`
|
||||
|
||||
trace_log_count
|
||||
: The number of trace log files retained. The trace log files are named
|
||||
|
@ -310,8 +311,8 @@ trace_log_instant_events
|
|||
|
||||
trace_log_interval_count
|
||||
: The number of trace log intervals per file. The total log duration is:
|
||||
`trace_log_interval_usec * trace_log_file_count * trace_log_interval_count`.
|
||||
The default value specifies 10 intervals per file.
|
||||
`trace_log_interval_usec * trace_log_count * trace_log_interval_count`. The
|
||||
default value specifies 10 intervals per file.
|
||||
|
||||
trace_log_disabled
|
||||
: An option to turn ON/OFF writing trace files to disk. Saving trace files to
|
||||
|
|
|
@ -75,6 +75,7 @@ alias(
|
|||
actual = select({
|
||||
":macos_i386": ":macos_i386",
|
||||
":macos_x86_64": ":macos_x86_64",
|
||||
":macos_arm64": ":macos_arm64",
|
||||
"//conditions:default": ":macos_i386", # Arbitrarily chosen from above.
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
|
@ -119,6 +120,15 @@ config_setting(
|
|||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "macos_arm64",
|
||||
values = {
|
||||
"apple_platform_type": "macos",
|
||||
"cpu": "darwin_arm64",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
[
|
||||
config_setting(
|
||||
name = arch,
|
||||
|
|
|
@ -214,6 +214,7 @@ cc_library(
|
|||
"//mediapipe/framework:collection_item_id",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/formats:rect_cc_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
|
@ -1257,3 +1258,36 @@ cc_test(
|
|||
"@com_google_absl//absl/time",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "get_vector_item_calculator",
|
||||
srcs = ["get_vector_item_calculator.cc"],
|
||||
hdrs = ["get_vector_item_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "vector_size_calculator",
|
||||
srcs = ["vector_size_calculator.cc"],
|
||||
hdrs = ["vector_size_calculator.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/formats:classification_cc_proto",
|
||||
"//mediapipe/framework/formats:landmark_cc_proto",
|
||||
"//mediapipe/framework/port:status",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -28,6 +28,10 @@ typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedLandmarkList>>
|
|||
BeginLoopNormalizedLandmarkListVectorCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopNormalizedLandmarkListVectorCalculator);
|
||||
|
||||
// A calculator to process std::vector<int>.
|
||||
typedef BeginLoopCalculator<std::vector<int>> BeginLoopIntCalculator;
|
||||
REGISTER_CALCULATOR(BeginLoopIntCalculator);
|
||||
|
||||
// A calculator to process std::vector<NormalizedRect>.
|
||||
typedef BeginLoopCalculator<std::vector<::mediapipe::NormalizedRect>>
|
||||
BeginLoopNormalizedRectCalculator;
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
#include "mediapipe/framework/formats/rect.pb.h"
|
||||
#include "mediapipe/util/render_data.pb.h"
|
||||
|
@ -50,4 +51,8 @@ REGISTER_CALCULATOR(EndLoopClassificationListCalculator);
|
|||
typedef EndLoopCalculator<std::vector<TfLiteTensor>> EndLoopTensorCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopTensorCalculator);
|
||||
|
||||
typedef EndLoopCalculator<std::vector<::mediapipe::Detection>>
|
||||
EndLoopDetectionCalculator;
|
||||
REGISTER_CALCULATOR(EndLoopDetectionCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
32
mediapipe/calculators/core/get_vector_item_calculator.cc
Normal file
32
mediapipe/calculators/core/get_vector_item_calculator.cc
Normal file
|
@ -0,0 +1,32 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/get_vector_item_calculator.h"
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using GetLandmarkListVectorItemCalculator =
|
||||
GetVectorItemCalculator<mediapipe::LandmarkList>;
|
||||
REGISTER_CALCULATOR(GetLandmarkListVectorItemCalculator);
|
||||
|
||||
using GetClassificationListVectorItemCalculator =
|
||||
GetVectorItemCalculator<mediapipe::ClassificationList>;
|
||||
REGISTER_CALCULATOR(GetClassificationListVectorItemCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
77
mediapipe/calculators/core/get_vector_item_calculator.h
Normal file
77
mediapipe/calculators/core/get_vector_item_calculator.h
Normal file
|
@ -0,0 +1,77 @@
|
|||
// Copyright 2019 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// A calculator to return an item from the vector by its index.
|
||||
//
|
||||
// Inputs:
|
||||
// VECTOR - std::vector<T>
|
||||
// Vector to take an item from.
|
||||
// INDEX - int
|
||||
// Index of the item to return.
|
||||
//
|
||||
// Outputs:
|
||||
// ITEM - T
|
||||
// Item from the vector at given index.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "Get{SpecificType}VectorItemCalculator"
|
||||
// input_stream: "VECTOR:vector"
|
||||
// input_stream: "INDEX:index"
|
||||
// output_stream: "ITEM:item"
|
||||
// }
|
||||
//
|
||||
template <typename T>
|
||||
class GetVectorItemCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::vector<T>> kIn{"VECTOR"};
|
||||
static constexpr Input<int> kIdx{"INDEX"};
|
||||
static constexpr Output<T> kOut{"ITEM"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kIdx, kOut);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (kIn(cc).IsEmpty() || kIdx(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const std::vector<T>& items = kIn(cc).Get();
|
||||
const int idx = kIdx(cc).Get();
|
||||
|
||||
RET_CHECK_LT(idx, items.size());
|
||||
kOut(cc).Send(items[idx]);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_GET_VECTOR_ITEM_CALCULATOR_H_
|
|
@ -83,4 +83,7 @@ REGISTER_CALCULATOR(SplitClassificationListVectorCalculator);
|
|||
typedef SplitVectorCalculator<uint64_t, false> SplitUint64tVectorCalculator;
|
||||
REGISTER_CALCULATOR(SplitUint64tVectorCalculator);
|
||||
|
||||
typedef SplitVectorCalculator<float, false> SplitFloatVectorCalculator;
|
||||
REGISTER_CALCULATOR(SplitFloatVectorCalculator);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
32
mediapipe/calculators/core/vector_size_calculator.cc
Normal file
32
mediapipe/calculators/core/vector_size_calculator.cc
Normal file
|
@ -0,0 +1,32 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "mediapipe/calculators/core/vector_size_calculator.h"
|
||||
|
||||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/landmark.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
using LandmarkListVectorSizeCalculator =
|
||||
VectorSizeCalculator<mediapipe::LandmarkList>;
|
||||
REGISTER_CALCULATOR(LandmarkListVectorSizeCalculator);
|
||||
|
||||
using ClassificationListVectorSizeCalculator =
|
||||
VectorSizeCalculator<mediapipe::ClassificationList>;
|
||||
REGISTER_CALCULATOR(ClassificationListVectorSizeCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
64
mediapipe/calculators/core/vector_size_calculator.h
Normal file
64
mediapipe/calculators/core/vector_size_calculator.h
Normal file
|
@ -0,0 +1,64 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
|
||||
#define MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// A calculator to return vector size.
|
||||
//
|
||||
// Inputs:
|
||||
// VECTOR - std::vector<T>
|
||||
// Vector whose size to return.
|
||||
//
|
||||
// Outputs:
|
||||
// SIZE - int
|
||||
// Size of the input vector.
|
||||
//
|
||||
// Example config:
|
||||
// node {
|
||||
// calculator: "{SpecificType}VectorSizeCalculator"
|
||||
// input_stream: "VECTOR:vector"
|
||||
// output_stream: "SIZE:vector_size"
|
||||
// }
|
||||
//
|
||||
template <typename T>
|
||||
class VectorSizeCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<std::vector<T>> kIn{"VECTOR"};
|
||||
static constexpr Output<int> kOut{"SIZE"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kIn, kOut);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
if (kIn(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
kOut(cc).Send(kIn(cc).Get().size());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
||||
#endif // MEDIAPIPE_CALCULATORS_CORE_VECTOR_SIZE_CALCULATOR_H_
|
|
@ -421,6 +421,10 @@ absl::Status ScaleImageCalculator::InitializeFromOptions() {
|
|||
alignment_boundary_ = options_.alignment_boundary();
|
||||
}
|
||||
|
||||
if (options_.has_output_format()) {
|
||||
output_format_ = options_.output_format();
|
||||
}
|
||||
|
||||
downscaler_.reset(new ImageResizer(options_.post_sharpening_coefficient()));
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -433,13 +437,17 @@ absl::Status ScaleImageCalculator::ValidateImageFormats() const {
|
|||
<< "The output image format was set to UNKNOWN.";
|
||||
// TODO Remove these conditions.
|
||||
RET_CHECK(output_format_ == ImageFormat::SRGB ||
|
||||
output_format_ == ImageFormat::SRGBA ||
|
||||
(input_format_ == output_format_ &&
|
||||
output_format_ == ImageFormat::YCBCR420P))
|
||||
<< "Outputting YCbCr420P images from SRGB input is not yet supported";
|
||||
RET_CHECK(input_format_ == output_format_ ||
|
||||
input_format_ == ImageFormat::YCBCR420P)
|
||||
(input_format_ == ImageFormat::YCBCR420P &&
|
||||
output_format_ == ImageFormat::SRGB) ||
|
||||
(input_format_ == ImageFormat::SRGB &&
|
||||
output_format_ == ImageFormat::SRGBA))
|
||||
<< "Conversion of the color space (except from "
|
||||
"YCbCr420P to SRGB) is not yet supported.";
|
||||
"YCbCr420P to SRGB or SRGB to SRBGA) is not yet supported.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -604,6 +612,15 @@ absl::Status ScaleImageCalculator::Process(CalculatorContext* cc) {
|
|||
.Add(output_image.release(), cc->InputTimestamp());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
} else if (input_format_ == ImageFormat::SRGB &&
|
||||
output_format_ == ImageFormat::SRGBA) {
|
||||
image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
|
||||
cv::Mat input_mat = ::mediapipe::formats::MatView(image_frame);
|
||||
converted_image_frame.Reset(ImageFormat::SRGBA, image_frame->Width(),
|
||||
image_frame->Height(), alignment_boundary_);
|
||||
cv::Mat output_mat = ::mediapipe::formats::MatView(&converted_image_frame);
|
||||
cv::cvtColor(input_mat, output_mat, cv::COLOR_RGB2RGBA, 4);
|
||||
image_frame = &converted_image_frame;
|
||||
} else {
|
||||
image_frame = &cc->Inputs().Get(input_data_id_).Get<ImageFrame>();
|
||||
MP_RETURN_IF_ERROR(ValidateImageFrame(cc, *image_frame));
|
||||
|
|
|
@ -28,7 +28,9 @@ package(default_visibility = ["//visibility:private"])
|
|||
|
||||
exports_files(
|
||||
glob(["testdata/image_to_tensor/*"]),
|
||||
visibility = ["//mediapipe/calculators/image:__subpackages__"],
|
||||
visibility = [
|
||||
"//mediapipe/calculators/image:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
|
@ -64,15 +66,16 @@ cc_library(
|
|||
":inference_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||
"//mediapipe/framework/tool:subgraph_expansion",
|
||||
"//mediapipe/util/tflite:config",
|
||||
"//mediapipe/util/tflite:tflite_model_loader",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@org_tensorflow//tensorflow/lite:framework",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -91,6 +94,7 @@ cc_library(
|
|||
"//mediapipe/util/tflite:tflite_gpu_runner",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
|
||||
],
|
||||
|
@ -142,6 +146,8 @@ cc_library(
|
|||
":inference_calculator_interface",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@org_tensorflow//tensorflow/lite/delegates/xnnpack:xnnpack_delegate",
|
||||
"@org_tensorflow//tensorflow/lite:framework_stable",
|
||||
"@org_tensorflow//tensorflow/lite/c:c_api_types",
|
||||
] + select({
|
||||
"//conditions:default": [
|
||||
"//mediapipe/util:cpu_util",
|
||||
|
|
|
@ -142,22 +142,35 @@ class ImageToTensorCalculator : public Node {
|
|||
cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
|
||||
RET_CHECK(options.has_output_tensor_float_range() ||
|
||||
options.has_output_tensor_int_range())
|
||||
options.has_output_tensor_int_range() ||
|
||||
options.has_output_tensor_uint_range())
|
||||
<< "Output tensor range is required.";
|
||||
if (options.has_output_tensor_float_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_float_range().min(),
|
||||
options.output_tensor_float_range().max())
|
||||
<< "Valid output float tensor range is required.";
|
||||
}
|
||||
if (options.has_output_tensor_uint_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_uint_range().min(),
|
||||
options.output_tensor_uint_range().max())
|
||||
<< "Valid output uint tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_uint_range().min(), 0)
|
||||
<< "The minimum of the output uint tensor range must be "
|
||||
"non-negative.";
|
||||
RET_CHECK_LE(options.output_tensor_uint_range().max(), 255)
|
||||
<< "The maximum of the output uint tensor range must be less than or "
|
||||
"equal to 255.";
|
||||
}
|
||||
if (options.has_output_tensor_int_range()) {
|
||||
RET_CHECK_LT(options.output_tensor_int_range().min(),
|
||||
options.output_tensor_int_range().max())
|
||||
<< "Valid output int tensor range is required.";
|
||||
RET_CHECK_GE(options.output_tensor_int_range().min(), 0)
|
||||
<< "The minimum of the output int tensor range must be non-negative.";
|
||||
RET_CHECK_LE(options.output_tensor_int_range().max(), 255)
|
||||
RET_CHECK_GE(options.output_tensor_int_range().min(), -128)
|
||||
<< "The minimum of the output int tensor range must be greater than "
|
||||
"or equal to -128.";
|
||||
RET_CHECK_LE(options.output_tensor_int_range().max(), 127)
|
||||
<< "The maximum of the output int tensor range must be less than or "
|
||||
"equal to 255.";
|
||||
"equal to 127.";
|
||||
}
|
||||
RET_CHECK_GT(options.output_tensor_width(), 0)
|
||||
<< "Valid output tensor width is required.";
|
||||
|
@ -187,15 +200,19 @@ class ImageToTensorCalculator : public Node {
|
|||
options_ = cc->Options<mediapipe::ImageToTensorCalculatorOptions>();
|
||||
output_width_ = options_.output_tensor_width();
|
||||
output_height_ = options_.output_tensor_height();
|
||||
is_int_output_ = options_.has_output_tensor_int_range();
|
||||
is_float_output_ = options_.has_output_tensor_float_range();
|
||||
if (options_.has_output_tensor_uint_range()) {
|
||||
range_min_ =
|
||||
is_int_output_
|
||||
? static_cast<float>(options_.output_tensor_int_range().min())
|
||||
: options_.output_tensor_float_range().min();
|
||||
static_cast<float>(options_.output_tensor_uint_range().min());
|
||||
range_max_ =
|
||||
is_int_output_
|
||||
? static_cast<float>(options_.output_tensor_int_range().max())
|
||||
: options_.output_tensor_float_range().max();
|
||||
static_cast<float>(options_.output_tensor_uint_range().max());
|
||||
} else if (options_.has_output_tensor_int_range()) {
|
||||
range_min_ = static_cast<float>(options_.output_tensor_int_range().min());
|
||||
range_max_ = static_cast<float>(options_.output_tensor_int_range().max());
|
||||
} else {
|
||||
range_min_ = options_.output_tensor_float_range().min();
|
||||
range_max_ = options_.output_tensor_float_range().max();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -275,6 +292,17 @@ class ImageToTensorCalculator : public Node {
|
|||
}
|
||||
}
|
||||
|
||||
Tensor::ElementType GetOutputTensorType() {
|
||||
if (is_float_output_) {
|
||||
return Tensor::ElementType::kFloat32;
|
||||
}
|
||||
if (range_min_ < 0) {
|
||||
return Tensor::ElementType::kInt8;
|
||||
} else {
|
||||
return Tensor::ElementType::kUInt8;
|
||||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<std::shared_ptr<const mediapipe::Image>> GetInputImage(
|
||||
CalculatorContext* cc) {
|
||||
if (kIn(cc).IsConnected()) {
|
||||
|
@ -305,7 +333,7 @@ class ImageToTensorCalculator : public Node {
|
|||
const Image& image) {
|
||||
// Lazy initialization of the GPU or CPU converter.
|
||||
if (image.UsesGpu()) {
|
||||
if (is_int_output_) {
|
||||
if (!is_float_output_) {
|
||||
return absl::UnimplementedError(
|
||||
"ImageToTensorConverter for the input GPU image currently doesn't "
|
||||
"support quantization.");
|
||||
|
@ -337,11 +365,9 @@ class ImageToTensorCalculator : public Node {
|
|||
} else {
|
||||
if (!cpu_converter_) {
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
ASSIGN_OR_RETURN(cpu_converter_,
|
||||
CreateOpenCvConverter(
|
||||
cc, GetBorderMode(),
|
||||
is_int_output_ ? Tensor::ElementType::kUInt8
|
||||
: Tensor::ElementType::kFloat32));
|
||||
ASSIGN_OR_RETURN(
|
||||
cpu_converter_,
|
||||
CreateOpenCvConverter(cc, GetBorderMode(), GetOutputTensorType()));
|
||||
#else
|
||||
LOG(FATAL) << "Cannot create image to tensor opencv converter since "
|
||||
"MEDIAPIPE_DISABLE_OPENCV is defined.";
|
||||
|
@ -356,7 +382,7 @@ class ImageToTensorCalculator : public Node {
|
|||
mediapipe::ImageToTensorCalculatorOptions options_;
|
||||
int output_width_ = 0;
|
||||
int output_height_ = 0;
|
||||
bool is_int_output_ = false;
|
||||
bool is_float_output_ = false;
|
||||
float range_min_ = 0.0f;
|
||||
float range_max_ = 1.0f;
|
||||
};
|
||||
|
|
|
@ -39,6 +39,14 @@ message ImageToTensorCalculatorOptions {
|
|||
optional int64 max = 2;
|
||||
}
|
||||
|
||||
// Range of uint values [min, max].
|
||||
// min, must be strictly less than max.
|
||||
// Please note that UIntRange is supported for CPU tensors only.
|
||||
message UIntRange {
|
||||
optional uint64 min = 1;
|
||||
optional uint64 max = 2;
|
||||
}
|
||||
|
||||
// Pixel extrapolation methods. See @border_mode.
|
||||
enum BorderMode {
|
||||
BORDER_UNSPECIFIED = 0;
|
||||
|
@ -58,6 +66,7 @@ message ImageToTensorCalculatorOptions {
|
|||
oneof range {
|
||||
FloatRange output_tensor_float_range = 4;
|
||||
IntRange output_tensor_int_range = 7;
|
||||
UIntRange output_tensor_uint_range = 8;
|
||||
}
|
||||
|
||||
// For CONVENTIONAL mode for OpenGL, input image starts at bottom and needs
|
||||
|
|
|
@ -76,12 +76,21 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
|
|||
}
|
||||
std::string output_tensor_range;
|
||||
if (output_int_tensor) {
|
||||
if (range_min < 0) {
|
||||
output_tensor_range = absl::Substitute(R"(output_tensor_int_range {
|
||||
min: $0
|
||||
max: $1
|
||||
})",
|
||||
static_cast<int>(range_min),
|
||||
static_cast<int>(range_max));
|
||||
} else {
|
||||
output_tensor_range = absl::Substitute(R"(output_tensor_uint_range {
|
||||
min: $0
|
||||
max: $1
|
||||
})",
|
||||
static_cast<uint>(range_min),
|
||||
static_cast<uint>(range_max));
|
||||
}
|
||||
} else {
|
||||
output_tensor_range = absl::Substitute(R"(output_tensor_float_range {
|
||||
min: $0
|
||||
|
@ -141,9 +150,15 @@ void RunTestWithInputImagePacket(const Packet& input_image_packet,
|
|||
auto view = tensor.GetCpuReadView();
|
||||
cv::Mat tensor_mat;
|
||||
if (output_int_tensor) {
|
||||
if (range_min < 0) {
|
||||
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kInt8);
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8SC3,
|
||||
const_cast<int8*>(view.buffer<int8>()));
|
||||
} else {
|
||||
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kUInt8);
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_8UC3,
|
||||
const_cast<uint8*>(view.buffer<uint8>()));
|
||||
}
|
||||
} else {
|
||||
EXPECT_EQ(tensor.element_type(), Tensor::ElementType::kFloat32);
|
||||
tensor_mat = cv::Mat(tensor_height, tensor_width, CV_32FC3,
|
||||
|
@ -190,27 +205,30 @@ const std::vector<InputType> kInputTypesToTest = {InputType::kImageFrame,
|
|||
InputType::kImage};
|
||||
|
||||
void RunTest(cv::Mat input, cv::Mat expected_result,
|
||||
std::vector<float> float_range, std::vector<int> int_range,
|
||||
int tensor_width, int tensor_height, bool keep_aspect,
|
||||
std::vector<std::pair<float, float>> float_ranges,
|
||||
std::vector<std::pair<int, int>> int_ranges, int tensor_width,
|
||||
int tensor_height, bool keep_aspect,
|
||||
absl::optional<BorderMode> border_mode,
|
||||
const mediapipe::NormalizedRect& roi) {
|
||||
ASSERT_EQ(2, float_range.size());
|
||||
ASSERT_EQ(2, int_range.size());
|
||||
for (auto input_type : kInputTypesToTest) {
|
||||
for (auto float_range : float_ranges) {
|
||||
RunTestWithInputImagePacket(
|
||||
input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
|
||||
: MakeImagePacket(input),
|
||||
expected_result, float_range[0], float_range[1], tensor_width,
|
||||
expected_result, float_range.first, float_range.second, tensor_width,
|
||||
tensor_height, keep_aspect, border_mode, roi,
|
||||
/*output_int_tensor=*/false);
|
||||
}
|
||||
for (auto int_range : int_ranges) {
|
||||
RunTestWithInputImagePacket(
|
||||
input_type == InputType::kImageFrame ? MakeImageFramePacket(input)
|
||||
: MakeImagePacket(input),
|
||||
expected_result, int_range[0], int_range[1], tensor_width,
|
||||
expected_result, int_range.first, int_range.second, tensor_width,
|
||||
tensor_height, keep_aspect, border_mode, roi,
|
||||
/*output_int_tensor=*/true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) {
|
||||
mediapipe::NormalizedRect roi;
|
||||
|
@ -224,8 +242,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspect) {
|
|||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/medium_sub_rect_keep_aspect.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
|
||||
/*border mode*/ {}, roi);
|
||||
}
|
||||
|
@ -242,8 +260,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectBorderZero) {
|
|||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"medium_sub_rect_keep_aspect_border_zero.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
|
||||
BorderMode::kZero, roi);
|
||||
}
|
||||
|
@ -260,8 +278,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectKeepAspectWithRotation) {
|
|||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"medium_sub_rect_keep_aspect_with_rotation.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}},
|
||||
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
|
||||
BorderMode::kReplicate, roi);
|
||||
}
|
||||
|
@ -279,8 +297,8 @@ TEST(ImageToTensorCalculatorTest,
|
|||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"medium_sub_rect_keep_aspect_with_rotation_border_zero.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/true,
|
||||
BorderMode::kZero, roi);
|
||||
}
|
||||
|
@ -298,8 +316,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotation) {
|
|||
GetRgb(
|
||||
"/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/medium_sub_rect_with_rotation.png"),
|
||||
/*float_range=*/{-1.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{-1.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false,
|
||||
BorderMode::kReplicate, roi);
|
||||
}
|
||||
|
@ -316,8 +334,8 @@ TEST(ImageToTensorCalculatorTest, MediumSubRectWithRotationBorderZero) {
|
|||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"medium_sub_rect_with_rotation_border_zero.png"),
|
||||
/*float_range=*/{-1.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{-1.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/256, /*tensor_height=*/256, /*keep_aspect=*/false,
|
||||
BorderMode::kZero, roi);
|
||||
}
|
||||
|
@ -333,8 +351,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRect) {
|
|||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/large_sub_rect.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false,
|
||||
BorderMode::kReplicate, roi);
|
||||
}
|
||||
|
@ -351,8 +369,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectBorderZero) {
|
|||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/large_sub_rect_border_zero.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/false,
|
||||
BorderMode::kZero, roi);
|
||||
}
|
||||
|
@ -369,8 +387,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspect) {
|
|||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/large_sub_rect_keep_aspect.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
BorderMode::kReplicate, roi);
|
||||
}
|
||||
|
@ -387,8 +405,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectBorderZero) {
|
|||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"large_sub_rect_keep_aspect_border_zero.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
BorderMode::kZero, roi);
|
||||
}
|
||||
|
@ -405,8 +423,8 @@ TEST(ImageToTensorCalculatorTest, LargeSubRectKeepAspectWithRotation) {
|
|||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"large_sub_rect_keep_aspect_with_rotation.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
/*border_mode=*/{}, roi);
|
||||
}
|
||||
|
@ -424,8 +442,8 @@ TEST(ImageToTensorCalculatorTest,
|
|||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/"
|
||||
"large_sub_rect_keep_aspect_with_rotation_border_zero.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}},
|
||||
/*tensor_width=*/128, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
/*border_mode=*/BorderMode::kZero, roi);
|
||||
}
|
||||
|
@ -441,8 +459,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRange) {
|
|||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/noop_except_range.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
BorderMode::kReplicate, roi);
|
||||
}
|
||||
|
@ -458,8 +476,8 @@ TEST(ImageToTensorCalculatorTest, NoOpExceptRangeBorderZero) {
|
|||
"tensor/testdata/image_to_tensor/input.jpg"),
|
||||
GetRgb("/mediapipe/calculators/"
|
||||
"tensor/testdata/image_to_tensor/noop_except_range.png"),
|
||||
/*float_range=*/{0.0f, 1.0f},
|
||||
/*int_range=*/{0, 255},
|
||||
/*float_ranges=*/{{0.0f, 1.0f}},
|
||||
/*int_ranges=*/{{0, 255}, {-128, 127}},
|
||||
/*tensor_width=*/64, /*tensor_height=*/128, /*keep_aspect=*/true,
|
||||
BorderMode::kZero, roi);
|
||||
}
|
||||
|
|
|
@ -268,9 +268,11 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
const RotatedRect& roi,
|
||||
const Size& output_dims, float range_min,
|
||||
float range_max) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) {
|
||||
return InvalidArgumentError(
|
||||
absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ",
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
return InvalidArgumentError(absl::StrCat(
|
||||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
|
||||
|
|
|
@ -172,9 +172,11 @@ class GlProcessor : public ImageToTensorConverter {
|
|||
const RotatedRect& roi,
|
||||
const Size& output_dims, float range_min,
|
||||
float range_max) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) {
|
||||
return InvalidArgumentError(
|
||||
absl::StrCat("Only BGRA/RGBA textures are supported, passed format: ",
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
return InvalidArgumentError(absl::StrCat(
|
||||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
|
||||
|
|
|
@ -352,10 +352,11 @@ class MetalProcessor : public ImageToTensorConverter {
|
|||
const RotatedRect& roi,
|
||||
const Size& output_dims, float range_min,
|
||||
float range_max) override {
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32) {
|
||||
return InvalidArgumentError(
|
||||
absl::StrCat("Only BGRA/RGBA textures are supported, passed "
|
||||
"format: ",
|
||||
if (input.format() != mediapipe::GpuBufferFormat::kBGRA32 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAHalf64 &&
|
||||
input.format() != mediapipe::GpuBufferFormat::kRGBAFloat128) {
|
||||
return InvalidArgumentError(absl::StrCat(
|
||||
"Only 4-channel texture input formats are supported, passed format: ",
|
||||
static_cast<uint32_t>(input.format())));
|
||||
}
|
||||
|
||||
|
|
|
@ -45,7 +45,19 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
border_mode_ = cv::BORDER_CONSTANT;
|
||||
break;
|
||||
}
|
||||
mat_type_ = tensor_type == Tensor::ElementType::kUInt8 ? CV_8UC3 : CV_32FC3;
|
||||
switch (tensor_type_) {
|
||||
case Tensor::ElementType::kInt8:
|
||||
mat_type_ = CV_8SC3;
|
||||
break;
|
||||
case Tensor::ElementType::kFloat32:
|
||||
mat_type_ = CV_32FC3;
|
||||
break;
|
||||
case Tensor::ElementType::kUInt8:
|
||||
mat_type_ = CV_8UC3;
|
||||
break;
|
||||
default:
|
||||
mat_type_ = -1;
|
||||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<Tensor> Convert(const mediapipe::Image& input,
|
||||
|
@ -65,12 +77,22 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
output_dims.width, kNumChannels});
|
||||
auto buffer_view = tensor.GetCpuWriteView();
|
||||
cv::Mat dst;
|
||||
if (tensor_type_ == Tensor::ElementType::kUInt8) {
|
||||
switch (tensor_type_) {
|
||||
case Tensor::ElementType::kInt8:
|
||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
||||
buffer_view.buffer<uint8>());
|
||||
} else {
|
||||
buffer_view.buffer<int8>());
|
||||
break;
|
||||
case Tensor::ElementType::kFloat32:
|
||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
||||
buffer_view.buffer<float>());
|
||||
break;
|
||||
case Tensor::ElementType::kUInt8:
|
||||
dst = cv::Mat(output_dims.height, output_dims.width, mat_type_,
|
||||
buffer_view.buffer<uint8>());
|
||||
break;
|
||||
default:
|
||||
return InvalidArgumentError(
|
||||
absl::StrCat("Unsupported tensor type: ", tensor_type_));
|
||||
}
|
||||
|
||||
const cv::RotatedRect rotated_rect(cv::Point2f(roi.center_x, roi.center_y),
|
||||
|
@ -124,6 +146,13 @@ class OpenCvProcessor : public ImageToTensorConverter {
|
|||
// Factory for the OpenCV-backed image-to-tensor converter.
//
// Accepts only kInt8, kFloat32 and kUInt8 output tensor types; any other
// element type yields an InvalidArgumentError instead of a converter.
// `cc` is not used by this function.
absl::StatusOr<std::unique_ptr<ImageToTensorConverter>> CreateOpenCvConverter(
    CalculatorContext* cc, BorderMode border_mode,
    Tensor::ElementType tensor_type) {
  switch (tensor_type) {
    case Tensor::ElementType::kInt8:
    case Tensor::ElementType::kFloat32:
    case Tensor::ElementType::kUInt8:
      // Supported element types: hand both settings over to the processor.
      return absl::make_unique<OpenCvProcessor>(border_mode, tensor_type);
    default:
      return absl::InvalidArgumentError(absl::StrCat(
          "Tensor type is currently not supported by OpenCvProcessor, type: ",
          tensor_type));
  }
}
|
||||
|
||||
|
|
|
@ -21,7 +21,9 @@
|
|||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/tool/subgraph_expansion.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
@ -67,5 +69,17 @@ absl::StatusOr<Packet<TfLiteModelPtr>> InferenceCalculator::GetModelAsPacket(
|
|||
"Must specify TFLite model as path or loaded model.");
|
||||
}
|
||||
|
||||
// Returns the TFLite op resolver to use, wrapped in a packet.
//
// Lookup order:
//   1. the "OP_RESOLVER" side input (kSideInOpResolver), when connected;
//   2. the "CUSTOM_OP_RESOLVER" side input (kSideInCustomOpResolver), when
//      connected;
//   3. otherwise a freshly created BuiltinOpResolverWithoutDefaultDelegates
//      that is handed to PacketAdopting.
absl::StatusOr<Packet<tflite::OpResolver>>
InferenceCalculator::GetOpResolverAsPacket(CalculatorContext* cc) {
  // The newer side input takes precedence over the deprecated one.
  if (kSideInOpResolver(cc).IsConnected()) {
    return kSideInOpResolver(cc).As<tflite::OpResolver>();
  }

  if (kSideInCustomOpResolver(cc).IsConnected()) {
    return kSideInCustomOpResolver(cc).As<tflite::OpResolver>();
  }

  // Neither side input is connected: fall back to the builtin resolver.
  return PacketAdopting<tflite::OpResolver>(
      std::make_unique<
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>());
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/tflite/tflite_model_loader.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/error_reporter.h"
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
|
@ -55,8 +56,11 @@ namespace api2 {
|
|||
// TENSORS - Vector of Tensors
|
||||
//
|
||||
// Input side packet:
|
||||
// DEPRECATED: Prefer to use the "OP_RESOLVER" input side packet instead.
|
||||
// CUSTOM_OP_RESOLVER (optional) - Use a custom op resolver,
|
||||
// instead of the builtin one.
|
||||
// OP_RESOLVER (optional) - Use to provide tflite op resolver
|
||||
// (tflite::OpResolver)
|
||||
// MODEL (optional) - Use to specify TfLite model
|
||||
// (std::unique_ptr<tflite::FlatBufferModel,
|
||||
// std::function<void(tflite::FlatBufferModel*)>>)
|
||||
|
@ -95,15 +99,21 @@ namespace api2 {
|
|||
class InferenceCalculator : public NodeIntf {
|
||||
public:
|
||||
static constexpr Input<std::vector<Tensor>> kInTensors{"TENSORS"};
|
||||
// Deprecated. Prefer to use the "OP_RESOLVER" input side packet instead.
|
||||
// TODO: Remove the "CUSTOM_OP_RESOLVER" side input after the
|
||||
// migration.
|
||||
static constexpr SideInput<tflite::ops::builtin::BuiltinOpResolver>::Optional
|
||||
kSideInCustomOpResolver{"CUSTOM_OP_RESOLVER"};
|
||||
static constexpr SideInput<tflite::OpResolver>::Optional kSideInOpResolver{
|
||||
"OP_RESOLVER"};
|
||||
static constexpr SideInput<TfLiteModelPtr>::Optional kSideInModel{"MODEL"};
|
||||
static constexpr Output<std::vector<Tensor>> kOutTensors{"TENSORS"};
|
||||
static constexpr SideInput<
|
||||
mediapipe::InferenceCalculatorOptions::Delegate>::Optional kDelegate{
|
||||
"DELEGATE"};
|
||||
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver, kSideInModel,
|
||||
kOutTensors, kDelegate);
|
||||
MEDIAPIPE_NODE_CONTRACT(kInTensors, kSideInCustomOpResolver,
|
||||
kSideInOpResolver, kSideInModel, kOutTensors,
|
||||
kDelegate);
|
||||
|
||||
protected:
|
||||
using TfLiteDelegatePtr =
|
||||
|
@ -111,6 +121,9 @@ class InferenceCalculator : public NodeIntf {
|
|||
|
||||
absl::StatusOr<Packet<TfLiteModelPtr>> GetModelAsPacket(
|
||||
CalculatorContext* cc);
|
||||
|
||||
absl::StatusOr<Packet<tflite::OpResolver>> GetOpResolverAsPacket(
|
||||
CalculatorContext* cc);
|
||||
};
|
||||
|
||||
struct InferenceCalculatorSelector : public InferenceCalculator {
|
||||
|
|
|
@ -116,6 +116,9 @@ message InferenceCalculatorOptions {
|
|||
// to ensure there is no clash of the tokens. If unspecified, NNAPI will
|
||||
// not try caching the compilation.
|
||||
optional string model_token = 2;
|
||||
// The name of an accelerator to be used for NNAPI delegate, e.g.
|
||||
// "google-edgetpu". When not specified, it will be selected by NNAPI.
|
||||
optional string accelerator_name = 3;
|
||||
}
|
||||
message Xnnpack {
|
||||
// Number of threads for XNNPACK delegate. (By default, calculator tries
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
#if defined(MEDIAPIPE_ANDROID)
|
||||
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
|
||||
#endif // ANDROID
|
||||
|
@ -28,6 +28,7 @@
|
|||
#include "mediapipe/util/cpu_util.h"
|
||||
#endif // !__EMSCRIPTEN__ || __EMSCRIPTEN_PTHREADS__
|
||||
|
||||
#include "tensorflow/lite/c/c_api_types.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
@ -61,6 +62,17 @@ int GetXnnpackNumThreads(
|
|||
return GetXnnpackDefaultNumThreads();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyTensorBuffer(const Tensor& input_tensor,
|
||||
tflite::Interpreter* interpreter,
|
||||
int input_tensor_index) {
|
||||
auto input_tensor_view = input_tensor.GetCpuReadView();
|
||||
auto input_tensor_buffer = input_tensor_view.buffer<T>();
|
||||
T* local_tensor_buffer =
|
||||
interpreter->typed_input_tensor<T>(input_tensor_index);
|
||||
std::memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor.bytes());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
class InferenceCalculatorCpuImpl
|
||||
|
@ -73,15 +85,16 @@ class InferenceCalculatorCpuImpl
|
|||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status LoadModel(CalculatorContext* cc);
|
||||
absl::Status LoadDelegate(CalculatorContext* cc);
|
||||
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
|
||||
absl::Status InitInterpreter(CalculatorContext* cc);
|
||||
absl::Status LoadDelegate(CalculatorContext* cc,
|
||||
tflite::InterpreterBuilder* interpreter_builder);
|
||||
absl::Status AllocateTensors();
|
||||
|
||||
// TfLite requires us to keep the model alive as long as the interpreter is.
|
||||
Packet<TfLiteModelPtr> model_packet_;
|
||||
std::unique_ptr<tflite::Interpreter> interpreter_;
|
||||
TfLiteDelegatePtr delegate_;
|
||||
bool has_quantized_input_;
|
||||
TfLiteType input_tensor_type_ = TfLiteType::kTfLiteNoType;
|
||||
};
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::UpdateContract(
|
||||
|
@ -94,8 +107,7 @@ absl::Status InferenceCalculatorCpuImpl::UpdateContract(
|
|||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::Open(CalculatorContext* cc) {
|
||||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||
return LoadDelegateAndAllocateTensors(cc);
|
||||
return InitInterpreter(cc);
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
|
||||
|
@ -108,19 +120,23 @@ absl::Status InferenceCalculatorCpuImpl::Process(CalculatorContext* cc) {
|
|||
|
||||
// Read CPU input into tensors.
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
const Tensor* input_tensor = &input_tensors[i];
|
||||
auto input_tensor_view = input_tensor->GetCpuReadView();
|
||||
if (has_quantized_input_) {
|
||||
// TODO: Support more quantized tensor types.
|
||||
auto input_tensor_buffer = input_tensor_view.buffer<uint8>();
|
||||
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
|
||||
std::memcpy(local_tensor_buffer, input_tensor_buffer,
|
||||
input_tensor->bytes());
|
||||
} else {
|
||||
auto input_tensor_buffer = input_tensor_view.buffer<float>();
|
||||
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
|
||||
std::memcpy(local_tensor_buffer, input_tensor_buffer,
|
||||
input_tensor->bytes());
|
||||
switch (input_tensor_type_) {
|
||||
case TfLiteType::kTfLiteFloat16:
|
||||
case TfLiteType::kTfLiteFloat32: {
|
||||
CopyTensorBuffer<float>(input_tensors[i], interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteUInt8: {
|
||||
CopyTensorBuffer<uint8>(input_tensors[i], interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
case TfLiteType::kTfLiteInt8: {
|
||||
CopyTensorBuffer<int8>(input_tensors[i], interpreter_.get(), i);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported input tensor type:", input_tensor_type_));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -150,39 +166,34 @@ absl::Status InferenceCalculatorCpuImpl::Close(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::LoadModel(CalculatorContext* cc) {
|
||||
absl::Status InferenceCalculatorCpuImpl::InitInterpreter(
|
||||
CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
|
||||
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
|
||||
RET_CHECK(interpreter_);
|
||||
|
||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
||||
const auto& op_resolver = op_resolver_packet.Get();
|
||||
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
|
||||
#if defined(__EMSCRIPTEN__)
|
||||
interpreter_->SetNumThreads(1);
|
||||
interpreter_builder.SetNumThreads(1);
|
||||
#else
|
||||
interpreter_->SetNumThreads(
|
||||
interpreter_builder.SetNumThreads(
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
return absl::OkStatus();
|
||||
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
|
||||
RET_CHECK(interpreter_);
|
||||
return AllocateTensors();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::LoadDelegateAndAllocateTensors(
|
||||
CalculatorContext* cc) {
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||
|
||||
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
|
||||
absl::Status InferenceCalculatorCpuImpl::AllocateTensors() {
|
||||
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
has_quantized_input_ =
|
||||
interpreter_->tensor(interpreter_->inputs()[0])->quantization.type ==
|
||||
kTfLiteAffineQuantization;
|
||||
input_tensor_type_ = interpreter_->tensor(interpreter_->inputs()[0])->type;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
|
||||
absl::Status InferenceCalculatorCpuImpl::LoadDelegate(
|
||||
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
|
||||
const auto& calculator_opts =
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>();
|
||||
auto opts_delegate = calculator_opts.delegate();
|
||||
|
@ -211,18 +222,20 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
if (nnapi_requested) {
|
||||
// Attempt to use NNAPI.
|
||||
// If not supported, the default CPU delegate will be created and used.
|
||||
interpreter_->SetAllowFp16PrecisionForFp32(1);
|
||||
tflite::StatefulNnApiDelegate::Options options;
|
||||
const auto& nnapi = opts_delegate.nnapi();
|
||||
options.allow_fp16 = true;
|
||||
// Set up cache_dir and model_token for NNAPI compilation cache.
|
||||
options.cache_dir =
|
||||
nnapi.has_cache_dir() ? nnapi.cache_dir().c_str() : nullptr;
|
||||
options.model_token =
|
||||
nnapi.has_model_token() ? nnapi.model_token().c_str() : nullptr;
|
||||
options.accelerator_name = nnapi.has_accelerator_name()
|
||||
? nnapi.accelerator_name().c_str()
|
||||
: nullptr;
|
||||
delegate_ = TfLiteDelegatePtr(new tflite::StatefulNnApiDelegate(options),
|
||||
[](TfLiteDelegate*) {});
|
||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||
kTfLiteOk);
|
||||
interpreter_builder->AddDelegate(delegate_.get());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
|
@ -239,8 +252,7 @@ absl::Status InferenceCalculatorCpuImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
GetXnnpackNumThreads(opts_has_delegate, opts_delegate);
|
||||
delegate_ = TfLiteDelegatePtr(TfLiteXNNPackDelegateCreate(&xnnpack_opts),
|
||||
&TfLiteXNNPackDelegateDelete);
|
||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||
kTfLiteOk);
|
||||
interpreter_builder->AddDelegate(delegate_.get());
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/util/tflite/config.h"
|
||||
#include "tensorflow/lite/interpreter_builder.h"
|
||||
|
||||
#if MEDIAPIPE_TFLITE_GL_INFERENCE
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
|
@ -52,9 +53,11 @@ class InferenceCalculatorGlImpl
|
|||
private:
|
||||
absl::Status ReadGpuCaches();
|
||||
absl::Status SaveGpuCaches();
|
||||
absl::Status LoadModel(CalculatorContext* cc);
|
||||
absl::Status LoadDelegate(CalculatorContext* cc);
|
||||
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
|
||||
absl::Status InitInterpreter(CalculatorContext* cc);
|
||||
absl::Status LoadDelegate(CalculatorContext* cc,
|
||||
tflite::InterpreterBuilder* interpreter_builder);
|
||||
absl::Status BindBuffersToTensors();
|
||||
absl::Status AllocateTensors();
|
||||
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
|
||||
|
||||
// TfLite requires us to keep the model alive as long as the interpreter is.
|
||||
|
@ -137,17 +140,11 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
|
|||
#endif // MEDIAPIPE_ANDROID
|
||||
}
|
||||
|
||||
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
|
||||
// for everything.
|
||||
if (!use_advanced_gpu_api_) {
|
||||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
MP_RETURN_IF_ERROR(
|
||||
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
|
||||
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
|
||||
: LoadDelegateAndAllocateTensors(cc);
|
||||
: InitInterpreter(cc);
|
||||
}));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -292,12 +289,6 @@ absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
|
|||
|
||||
absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
||||
CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
|
||||
// Create runner
|
||||
tflite::gpu::InferenceOptions options;
|
||||
options.priority1 = allow_precision_loss_
|
||||
|
@ -335,6 +326,10 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
|||
break;
|
||||
}
|
||||
}
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
||||
const auto& op_resolver = op_resolver_packet.Get();
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
||||
model, op_resolver, /*allow_quant_ops=*/true));
|
||||
|
||||
|
@ -355,31 +350,27 @@ absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
|
||||
absl::Status InferenceCalculatorGlImpl::InitInterpreter(CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
|
||||
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
|
||||
RET_CHECK(interpreter_);
|
||||
|
||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
||||
const auto& op_resolver = op_resolver_packet.Get();
|
||||
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc, &interpreter_builder));
|
||||
#if defined(__EMSCRIPTEN__)
|
||||
interpreter_->SetNumThreads(1);
|
||||
interpreter_builder.SetNumThreads(1);
|
||||
#else
|
||||
interpreter_->SetNumThreads(
|
||||
interpreter_builder.SetNumThreads(
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
|
||||
RET_CHECK(interpreter_);
|
||||
MP_RETURN_IF_ERROR(BindBuffersToTensors());
|
||||
MP_RETURN_IF_ERROR(AllocateTensors());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
|
||||
CalculatorContext* cc) {
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||
|
||||
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
|
||||
absl::Status InferenceCalculatorGlImpl::AllocateTensors() {
|
||||
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
// TODO: Support quantized tensors.
|
||||
RET_CHECK_NE(
|
||||
|
@ -388,7 +379,8 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegateAndAllocateTensors(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
|
||||
absl::Status InferenceCalculatorGlImpl::LoadDelegate(
|
||||
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
|
||||
// Configure and create the delegate.
|
||||
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
|
||||
options.compile_options.precision_loss_allowed =
|
||||
|
@ -399,7 +391,11 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
options.compile_options.inline_parameters = 1;
|
||||
delegate_ = TfLiteDelegatePtr(TfLiteGpuDelegateCreate(&options),
|
||||
&TfLiteGpuDelegateDelete);
|
||||
interpreter_builder->AddDelegate(delegate_.get());
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::BindBuffersToTensors() {
|
||||
// Get input image sizes.
|
||||
const auto& input_indices = interpreter_->inputs();
|
||||
for (int i = 0; i < input_indices.size(); ++i) {
|
||||
|
@ -431,11 +427,6 @@ absl::Status InferenceCalculatorGlImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
output_indices[i]),
|
||||
kTfLiteOk);
|
||||
}
|
||||
|
||||
// Must call this last.
|
||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||
kTfLiteOk);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -90,9 +90,10 @@ class InferenceCalculatorMetalImpl
|
|||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status LoadModel(CalculatorContext* cc);
|
||||
absl::Status LoadDelegate(CalculatorContext* cc);
|
||||
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
|
||||
absl::Status InitInterpreter(CalculatorContext* cc);
|
||||
void AddDelegate(CalculatorContext* cc,
|
||||
tflite::InterpreterBuilder* interpreter_builder);
|
||||
absl::Status CreateConverters(CalculatorContext* cc);
|
||||
|
||||
// TfLite requires us to keep the model alive as long as the interpreter is.
|
||||
Packet<TfLiteModelPtr> model_packet_;
|
||||
|
@ -127,11 +128,9 @@ absl::Status InferenceCalculatorMetalImpl::Open(CalculatorContext* cc) {
|
|||
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||
allow_precision_loss_ = options.delegate().gpu().allow_precision_loss();
|
||||
|
||||
MP_RETURN_IF_ERROR(LoadModel(cc));
|
||||
|
||||
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
|
||||
RET_CHECK(gpu_helper_);
|
||||
return LoadDelegateAndAllocateTensors(cc);
|
||||
return InitInterpreter(cc);
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorMetalImpl::Process(CalculatorContext* cc) {
|
||||
|
@ -199,27 +198,20 @@ absl::Status InferenceCalculatorMetalImpl::Close(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorMetalImpl::LoadModel(CalculatorContext* cc) {
|
||||
absl::Status InferenceCalculatorMetalImpl::InitInterpreter(
|
||||
CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
|
||||
tflite::InterpreterBuilder(model, op_resolver)(&interpreter_);
|
||||
ASSIGN_OR_RETURN(auto op_resolver_packet, GetOpResolverAsPacket(cc));
|
||||
const auto& op_resolver = op_resolver_packet.Get();
|
||||
tflite::InterpreterBuilder interpreter_builder(model, op_resolver);
|
||||
AddDelegate(cc, &interpreter_builder);
|
||||
interpreter_builder.SetNumThreads(
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
||||
RET_CHECK_EQ(interpreter_builder(&interpreter_), kTfLiteOk);
|
||||
RET_CHECK(interpreter_);
|
||||
|
||||
interpreter_->SetNumThreads(
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
|
||||
CalculatorContext* cc) {
|
||||
MP_RETURN_IF_ERROR(LoadDelegate(cc));
|
||||
|
||||
// AllocateTensors() can be called only after ModifyGraphWithDelegate.
|
||||
MP_RETURN_IF_ERROR(CreateConverters(cc));
|
||||
RET_CHECK_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
|
||||
// TODO: Support quantized tensors.
|
||||
RET_CHECK_NE(
|
||||
|
@ -228,7 +220,8 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegateAndAllocateTensors(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
||||
void InferenceCalculatorMetalImpl::AddDelegate(
|
||||
CalculatorContext* cc, tflite::InterpreterBuilder* interpreter_builder) {
|
||||
const auto& calculator_opts =
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>();
|
||||
|
||||
|
@ -242,9 +235,11 @@ absl::Status InferenceCalculatorMetalImpl::LoadDelegate(CalculatorContext* cc) {
|
|||
options.wait_type = TFLGpuDelegateWaitType::TFLGpuDelegateWaitTypeDoNotWait;
|
||||
delegate_ =
|
||||
TfLiteDelegatePtr(TFLGpuDelegateCreate(&options), &TFLGpuDelegateDelete);
|
||||
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_.get()),
|
||||
kTfLiteOk);
|
||||
interpreter_builder->AddDelegate(delegate_.get());
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorMetalImpl::CreateConverters(
|
||||
CalculatorContext* cc) {
|
||||
id<MTLDevice> device = gpu_helper_.mtlDevice;
|
||||
|
||||
// Get input image sizes.
|
||||
|
|
|
@ -91,6 +91,40 @@ void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
|
|||
}
|
||||
}
|
||||
|
||||
// Validates a user-provided custom output tensor mapping.
//
// The "detections" and "scores" tensor indices are always required. Beyond
// those, exactly one of the following layouts is accepted:
//   * no other index is set: the indices must cover exactly {0, 1};
//   * only the "anchors" index is set: the indices must cover exactly
//     {0, 1, 2};
//   * both the "classes" and "number of detections" indices are set (and
//     "anchors" is not): the indices must cover exactly {0, 1, 2, 3}.
//
// Returns an error status when a required index is missing, when an index is
// outside [0, 3], or when the indices do not cover the expected range.
absl::Status CheckCustomTensorMapping(
    const TensorsToDetectionsCalculatorOptions::TensorMapping& tensor_mapping) {
  RET_CHECK(tensor_mapping.has_detections_tensor_index() &&
            tensor_mapping.has_scores_tensor_index());

  // Every index is range-checked before it is used as a shift amount:
  // shifting by a negative value or by >= the bit width of int is undefined
  // behavior, and could let an out-of-range index slip through the bitmap
  // comparison below.
  const auto is_valid_index = [](int index) {
    constexpr int kMaxTensorIndex = 3;
    return 0 <= index && index <= kMaxTensorIndex;
  };
  RET_CHECK(is_valid_index(tensor_mapping.detections_tensor_index()) &&
            is_valid_index(tensor_mapping.scores_tensor_index()))
      << "The custom output tensor indices must be within [0, 3].";

  // Bit i of `bitmap` is set when output tensor index i is occupied.
  int bitmap = 0;
  bitmap |= 1 << tensor_mapping.detections_tensor_index();
  bitmap |= 1 << tensor_mapping.scores_tensor_index();
  if (!tensor_mapping.has_num_detections_tensor_index() &&
      !tensor_mapping.has_classes_tensor_index() &&
      !tensor_mapping.has_anchors_tensor_index()) {
    // Only allows the output tensor index 0 and 1 to be occupied.
    RET_CHECK_EQ(3, bitmap) << "The custom output tensor indices should only "
                               "cover index 0 and 1.";
  } else if (tensor_mapping.has_anchors_tensor_index()) {
    RET_CHECK(!tensor_mapping.has_classes_tensor_index() &&
              !tensor_mapping.has_num_detections_tensor_index());
    RET_CHECK(is_valid_index(tensor_mapping.anchors_tensor_index()))
        << "The custom output tensor indices must be within [0, 3].";
    bitmap |= 1 << tensor_mapping.anchors_tensor_index();
    // If the "anchors" tensor will be available, only allows the output tensor
    // index 0, 1, 2 to be occupied.
    RET_CHECK_EQ(7, bitmap) << "The custom output tensor indices should only "
                               "cover index 0, 1 and 2.";
  } else {
    RET_CHECK(tensor_mapping.has_classes_tensor_index() &&
              tensor_mapping.has_num_detections_tensor_index());
    RET_CHECK(is_valid_index(tensor_mapping.classes_tensor_index()) &&
              is_valid_index(tensor_mapping.num_detections_tensor_index()))
        << "The custom output tensor indices must be within [0, 3].";
    // If the "classes" and the "number of detections" tensors will be
    // available, only allows the output tensor index 0, 1, 2, 3 to be occupied.
    bitmap |= 1 << tensor_mapping.classes_tensor_index();
    bitmap |= 1 << tensor_mapping.num_detections_tensor_index();
    RET_CHECK_EQ(15, bitmap) << "The custom output tensor indices should only "
                                "cover index 0, 1, 2 and 3.";
  }
  return absl::OkStatus();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Convert result Tensors from object detection models into MediaPipe
|
||||
|
@ -170,13 +204,27 @@ class TensorsToDetectionsCalculator : public Node {
|
|||
Detection ConvertToDetection(float box_ymin, float box_xmin, float box_ymax,
|
||||
float box_xmax, float score, int class_id,
|
||||
bool flip_vertically);
|
||||
bool IsClassIndexAllowed(int class_index);
|
||||
|
||||
int num_classes_ = 0;
|
||||
int num_boxes_ = 0;
|
||||
int num_coords_ = 0;
|
||||
std::set<int> ignore_classes_;
|
||||
int max_results_ = -1;
|
||||
|
||||
::mediapipe::TensorsToDetectionsCalculatorOptions options_;
|
||||
// Set of allowed or ignored class indices.
|
||||
struct ClassIndexSet {
|
||||
absl::flat_hash_set<int> values;
|
||||
bool is_allowlist;
|
||||
};
|
||||
// Allowed or ignored class indices based on provided options or side packet.
|
||||
// These are used to filter out the output detection results.
|
||||
ClassIndexSet class_index_set_;
|
||||
|
||||
TensorsToDetectionsCalculatorOptions options_;
|
||||
bool scores_tensor_index_is_set_ = false;
|
||||
TensorsToDetectionsCalculatorOptions::TensorMapping tensor_mapping_;
|
||||
std::vector<int> box_indices_ = {0, 1, 2, 3};
|
||||
bool has_custom_box_indices_ = false;
|
||||
std::vector<Anchor> anchors_;
|
||||
|
||||
#ifndef MEDIAPIPE_DISABLE_GL_COMPUTE
|
||||
|
@ -239,6 +287,21 @@ absl::Status TensorsToDetectionsCalculator::Process(CalculatorContext* cc) {
|
|||
}
|
||||
}
|
||||
}
|
||||
const int num_input_tensors = kInTensors(cc)->size();
|
||||
if (!scores_tensor_index_is_set_) {
|
||||
if (num_input_tensors == 2 ||
|
||||
num_input_tensors == kNumInputTensorsWithAnchors) {
|
||||
tensor_mapping_.set_scores_tensor_index(1);
|
||||
} else {
|
||||
tensor_mapping_.set_scores_tensor_index(2);
|
||||
}
|
||||
scores_tensor_index_is_set_ = true;
|
||||
}
|
||||
if (gpu_processing || num_input_tensors != 4) {
|
||||
// Allows custom bounding box indices when receiving 4 cpu tensors.
|
||||
// Uses the default bbox indices in other cases.
|
||||
RET_CHECK(!has_custom_box_indices_);
|
||||
}
|
||||
|
||||
if (gpu_processing) {
|
||||
if (!gpu_inited_) {
|
||||
|
@ -263,13 +326,15 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
// Postprocessing on CPU for model without postprocessing op. E.g. output
|
||||
// raw score tensor and box tensor. Anchor decoding will be handled below.
|
||||
// TODO: Add flexible input tensor size handling.
|
||||
auto raw_box_tensor = &input_tensors[0];
|
||||
auto raw_box_tensor =
|
||||
&input_tensors[tensor_mapping_.detections_tensor_index()];
|
||||
RET_CHECK_EQ(raw_box_tensor->shape().dims.size(), 3);
|
||||
RET_CHECK_EQ(raw_box_tensor->shape().dims[0], 1);
|
||||
RET_CHECK_GT(num_boxes_, 0) << "Please set num_boxes in calculator options";
|
||||
RET_CHECK_EQ(raw_box_tensor->shape().dims[1], num_boxes_);
|
||||
RET_CHECK_EQ(raw_box_tensor->shape().dims[2], num_coords_);
|
||||
auto raw_score_tensor = &input_tensors[1];
|
||||
auto raw_score_tensor =
|
||||
&input_tensors[tensor_mapping_.scores_tensor_index()];
|
||||
RET_CHECK_EQ(raw_score_tensor->shape().dims.size(), 3);
|
||||
RET_CHECK_EQ(raw_score_tensor->shape().dims[0], 1);
|
||||
RET_CHECK_EQ(raw_score_tensor->shape().dims[1], num_boxes_);
|
||||
|
@ -282,7 +347,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
// TODO: Support other options to load anchors.
|
||||
if (!anchors_init_) {
|
||||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
auto anchor_tensor = &input_tensors[2];
|
||||
auto anchor_tensor =
|
||||
&input_tensors[tensor_mapping_.anchors_tensor_index()];
|
||||
RET_CHECK_EQ(anchor_tensor->shape().dims.size(), 2);
|
||||
RET_CHECK_EQ(anchor_tensor->shape().dims[0], num_boxes_);
|
||||
RET_CHECK_EQ(anchor_tensor->shape().dims[1], kNumCoordsPerBox);
|
||||
|
@ -308,7 +374,7 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
float max_score = -std::numeric_limits<float>::max();
|
||||
// Find the top score for box i.
|
||||
for (int score_idx = 0; score_idx < num_classes_; ++score_idx) {
|
||||
if (ignore_classes_.find(score_idx) == ignore_classes_.end()) {
|
||||
if (IsClassIndexAllowed(score_idx)) {
|
||||
auto score = raw_scores[i * num_classes_ + score_idx];
|
||||
if (options_.sigmoid_score()) {
|
||||
if (options_.has_score_clipping_thresh()) {
|
||||
|
@ -338,23 +404,26 @@ absl::Status TensorsToDetectionsCalculator::ProcessCPU(
|
|||
// Postprocessing on CPU with postprocessing op (e.g. anchor decoding and
|
||||
// non-maximum suppression) within the model.
|
||||
RET_CHECK_EQ(input_tensors.size(), 4);
|
||||
|
||||
auto num_boxes_tensor = &input_tensors[3];
|
||||
auto num_boxes_tensor =
|
||||
&input_tensors[tensor_mapping_.num_detections_tensor_index()];
|
||||
RET_CHECK_EQ(num_boxes_tensor->shape().dims.size(), 1);
|
||||
RET_CHECK_EQ(num_boxes_tensor->shape().dims[0], 1);
|
||||
|
||||
auto detection_boxes_tensor = &input_tensors[0];
|
||||
auto detection_boxes_tensor =
|
||||
&input_tensors[tensor_mapping_.detections_tensor_index()];
|
||||
RET_CHECK_EQ(detection_boxes_tensor->shape().dims.size(), 3);
|
||||
RET_CHECK_EQ(detection_boxes_tensor->shape().dims[0], 1);
|
||||
const int max_detections = detection_boxes_tensor->shape().dims[1];
|
||||
RET_CHECK_EQ(detection_boxes_tensor->shape().dims[2], num_coords_);
|
||||
|
||||
auto detection_classes_tensor = &input_tensors[1];
|
||||
auto detection_classes_tensor =
|
||||
&input_tensors[tensor_mapping_.classes_tensor_index()];
|
||||
RET_CHECK_EQ(detection_classes_tensor->shape().dims.size(), 2);
|
||||
RET_CHECK_EQ(detection_classes_tensor->shape().dims[0], 1);
|
||||
RET_CHECK_EQ(detection_classes_tensor->shape().dims[1], max_detections);
|
||||
|
||||
auto detection_scores_tensor = &input_tensors[2];
|
||||
auto detection_scores_tensor =
|
||||
&input_tensors[tensor_mapping_.scores_tensor_index()];
|
||||
RET_CHECK_EQ(detection_scores_tensor->shape().dims.size(), 2);
|
||||
RET_CHECK_EQ(detection_scores_tensor->shape().dims[0], 1);
|
||||
RET_CHECK_EQ(detection_scores_tensor->shape().dims[1], max_detections);
|
||||
|
@ -394,12 +463,14 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
-> absl::Status {
|
||||
if (!anchors_init_) {
|
||||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
auto read_view = input_tensors[2].GetOpenGlBufferReadView();
|
||||
auto read_view = input_tensors[tensor_mapping_.anchors_tensor_index()]
|
||||
.GetOpenGlBufferReadView();
|
||||
glBindBuffer(GL_COPY_READ_BUFFER, read_view.name());
|
||||
auto write_view = raw_anchors_buffer_->GetOpenGlBufferWriteView();
|
||||
glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name());
|
||||
glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
|
||||
input_tensors[2].bytes());
|
||||
glCopyBufferSubData(
|
||||
GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
|
||||
input_tensors[tensor_mapping_.anchors_tensor_index()].bytes());
|
||||
} else if (!kInAnchors(cc).IsEmpty()) {
|
||||
const auto& anchors = *kInAnchors(cc);
|
||||
auto anchors_view = raw_anchors_buffer_->GetCpuWriteView();
|
||||
|
@ -418,7 +489,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
auto decoded_boxes_view =
|
||||
decoded_boxes_buffer_->GetOpenGlBufferWriteView();
|
||||
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, decoded_boxes_view.name());
|
||||
auto input0_view = input_tensors[0].GetOpenGlBufferReadView();
|
||||
auto input0_view =
|
||||
input_tensors[tensor_mapping_.detections_tensor_index()]
|
||||
.GetOpenGlBufferReadView();
|
||||
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input0_view.name());
|
||||
auto raw_anchors_view = raw_anchors_buffer_->GetOpenGlBufferReadView();
|
||||
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 2, raw_anchors_view.name());
|
||||
|
@ -427,7 +500,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
|
||||
// Score boxes.
|
||||
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 0, scored_boxes_view.name());
|
||||
auto input1_view = input_tensors[1].GetOpenGlBufferReadView();
|
||||
auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()]
|
||||
.GetOpenGlBufferReadView();
|
||||
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, input1_view.name());
|
||||
glUseProgram(score_program_);
|
||||
glDispatchCompute(num_boxes_, 1, 1);
|
||||
|
@ -459,7 +533,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
if (input_tensors.size() == kNumInputTensorsWithAnchors) {
|
||||
RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
|
||||
auto command_buffer = [gpu_helper_ commandBuffer];
|
||||
auto src_buffer = input_tensors[2].GetMtlBufferReadView(command_buffer);
|
||||
auto src_buffer = input_tensors[tensor_mapping_.anchors_tensor_index()]
|
||||
.GetMtlBufferReadView(command_buffer);
|
||||
auto dest_buffer =
|
||||
raw_anchors_buffer_->GetMtlBufferWriteView(command_buffer);
|
||||
id<MTLBlitCommandEncoder> blit_command =
|
||||
|
@ -468,7 +543,9 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
sourceOffset:0
|
||||
toBuffer:dest_buffer.buffer()
|
||||
destinationOffset:0
|
||||
size:input_tensors[2].bytes()];
|
||||
size:input_tensors[tensor_mapping_
|
||||
.anchors_tensor_index()]
|
||||
.bytes()];
|
||||
[blit_command endEncoding];
|
||||
[command_buffer commit];
|
||||
} else if (!kInAnchors(cc).IsEmpty()) {
|
||||
|
@ -495,7 +572,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
auto decoded_boxes_view =
|
||||
decoded_boxes_buffer_->GetMtlBufferWriteView(command_buffer);
|
||||
[command_encoder setBuffer:decoded_boxes_view.buffer() offset:0 atIndex:0];
|
||||
auto input0_view = input_tensors[0].GetMtlBufferReadView(command_buffer);
|
||||
auto input0_view = input_tensors[tensor_mapping_.detections_tensor_index()]
|
||||
.GetMtlBufferReadView(command_buffer);
|
||||
[command_encoder setBuffer:input0_view.buffer() offset:0 atIndex:1];
|
||||
auto raw_anchors_view =
|
||||
raw_anchors_buffer_->GetMtlBufferReadView(command_buffer);
|
||||
|
@ -507,7 +585,8 @@ absl::Status TensorsToDetectionsCalculator::ProcessGPU(
|
|||
|
||||
[command_encoder setComputePipelineState:score_program_];
|
||||
[command_encoder setBuffer:scored_boxes_view.buffer() offset:0 atIndex:0];
|
||||
auto input1_view = input_tensors[1].GetMtlBufferReadView(command_buffer);
|
||||
auto input1_view = input_tensors[tensor_mapping_.scores_tensor_index()]
|
||||
.GetMtlBufferReadView(command_buffer);
|
||||
[command_encoder setBuffer:input1_view.buffer() offset:0 atIndex:1];
|
||||
MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1);
|
||||
MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
|
||||
|
@ -570,6 +649,10 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
|
|||
num_classes_ = options_.num_classes();
|
||||
num_boxes_ = options_.num_boxes();
|
||||
num_coords_ = options_.num_coords();
|
||||
CHECK_NE(options_.max_results(), 0)
|
||||
<< "The maximum number of the top-scored detection results must be "
|
||||
"non-zero.";
|
||||
max_results_ = options_.max_results();
|
||||
|
||||
// Currently only support 2D when num_values_per_keypoint equals to 2.
|
||||
CHECK_EQ(options_.num_values_per_keypoint(), 2);
|
||||
|
@ -581,15 +664,55 @@ absl::Status TensorsToDetectionsCalculator::LoadOptions(CalculatorContext* cc) {
|
|||
|
||||
if (kSideInIgnoreClasses(cc).IsConnected()) {
|
||||
RET_CHECK(!kSideInIgnoreClasses(cc).IsEmpty());
|
||||
RET_CHECK(options_.allow_classes().empty());
|
||||
class_index_set_.is_allowlist = false;
|
||||
for (int ignore_class : *kSideInIgnoreClasses(cc)) {
|
||||
ignore_classes_.insert(ignore_class);
|
||||
class_index_set_.values.insert(ignore_class);
|
||||
}
|
||||
} else if (!options_.allow_classes().empty()) {
|
||||
RET_CHECK(options_.ignore_classes().empty());
|
||||
class_index_set_.is_allowlist = true;
|
||||
for (int i = 0; i < options_.allow_classes_size(); ++i) {
|
||||
class_index_set_.values.insert(options_.allow_classes(i));
|
||||
}
|
||||
} else {
|
||||
class_index_set_.is_allowlist = false;
|
||||
for (int i = 0; i < options_.ignore_classes_size(); ++i) {
|
||||
ignore_classes_.insert(options_.ignore_classes(i));
|
||||
class_index_set_.values.insert(options_.ignore_classes(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (options_.has_tensor_mapping()) {
|
||||
RET_CHECK_OK(CheckCustomTensorMapping(options_.tensor_mapping()));
|
||||
tensor_mapping_ = options_.tensor_mapping();
|
||||
scores_tensor_index_is_set_ = true;
|
||||
} else {
|
||||
// Assigns the default tensor indices.
|
||||
tensor_mapping_.set_detections_tensor_index(0);
|
||||
tensor_mapping_.set_classes_tensor_index(1);
|
||||
tensor_mapping_.set_anchors_tensor_index(2);
|
||||
tensor_mapping_.set_num_detections_tensor_index(3);
|
||||
// The scores tensor index needs to be determined based on the number of
|
||||
// model's output tensors, which will be available in the first invocation
|
||||
// of the Process() method.
|
||||
tensor_mapping_.set_scores_tensor_index(-1);
|
||||
scores_tensor_index_is_set_ = false;
|
||||
}
|
||||
|
||||
if (options_.has_box_boundaries_indices()) {
|
||||
box_indices_ = {options_.box_boundaries_indices().ymin(),
|
||||
options_.box_boundaries_indices().xmin(),
|
||||
options_.box_boundaries_indices().ymax(),
|
||||
options_.box_boundaries_indices().xmax()};
|
||||
int bitmap = 0;
|
||||
for (int i : box_indices_) {
|
||||
bitmap |= 1 << i;
|
||||
}
|
||||
RET_CHECK_EQ(bitmap, 15) << "The custom box boundaries indices should only "
|
||||
"cover index 0, 1, 2, and 3.";
|
||||
has_custom_box_indices_ = true;
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -661,14 +784,22 @@ absl::Status TensorsToDetectionsCalculator::ConvertToDetections(
|
|||
const float* detection_boxes, const float* detection_scores,
|
||||
const int* detection_classes, std::vector<Detection>* output_detections) {
|
||||
for (int i = 0; i < num_boxes_; ++i) {
|
||||
if (max_results_ > 0 && output_detections->size() == max_results_) {
|
||||
break;
|
||||
}
|
||||
if (options_.has_min_score_thresh() &&
|
||||
detection_scores[i] < options_.min_score_thresh()) {
|
||||
continue;
|
||||
}
|
||||
if (!IsClassIndexAllowed(detection_classes[i])) {
|
||||
continue;
|
||||
}
|
||||
const int box_offset = i * num_coords_;
|
||||
Detection detection = ConvertToDetection(
|
||||
detection_boxes[box_offset + 0], detection_boxes[box_offset + 1],
|
||||
detection_boxes[box_offset + 2], detection_boxes[box_offset + 3],
|
||||
/*box_ymin=*/detection_boxes[box_offset + box_indices_[0]],
|
||||
/*box_xmin=*/detection_boxes[box_offset + box_indices_[1]],
|
||||
/*box_ymax=*/detection_boxes[box_offset + box_indices_[2]],
|
||||
/*box_xmax=*/detection_boxes[box_offset + box_indices_[3]],
|
||||
detection_scores[i], detection_classes[i], options_.flip_vertically());
|
||||
const auto& bbox = detection.location_data().relative_bounding_box();
|
||||
if (bbox.width() < 0 || bbox.height() < 0 || std::isnan(bbox.width()) ||
|
||||
|
@ -910,7 +1041,7 @@ void main() {
|
|||
options_.has_score_clipping_thresh() ? 1 : 0,
|
||||
options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
|
||||
: 0,
|
||||
!ignore_classes_.empty() ? 1 : 0);
|
||||
!IsClassIndexAllowed(0));
|
||||
|
||||
// # filter classes supported is hardware dependent.
|
||||
int max_wg_size; // typically <= 1024
|
||||
|
@ -919,7 +1050,14 @@ void main() {
|
|||
CHECK_LT(num_classes_, max_wg_size)
|
||||
<< "# classes must be < " << max_wg_size;
|
||||
// TODO support better filtering.
|
||||
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed";
|
||||
if (class_index_set_.is_allowlist) {
|
||||
CHECK_EQ(class_index_set_.values.size(),
|
||||
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
|
||||
<< "Only all classes >= class 0 or >= class 1";
|
||||
} else {
|
||||
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1)
|
||||
<< "Only ignore class 0 is allowed";
|
||||
}
|
||||
|
||||
// Shader program
|
||||
{
|
||||
|
@ -1126,10 +1264,17 @@ kernel void scoreKernel(
|
|||
options_.has_score_clipping_thresh() ? 1 : 0,
|
||||
options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
|
||||
: 0,
|
||||
ignore_classes_.size() ? 1 : 0);
|
||||
!IsClassIndexAllowed(0));
|
||||
|
||||
// TODO support better filtering.
|
||||
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed";
|
||||
if (class_index_set_.is_allowlist) {
|
||||
CHECK_EQ(class_index_set_.values.size(),
|
||||
IsClassIndexAllowed(0) ? num_classes_ : num_classes_ - 1)
|
||||
<< "Only all classes >= class 0 or >= class 1";
|
||||
} else {
|
||||
CHECK_EQ(class_index_set_.values.size(), IsClassIndexAllowed(0) ? 0 : 1)
|
||||
<< "Only ignore class 0 is allowed";
|
||||
}
|
||||
|
||||
{
|
||||
// Shader program
|
||||
|
@ -1161,5 +1306,16 @@ kernel void scoreKernel(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
bool TensorsToDetectionsCalculator::IsClassIndexAllowed(int class_index) {
|
||||
if (class_index_set_.values.empty()) {
|
||||
return true;
|
||||
}
|
||||
if (class_index_set_.is_allowlist) {
|
||||
return class_index_set_.values.contains(class_index);
|
||||
} else {
|
||||
return !class_index_set_.values.contains(class_index);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -57,7 +57,12 @@ message TensorsToDetectionsCalculatorOptions {
|
|||
optional bool reverse_output_order = 14 [default = false];
|
||||
// The ids of classes that should be ignored during decoding the score for
|
||||
// each predicted box. Can be overridden with IGNORE_CLASSES side packet.
|
||||
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
||||
repeated int32 ignore_classes = 8;
|
||||
// The ids of classes that should be allowed during decoding the score for
|
||||
// each predicted box. `ignore_classes` and `allow_classes` are mutually
|
||||
// exclusive.
|
||||
repeated int32 allow_classes = 21 [packed = true];
|
||||
|
||||
optional bool sigmoid_score = 15 [default = false];
|
||||
optional float score_clipping_thresh = 16;
|
||||
|
@ -71,4 +76,40 @@ message TensorsToDetectionsCalculatorOptions {
|
|||
|
||||
// Score threshold for preserving decoded detections.
|
||||
optional float min_score_thresh = 19;
|
||||
|
||||
// The maximum number of the detection results to return. If < 0, all
|
||||
// available results will be returned.
|
||||
// For the detection models that have built-in non max suppression op, the
|
||||
// output detections are the top-scored results. Otherwise, the output
|
||||
// detections are the first N results that have higher scores than
|
||||
// `min_score_thresh`.
|
||||
optional int32 max_results = 20 [default = -1];
|
||||
|
||||
// The custom model output tensor mapping.
|
||||
// The indices of the "detections" tensor and the "scores" tensor are always
|
||||
// required. If the model outputs an "anchors" tensor, `anchors_tensor_index`
|
||||
// must be specified. If the model outputs both "classes" tensor and "number
|
||||
// of detections" tensors, `classes_tensor_index` and
|
||||
// `num_detections_tensor_index` must be set.
|
||||
// Maps each logical model output to its position among the model's output
// tensors.
message TensorMapping {
|
||||
// Index of the "detections" tensor. Always required.
optional int32 detections_tensor_index = 1;
|
||||
// Index of the "classes" tensor. Must be set together with
// `num_detections_tensor_index`, and not together with
// `anchors_tensor_index`.
optional int32 classes_tensor_index = 2;
|
||||
// Index of the "scores" tensor. Always required.
optional int32 scores_tensor_index = 3;
|
||||
// Index of the "number of detections" tensor. Must be set together with
// `classes_tensor_index`.
optional int32 num_detections_tensor_index = 4;
|
||||
// Index of the "anchors" tensor. Set only when the model outputs anchors.
optional int32 anchors_tensor_index = 5;
|
||||
}
|
||||
optional TensorMapping tensor_mapping = 22;
|
||||
|
||||
// Represents the bounding box by using the combination of boundaries,
|
||||
// {ymin, xmin, ymax, xmax}.
|
||||
// The default order is {ymin, xmin, ymax, xmax}.
|
||||
// Position of each box boundary within the coordinates of a single box.
// The four values are expected to cover the indices 0, 1, 2 and 3.
message BoxBoundariesIndices {
|
||||
// Position of the ymin boundary.
optional int32 ymin = 1 [default = 0];
|
||||
// Position of the xmin boundary.
optional int32 xmin = 2 [default = 1];
|
||||
// Position of the ymax boundary.
optional int32 ymax = 3 [default = 2];
|
||||
// Position of the xmax boundary.
optional int32 xmax = 4 [default = 3];
|
||||
}
|
||||
oneof box_indices {
|
||||
BoxBoundariesIndices box_boundaries_indices = 23;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -121,8 +121,12 @@ absl::Status TensorToImageFrameCalculator::Process(CalculatorContext* cc) {
|
|||
if (d > 255) d = 255;
|
||||
buffer[i] = d;
|
||||
}
|
||||
output = ::absl::make_unique<ImageFrame>(format, width, height,
|
||||
width * depth, buffer.release());
|
||||
output = ::absl::make_unique<ImageFrame>(
|
||||
format, width, height, width * depth, buffer.release(),
|
||||
[total_size](uint8* ptr) {
|
||||
::operator delete[](ptr, total_size,
|
||||
std::align_val_t(EIGEN_MAX_ALIGN_BYTES));
|
||||
});
|
||||
} else if (input_tensor.dtype() == tensorflow::DT_UINT8) {
|
||||
if (scale_factor_ != 1.0) {
|
||||
return absl::InvalidArgumentError("scale_factor_ given for uint8 tensor");
|
||||
|
|
|
@ -121,10 +121,11 @@ cc_library(
|
|||
deps = [
|
||||
":tflite_custom_op_resolver_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/util/tflite:cpu_op_resolver",
|
||||
"//mediapipe/util/tflite:op_resolver",
|
||||
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
|
|
@ -12,14 +12,22 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "mediapipe/calculators/tflite/tflite_custom_op_resolver_calculator.pb.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/tflite/cpu_op_resolver.h"
|
||||
#include "mediapipe/util/tflite/op_resolver.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
constexpr char kOpResolverTag[] = "OP_RESOLVER";
|
||||
} // namespace
|
||||
|
||||
// This calculator creates a custom op resolver as a side packet that can be
|
||||
// used in TfLiteInferenceCalculator. Current custom op resolver supports the
|
||||
// following custom op on CPU and GPU:
|
||||
|
@ -27,7 +35,9 @@ namespace mediapipe {
|
|||
// MaxPoolArgmax
|
||||
// MaxUnpooling
|
||||
//
|
||||
// Usage example:
|
||||
// Usage examples:
|
||||
//
|
||||
// For using with TfliteInferenceCalculator:
|
||||
// node {
|
||||
// calculator: "TfLiteCustomOpResolverCalculator"
|
||||
// output_side_packet: "op_resolver"
|
||||
|
@ -37,12 +47,27 @@ namespace mediapipe {
|
|||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// For using with InferenceCalculator:
|
||||
// node {
|
||||
// calculator: "TfLiteCustomOpResolverCalculator"
|
||||
// output_side_packet: "OP_RESOLVER:op_resolver"
|
||||
// node_options: {
|
||||
// [type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
|
||||
// use_gpu: true
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class TfLiteCustomOpResolverCalculator : public CalculatorBase {
|
||||
public:
|
||||
// Declares the type of the output side packet that carries the op resolver.
// When the output is tagged OP_RESOLVER the packet is typed as the generic
// tflite::OpResolver; otherwise the untagged output at index 0 is typed as
// tflite::ops::builtin::BuiltinOpResolver.
static absl::Status GetContract(CalculatorContract* cc) {
  if (!cc->OutputSidePackets().HasTag(kOpResolverTag)) {
    // Untagged form: a single output side packet addressed by index.
    cc->OutputSidePackets()
        .Index(0)
        .Set<tflite::ops::builtin::BuiltinOpResolver>();
    return absl::OkStatus();
  }
  // Tagged form.
  cc->OutputSidePackets().Tag(kOpResolverTag).Set<tflite::OpResolver>();
  return absl::OkStatus();
}
|
||||
|
||||
|
@ -59,7 +84,14 @@ class TfLiteCustomOpResolverCalculator : public CalculatorBase {
|
|||
op_resolver = absl::make_unique<mediapipe::CpuOpResolver>();
|
||||
}
|
||||
|
||||
if (cc->OutputSidePackets().HasTag(kOpResolverTag)) {
|
||||
cc->OutputSidePackets()
|
||||
.Tag(kOpResolverTag)
|
||||
.Set(mediapipe::api2::PacketAdopting<tflite::OpResolver>(
|
||||
std::move(op_resolver)));
|
||||
} else {
|
||||
cc->OutputSidePackets().Index(0).Set(Adopt(op_resolver.release()));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
@ -54,6 +54,7 @@ mediapipe_proto_library(
|
|||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
"//mediapipe/util:label_map_proto",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -304,6 +305,7 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/util:resource_util",
|
||||
"//mediapipe/util:label_map_cc_proto",
|
||||
] + select({
|
||||
"//mediapipe:android": [
|
||||
"//mediapipe/util/android/file/base",
|
||||
|
@ -350,6 +352,40 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "detection_transformation_calculator",
|
||||
srcs = ["detection_transformation_calculator.cc"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "detection_transformation_calculator_test",
|
||||
size = "small",
|
||||
srcs = ["detection_transformation_calculator_test.cc"],
|
||||
deps = [
|
||||
":detection_transformation_calculator",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:calculator_runner",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"//mediapipe/framework/formats:location_data_cc_proto",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "non_max_suppression_calculator",
|
||||
srcs = ["non_max_suppression_calculator.cc"],
|
||||
|
|
|
@ -12,12 +12,12 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "mediapipe/calculators/util/detection_label_id_to_text_calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
|
||||
#if defined(MEDIAPIPE_MOBILE)
|
||||
|
@ -53,8 +53,11 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
|
|||
absl::Status Process(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::node_hash_map<int, std::string> label_map_;
|
||||
::mediapipe::DetectionLabelIdToTextCalculatorOptions options_;
|
||||
// Local label map built from the calculator options' `label_map_path` or
|
||||
// `label` field.
|
||||
LabelMap local_label_map_;
|
||||
bool keep_label_id_;
|
||||
const LabelMap& GetLabelMap(CalculatorContext* cc);
|
||||
};
|
||||
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
|
||||
|
||||
|
@ -69,13 +72,16 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
|
|||
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
options_ =
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
|
||||
|
||||
if (options_.has_label_map_path()) {
|
||||
if (options.has_label_map_path()) {
|
||||
RET_CHECK(!options.has_label_map() && options.label().empty())
|
||||
<< "Only can set one of the following fields in the CalculatorOptions: "
|
||||
"label_map_path, label, and label_map.";
|
||||
std::string string_path;
|
||||
ASSIGN_OR_RETURN(string_path,
|
||||
PathToResourceAsFile(options_.label_map_path()));
|
||||
PathToResourceAsFile(options.label_map_path()));
|
||||
std::string label_map_string;
|
||||
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
|
||||
|
||||
|
@ -83,13 +89,21 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
|||
std::string line;
|
||||
int i = 0;
|
||||
while (std::getline(stream, line)) {
|
||||
label_map_[i++] = line;
|
||||
LabelMapItem item;
|
||||
item.set_name(line);
|
||||
(*local_label_map_.mutable_index_to_item())[i++] = item;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < options_.label_size(); ++i) {
|
||||
label_map_[i] = options_.label(i);
|
||||
} else if (!options.label().empty()) {
|
||||
RET_CHECK(!options.has_label_map())
|
||||
<< "Only can set one of the following fields in the CalculatorOptions: "
|
||||
"label_map_path, label, and label_map.";
|
||||
for (int i = 0; i < options.label_size(); ++i) {
|
||||
LabelMapItem item;
|
||||
item.set_name(options.label(i));
|
||||
(*local_label_map_.mutable_index_to_item())[i] = item;
|
||||
}
|
||||
}
|
||||
keep_label_id_ = options.keep_label_id();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -101,13 +115,18 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
|
|||
Detection& output_detection = output_detections.back();
|
||||
bool has_text_label = false;
|
||||
for (const int32 label_id : output_detection.label_id()) {
|
||||
if (label_map_.find(label_id) != label_map_.end()) {
|
||||
output_detection.add_label(label_map_[label_id]);
|
||||
if (GetLabelMap(cc).index_to_item().find(label_id) !=
|
||||
GetLabelMap(cc).index_to_item().end()) {
|
||||
auto item = GetLabelMap(cc).index_to_item().at(label_id);
|
||||
output_detection.add_label(item.name());
|
||||
if (item.has_display_name()) {
|
||||
output_detection.add_display_name(item.display_name());
|
||||
}
|
||||
has_text_label = true;
|
||||
}
|
||||
}
|
||||
// Remove label_id field if text labels exist.
|
||||
if (has_text_label && !options_.keep_label_id()) {
|
||||
if (has_text_label && !keep_label_id_) {
|
||||
output_detection.clear_label_id();
|
||||
}
|
||||
}
|
||||
|
@ -117,4 +136,13 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns the label map used to translate label ids into text. The locally
// built map (filled in Open() from `label_map_path` or `label`) is used when
// it has entries; otherwise the `label_map` field of the calculator options
// is returned.
const LabelMap& DetectionLabelIdToTextCalculator::GetLabelMap(
    CalculatorContext* cc) {
  if (local_label_map_.index_to_item().empty()) {
    return cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>()
        .label_map();
  }
  return local_label_map_;
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -17,6 +17,7 @@ syntax = "proto2";
|
|||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/util/label_map.proto";
|
||||
|
||||
message DetectionLabelIdToTextCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
|
@ -26,7 +27,7 @@ message DetectionLabelIdToTextCalculatorOptions {
|
|||
// Path to a label map file for getting the actual name of detected classes.
|
||||
optional string label_map_path = 1;
|
||||
|
||||
// Alternative way to specify label map
|
||||
// Alternative way to specify label map.
|
||||
// label: "label for id 0"
|
||||
// label: "label for id 1"
|
||||
// ...
|
||||
|
@ -36,4 +37,7 @@ message DetectionLabelIdToTextCalculatorOptions {
|
|||
// could be found. By setting this field to true, it is always copied to the
|
||||
// output detections.
|
||||
optional bool keep_label_id = 3;
|
||||
|
||||
// Label map.
|
||||
optional LabelMap label_map = 4;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,298 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/api2/packet.h"
|
||||
#include "mediapipe/framework/api2/port.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
// Caps `value` at `upper_bound`, then raises a negative result to zero.
template <typename T>
T BoundedValue(T value, T upper_bound) {
  const T capped = std::min(value, upper_bound);
  return capped < 0 ? static_cast<T>(0) : capped;
}
|
||||
|
||||
// Converts the relative bounding box stored in `detection` into a pixel
// bounding box. `image_size` holds {width, height}. Each relative coordinate
// is multiplied by the matching image dimension, converted to int and passed
// through BoundedValue with that dimension as the upper bound. Afterwards the
// location data format is BOUNDING_BOX and the relative box is cleared.
// Always returns absl::OkStatus().
absl::Status ConvertRelativeBoundingBoxToBoundingBox(
    const std::pair<int, int>& image_size, Detection* detection) {
  const int width_px = image_size.first;
  const int height_px = image_size.second;
  // Read-only view of the source box; it is cleared only after all four
  // fields have been copied.
  const auto& source_box = detection->location_data().relative_bounding_box();
  LocationData* location_data = detection->mutable_location_data();
  auto* pixel_box = location_data->mutable_bounding_box();

  pixel_box->set_xmin(
      BoundedValue<int>(source_box.xmin() * width_px, width_px));
  pixel_box->set_ymin(
      BoundedValue<int>(source_box.ymin() * height_px, height_px));
  pixel_box->set_width(
      BoundedValue<int>(source_box.width() * width_px, width_px));
  pixel_box->set_height(
      BoundedValue<int>(source_box.height() * height_px, height_px));

  location_data->set_format(LocationData::BOUNDING_BOX);
  location_data->clear_relative_bounding_box();
  return absl::OkStatus();
}
|
||||
|
||||
// Converts the pixel bounding box stored in `detection` into a relative
// bounding box. `image_size` holds {width, height}. Each pixel coordinate is
// divided (as float) by the matching image dimension and passed through
// BoundedValue with 1.0f as the upper bound. Afterwards the pixel box is
// cleared and the location data format is RELATIVE_BOUNDING_BOX.
// Always returns absl::OkStatus().
//
// NOTE: there is no guard against a zero image dimension; the division is
// performed as-is.
absl::Status ConvertBoundingBoxToRelativeBoundingBox(
    const std::pair<int, int>& image_size, Detection* detection) {
  const int image_width = image_size.first;
  const int image_height = image_size.second;
  // Read-only view of the source box; it is cleared only after all four
  // fields have been copied.
  const auto& bbox = detection->location_data().bounding_box();
  LocationData* location_data = detection->mutable_location_data();
  auto* relative_bbox = location_data->mutable_relative_bounding_box();

  relative_bbox->set_xmin(BoundedValue<float>(
      static_cast<float>(bbox.xmin()) / image_width, 1.0f));
  relative_bbox->set_ymin(BoundedValue<float>(
      static_cast<float>(bbox.ymin()) / image_height, 1.0f));
  relative_bbox->set_width(BoundedValue<float>(
      static_cast<float>(bbox.width()) / image_width, 1.0f));
  relative_bbox->set_height(BoundedValue<float>(
      static_cast<float>(bbox.height()) / image_height, 1.0f));

  location_data->clear_bounding_box();
  location_data->set_format(LocationData::RELATIVE_BOUNDING_BOX);
  return absl::OkStatus();
}
|
||||
|
||||
// Returns the location data format of `detection`.
// Fails with InvalidArgumentError when the detection has no location data,
// and with a RET_CHECK failure when the format is neither
// RELATIVE_BOUNDING_BOX nor BOUNDING_BOX.
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
    const Detection& detection) {
  const bool missing_location_data = !detection.has_location_data();
  if (missing_location_data) {
    return absl::InvalidArgumentError("Detection must have location data.");
  }
  const LocationData::Format format = detection.location_data().format();
  // The checked expression is kept verbatim: RET_CHECK reports it on failure.
  RET_CHECK(format == LocationData::RELATIVE_BOUNDING_BOX ||
            format == LocationData::BOUNDING_BOX)
      << "Detection's location data format must be either "
         "RELATIVE_BOUNDING_BOX or BOUNDING_BOX";
  return format;
}
|
||||
|
||||
// Returns the location data format shared by every element of `detections`.
// Fails with a RET_CHECK failure when `detections` is empty, propagates any
// error from the single-detection overload, and returns InvalidArgumentError
// when two detections disagree on the format.
//
// The vector is only read, so it is taken by const reference; callers passing
// a non-const lvalue keep working unchanged.
absl::StatusOr<LocationData::Format> GetLocationDataFormat(
    const std::vector<Detection>& detections) {
  RET_CHECK(!detections.empty());
  LocationData::Format output_format;
  ASSIGN_OR_RETURN(output_format, GetLocationDataFormat(detections[0]));
  // The first element defines the reference format; compare the rest to it.
  for (std::vector<Detection>::size_type i = 1; i < detections.size(); ++i) {
    ASSIGN_OR_RETURN(LocationData::Format format,
                     GetLocationDataFormat(detections[i]));
    if (output_format != format) {
      return absl::InvalidArgumentError(
          "Input detections have different location data formats.");
    }
  }
  return output_format;
}
|
||||
|
||||
// Flips the bounding box representation of `detection` in place: a relative
// box becomes a pixel box and a pixel box becomes a relative box, using
// `image_size` ({width, height}) for the scaling.
// Returns InvalidArgumentError when the detection has no location data or
// when its format is neither RELATIVE_BOUNDING_BOX nor BOUNDING_BOX.
absl::Status ConvertBoundingBox(const std::pair<int, int>& image_size,
                                Detection* detection) {
  if (!detection->has_location_data()) {
    return absl::InvalidArgumentError("Detection must have location data.");
  }
  const LocationData::Format format = detection->location_data().format();
  if (format == LocationData::RELATIVE_BOUNDING_BOX) {
    return ConvertRelativeBoundingBoxToBoundingBox(image_size, detection);
  }
  if (format == LocationData::BOUNDING_BOX) {
    return ConvertBoundingBoxToRelativeBoundingBox(image_size, detection);
  }
  return absl::InvalidArgumentError(
      "Detection's location data format must be either "
      "RELATIVE_BOUNDING_BOX or BOUNDING_BOX.");
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Transforms relative bounding box(es) to pixel bounding box(es) in a detection
|
||||
// proto/detection list/detection vector, or vice versa.
|
||||
//
|
||||
// Inputs:
|
||||
// One of the following:
|
||||
// DETECTION: A Detection proto.
|
||||
// DETECTIONS: An std::vector<Detection>/ a DetectionList proto.
|
||||
// IMAGE_SIZE: A std::pair<int, int> representing image width and height.
|
||||
//
|
||||
// Outputs:
|
||||
// At least one of the following:
|
||||
// PIXEL_DETECTION: A Detection proto with pixel bounding box.
|
||||
// PIXEL_DETECTIONS: An std::vector<Detection> with pixel bounding boxes.
|
||||
// PIXEL_DETECTION_LIST: A DetectionList proto with pixel bounding boxes.
|
||||
// RELATIVE_DETECTION: A Detection proto with relative bounding box.
|
||||
// RELATIVE_DETECTIONS: An std::vector<Detection> with relative bounding boxes.
|
||||
// RELATIVE_DETECTION_LIST: A DetectionList proto with relative bounding boxes.
|
||||
//
|
||||
// Example config:
|
||||
// For input detection(s) with relative bounding box(es):
|
||||
// node {
|
||||
// calculator: "DetectionTransformationCalculator"
|
||||
// input_stream: "DETECTION:input_detection"
|
||||
// input_stream: "IMAGE_SIZE:image_size"
|
||||
// output_stream: "PIXEL_DETECTION:output_detection"
|
||||
// output_stream: "PIXEL_DETECTIONS:output_detections"
|
||||
// output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
|
||||
// }
|
||||
//
|
||||
// For input detection(s) with pixel bounding box(es):
|
||||
// node {
|
||||
// calculator: "DetectionTransformationCalculator"
|
||||
// input_stream: "DETECTION:input_detection"
|
||||
// input_stream: "IMAGE_SIZE:image_size"
|
||||
// output_stream: "RELATIVE_DETECTION:output_detection"
|
||||
// output_stream: "RELATIVE_DETECTIONS:output_detections"
|
||||
// output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
|
||||
// }
|
||||
class DetectionTransformationCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<Detection>::Optional kInDetection{"DETECTION"};
|
||||
static constexpr Input<OneOf<DetectionList, std::vector<Detection>>>::Optional
|
||||
kInDetections{"DETECTIONS"};
|
||||
static constexpr Input<std::pair<int, int>> kInImageSize{"IMAGE_SIZE"};
|
||||
static constexpr Output<Detection>::Optional kOutPixelDetection{
|
||||
"PIXEL_DETECTION"};
|
||||
static constexpr Output<std::vector<Detection>>::Optional kOutPixelDetections{
|
||||
"PIXEL_DETECTIONS"};
|
||||
static constexpr Output<DetectionList>::Optional kOutPixelDetectionList{
|
||||
"PIXEL_DETECTION_LIST"};
|
||||
static constexpr Output<Detection>::Optional kOutRelativeDetection{
|
||||
"RELATIVE_DETECTION"};
|
||||
static constexpr Output<std::vector<Detection>>::Optional
|
||||
kOutRelativeDetections{"RELATIVE_DETECTIONS"};
|
||||
static constexpr Output<DetectionList>::Optional kOutRelativeDetectionList{
|
||||
"RELATIVE_DETECTION_LIST"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kInDetection, kInDetections, kInImageSize,
|
||||
kOutPixelDetection, kOutPixelDetections,
|
||||
kOutPixelDetectionList, kOutRelativeDetection,
|
||||
kOutRelativeDetections, kOutRelativeDetectionList);
|
||||
|
||||
static absl::Status UpdateContract(CalculatorContract* cc) {
|
||||
RET_CHECK(kInImageSize(cc).IsConnected()) << "Image size must be provided.";
|
||||
RET_CHECK(kInDetections(cc).IsConnected() ^ kInDetection(cc).IsConnected());
|
||||
if (kInDetections(cc).IsConnected()) {
|
||||
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
|
||||
kOutPixelDetectionList(cc).IsConnected() ||
|
||||
kOutRelativeDetections(cc).IsConnected() ||
|
||||
kOutRelativeDetectionList(cc).IsConnected())
|
||||
<< "Output must be a container of detections.";
|
||||
}
|
||||
RET_CHECK(kOutPixelDetections(cc).IsConnected() ||
|
||||
kOutPixelDetectionList(cc).IsConnected() ||
|
||||
kOutPixelDetection(cc).IsConnected() ||
|
||||
kOutRelativeDetections(cc).IsConnected() ||
|
||||
kOutRelativeDetectionList(cc).IsConnected() ||
|
||||
kOutRelativeDetection(cc).IsConnected())
|
||||
<< "Must connect at least one output stream.";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override {
|
||||
output_pixel_bounding_boxes_ = kOutPixelDetections(cc).IsConnected() ||
|
||||
kOutPixelDetectionList(cc).IsConnected() ||
|
||||
kOutPixelDetection(cc).IsConnected();
|
||||
output_relative_bounding_boxes_ =
|
||||
kOutRelativeDetections(cc).IsConnected() ||
|
||||
kOutRelativeDetectionList(cc).IsConnected() ||
|
||||
kOutRelativeDetection(cc).IsConnected();
|
||||
RET_CHECK(output_pixel_bounding_boxes_ ^ output_relative_bounding_boxes_)
|
||||
<< "All output streams must have the same stream tag prefix, either "
|
||||
"\"PIXEL\" or \"RELATIVE_\".";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
std::pair<int, int> image_size = kInImageSize(cc).Get();
|
||||
std::vector<Detection> transformed_detections;
|
||||
LocationData::Format input_location_data_format;
|
||||
if (kInDetections(cc).IsConnected()) {
|
||||
transformed_detections = kInDetections(cc).Visit(
|
||||
[&](const DetectionList& detection_list) {
|
||||
return std::vector<Detection>(detection_list.detection().begin(),
|
||||
detection_list.detection().end());
|
||||
},
|
||||
[&](const std::vector<Detection>& detection_vector) {
|
||||
return detection_vector;
|
||||
});
|
||||
ASSIGN_OR_RETURN(input_location_data_format,
|
||||
GetLocationDataFormat(transformed_detections));
|
||||
for (Detection& detection : transformed_detections) {
|
||||
MP_RETURN_IF_ERROR(ConvertBoundingBox(image_size, &detection));
|
||||
}
|
||||
} else {
|
||||
ASSIGN_OR_RETURN(input_location_data_format,
|
||||
GetLocationDataFormat(kInDetection(cc).Get()));
|
||||
Detection transformed_detection(kInDetection(cc).Get());
|
||||
MP_RETURN_IF_ERROR(
|
||||
ConvertBoundingBox(image_size, &transformed_detection));
|
||||
transformed_detections.push_back(transformed_detection);
|
||||
}
|
||||
if (input_location_data_format == LocationData::RELATIVE_BOUNDING_BOX) {
|
||||
RET_CHECK(!output_relative_bounding_boxes_)
|
||||
<< "Input detections are with relative bounding box(es), and the "
|
||||
"output detections must have pixel bounding box(es).";
|
||||
if (kOutPixelDetection(cc).IsConnected()) {
|
||||
kOutPixelDetection(cc).Send(transformed_detections[0]);
|
||||
}
|
||||
if (kOutPixelDetections(cc).IsConnected()) {
|
||||
kOutPixelDetections(cc).Send(transformed_detections);
|
||||
}
|
||||
if (kOutPixelDetectionList(cc).IsConnected()) {
|
||||
DetectionList detection_list;
|
||||
for (const auto& detection : transformed_detections) {
|
||||
detection_list.add_detection()->CopyFrom(detection);
|
||||
}
|
||||
kOutPixelDetectionList(cc).Send(detection_list);
|
||||
}
|
||||
} else {
|
||||
RET_CHECK(!output_pixel_bounding_boxes_)
|
||||
<< "Input detections are with pixel bounding box(es), and the "
|
||||
"output detections must have relative bounding box(es).";
|
||||
if (kOutRelativeDetection(cc).IsConnected()) {
|
||||
kOutRelativeDetection(cc).Send(transformed_detections[0]);
|
||||
}
|
||||
if (kOutRelativeDetections(cc).IsConnected()) {
|
||||
kOutRelativeDetections(cc).Send(transformed_detections);
|
||||
}
|
||||
if (kOutRelativeDetectionList(cc).IsConnected()) {
|
||||
DetectionList detection_list;
|
||||
for (const auto& detection : transformed_detections) {
|
||||
detection_list.add_detection()->CopyFrom(detection);
|
||||
}
|
||||
kOutRelativeDetectionList(cc).Send(detection_list);
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
private:
|
||||
bool output_relative_bounding_boxes_;
|
||||
bool output_pixel_bounding_boxes_;
|
||||
};
|
||||
|
||||
MEDIAPIPE_REGISTER_NODE(DetectionTransformationCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -0,0 +1,287 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/calculator_runner.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/formats/location_data.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
// Input stream tags of DetectionTransformationCalculator used by the tests.
constexpr char kDetectionTag[] = "DETECTION";
constexpr char kDetectionsTag[] = "DETECTIONS";
constexpr char kImageSizeTag[] = "IMAGE_SIZE";

// Output stream tags carrying pixel bounding boxes.
constexpr char kPixelDetectionTag[] = "PIXEL_DETECTION";
constexpr char kPixelDetectionListTag[] = "PIXEL_DETECTION_LIST";
constexpr char kPixelDetectionsTag[] = "PIXEL_DETECTIONS";

// Output stream tags carrying relative bounding boxes.
constexpr char kRelativeDetectionListTag[] = "RELATIVE_DETECTION_LIST";
constexpr char kRelativeDetectionsTag[] = "RELATIVE_DETECTIONS";
|
||||
|
||||
// Builds a Detection whose location data is a pixel BOUNDING_BOX with the
// given origin and size.
Detection DetectionWithBoundingBox(int32 xmin, int32 ymin, int32 width,
                                   int32 height) {
  Detection result;
  LocationData* location = result.mutable_location_data();
  location->set_format(LocationData::BOUNDING_BOX);
  auto* box = location->mutable_bounding_box();
  box->set_xmin(xmin);
  box->set_ymin(ymin);
  box->set_width(width);
  box->set_height(height);
  return result;
}
|
||||
|
||||
// Builds a Detection whose location data is a RELATIVE_BOUNDING_BOX with the
// given origin and size.
Detection DetectionWithRelativeBoundingBox(float xmin, float ymin, float width,
                                           float height) {
  Detection result;
  LocationData* location = result.mutable_location_data();
  location->set_format(LocationData::RELATIVE_BOUNDING_BOX);
  auto* box = location->mutable_relative_bounding_box();
  box->set_xmin(xmin);
  box->set_ymin(ymin);
  box->set_width(width);
  box->set_height(height);
  return result;
}
|
||||
|
||||
std::vector<Detection> ConvertToDetectionVector(
|
||||
const DetectionList& detection_list) {
|
||||
std::vector<Detection> output;
|
||||
for (const auto& detection : detection_list.detection()) {
|
||||
output.push_back(detection);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
void CheckBoundingBox(const Detection& output, const Detection& expected) {
|
||||
const auto& output_bbox = output.location_data().bounding_box();
|
||||
const auto& expected_bbox = output.location_data().bounding_box();
|
||||
EXPECT_THAT(output_bbox.xmin(), testing::Eq(expected_bbox.xmin()));
|
||||
EXPECT_THAT(output_bbox.ymin(), testing::Eq(expected_bbox.ymin()));
|
||||
EXPECT_THAT(output_bbox.width(), testing::Eq(expected_bbox.width()));
|
||||
EXPECT_THAT(output_bbox.height(), testing::Eq(expected_bbox.height()));
|
||||
}
|
||||
|
||||
void CheckRelativeBoundingBox(const Detection& output,
|
||||
const Detection& expected) {
|
||||
const auto& output_bbox = output.location_data().relative_bounding_box();
|
||||
const auto& expected_bbox = output.location_data().relative_bounding_box();
|
||||
EXPECT_THAT(output_bbox.xmin(), testing::FloatEq(expected_bbox.xmin()));
|
||||
EXPECT_THAT(output_bbox.ymin(), testing::FloatEq(expected_bbox.ymin()));
|
||||
EXPECT_THAT(output_bbox.width(), testing::FloatEq(expected_bbox.width()));
|
||||
EXPECT_THAT(output_bbox.height(), testing::FloatEq(expected_bbox.height()));
|
||||
}
|
||||
|
||||
void CheckOutputDetections(const std::vector<Detection>& expected,
|
||||
const std::vector<Detection>& output) {
|
||||
ASSERT_EQ(output.size(), expected.size());
|
||||
for (int i = 0; i < output.size(); ++i) {
|
||||
auto output_format = output[i].location_data().format();
|
||||
ASSERT_TRUE(output_format == LocationData::RELATIVE_BOUNDING_BOX ||
|
||||
output_format == LocationData::BOUNDING_BOX);
|
||||
ASSERT_EQ(output_format, expected[i].location_data().format());
|
||||
if (output_format == LocationData::RELATIVE_BOUNDING_BOX) {
|
||||
CheckRelativeBoundingBox(output[i], expected[i]);
|
||||
}
|
||||
if (output_format == LocationData::BOUNDING_BOX) {
|
||||
CheckBoundingBox(output[i], expected[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A node without an IMAGE_SIZE input stream must fail to run and report that
// the image size is required.
TEST(DetectionsTransformationCalculatorTest, MissingImageSize) {
  const auto node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
        calculator: "DetectionTransformationCalculator"
        input_stream: "DETECTIONS:detections"
        output_stream: "PIXEL_DETECTION:detection"
      )pb");
  CalculatorRunner runner(node_config);

  const auto run_status = runner.Run();
  ASSERT_FALSE(run_status.ok());
  EXPECT_THAT(run_status.message(),
              testing::HasSubstr("Image size must be provided"));
}
|
||||
|
||||
// A DETECTIONS input combined with only a single-detection output must fail
// to run and report that a container output is required.
TEST(DetectionsTransformationCalculatorTest, WrongOutputType) {
  const auto node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
        calculator: "DetectionTransformationCalculator"
        input_stream: "DETECTIONS:detections"
        input_stream: "IMAGE_SIZE:image_size"
        output_stream: "PIXEL_DETECTION:detection"
      )pb");
  CalculatorRunner runner(node_config);

  const auto run_status = runner.Run();
  ASSERT_FALSE(run_status.ok());
  EXPECT_THAT(run_status.message(),
              testing::HasSubstr("Output must be a container of detections"));
}
|
||||
|
||||
// A detection whose location data format is GLOBAL must make the run fail
// with the "format must be either ..." error.
TEST(DetectionsTransformationCalculatorTest, WrongLocationDataFormat) {
  const auto node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
        calculator: "DetectionTransformationCalculator"
        input_stream: "DETECTION:input_detection"
        input_stream: "IMAGE_SIZE:image_size"
        output_stream: "PIXEL_DETECTION:output_detection"
      )pb");
  CalculatorRunner runner(node_config);

  Detection global_detection;
  global_detection.mutable_location_data()->set_format(LocationData::GLOBAL);
  auto& detection_packets = runner.MutableInputs()->Tag(kDetectionTag).packets;
  detection_packets.push_back(
      MakePacket<Detection>(global_detection).At(Timestamp(0)));

  std::pair<int, int> image_size({2000, 1000});
  auto& image_size_packets =
      runner.MutableInputs()->Tag(kImageSizeTag).packets;
  image_size_packets.push_back(
      MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));

  const auto run_status = runner.Run();
  ASSERT_FALSE(run_status.ok());
  EXPECT_THAT(run_status.message(),
              testing::HasSubstr("location data format must be either "
                                 "RELATIVE_BOUNDING_BOX or BOUNDING_BOX"));
}
|
||||
|
||||
// Pixel bounding boxes fed through DETECTIONS must come out as relative
// bounding boxes on both the vector and the DetectionList outputs.
TEST(DetectionsTransformationCalculatorTest,
     ConvertBoundingBoxToRelativeBoundingBox) {
  const auto node_config =
      ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
        calculator: "DetectionTransformationCalculator"
        input_stream: "DETECTIONS:input_detections"
        input_stream: "IMAGE_SIZE:image_size"
        output_stream: "RELATIVE_DETECTIONS:output_detections"
        output_stream: "RELATIVE_DETECTION_LIST:output_detection_list"
      )pb");
  CalculatorRunner runner(node_config);

  // Inputs: two pixel boxes on a 2000x1000 image, both at timestamp 0.
  auto input_detections(absl::make_unique<std::vector<Detection>>());
  input_detections->push_back(DetectionWithBoundingBox(100, 200, 400, 300));
  input_detections->push_back(DetectionWithBoundingBox(0, 0, 2000, 1000));
  std::pair<int, int> image_size({2000, 1000});
  auto& detections_packets =
      runner.MutableInputs()->Tag(kDetectionsTag).packets;
  detections_packets.push_back(
      Adopt(input_detections.release()).At(Timestamp(0)));
  auto& image_size_packets =
      runner.MutableInputs()->Tag(kImageSizeTag).packets;
  image_size_packets.push_back(
      MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());

  std::vector<Detection> expected(
      {DetectionWithRelativeBoundingBox(0.05, 0.2, 0.2, 0.3),
       DetectionWithRelativeBoundingBox(0, 0, 1, 1)});

  // Vector output.
  const std::vector<Packet>& vector_packets =
      runner.Outputs().Tag(kRelativeDetectionsTag).packets;
  ASSERT_EQ(1, vector_packets.size());
  CheckOutputDetections(expected,
                        vector_packets[0].Get<std::vector<Detection>>());

  // DetectionList output.
  const std::vector<Packet>& list_packets =
      runner.Outputs().Tag(kRelativeDetectionListTag).packets;
  ASSERT_EQ(1, list_packets.size());
  CheckOutputDetections(
      expected, ConvertToDetectionVector(list_packets[0].Get<DetectionList>()));
}
|
||||
|
||||
// Relative bounding boxes fed through DETECTIONS must come out as pixel
// bounding boxes on both the vector and the DetectionList outputs.
TEST(DetectionsTransformationCalculatorTest,
     ConvertRelativeBoundingBoxToBoundingBox) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTIONS:input_detections"
    input_stream: "IMAGE_SIZE:image_size"
    output_stream: "PIXEL_DETECTIONS:output_detections"
    output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
  )pb"));

  auto detections(absl::make_unique<std::vector<Detection>>());
  detections->push_back(DetectionWithRelativeBoundingBox(0.1, 0.2, 0.3, 0.4));
  detections->push_back(DetectionWithRelativeBoundingBox(0, 0, 1, 1));
  std::pair<int, int> image_size({2000, 1000});
  runner.MutableInputs()
      ->Tag(kDetectionsTag)
      .packets.push_back(Adopt(detections.release()).At(Timestamp(0)));
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(
          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());

  // Expected pixel boxes are the relative inputs scaled by the 2000x1000
  // image: x/width by 2000, y/height by 1000.
  //   (0.1, 0.2, 0.3, 0.4) -> (200, 200, 600, 400)
  //   (0, 0, 1, 1)         -> (0, 0, 2000, 1000)
  std::vector<Detection> expected({DetectionWithBoundingBox(200, 200, 600, 400),
                                   DetectionWithBoundingBox(0, 0, 2000, 1000)});
  const std::vector<Packet>& detections_output =
      runner.Outputs().Tag(kPixelDetectionsTag).packets;
  ASSERT_EQ(1, detections_output.size());
  CheckOutputDetections(expected,
                        detections_output[0].Get<std::vector<Detection>>());

  const std::vector<Packet>& detection_list_output =
      runner.Outputs().Tag(kPixelDetectionListTag).packets;
  ASSERT_EQ(1, detection_list_output.size());
  CheckOutputDetections(
      expected,
      ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
}
|
||||
|
||||
// A single relative detection fed through DETECTION must be emitted as a
// pixel detection on the single, vector and DetectionList outputs.
TEST(DetectionsTransformationCalculatorTest, ConvertSingleDetection) {
  CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionTransformationCalculator"
    input_stream: "DETECTION:input_detection"
    input_stream: "IMAGE_SIZE:image_size"
    output_stream: "PIXEL_DETECTION:output_detection"
    output_stream: "PIXEL_DETECTIONS:output_detections"
    output_stream: "PIXEL_DETECTION_LIST:output_detection_list"
  )pb"));

  runner.MutableInputs()
      ->Tag(kDetectionTag)
      .packets.push_back(MakePacket<Detection>(DetectionWithRelativeBoundingBox(
                                                   0.05, 0.2, 0.2, 0.3))
                             .At(Timestamp(0)));
  std::pair<int, int> image_size({2000, 1000});
  runner.MutableInputs()
      ->Tag(kImageSizeTag)
      .packets.push_back(
          MakePacket<std::pair<int, int>>(image_size).At(Timestamp(0)));
  MP_ASSERT_OK(runner.Run());

  // (0.05, 0.2, 0.2, 0.3) scaled by the 2000x1000 image.
  std::vector<Detection> expected(
      {DetectionWithBoundingBox(100, 200, 400, 300)});
  const std::vector<Packet>& detection_output =
      runner.Outputs().Tag(kPixelDetectionTag).packets;
  ASSERT_EQ(1, detection_output.size());
  CheckOutputDetections(expected, {detection_output[0].Get<Detection>()});

  const std::vector<Packet>& detections_output =
      runner.Outputs().Tag(kPixelDetectionsTag).packets;
  ASSERT_EQ(1, detections_output.size());
  CheckOutputDetections(expected,
                        detections_output[0].Get<std::vector<Detection>>());

  const std::vector<Packet>& detection_list_output =
      runner.Outputs().Tag(kPixelDetectionListTag).packets;
  ASSERT_EQ(1, detection_list_output.size());
  CheckOutputDetections(
      expected,
      ConvertToDetectionVector(detection_list_output[0].Get<DetectionList>()));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
|
@ -181,7 +181,7 @@ class TrackingGraphTest : public Test {
|
|||
// Each image is shifted to the right and bottom by kTranslationStep
|
||||
// pixels compared with the previous image.
|
||||
static constexpr int kTranslationStep = 10;
|
||||
static constexpr float kEqualityTolerance = 3e-4f;
|
||||
static constexpr float kEqualityTolerance = 1e-3f;
|
||||
};
|
||||
|
||||
void TrackingGraphTest::ExpectBoxAtFrame(const TimedBoxProto& box, float frame,
|
||||
|
|
|
@ -85,7 +85,7 @@ class KinematicPathSolver {
|
|||
double current_position_px_;
|
||||
double prior_position_px_;
|
||||
double current_velocity_deg_per_s_;
|
||||
uint64 current_time_;
|
||||
uint64 current_time_ = 0;
|
||||
// History of observations (second) and their time (first).
|
||||
std::deque<std::pair<uint64, int>> raw_positions_at_time_;
|
||||
// Current target position.
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""Example of reading a MediaSequence dataset.
|
||||
"""
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "facedetectioncpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "facedetectiongpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "faceeffect",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "facemeshgpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "handdetectiongpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "handtrackinggpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "helloworld",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "holistictrackinggpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "iristrackinggpu",
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Lint as: python3
|
||||
"""This script is used to set up automatic provisioning for iOS examples.
|
||||
|
||||
It scans the provisioning profiles used by Xcode, looking for one matching the
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "objectdetectioncpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "objectdetectiongpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "objectdetectiontrackinggpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "posetrackinggpu",
|
||||
|
|
|
@ -24,7 +24,7 @@ load(
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
MIN_IOS_VERSION = "10.0"
|
||||
MIN_IOS_VERSION = "11.0"
|
||||
|
||||
alias(
|
||||
name = "selfiesegmentationgpu",
|
||||
|
|
|
@ -234,7 +234,9 @@ cc_library(
|
|||
"//mediapipe/framework/tool:options_map",
|
||||
"//mediapipe/framework/tool:packet_generator_wrapper_calculator_cc_proto",
|
||||
"//mediapipe/framework/tool:tag_map",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -348,6 +350,7 @@ cc_library(
|
|||
"//mediapipe/framework/tool:validate",
|
||||
"//mediapipe/framework/tool:validate_name",
|
||||
"//mediapipe/gpu:graph_support",
|
||||
"//mediapipe/gpu:gpu_service",
|
||||
"//mediapipe/util:cpu_util",
|
||||
] + select({
|
||||
"//conditions:default": ["//mediapipe/gpu:gpu_shared_data_internal"],
|
||||
|
@ -416,7 +419,6 @@ cc_library(
|
|||
"//mediapipe/framework/tool:status_util",
|
||||
"//mediapipe/framework/tool:tag_map",
|
||||
"//mediapipe/framework/tool:validate_name",
|
||||
"//mediapipe/gpu:graph_support",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
|
@ -613,7 +615,11 @@ cc_library(
|
|||
hdrs = ["graph_service.h"],
|
||||
visibility = [":mediapipe_internal"],
|
||||
deps = [
|
||||
":packet",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/port:statusor",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -167,7 +167,6 @@ struct IsCompatibleType<V, OneOf<U...>>
|
|||
template <typename T>
|
||||
inline Packet<T> PacketBase::As() const {
|
||||
if (!payload_) return Packet<T>().At(timestamp_);
|
||||
packet_internal::Holder<T>* typed_payload = payload_->As<T>();
|
||||
internal::CheckCompatibleType(*payload_, internal::Wrap<T>{});
|
||||
return Packet<T>(payload_).At(timestamp_);
|
||||
}
|
||||
|
@ -217,8 +216,8 @@ class Packet : public Packet<internal::Generic> {
|
|||
const T& operator*() const { return Get(); }
|
||||
const T* operator->() const { return &Get(); }
|
||||
|
||||
template <typename U>
|
||||
T GetOr(U&& v) const {
|
||||
template <typename U, typename TT = T>
|
||||
std::enable_if_t<!std::is_abstract_v<TT>, TT> GetOr(U&& v) const {
|
||||
return IsEmpty() ? static_cast<T>(absl::forward<U>(v)) : **this;
|
||||
}
|
||||
|
||||
|
|
|
@ -4,11 +4,15 @@ namespace api2 {
|
|||
namespace {
|
||||
|
||||
#if defined(TEST_NO_ASSIGN_WRONG_PACKET_TYPE)
|
||||
void AssignWrongPacketType() { Packet<int> p = MakePacket<float>(1.0); }
|
||||
int AssignWrongPacketType() {
|
||||
Packet<int> p = MakePacket<float>(1.0);
|
||||
return *p;
|
||||
}
|
||||
#elif defined(TEST_NO_ASSIGN_GENERIC_TO_SPECIFIC)
|
||||
void AssignWrongPacketType() {
|
||||
int AssignWrongPacketType() {
|
||||
Packet<> p = MakePacket<float>(1.0);
|
||||
Packet<int> p2 = p;
|
||||
return *p2;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -264,6 +264,23 @@ TEST(PacketTest, Polymorphism) {
|
|||
EXPECT_EQ((**mutable_base).name(), "Derived");
|
||||
}
|
||||
|
||||
// Test-only interface with a pure virtual method, so it cannot be
// instantiated directly. Used below to check that a Packet can be declared
// for an abstract type.
class AbstractBase {
|
||||
public:
|
||||
virtual ~AbstractBase() = default;
|
||||
// Returns a name identifying the concrete implementation.
virtual absl::string_view name() const = 0;
|
||||
};
|
||||
|
||||
// Concrete implementation of AbstractBase; name() reports "ConcreteDerived".
class ConcreteDerived : public AbstractBase {
|
||||
public:
|
||||
absl::string_view name() const override { return "ConcreteDerived"; }
|
||||
};
|
||||
|
||||
// Verifies that a Packet typed on an abstract base can adopt a
// derived object and dispatch virtual calls to it through operator->.
TEST(PacketTest, PolymorphismAbstract) {
|
||||
Packet<AbstractBase> base =
|
||||
PacketAdopting<AbstractBase>(absl::make_unique<ConcreteDerived>());
|
||||
EXPECT_EQ(base->name(), "ConcreteDerived");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -40,6 +40,17 @@ TEST(PortTest, DeletedCopyConstructorInput) {
|
|||
EXPECT_EQ(std::string(kSideOutputPort.Tag()), "SIDE_OUTPUT");
|
||||
}
|
||||
|
||||
// Test-only interface with a pure virtual method, so it cannot be
// instantiated directly. Used below as the payload type of a port.
class AbstractBase {
|
||||
public:
|
||||
virtual ~AbstractBase() = default;
|
||||
// Returns a name identifying the concrete implementation.
virtual absl::string_view name() const = 0;
|
||||
};
|
||||
|
||||
// Verifies that an Input port can be declared (as a constexpr) for an abstract
// payload type and that its tag is reported unchanged.
TEST(PortTest, Abstract) {
|
||||
static constexpr Input<AbstractBase> kInputPort{"INPUT"};
|
||||
EXPECT_EQ(std::string(kInputPort.Tag()), "INPUT");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -21,6 +21,8 @@
|
|||
#include <typeindex>
|
||||
|
||||
// TODO: Move protos in another CL after the C++ code migration.
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/framework/graph_service.h"
|
||||
#include "mediapipe/framework/mediapipe_options.pb.h"
|
||||
|
@ -147,7 +149,7 @@ class CalculatorContract {
|
|||
bool IsOptional() const { return optional_; }
|
||||
|
||||
private:
|
||||
GraphServiceBase service_;
|
||||
const GraphServiceBase& service_;
|
||||
bool optional_ = false;
|
||||
};
|
||||
|
||||
|
@ -156,9 +158,12 @@ class CalculatorContract {
|
|||
return it->second;
|
||||
}
|
||||
|
||||
const std::map<std::string, GraphServiceRequest>& ServiceRequests() const {
|
||||
return service_requests_;
|
||||
}
|
||||
// A GraphService's key is always a static constant, so we can use string_view
|
||||
// as the key type without lifetime issues.
|
||||
using ServiceReqMap =
|
||||
absl::flat_hash_map<absl::string_view, GraphServiceRequest>;
|
||||
|
||||
const ServiceReqMap& ServiceRequests() const { return service_requests_; }
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
|
@ -180,7 +185,7 @@ class CalculatorContract {
|
|||
std::string input_stream_handler_;
|
||||
MediaPipeOptions input_stream_handler_options_;
|
||||
std::string node_name_;
|
||||
std::map<std::string, GraphServiceRequest> service_requests_;
|
||||
ServiceReqMap service_requests_;
|
||||
bool process_timestamps_ = false;
|
||||
TimestampDiff timestamp_offset_ = TimestampDiff::Unset();
|
||||
|
||||
|
|
|
@ -226,6 +226,16 @@ absl::Status CalculatorGraph::InitializeStreams() {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Hack for backwards compatibility with ancient GPU calculators. Can it
|
||||
// be retired yet?
// If the node's contract declares an input side packet tagged
// kGpuSharedTagName, this adds a kGpuService request to that contract via
// UseService. Does nothing when MEDIAPIPE_DISABLE_GPU is set.
|
||||
static void MaybeFixupLegacyGpuNodeContract(CalculatorNode& node) {
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
if (node.Contract().InputSidePackets().HasTag(kGpuSharedTagName)) {
|
||||
// Contract() is used here as a const reference, so the const_cast is what
// allows the service request to be added to it.
const_cast<CalculatorContract&>(node.Contract()).UseService(kGpuService);
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
}
|
||||
|
||||
absl::Status CalculatorGraph::InitializeCalculatorNodes() {
|
||||
// Check if the user has specified a maximum queue size for an input stream.
|
||||
max_queue_size_ = validated_graph_->Config().max_queue_size();
|
||||
|
@ -246,6 +256,7 @@ absl::Status CalculatorGraph::InitializeCalculatorNodes() {
|
|||
validated_graph_.get(), node_ref, input_stream_managers_.get(),
|
||||
output_stream_managers_.get(), output_side_packets_.get(),
|
||||
&buffer_size_hint, profiler_);
|
||||
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
|
||||
if (buffer_size_hint > 0) {
|
||||
max_queue_size_ = std::max(max_queue_size_, buffer_size_hint);
|
||||
}
|
||||
|
@ -283,6 +294,7 @@ absl::Status CalculatorGraph::InitializePacketGeneratorNodes(
|
|||
validated_graph_.get(), node_ref, input_stream_managers_.get(),
|
||||
output_stream_managers_.get(), output_side_packets_.get(),
|
||||
&buffer_size_hint, profiler_);
|
||||
MaybeFixupLegacyGpuNodeContract(*nodes_.back());
|
||||
if (!result.ok()) {
|
||||
// Collect as many errors as we can before failing.
|
||||
errors.push_back(result);
|
||||
|
@ -495,9 +507,8 @@ absl::StatusOr<Packet> CalculatorGraph::GetOutputSidePacket(
|
|||
<< "\" because it doesn't exist.";
|
||||
}
|
||||
Packet output_packet;
|
||||
if (scheduler_.IsTerminated()) {
|
||||
// Side-packets from calculators can be retrieved only after the graph is
|
||||
// done.
|
||||
if (!output_side_packets_[side_packet_index].GetPacket().IsEmpty() ||
|
||||
scheduler_.IsTerminated()) {
|
||||
output_packet = output_side_packets_[side_packet_index].GetPacket();
|
||||
}
|
||||
if (output_packet.IsEmpty()) {
|
||||
|
@ -546,6 +557,7 @@ absl::Status CalculatorGraph::StartRun(
|
|||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::Status CalculatorGraph::SetGpuResources(
|
||||
std::shared_ptr<::mediapipe::GpuResources> resources) {
|
||||
RET_CHECK_NE(resources, nullptr);
|
||||
auto gpu_service = service_manager_.GetServiceObject(kGpuService);
|
||||
RET_CHECK_EQ(gpu_service, nullptr)
|
||||
<< "The GPU resources have already been configured.";
|
||||
|
@ -557,56 +569,56 @@ std::shared_ptr<::mediapipe::GpuResources> CalculatorGraph::GetGpuResources()
|
|||
return service_manager_.GetServiceObject(kGpuService);
|
||||
}
|
||||
|
||||
absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
|
||||
static Packet GetLegacyGpuSharedSidePacket(
|
||||
const std::map<std::string, Packet>& side_packets) {
|
||||
std::map<std::string, Packet> additional_side_packets;
|
||||
bool update_sp = false;
|
||||
bool uses_gpu = false;
|
||||
for (const auto& node : nodes_) {
|
||||
if (node->UsesGpu()) {
|
||||
uses_gpu = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (uses_gpu) {
|
||||
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
|
||||
|
||||
auto legacy_sp_iter = side_packets.find(kGpuSharedSidePacketName);
|
||||
// Workaround for b/116875321: CalculatorRunner provides an empty packet,
|
||||
// instead of just leaving it unset.
|
||||
bool has_legacy_sp = legacy_sp_iter != side_packets.end() &&
|
||||
!legacy_sp_iter->second.IsEmpty();
|
||||
if (legacy_sp_iter == side_packets.end()) return {};
|
||||
// Note that, because of b/116875321, the legacy side packet may be set but
|
||||
// empty. But it's ok, because here we return an empty packet to indicate the
|
||||
// missing case anyway.
|
||||
return legacy_sp_iter->second;
|
||||
}
|
||||
|
||||
absl::Status CalculatorGraph::MaybeSetUpGpuServiceFromLegacySidePacket(
|
||||
Packet legacy_sp) {
|
||||
if (legacy_sp.IsEmpty()) return absl::OkStatus();
|
||||
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
|
||||
if (gpu_resources) {
|
||||
if (has_legacy_sp) {
|
||||
LOG(WARNING)
|
||||
<< "::mediapipe::GpuSharedData provided as a side packet while the "
|
||||
<< "graph already had one; ignoring side packet";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
update_sp = true;
|
||||
} else {
|
||||
if (has_legacy_sp) {
|
||||
gpu_resources =
|
||||
legacy_sp_iter->second.Get<::mediapipe::GpuSharedData*>()
|
||||
->gpu_resources;
|
||||
} else {
|
||||
ASSIGN_OR_RETURN(gpu_resources, ::mediapipe::GpuResources::Create());
|
||||
update_sp = true;
|
||||
}
|
||||
MP_RETURN_IF_ERROR(
|
||||
service_manager_.SetServiceObject(kGpuService, gpu_resources));
|
||||
gpu_resources = legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources;
|
||||
return service_manager_.SetServiceObject(kGpuService, gpu_resources);
|
||||
}
|
||||
|
||||
// Create or replace the legacy side packet if needed.
|
||||
if (update_sp) {
|
||||
legacy_gpu_shared_.reset(new ::mediapipe::GpuSharedData(gpu_resources));
|
||||
std::map<std::string, Packet> CalculatorGraph::MaybeCreateLegacyGpuSidePacket(
|
||||
Packet legacy_sp) {
|
||||
std::map<std::string, Packet> additional_side_packets;
|
||||
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
|
||||
if (gpu_resources &&
|
||||
(legacy_sp.IsEmpty() ||
|
||||
legacy_sp.Get<::mediapipe::GpuSharedData*>()->gpu_resources !=
|
||||
gpu_resources)) {
|
||||
legacy_gpu_shared_ =
|
||||
absl::make_unique<mediapipe::GpuSharedData>(gpu_resources);
|
||||
additional_side_packets[kGpuSharedSidePacketName] =
|
||||
MakePacket<::mediapipe::GpuSharedData*>(legacy_gpu_shared_.get());
|
||||
}
|
||||
return additional_side_packets;
|
||||
}
|
||||
|
||||
// Returns true if the node's contract contains a service request keyed by
// kGpuService.key.
static bool UsesGpu(const CalculatorNode& node) {
|
||||
return node.Contract().ServiceRequests().contains(kGpuService.key);
|
||||
}
|
||||
|
||||
absl::Status CalculatorGraph::PrepareGpu() {
|
||||
auto gpu_resources = service_manager_.GetServiceObject(kGpuService);
|
||||
if (!gpu_resources) return absl::OkStatus();
|
||||
// Set up executors.
|
||||
for (auto& node : nodes_) {
|
||||
if (node->UsesGpu()) {
|
||||
if (UsesGpu(*node)) {
|
||||
MP_RETURN_IF_ERROR(gpu_resources->PrepareGpuNode(node.get()));
|
||||
}
|
||||
}
|
||||
|
@ -614,11 +626,32 @@ absl::StatusOr<std::map<std::string, Packet>> CalculatorGraph::PrepareGpu(
|
|||
MP_RETURN_IF_ERROR(
|
||||
SetExecutorInternal(name_executor.first, name_executor.second));
|
||||
}
|
||||
}
|
||||
return additional_side_packets;
|
||||
return absl::OkStatus();
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
// Walks every service request of every node and makes sure a service packet
// exists for it in service_manager_:
//  - a service that already has a non-empty packet is left untouched;
//  - otherwise the service's CreateDefaultObject() is tried and, on success,
//    the resulting packet is registered with the service manager;
//  - if creation fails and the request is optional, the request is skipped;
//  - if creation fails and the request is not optional, an InternalError
//    naming the service, the node and the creation error is returned.
// Returns OkStatus() once all requests have been processed, or the error from
// SetServicePacket if registration fails.
absl::Status CalculatorGraph::PrepareServices() {
|
||||
for (const auto& node : nodes_) {
|
||||
for (const auto& [key, request] : node->Contract().ServiceRequests()) {
|
||||
auto packet = service_manager_.GetServicePacket(request.Service());
|
||||
// Already provided; nothing to create.
if (!packet.IsEmpty()) continue;
|
||||
auto packet_or = request.Service().CreateDefaultObject();
|
||||
if (packet_or.ok()) {
|
||||
MP_RETURN_IF_ERROR(service_manager_.SetServicePacket(
|
||||
request.Service(), std::move(packet_or).value()));
|
||||
} else if (request.IsOptional()) {
|
||||
// An optional service that cannot be default-created is simply left unset.
continue;
|
||||
} else {
|
||||
return absl::InternalError(absl::StrCat(
|
||||
"Service \"", request.Service().key, "\", required by node ",
|
||||
node->DebugName(), ", was not provided and cannot be created: ",
|
||||
std::move(packet_or).status().message()));
|
||||
}
|
||||
}
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status CalculatorGraph::PrepareForRun(
|
||||
const std::map<std::string, Packet>& extra_side_packets,
|
||||
const std::map<std::string, Packet>& stream_headers) {
|
||||
|
@ -637,7 +670,13 @@ absl::Status CalculatorGraph::PrepareForRun(
|
|||
|
||||
std::map<std::string, Packet> additional_side_packets;
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
ASSIGN_OR_RETURN(additional_side_packets, PrepareGpu(extra_side_packets));
|
||||
auto legacy_sp = GetLegacyGpuSharedSidePacket(extra_side_packets);
|
||||
MP_RETURN_IF_ERROR(MaybeSetUpGpuServiceFromLegacySidePacket(legacy_sp));
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
MP_RETURN_IF_ERROR(PrepareServices());
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
MP_RETURN_IF_ERROR(PrepareGpu());
|
||||
additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp);
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
const std::map<std::string, Packet>* input_side_packets;
|
||||
|
|
|
@ -165,10 +165,13 @@ class CalculatorGraph {
|
|||
StatusOrPoller AddOutputStreamPoller(const std::string& stream_name,
|
||||
bool observe_timestamp_bounds = false);
|
||||
|
||||
// Gets output side packet by name after the graph is done. However, base
|
||||
// packets (generated by PacketGenerators) can be retrieved before
|
||||
// graph is done. Returns error if the graph is still running (for non-base
|
||||
// packets) or the output side packet is not found or empty.
|
||||
// Gets output side packet by name. The output side packet can be successfully
|
||||
// retrieved in one of the following situations:
|
||||
// - The graph is done.
|
||||
// - The output side packet has been generated by a calculator and the graph
|
||||
// is currently idle.
|
||||
// - The side packet is a base packet generated by a PacketGenerator.
|
||||
// Returns error if the output side packet is not found or empty.
|
||||
absl::StatusOr<Packet> GetOutputSidePacket(const std::string& packet_name);
|
||||
|
||||
// Runs the graph after adding the given extra input side packets. All
|
||||
|
@ -367,13 +370,8 @@ class CalculatorGraph {
|
|||
std::shared_ptr<GpuResources> GetGpuResources() const;
|
||||
|
||||
absl::Status SetGpuResources(std::shared_ptr<GpuResources> resources);
|
||||
|
||||
// Helper for PrepareForRun. If it returns a non-empty map, those packets
|
||||
// must be added to the existing side packets, replacing existing values
|
||||
// that have the same key.
|
||||
absl::StatusOr<std::map<std::string, Packet>> PrepareGpu(
|
||||
const std::map<std::string, Packet>& side_packets);
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
template <typename T>
|
||||
absl::Status SetServiceObject(const GraphService<T>& service,
|
||||
std::shared_ptr<T> object) {
|
||||
|
@ -495,6 +493,18 @@ class CalculatorGraph {
|
|||
const std::map<std::string, Packet>& extra_side_packets,
|
||||
const std::map<std::string, Packet>& stream_headers);
|
||||
|
||||
absl::Status PrepareServices();
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
absl::Status MaybeSetUpGpuServiceFromLegacySidePacket(Packet legacy_sp);
|
||||
// Helper for PrepareForRun. If it returns a non-empty map, those packets
|
||||
// must be added to the existing side packets, replacing existing values
|
||||
// that have the same key.
|
||||
std::map<std::string, Packet> MaybeCreateLegacyGpuSidePacket(
|
||||
Packet legacy_sp);
|
||||
absl::Status PrepareGpu();
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
// Cleans up any remaining state after the run and returns any errors that may
|
||||
// have occurred during the run. Called after the scheduler has terminated.
|
||||
absl::Status FinishRun();
|
||||
|
|
|
@ -732,11 +732,12 @@ TEST(CalculatorGraph, GetOutputSidePacket) {
|
|||
status_or_packet = graph.GetOutputSidePacket("unknown");
|
||||
EXPECT_FALSE(status_or_packet.ok());
|
||||
EXPECT_EQ(absl::StatusCode::kNotFound, status_or_packet.status().code());
|
||||
// Should return UNAVAILABLE before graph is done for valid non-base
|
||||
// packets.
|
||||
// Should return the packet after the graph becomes idle.
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
status_or_packet = graph.GetOutputSidePacket("num_of_packets");
|
||||
EXPECT_FALSE(status_or_packet.ok());
|
||||
EXPECT_EQ(absl::StatusCode::kUnavailable, status_or_packet.status().code());
|
||||
MP_ASSERT_OK(status_or_packet);
|
||||
EXPECT_EQ(max_count, status_or_packet.value().Get<int>());
|
||||
EXPECT_EQ(Timestamp::Unset(), status_or_packet.value().Timestamp());
|
||||
// Should still return a base packet even before graph is done.
|
||||
status_or_packet = graph.GetOutputSidePacket("output_uint64");
|
||||
MP_ASSERT_OK(status_or_packet);
|
||||
|
@ -896,5 +897,23 @@ TEST(CalculatorGraph, GeneratorAfterCalculatorProcess) {
|
|||
}
|
||||
}
|
||||
|
||||
// Verifies that GetOutputSidePacket() succeeds once a started graph has become
// idle: the "offset" side packet of the single
// IntegerOutputSidePacketCalculator node is expected to hold the value 1.
TEST(CalculatorGraph, GetOutputSidePacketAfterCalculatorIsOpened) {
|
||||
CalculatorGraph graph;
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
node {
|
||||
calculator: "IntegerOutputSidePacketCalculator"
|
||||
output_side_packet: "offset"
|
||||
}
|
||||
)pb");
|
||||
MP_ASSERT_OK(graph.Initialize(config));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
// Must be called to ensure that the calculator is opened.
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
absl::StatusOr<Packet> status_or_packet = graph.GetOutputSidePacket("offset");
|
||||
MP_ASSERT_OK(status_or_packet);
|
||||
EXPECT_EQ(1, status_or_packet.value().Get<int>());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -46,7 +46,6 @@
|
|||
#include "mediapipe/framework/tool/status_util.h"
|
||||
#include "mediapipe/framework/tool/tag_map.h"
|
||||
#include "mediapipe/framework/tool/validate_name.h"
|
||||
#include "mediapipe/gpu/graph_support.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -155,11 +154,6 @@ absl::Status CalculatorNode::Initialize(
|
|||
|
||||
const CalculatorContract& contract = node_type_info_->Contract();
|
||||
|
||||
uses_gpu_ =
|
||||
node_type_info_->InputSidePacketTypes().HasTag(kGpuSharedTagName) ||
|
||||
ContainsKey(node_type_info_->Contract().ServiceRequests(),
|
||||
kGpuService.key);
|
||||
|
||||
// TODO Propagate types between calculators when SetAny is used.
|
||||
|
||||
MP_RETURN_IF_ERROR(InitializeOutputSidePackets(
|
||||
|
@ -397,7 +391,7 @@ absl::Status CalculatorNode::PrepareForRun(
|
|||
std::move(schedule_callback), error_callback);
|
||||
output_stream_handler_->PrepareForRun(error_callback);
|
||||
|
||||
const auto& contract = node_type_info_->Contract();
|
||||
const auto& contract = Contract();
|
||||
input_side_packet_types_ = RemoveOmittedPacketTypes(
|
||||
contract.InputSidePackets(), all_side_packets, validated_graph_);
|
||||
MP_RETURN_IF_ERROR(input_side_packet_handler_.PrepareForRun(
|
||||
|
|
|
@ -195,9 +195,6 @@ class CalculatorNode {
|
|||
// Called by SchedulerQueue when a node is opened.
|
||||
void NodeOpened() ABSL_LOCKS_EXCLUDED(status_mutex_);
|
||||
|
||||
// Returns whether this is a GPU calculator node.
|
||||
bool UsesGpu() const { return uses_gpu_; }
|
||||
|
||||
// Returns the scheduler queue the node is assigned to.
|
||||
internal::SchedulerQueue* GetSchedulerQueue() const {
|
||||
return scheduler_queue_;
|
||||
|
@ -234,6 +231,12 @@ class CalculatorNode {
|
|||
return *calculator_state_;
|
||||
}
|
||||
|
||||
// Returns the node's contract.
|
||||
// Must not be called before the CalculatorNode is initialized.
|
||||
const CalculatorContract& Contract() const {
|
||||
return node_type_info_->Contract();
|
||||
}
|
||||
|
||||
private:
|
||||
// Sets up the output side packets from the main flat array.
|
||||
absl::Status InitializeOutputSidePackets(
|
||||
|
@ -363,9 +366,6 @@ class CalculatorNode {
|
|||
|
||||
std::unique_ptr<OutputStreamHandler> output_stream_handler_;
|
||||
|
||||
// Whether this is a GPU calculator.
|
||||
bool uses_gpu_ = false;
|
||||
|
||||
// True if CleanupAfterRun() needs to call CloseNode().
|
||||
bool needs_to_close_ = false;
|
||||
|
||||
|
|
|
@ -187,6 +187,21 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
# Matches builds invoked with --define use_opencv=true.
config_setting(
|
||||
name = "opencv",
|
||||
define_values = {
|
||||
"use_opencv": "true",
|
||||
},
|
||||
)
|
||||
|
||||
# Matches builds invoked with both --define use_portable_opencv=true and
# --define use_opencv=false.
config_setting(
|
||||
name = "portable_opencv",
|
||||
define_values = {
|
||||
"use_portable_opencv": "true",
|
||||
"use_opencv": "false",
|
||||
},
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "location",
|
||||
srcs = ["location.cc"],
|
||||
|
@ -194,6 +209,8 @@ cc_library(
|
|||
defines = select({
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:android": ["MEDIAPIPE_ANDROID_OPENCV"],
|
||||
":portable_opencv": ["MEDIAPIPE_ANDROID_OPENCV"],
|
||||
":opencv": [],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
|
|
|
@ -76,7 +76,7 @@ class Tensor {
|
|||
|
||||
public:
|
||||
// No resources are allocated here.
|
||||
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8 };
|
||||
enum class ElementType { kNone, kFloat16, kFloat32, kUInt8, kInt8 };
|
||||
struct Shape {
|
||||
Shape() = default;
|
||||
Shape(std::initializer_list<int> dimensions) : dims(dimensions) {}
|
||||
|
@ -217,6 +217,8 @@ class Tensor {
|
|||
return sizeof(float);
|
||||
case ElementType::kUInt8:
|
||||
return 1;
|
||||
case ElementType::kInt8:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
int bytes() const { return shape_.num_elements() * element_size(); }
|
||||
|
|
|
@ -16,6 +16,12 @@
|
|||
#define MEDIAPIPE_FRAMEWORK_GRAPH_SERVICE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -27,18 +33,74 @@ namespace mediapipe {
|
|||
// IMPORTANT: this is an experimental API. Get in touch with the MediaPipe team
|
||||
// if you want to use it. In most cases, you should use a side packet instead.
|
||||
|
||||
struct GraphServiceBase {
|
||||
class GraphServiceBase {
|
||||
public:
|
||||
// TODO: fix services for which default init is broken, remove
|
||||
// this setting.
|
||||
enum DefaultInitSupport {
|
||||
kAllowDefaultInitialization,
|
||||
kDisallowDefaultInitialization
|
||||
};
|
||||
|
||||
constexpr GraphServiceBase(const char* key) : key(key) {}
|
||||
|
||||
virtual ~GraphServiceBase() = default;
|
||||
inline virtual absl::StatusOr<Packet> CreateDefaultObject() const {
|
||||
return DefaultInitializationUnsupported();
|
||||
}
|
||||
|
||||
const char* key;
|
||||
|
||||
protected:
|
||||
absl::Status DefaultInitializationUnsupported() const {
|
||||
return absl::UnimplementedError(absl::StrCat(
|
||||
"Graph service '", key, "' does not support default initialization"));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GraphService : public GraphServiceBase {
|
||||
class GraphService : public GraphServiceBase {
|
||||
public:
|
||||
using type = T;
|
||||
using packet_type = std::shared_ptr<T>;
|
||||
|
||||
constexpr GraphService(const char* key) : GraphServiceBase(key) {}
|
||||
constexpr GraphService(const char* my_key, DefaultInitSupport default_init =
|
||||
kDisallowDefaultInitialization)
|
||||
: GraphServiceBase(my_key), default_init_(default_init) {}
|
||||
|
||||
absl::StatusOr<Packet> CreateDefaultObject() const override {
|
||||
if (default_init_ != kAllowDefaultInitialization) {
|
||||
return DefaultInitializationUnsupported();
|
||||
}
|
||||
auto packet_or = CreateDefaultObjectInternal();
|
||||
if (packet_or.ok()) {
|
||||
return MakePacket<std::shared_ptr<T>>(std::move(packet_or).value());
|
||||
} else {
|
||||
return packet_or.status();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
absl::StatusOr<std::shared_ptr<T>> CreateDefaultObjectInternal() const {
|
||||
auto call_create = [](auto x) -> decltype(decltype(x)::type::Create()) {
|
||||
return decltype(x)::type::Create();
|
||||
};
|
||||
if constexpr (std::is_invocable_r_v<absl::StatusOr<std::shared_ptr<T>>,
|
||||
decltype(call_create), type_tag<T>>) {
|
||||
return T::Create();
|
||||
}
|
||||
if constexpr (std::is_default_constructible_v<T>) {
|
||||
return std::make_shared<T>();
|
||||
}
|
||||
return DefaultInitializationUnsupported();
|
||||
}
|
||||
|
||||
template <class U>
|
||||
struct type_tag {
|
||||
using type = U;
|
||||
};
|
||||
|
||||
DefaultInitSupport default_init_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -35,6 +35,8 @@ class GraphServiceManager {
|
|||
Packet GetServicePacket(const GraphServiceBase& service) const;
|
||||
|
||||
std::map<std::string, Packet> service_packets_;
|
||||
|
||||
friend class CalculatorGraph;
|
||||
};
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -6,11 +6,13 @@
|
|||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
const GraphService<int> kIntService("mediapipe::IntService");
|
||||
} // namespace
|
||||
|
||||
TEST(GraphServiceManager, SetGetServiceObject) {
|
||||
GraphServiceManager service_manager;
|
||||
|
||||
constexpr GraphService<int> kIntService("mediapipe::IntService");
|
||||
EXPECT_EQ(service_manager.GetServiceObject(kIntService), nullptr);
|
||||
|
||||
MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
|
||||
|
@ -22,8 +24,6 @@ TEST(GraphServiceManager, SetGetServiceObject) {
|
|||
TEST(GraphServiceManager, SetServicePacket) {
|
||||
GraphServiceManager service_manager;
|
||||
|
||||
constexpr GraphService<int> kIntService("mediapipe::IntService");
|
||||
|
||||
MP_EXPECT_OK(service_manager.SetServicePacket(
|
||||
kIntService,
|
||||
mediapipe::MakePacket<std::shared_ptr<int>>(std::make_shared<int>(100))));
|
||||
|
@ -36,8 +36,6 @@ TEST(GraphServiceManager, ServicePackets) {
|
|||
|
||||
EXPECT_TRUE(service_manager.ServicePackets().empty());
|
||||
|
||||
constexpr GraphService<int> kIntService("mediapipe::IntService");
|
||||
|
||||
MP_EXPECT_OK(service_manager.SetServiceObject(kIntService,
|
||||
std::make_shared<int>(100)));
|
||||
|
||||
|
|
|
@ -150,5 +150,12 @@ TEST_F(GraphServiceTest, OptionalIsAvailable) {
|
|||
EXPECT_EQ(PacketValues<int>(output_packets_), (std::vector<int>{108}));
|
||||
}
|
||||
|
||||
// Checks which test services can produce a default object through
// CreateDefaultObject(): kTestService and kNoDefaultService are expected to
// fail, kAnotherService and kNeedsCreateService to succeed.
// NOTE(review): the four services are declared outside this view; confirm
// their default-initialization settings match these expectations.
TEST_F(GraphServiceTest, CreateDefault) {
|
||||
EXPECT_FALSE(kTestService.CreateDefaultObject().ok());
|
||||
MP_EXPECT_OK(kAnotherService.CreateDefaultObject());
|
||||
EXPECT_FALSE(kNoDefaultService.CreateDefaultObject().ok());
|
||||
MP_EXPECT_OK(kNeedsCreateService.CreateDefaultObject());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -50,15 +50,18 @@ absl::Status InputStreamHandler::SetupInputShards(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, int>>
|
||||
std::vector<std::tuple<std::string, int, int, Timestamp>>
|
||||
InputStreamHandler::GetMonitoringInfo() {
|
||||
std::vector<std::pair<std::string, int>> monitoring_info_vector;
|
||||
std::vector<std::tuple<std::string, int, int, Timestamp>>
|
||||
monitoring_info_vector;
|
||||
for (auto& stream : input_stream_managers_) {
|
||||
if (!stream) {
|
||||
continue;
|
||||
}
|
||||
monitoring_info_vector.emplace_back(
|
||||
std::pair<std::string, int>(stream->Name(), stream->QueueSize()));
|
||||
std::tuple<std::string, int, int, Timestamp>(
|
||||
stream->Name(), stream->QueueSize(), stream->NumPacketsAdded(),
|
||||
stream->MinTimestampOrBound(nullptr)));
|
||||
}
|
||||
return monitoring_info_vector;
|
||||
}
|
||||
|
|
|
@ -94,7 +94,7 @@ class InputStreamHandler {
|
|||
|
||||
// Returns a vector of tuples (stream name, queue size, number of packets
|
||||
// added, minimum timestamp or bound) for monitoring purposes.
|
||||
std::vector<std::pair<std::string, int>> GetMonitoringInfo();
|
||||
std::vector<std::tuple<std::string, int, int, Timestamp>> GetMonitoringInfo();
|
||||
|
||||
// Resets the input stream handler and its underlying input streams for
|
||||
// another run of the graph.
|
||||
|
|
|
@ -329,6 +329,11 @@ Packet InputStreamManager::PopQueueHead(bool* stream_is_done) {
|
|||
return packet;
|
||||
}
|
||||
|
||||
int InputStreamManager::NumPacketsAdded() const {
|
||||
absl::MutexLock lock(&stream_mutex_);
|
||||
return num_packets_added_;
|
||||
}
|
||||
|
||||
int InputStreamManager::QueueSize() const {
|
||||
absl::MutexLock lock(&stream_mutex_);
|
||||
return static_cast<int>(queue_.size());
|
||||
|
|
|
@ -87,12 +87,14 @@ class InputStreamManager {
|
|||
// Timestamp::PostStream(), the packet must be the only packet in the
|
||||
// stream.
|
||||
// Violation of any of these conditions causes an error status.
|
||||
absl::Status AddPackets(const std::list<Packet>& container, bool* notify);
|
||||
absl::Status AddPackets(const std::list<Packet>& container, bool* notify)
|
||||
ABSL_LOCKS_EXCLUDED(stream_mutex_);
|
||||
|
||||
// Move a list of timestamped packets. Sets "notify" to true if the queue
|
||||
// becomes non-empty. Does nothing if the input stream is closed. After the
|
||||
// move, all packets in the container must be empty.
|
||||
absl::Status MovePackets(std::list<Packet>* container, bool* notify);
|
||||
absl::Status MovePackets(std::list<Packet>* container, bool* notify)
|
||||
ABSL_LOCKS_EXCLUDED(stream_mutex_);
|
||||
|
||||
// Closes the input stream. This function can be called multiple times.
|
||||
void Close() ABSL_LOCKS_EXCLUDED(stream_mutex_);
|
||||
|
@ -140,6 +142,9 @@ class InputStreamManager {
|
|||
// Timestamp::Done() after the pop.
|
||||
Packet PopQueueHead(bool* stream_is_done) ABSL_LOCKS_EXCLUDED(stream_mutex_);
|
||||
|
||||
// Returns the total number of packets that have been added to the stream.
|
||||
int NumPacketsAdded() const ABSL_LOCKS_EXCLUDED(stream_mutex_);
|
||||
|
||||
// Returns the number of packets in the queue.
|
||||
int QueueSize() const ABSL_LOCKS_EXCLUDED(stream_mutex_);
|
||||
|
||||
|
|
|
@ -767,6 +767,7 @@ TEST_F(InputStreamManagerTest, QueueSizeTest) {
|
|||
EXPECT_EQ(3, num_packets_dropped_);
|
||||
EXPECT_TRUE(input_stream_manager_->IsEmpty());
|
||||
EXPECT_FALSE(stream_is_done_);
|
||||
EXPECT_EQ(3, input_stream_manager_->NumPacketsAdded());
|
||||
|
||||
packets.clear();
|
||||
packets.push_back(MakePacket<std::string>("packet 4").At(Timestamp(60)));
|
||||
|
@ -776,6 +777,7 @@ TEST_F(InputStreamManagerTest, QueueSizeTest) {
|
|||
input_stream_manager_->AddPackets(packets, &notify_)); // Notification
|
||||
EXPECT_FALSE(input_stream_manager_->IsEmpty());
|
||||
EXPECT_TRUE(notify_);
|
||||
EXPECT_EQ(5, input_stream_manager_->NumPacketsAdded());
|
||||
|
||||
expected_queue_becomes_full_count_ = 2;
|
||||
expected_queue_becomes_not_full_count_ = 1;
|
||||
|
|
|
@ -12,6 +12,8 @@ def mediapipe_cc_test(
|
|||
timeout = None,
|
||||
args = [],
|
||||
additional_deps = DEFAULT_ADDITIONAL_TEST_DEPS,
|
||||
platforms = ["linux", "android", "ios", "wasm"],
|
||||
exclude_platforms = None,
|
||||
# ios_unit_test arguments
|
||||
ios_minimum_os_version = "9.0",
|
||||
# android_cc_test arguments
|
||||
|
|
|
@ -412,8 +412,7 @@ cc_library(
|
|||
name = "status_matchers",
|
||||
testonly = 1,
|
||||
hdrs = ["status_matchers.h"],
|
||||
# Use this library through "mediapipe/framework/port:gtest_main".
|
||||
visibility = ["//mediapipe/framework/port:__pkg__"],
|
||||
visibility = ["//visibility:private"],
|
||||
deps = [
|
||||
":status",
|
||||
"@com_google_googletest//:gtest",
|
||||
|
|
|
@ -16,8 +16,14 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
const GraphService<TestServiceObject> kTestService("test_service");
|
||||
const GraphService<int> kAnotherService("another_service");
|
||||
const GraphService<TestServiceObject> kTestService(
|
||||
"test_service", GraphServiceBase::kDisallowDefaultInitialization);
|
||||
const GraphService<int> kAnotherService(
|
||||
"another_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
const GraphService<NoDefaultConstructor> kNoDefaultService(
|
||||
"no_default_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
const GraphService<NeedsCreateMethod> kNeedsCreateService(
|
||||
"needs_create_service", GraphServiceBase::kAllowDefaultInitialization);
|
||||
|
||||
absl::Status TestServiceCalculator::GetContract(CalculatorContract* cc) {
|
||||
cc->Inputs().Index(0).Set<int>();
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#define MEDIAPIPE_FRAMEWORK_TEST_SERVICE_H_
|
||||
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/graph_service.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -24,6 +25,23 @@ using TestServiceObject = std::map<std::string, int>;
|
|||
extern const GraphService<TestServiceObject> kTestService;
|
||||
extern const GraphService<int> kAnotherService;
|
||||
|
||||
class NoDefaultConstructor {
|
||||
public:
|
||||
NoDefaultConstructor() = delete;
|
||||
};
|
||||
extern const GraphService<NoDefaultConstructor> kNoDefaultService;
|
||||
|
||||
class NeedsCreateMethod {
|
||||
public:
|
||||
static absl::StatusOr<std::shared_ptr<NeedsCreateMethod>> Create() {
|
||||
return std::shared_ptr<NeedsCreateMethod>(new NeedsCreateMethod());
|
||||
}
|
||||
|
||||
private:
|
||||
NeedsCreateMethod() = default;
|
||||
};
|
||||
extern const GraphService<NeedsCreateMethod> kNeedsCreateService;
|
||||
|
||||
// Use a service.
|
||||
class TestServiceCalculator : public CalculatorBase {
|
||||
public:
|
||||
|
|
|
@ -134,7 +134,7 @@ cc_library(
|
|||
name = "name_util",
|
||||
srcs = ["name_util.cc"],
|
||||
hdrs = ["name_util.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":validate_name",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
|
|
|
@ -225,7 +225,7 @@ std::string GetTestOutputsDir() {
|
|||
return output_dir;
|
||||
}
|
||||
|
||||
std::string GetTestDataDir(const std::string& package_base_path) {
|
||||
std::string GetTestDataDir(absl::string_view package_base_path) {
|
||||
return file::JoinPath(GetTestRootDir(), package_base_path, "testdata/");
|
||||
}
|
||||
|
||||
|
@ -270,7 +270,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
|
|||
format, width, height, width * output_channels, data, stbi_image_free);
|
||||
}
|
||||
|
||||
std::unique_ptr<ImageFrame> LoadTestPng(const std::string& path,
|
||||
std::unique_ptr<ImageFrame> LoadTestPng(absl::string_view path,
|
||||
ImageFormat::Format format) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ std::string GetTestFilePath(absl::string_view relative_path);
|
|||
// directory.
|
||||
// This handles the different paths where test data ends up when using
|
||||
// ion_cc_test on various platforms.
|
||||
std::string GetTestDataDir(const std::string& package_base_path);
|
||||
std::string GetTestDataDir(absl::string_view package_base_path);
|
||||
|
||||
// Loads a binary graph from path. Returns true iff successful.
|
||||
bool LoadTestGraph(CalculatorGraphConfig* proto, const std::string& path);
|
||||
|
@ -75,7 +75,7 @@ absl::StatusOr<std::unique_ptr<ImageFrame>> LoadTestImage(
|
|||
// Loads a PNG image from path using the given ImageFormat. Returns nullptr in
|
||||
// case of failure.
|
||||
std::unique_ptr<ImageFrame> LoadTestPng(
|
||||
const std::string& path, ImageFormat::Format format = ImageFormat::SRGBA);
|
||||
absl::string_view path, ImageFormat::Format format = ImageFormat::SRGBA);
|
||||
|
||||
// Returns the luminance image of |original_image|.
|
||||
// The format of |original_image| must be sRGB or sRGBA.
|
||||
|
|
|
@ -38,14 +38,19 @@ cc_library(
|
|||
srcs = ["gpu_service.cc"],
|
||||
hdrs = ["gpu_service.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = ["//mediapipe/framework:graph_service"],
|
||||
deps = ["//mediapipe/framework:graph_service"] + select({
|
||||
"//conditions:default": [
|
||||
":gpu_shared_data_internal",
|
||||
],
|
||||
"//mediapipe/gpu:disable_gpu": [],
|
||||
}),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "graph_support",
|
||||
hdrs = ["graph_support.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [":gpu_service"],
|
||||
deps = ["//mediapipe/framework:graph_service"],
|
||||
)
|
||||
|
||||
GL_BASE_LINK_OPTS = select({
|
||||
|
@ -366,7 +371,6 @@ objc_library(
|
|||
hdrs = ["pixel_buffer_pool_util.h"],
|
||||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-std=c++17",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"Accelerate",
|
||||
|
@ -389,7 +393,6 @@ objc_library(
|
|||
copts = [
|
||||
"-x objective-c++",
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-std=c++17",
|
||||
],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
|
@ -425,7 +428,6 @@ objc_library(
|
|||
copts = [
|
||||
"-x objective-c++",
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-std=c++17",
|
||||
],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
|
@ -691,7 +693,6 @@ objc_library(
|
|||
name = "gl_calculator_helper_ios",
|
||||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-std=c++17",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
|
@ -707,7 +708,6 @@ objc_library(
|
|||
hdrs = ["MPPMetalHelper.h"],
|
||||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-std=c++17",
|
||||
],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
|
@ -801,7 +801,6 @@ cc_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":gl_calculator_helper",
|
||||
":gpu_buffer_storage_image_frame",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/formats:image_frame",
|
||||
"//mediapipe/framework/port:status",
|
||||
|
@ -927,7 +926,6 @@ mediapipe_cc_proto_library(
|
|||
objc_library(
|
||||
name = "metal_copy_calculator",
|
||||
srcs = ["MetalCopyCalculator.mm"],
|
||||
copts = ["-std=c++17"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
|
@ -946,7 +944,6 @@ objc_library(
|
|||
objc_library(
|
||||
name = "metal_rgb_weight_calculator",
|
||||
srcs = ["MetalRgbWeightCalculator.mm"],
|
||||
copts = ["-std=c++17"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
|
@ -964,7 +961,6 @@ objc_library(
|
|||
objc_library(
|
||||
name = "metal_sobel_calculator",
|
||||
srcs = ["MetalSobelCalculator.mm"],
|
||||
copts = ["-std=c++17"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
|
@ -982,7 +978,6 @@ objc_library(
|
|||
objc_library(
|
||||
name = "metal_sobel_compute_calculator",
|
||||
srcs = ["MetalSobelComputeCalculator.mm"],
|
||||
copts = ["-std=c++17"],
|
||||
features = ["-layering_check"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
|
@ -1018,7 +1013,6 @@ objc_library(
|
|||
objc_library(
|
||||
name = "mps_threshold_calculator",
|
||||
srcs = ["MPSThresholdCalculator.mm"],
|
||||
copts = ["-std=c++17"],
|
||||
sdk_frameworks = [
|
||||
"CoreVideo",
|
||||
"Metal",
|
||||
|
@ -1053,7 +1047,6 @@ objc_library(
|
|||
],
|
||||
copts = [
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-std=c++17",
|
||||
],
|
||||
data = [
|
||||
"//mediapipe/objc:testdata/googlelogo_color_272x92dp.png",
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
#include "absl/base/dynamic_annotations.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
|
@ -358,6 +359,7 @@ absl::Status GlContext::FinishInitialization(bool create_thread) {
|
|||
GlContext::GlContext() {}
|
||||
|
||||
GlContext::~GlContext() {
|
||||
destructing_ = true;
|
||||
// Note: on Apple platforms, this object contains Objective-C objects.
|
||||
// The destructor will release them, but ARC must be on.
|
||||
#ifdef __OBJC__
|
||||
|
@ -366,11 +368,16 @@ GlContext::~GlContext() {
|
|||
#endif
|
||||
#endif // __OBJC__
|
||||
|
||||
if (thread_) {
|
||||
auto status = thread_->Run([this] {
|
||||
auto clear_attachments = [this] {
|
||||
attachments_.clear();
|
||||
if (profiling_helper_) {
|
||||
profiling_helper_->LogAllTimestamps();
|
||||
}
|
||||
};
|
||||
|
||||
if (thread_) {
|
||||
auto status = thread_->Run([this, clear_attachments] {
|
||||
clear_attachments();
|
||||
return ExitContext(nullptr);
|
||||
});
|
||||
LOG_IF(ERROR, !status.ok())
|
||||
|
@ -378,6 +385,17 @@ GlContext::~GlContext() {
|
|||
if (thread_->IsCurrentThread()) {
|
||||
thread_.release()->SelfDestruct();
|
||||
}
|
||||
} else {
|
||||
if (IsCurrent()) {
|
||||
clear_attachments();
|
||||
} else {
|
||||
ContextBinding saved_context;
|
||||
auto status = SwitchContextAndRun([&clear_attachments] {
|
||||
clear_attachments();
|
||||
return absl::OkStatus();
|
||||
});
|
||||
LOG_IF(ERROR, !status.ok()) << status;
|
||||
}
|
||||
}
|
||||
DestroyContext();
|
||||
}
|
||||
|
@ -501,6 +519,14 @@ absl::Status GlContext::SwitchContext(ContextBinding* saved_context,
|
|||
}
|
||||
}
|
||||
|
||||
GlContext::ContextBinding GlContext::ThisContextBinding() {
|
||||
GlContext::ContextBinding result = ThisContextBindingPlatform();
|
||||
if (!destructing_) {
|
||||
result.context_object = shared_from_this();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status GlContext::EnterContext(ContextBinding* saved_context) {
|
||||
DCHECK(HasContext());
|
||||
return SwitchContext(saved_context, ThisContextBinding());
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/executor.h"
|
||||
#include "mediapipe/framework/mediapipe_profiling.h"
|
||||
|
@ -285,6 +286,48 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
|
|||
// Sets default texture filtering parameters.
|
||||
void SetStandardTextureParams(GLenum target, GLint internal_format);
|
||||
|
||||
template <class T>
|
||||
using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;
|
||||
|
||||
template <class T, class... Args>
|
||||
static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
|
||||
MakeAttachmentPtr(Args&&... args) {
|
||||
return {new T(std::forward<Args>(args)...),
|
||||
[](void* ptr) { delete static_cast<T*>(ptr); }};
|
||||
}
|
||||
|
||||
class AttachmentBase {};
|
||||
|
||||
template <class T>
|
||||
class Attachment : public AttachmentBase {
|
||||
public:
|
||||
using FactoryT = std::function<AttachmentPtr<T>(GlContext&)>;
|
||||
Attachment(FactoryT factory) : factory_(factory) {}
|
||||
|
||||
Attachment(const Attachment&) = delete;
|
||||
Attachment(Attachment&&) = delete;
|
||||
Attachment& operator=(const Attachment&) = delete;
|
||||
Attachment& operator=(Attachment&&) = delete;
|
||||
|
||||
T& Get(GlContext& ctx) const { return ctx.GetCachedAttachment(*this); }
|
||||
|
||||
const FactoryT& factory() const { return factory_; }
|
||||
|
||||
private:
|
||||
FactoryT factory_;
|
||||
};
|
||||
|
||||
// TODO: const result?
|
||||
template <class T>
|
||||
T& GetCachedAttachment(const Attachment<T>& attachment) {
|
||||
DCHECK(IsCurrent());
|
||||
AttachmentPtr<void>& entry = attachments_[&attachment];
|
||||
if (entry == nullptr) {
|
||||
entry = attachment.factory()(*this);
|
||||
}
|
||||
return *static_cast<T*>(entry.get());
|
||||
}
|
||||
|
||||
// These are used for testing specific SyncToken implementations. Do not use
|
||||
// outside of tests.
|
||||
enum class SyncTokenTypeForTest {
|
||||
|
@ -387,6 +430,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
|
|||
|
||||
// A binding that can be used to make this GlContext current.
|
||||
ContextBinding ThisContextBinding();
|
||||
// Fill in platform-specific fields. Must _not_ set context_object.
|
||||
ContextBinding ThisContextBindingPlatform();
|
||||
// Fills in a ContextBinding with platform-specific information about which
|
||||
// context is current on this thread.
|
||||
static void GetCurrentContextBinding(ContextBinding* binding);
|
||||
|
@ -409,6 +454,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
|
|||
// better mechanism?
|
||||
bool can_linear_filter_float_textures_;
|
||||
|
||||
absl::flat_hash_map<const AttachmentBase*, AttachmentPtr<void>> attachments_;
|
||||
|
||||
// Number of glFinish calls completed on the GL thread.
|
||||
// Changes should be guarded by mutex_. However, we use simple atomic
|
||||
// loads for efficiency on the fast path.
|
||||
|
@ -428,6 +475,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
|
|||
absl::CondVar wait_for_gl_finish_cv_ ABSL_GUARDED_BY(mutex_);
|
||||
|
||||
std::unique_ptr<mediapipe::GlProfilingHelper> profiling_helper_ = nullptr;
|
||||
|
||||
bool destructing_ = false;
|
||||
};
|
||||
|
||||
// For backward compatibility. TODO: migrate remaining callers.
|
||||
|
|
|
@ -84,9 +84,8 @@ void GlContext::DestroyContext() {
|
|||
}
|
||||
}
|
||||
|
||||
GlContext::ContextBinding GlContext::ThisContextBinding() {
|
||||
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
|
||||
GlContext::ContextBinding result;
|
||||
result.context_object = shared_from_this();
|
||||
result.context = context_;
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -269,9 +269,8 @@ void GlContext::DestroyContext() {
|
|||
#endif // __ANDROID__
|
||||
}
|
||||
|
||||
GlContext::ContextBinding GlContext::ThisContextBinding() {
|
||||
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
|
||||
GlContext::ContextBinding result;
|
||||
result.context_object = shared_from_this();
|
||||
result.display = display_;
|
||||
result.draw_surface = surface_;
|
||||
result.read_surface = surface_;
|
||||
|
|
|
@ -134,9 +134,8 @@ void GlContext::DestroyContext() {
|
|||
}
|
||||
}
|
||||
|
||||
GlContext::ContextBinding GlContext::ThisContextBinding() {
|
||||
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
|
||||
GlContext::ContextBinding result;
|
||||
result.context_object = shared_from_this();
|
||||
result.context = context_;
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -173,9 +173,8 @@ void GlContext::DestroyContext() {
|
|||
}
|
||||
}
|
||||
|
||||
GlContext::ContextBinding GlContext::ThisContextBinding() {
|
||||
GlContext::ContextBinding GlContext::ThisContextBindingPlatform() {
|
||||
GlContext::ContextBinding result;
|
||||
result.context_object = shared_from_this();
|
||||
result.context = context_;
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -111,7 +111,7 @@ absl::Status QuadRenderer::GlRender(float frame_width, float frame_height,
|
|||
FrameScaleMode scale_mode,
|
||||
FrameRotation rotation,
|
||||
bool flip_horizontal, bool flip_vertical,
|
||||
bool flip_texture) {
|
||||
bool flip_texture) const {
|
||||
RET_CHECK(program_) << "Must setup the program before rendering.";
|
||||
|
||||
glUseProgram(program_);
|
||||
|
|
|
@ -72,7 +72,7 @@ class QuadRenderer {
|
|||
absl::Status GlRender(float frame_width, float frame_height, float view_width,
|
||||
float view_height, FrameScaleMode scale_mode,
|
||||
FrameRotation rotation, bool flip_horizontal,
|
||||
bool flip_vertical, bool flip_texture);
|
||||
bool flip_vertical, bool flip_texture) const;
|
||||
// Deletes the rendering program. Must be called within the GL context where
|
||||
// it was created.
|
||||
void GlTeardown();
|
||||
|
|
|
@ -144,7 +144,7 @@ const GlTextureInfo& GlTextureInfoForGpuBufferFormat(GpuBufferFormat format,
|
|||
}},
|
||||
{GpuBufferFormat::kRGBAFloat128,
|
||||
{
|
||||
{GL_RGBA, GL_RGBA, GL_FLOAT, 1},
|
||||
{GL_RGBA32F, GL_RGBA, GL_FLOAT, 1},
|
||||
}},
|
||||
}};
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
const GraphService<GpuResources> kGpuService("kGpuService");
|
||||
const GraphService<GpuResources> kGpuService(
|
||||
"kGpuService", GraphServiceBase::kAllowDefaultInitialization);
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -17,9 +17,18 @@
|
|||
|
||||
#include "mediapipe/framework/graph_service.h"
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
#include "mediapipe/gpu/gpu_shared_data_internal.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
class GpuResources;
|
||||
#if MEDIAPIPE_DISABLE_GPU
|
||||
class GpuResources {
|
||||
GpuResources() = delete;
|
||||
};
|
||||
#endif // MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
extern const GraphService<GpuResources> kGpuService;
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user