Merge branch 'ios-task-files' into ios-task
commit 7f7776ef80

.github/ISSUE_TEMPLATE/13-solution-issue.md (vendored): 2 changes
@@ -1,6 +1,6 @@
 ---
 name: "Solution (legacy) Issue"
-about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions), such as "Pose" or "Iris", including inference model usage/training, solution-specific calculators, etc.
+about: Use this template for assistance with a specific Mediapipe solution (google.github.io/mediapipe/solutions) such as "Pose", including inference model usage/training, solution-specific calculators etc.
 labels: type:support

 ---

@@ -259,6 +259,7 @@ mp_holistic = mp.solutions.holistic
 # For static images:
 IMAGE_FILES = []
 BG_COLOR = (192, 192, 192) # gray
 with mp_holistic.Holistic(
     static_image_mode=True,
+    model_complexity=2,

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
+
 licenses(["notice"])
 
 package(default_visibility = ["//visibility:private"])
 
-load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
-
 proto_library(
     name = "mfcc_mel_calculators_proto",
     srcs = ["mfcc_mel_calculators.proto"],

@@ -567,7 +567,7 @@ cc_library(
     name = "packet_thinner_calculator",
     srcs = ["packet_thinner_calculator.cc"],
     deps = [
-        "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
+        ":packet_thinner_calculator_cc_proto",
         "//mediapipe/framework:calculator_context",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:video_stream_header",
@@ -584,7 +584,7 @@ cc_test(
     srcs = ["packet_thinner_calculator_test.cc"],
     deps = [
         ":packet_thinner_calculator",
-        "//mediapipe/calculators/core:packet_thinner_calculator_cc_proto",
+        ":packet_thinner_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework/formats:video_stream_header",
@@ -762,7 +762,7 @@ cc_library(
     srcs = ["packet_resampler_calculator.cc"],
    hdrs = ["packet_resampler_calculator.h"],
    deps = [
-        "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
+        ":packet_resampler_calculator_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework/deps:mathutil",
@@ -786,7 +786,7 @@ cc_test(
     ],
     deps = [
         ":packet_resampler_calculator",
-        "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
+        ":packet_resampler_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework/formats:video_stream_header",
@@ -852,10 +852,10 @@ cc_test(
     name = "flow_limiter_calculator_test",
     srcs = ["flow_limiter_calculator_test.cc"],
     deps = [
+        ":counting_source_calculator",
         ":flow_limiter_calculator",
         ":flow_limiter_calculator_cc_proto",
-        "//mediapipe/calculators/core:counting_source_calculator",
-        "//mediapipe/calculators/core:pass_through_calculator",
+        ":pass_through_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_runner",
         "//mediapipe/framework:test_calculators",
@@ -1302,7 +1302,7 @@ cc_test(
     srcs = ["packet_sequencer_calculator_test.cc"],
     deps = [
         ":packet_sequencer_calculator",
-        "//mediapipe/calculators/core:pass_through_calculator",
+        ":pass_through_calculator",
         "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:subgraph",

@@ -47,7 +47,7 @@ namespace api2 {
 //   calculator: "Get{SpecificType}VectorItemCalculator"
 //   input_stream: "VECTOR:vector"
 //   input_stream: "INDEX:index"
-//   input_stream: "ITEM:item"
+//   output_stream: "ITEM:item"
 //   options {
 //     [mediapipe.GetVectorItemCalculatorOptions.ext] {
 //       item_index: 5

@@ -76,7 +76,11 @@ constexpr char kMaxInFlightTag[] = "MAX_IN_FLIGHT";
 //       }
 //       output_stream: "gated_frames"
 //     }
-class RealTimeFlowLimiterCalculator : public CalculatorBase {
+//
+// Please use FlowLimiterCalculator, which replaces this calculator and
+// defines a few additional configuration options.
+class ABSL_DEPRECATED("Use FlowLimiterCalculator instead.")
+    RealTimeFlowLimiterCalculator : public CalculatorBase {
  public:
   static absl::Status GetContract(CalculatorContract* cc) {
     int num_data_streams = cc->Inputs().NumEntries("");

@@ -66,12 +66,16 @@ class SequenceShiftCalculator : public Node {
   // The number of packets or timestamps we need to store to output packet[i] at
   // the timestamp of packet[i + packet_offset]; equal to abs(packet_offset).
   int cache_size_;
+  bool emit_empty_packets_before_first_packet_ = false;
 };
 MEDIAPIPE_REGISTER_NODE(SequenceShiftCalculator);
 
 absl::Status SequenceShiftCalculator::Open(CalculatorContext* cc) {
   packet_offset_ = kOffset(cc).GetOr(
       cc->Options<mediapipe::SequenceShiftCalculatorOptions>().packet_offset());
+  emit_empty_packets_before_first_packet_ =
+      cc->Options<mediapipe::SequenceShiftCalculatorOptions>()
+          .emit_empty_packets_before_first_packet();
   cache_size_ = abs(packet_offset_);
   // An offset of zero is a no-op, but someone might still request it.
   if (packet_offset_ == 0) {
@@ -96,6 +100,8 @@ void SequenceShiftCalculator::ProcessPositiveOffset(CalculatorContext* cc) {
     // Ready to output oldest packet with current timestamp.
     kOut(cc).Send(packet_cache_.front().At(cc->InputTimestamp()));
     packet_cache_.pop_front();
+  } else if (emit_empty_packets_before_first_packet_) {
+    LOG(FATAL) << "Not supported yet";
   }
   // Store current packet for later output.
   packet_cache_.push_back(kIn(cc).packet());

@@ -23,4 +23,8 @@ message SequenceShiftCalculatorOptions {
     optional SequenceShiftCalculatorOptions ext = 107633927;
   }
   optional int32 packet_offset = 1 [default = -1];
+
+  // Emits empty packets before the first delayed packet is emitted. Takes
+  // effect only when packet offset is set to positive.
+  optional bool emit_empty_packets_before_first_packet = 2 [default = false];
 }

@@ -378,8 +378,8 @@ cc_library(
     name = "scale_image_calculator",
     srcs = ["scale_image_calculator.cc"],
     deps = [
+        ":scale_image_calculator_cc_proto",
         ":scale_image_utils",
-        "//mediapipe/calculators/image:scale_image_calculator_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/formats:image_frame",
@@ -747,8 +747,8 @@ cc_test(
     tags = ["desktop_only_test"],
     deps = [
         ":affine_transformation",
+        ":image_transformation_calculator",
         ":warp_affine_calculator",
-        "//mediapipe/calculators/image:image_transformation_calculator",
         "//mediapipe/calculators/tensor:image_to_tensor_converter",
         "//mediapipe/calculators/tensor:image_to_tensor_utils",
         "//mediapipe/calculators/util:from_image_calculator",

@@ -92,8 +92,8 @@ class GlTextureWarpAffineRunner
 
 constexpr GLchar kVertShader[] = R"(
   in vec4 position;
-  in mediump vec4 texture_coordinate;
-  out mediump vec2 sample_coordinate;
+  in highp vec4 texture_coordinate;
+  out highp vec2 sample_coordinate;
   uniform mat4 transform_matrix;
 
   void main() {
@@ -104,7 +104,7 @@ class GlTextureWarpAffineRunner
 )";
 
 constexpr GLchar kFragShader[] = R"(
-  DEFAULT_PRECISION(mediump, float)
+  DEFAULT_PRECISION(highp, float)
   in vec2 sample_coordinate;
   uniform sampler2D input_texture;
 

@@ -38,6 +38,7 @@ void SetColorChannel(int channel, uint8 value, cv::Mat* mat) {
 
 constexpr char kRgbaInTag[] = "RGBA_IN";
 constexpr char kRgbInTag[] = "RGB_IN";
+constexpr char kBgrInTag[] = "BGR_IN";
 constexpr char kBgraInTag[] = "BGRA_IN";
 constexpr char kGrayInTag[] = "GRAY_IN";
 constexpr char kRgbaOutTag[] = "RGBA_OUT";
@@ -57,6 +58,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT";
 //   RGB -> RGBA
 //   RGBA -> BGRA
 //   BGRA -> RGBA
+//   BGR -> RGB
 //
 // This calculator only supports a single input stream and output stream at a
 // time. If more than one input stream or output stream is present, the
@@ -69,6 +71,7 @@ constexpr char kGrayOutTag[] = "GRAY_OUT";
 //   RGB_IN:  The input video stream (ImageFrame, SRGB).
 //   BGRA_IN: The input video stream (ImageFrame, SBGRA).
 //   GRAY_IN: The input video stream (ImageFrame, GRAY8).
+//   BGR_IN:  The input video stream (ImageFrame, SBGR).
 //
 // Output streams:
 //   RGBA_OUT: The output video stream (ImageFrame, SRGBA).
@@ -122,6 +125,10 @@ absl::Status ColorConvertCalculator::GetContract(CalculatorContract* cc) {
     cc->Inputs().Tag(kBgraInTag).Set<ImageFrame>();
   }
 
+  if (cc->Inputs().HasTag(kBgrInTag)) {
+    cc->Inputs().Tag(kBgrInTag).Set<ImageFrame>();
+  }
+
   if (cc->Outputs().HasTag(kRgbOutTag)) {
     cc->Outputs().Tag(kRgbOutTag).Set<ImageFrame>();
   }
@@ -194,6 +201,11 @@ absl::Status ColorConvertCalculator::Process(CalculatorContext* cc) {
     return ConvertAndOutput(kRgbaInTag, kBgraOutTag, ImageFormat::SBGRA,
                             cv::COLOR_RGBA2BGRA, cc);
   }
+  // BGR -> RGB
+  if (cc->Inputs().HasTag(kBgrInTag) && cc->Outputs().HasTag(kRgbOutTag)) {
+    return ConvertAndOutput(kBgrInTag, kRgbOutTag, ImageFormat::SRGB,
+                            cv::COLOR_BGR2RGB, cc);
+  }
 
   return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
          << "Unsupported image format conversion.";

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-licenses(["notice"])
-
 load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
 
+licenses(["notice"])
+
 package(default_visibility = ["//visibility:private"])
 
 proto_library(

@@ -68,8 +68,8 @@ class GlProcessor : public ImageToTensorConverter {
 
 constexpr GLchar kExtractSubRectVertexShader[] = R"(
   in vec4 position;
-  in mediump vec4 texture_coordinate;
-  out mediump vec2 sample_coordinate;
+  in highp vec4 texture_coordinate;
+  out highp vec2 sample_coordinate;
   uniform mat4 transform_matrix;
 
   void main() {
@@ -86,7 +86,7 @@ class GlProcessor : public ImageToTensorConverter {
 )";
 
 constexpr GLchar kExtractSubRectFragBody[] = R"(
-  DEFAULT_PRECISION(mediump, float)
+  DEFAULT_PRECISION(highp, float)
 
   // Provided by kExtractSubRectVertexShader.
   in vec2 sample_coordinate;

@@ -22,8 +22,8 @@ cc_library(
     name = "alignment_points_to_rects_calculator",
     srcs = ["alignment_points_to_rects_calculator.cc"],
     deps = [
+        ":detections_to_rects_calculator",
         ":detections_to_rects_calculator_cc_proto",
-        "//mediapipe/calculators/util:detections_to_rects_calculator",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework:calculator_options_cc_proto",
         "//mediapipe/framework/formats:detection_cc_proto",

@@ -1,4 +1,3 @@
-#
 # Copyright 2019 The MediaPipe Authors.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -227,13 +226,13 @@ cc_library(
         ":mediapipe_internal",
     ],
     deps = [
+        ":calculator_cc_proto",
         ":graph_service",
+        ":mediapipe_options_cc_proto",
+        ":packet_generator_cc_proto",
         ":packet_type",
         ":port",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
+        ":status_handler_cc_proto",
         "//mediapipe/framework/port:any_proto",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/tool:options_map",
@@ -329,10 +328,10 @@ cc_library(
         ":thread_pool_executor",
         ":timestamp",
         ":validated_graph_config",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
+        ":calculator_cc_proto",
+        ":packet_generator_cc_proto",
+        ":status_handler_cc_proto",
+        ":thread_pool_executor_cc_proto",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/container:fixed_array",
        "@com_google_absl//absl/container:flat_hash_map",
@@ -370,7 +369,7 @@ cc_library(
     visibility = [":mediapipe_internal"],
     deps = [
         ":graph_service",
-        "//mediapipe/framework:packet",
+        ":packet",
         "@com_google_absl//absl/status",
     ],
 )
@@ -380,7 +379,7 @@ cc_test(
     srcs = ["graph_service_manager_test.cc"],
     deps = [
         ":graph_service_manager",
-        "//mediapipe/framework:packet",
+        ":packet",
         "//mediapipe/framework/port:gtest_main",
     ],
 )
@@ -392,6 +391,7 @@ cc_library(
     visibility = [":mediapipe_internal"],
     deps = [
         ":calculator_base",
+        ":calculator_cc_proto",
         ":calculator_context",
         ":calculator_context_manager",
         ":calculator_state",
@@ -408,10 +408,9 @@ cc_library(
         ":packet_set",
         ":packet_type",
         ":port",
+        ":stream_handler_cc_proto",
         ":timestamp",
         ":validated_graph_config",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:stream_handler_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",
@@ -467,6 +466,7 @@ cc_library(
     hdrs = ["calculator_state.h"],
     visibility = [":mediapipe_internal"],
     deps = [
+        ":calculator_cc_proto",
         ":counter",
         ":counter_factory",
         ":graph_service",
@@ -476,7 +476,6 @@ cc_library(
         ":packet",
         ":packet_set",
         ":port",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:any_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/tool:options_map",
@@ -584,7 +583,7 @@ cc_library(
     hdrs = ["executor.h"],
     visibility = ["//visibility:public"],
     deps = [
-        "//mediapipe/framework:mediapipe_options_cc_proto",
+        ":mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:statusor",
@@ -671,11 +670,11 @@ cc_library(
         ":collection_item_id",
         ":input_stream_manager",
         ":input_stream_shard",
+        ":mediapipe_options_cc_proto",
         ":mediapipe_profiling",
         ":packet",
         ":packet_set",
         ":packet_type",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
@@ -785,12 +784,12 @@ cc_library(
         ":calculator_context_manager",
         ":collection",
         ":collection_item_id",
+        ":mediapipe_options_cc_proto",
         ":output_stream_manager",
         ":output_stream_shard",
         ":packet_set",
         ":packet_type",
         ":timestamp",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:status",
@@ -876,10 +875,10 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":packet",
+        ":packet_generator_cc_proto",
         ":packet_set",
         ":packet_type",
         ":port",
-        "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:status",
@@ -897,13 +896,13 @@ cc_library(
         ":delegating_executor",
         ":executor",
         ":packet",
+        ":packet_factory_cc_proto",
         ":packet_generator",
+        ":packet_generator_cc_proto",
         ":packet_type",
         ":port",
         ":thread_pool_executor",
         ":validated_graph_config",
-        "//mediapipe/framework:packet_factory_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
@@ -1020,10 +1019,10 @@ cc_library(
     hdrs = ["status_handler.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":mediapipe_options_cc_proto",
         ":packet_set",
         ":packet_type",
         ":port",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:status",
         "@com_google_absl//absl/memory",
@@ -1036,10 +1035,10 @@ cc_library(
     hdrs = ["subgraph.h"],
     visibility = ["//visibility:public"],
     deps = [
+        ":calculator_cc_proto",
         ":graph_service",
         ":graph_service_manager",
         ":port",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/deps:registration",
         "//mediapipe/framework/port:ret_check",
         "//mediapipe/framework/port:status",
@@ -1061,7 +1060,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":calculator_framework",
-        "//mediapipe/framework:test_calculators_cc_proto",
+        ":test_calculators_cc_proto",
         "//mediapipe/framework/deps:mathutil",
         "//mediapipe/framework/formats:matrix",
         "//mediapipe/framework/port:integral_types",
@@ -1098,7 +1097,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":executor",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
+        ":thread_pool_executor_cc_proto",
         "//mediapipe/framework/deps:thread_options",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:status",
@@ -1163,22 +1162,22 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":calculator_base",
+        ":calculator_cc_proto",
         ":calculator_contract",
         ":graph_service_manager",
         ":legacy_calculator_support",
         ":packet",
         ":packet_generator",
+        ":packet_generator_cc_proto",
         ":packet_set",
         ":packet_type",
         ":port",
         ":status_handler",
+        ":status_handler_cc_proto",
+        ":stream_handler_cc_proto",
         ":subgraph",
+        ":thread_pool_executor_cc_proto",
         ":timestamp",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
-        "//mediapipe/framework:stream_handler_cc_proto",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",
@@ -1203,11 +1202,11 @@ cc_test(
     name = "validated_graph_config_test",
     srcs = ["validated_graph_config_test.cc"],
     deps = [
+        ":calculator_cc_proto",
         ":calculator_framework",
         ":graph_service",
         ":graph_service_manager",
         ":validated_graph_config",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/api2:node",
         "//mediapipe/framework/api2:port",
         "//mediapipe/framework/port:gtest_main",
@@ -1234,6 +1233,7 @@ cc_test(
     linkstatic = 1,
     deps = [
         ":calculator_base",
+        ":calculator_cc_proto",
         ":calculator_context",
         ":calculator_context_manager",
         ":calculator_registry",
@@ -1243,7 +1243,6 @@ cc_test(
         ":output_stream_shard",
         ":packet_set",
         ":packet_type",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/tool:status_util",
@@ -1257,11 +1256,11 @@ cc_test(
     srcs = ["calculator_contract_test.cc"],
     linkstatic = 1,
     deps = [
+        ":calculator_cc_proto",
         ":calculator_contract",
         ":calculator_contract_test_cc_proto",
-        "//mediapipe/framework:calculator_cc_proto",
-        "//mediapipe/framework:packet_generator_cc_proto",
-        "//mediapipe/framework:status_handler_cc_proto",
+        ":packet_generator_cc_proto",
+        ":status_handler_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:parse_text_proto",
     ],
@@ -1369,6 +1368,7 @@ cc_test(
     srcs = ["calculator_context_test.cc"],
     linkstatic = 1,
     deps = [
+        ":calculator_cc_proto",
         ":calculator_context",
         ":calculator_context_manager",
         ":calculator_state",
@@ -1377,7 +1377,6 @@ cc_test(
         ":output_stream_shard",
         ":packet_set",
         ":packet_type",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:parse_text_proto",
         "//mediapipe/framework/port:status",
@@ -1404,6 +1403,7 @@ cc_test(
         ":executor",
         ":input_stream_handler",
         ":lifetime_tracker",
+        ":mediapipe_options_cc_proto",
         ":output_stream_poller",
         ":packet_set",
         ":packet_type",
@@ -1411,13 +1411,12 @@ cc_test(
         ":subgraph",
         ":test_calculators",
         ":thread_pool_executor",
+        ":thread_pool_executor_cc_proto",
         ":timestamp",
         ":type_map",
         "//mediapipe/calculators/core:counting_source_calculator",
         "//mediapipe/calculators/core:mux_calculator",
         "//mediapipe/calculators/core:pass_through_calculator",
-        "//mediapipe/framework:mediapipe_options_cc_proto",
-        "//mediapipe/framework:thread_pool_executor_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:parse_text_proto",
@@ -1482,12 +1481,12 @@ cc_test(
     ],
     visibility = ["//visibility:public"],
     deps = [
+        ":calculator_cc_proto",
         ":calculator_framework",
         ":test_calculators",
         "//mediapipe/calculators/core:counting_source_calculator",
         "//mediapipe/calculators/core:mux_calculator",
         "//mediapipe/calculators/core:pass_through_calculator",
-        "//mediapipe/framework:calculator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:parse_text_proto",
@@ -1631,8 +1630,8 @@ cc_test(
     srcs = ["packet_generator_test.cc"],
     deps = [
         ":packet_generator",
+        ":packet_generator_cc_proto",
         ":packet_type",
-        "//mediapipe/framework:packet_generator_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/tool:validate_type",
         "@com_google_absl//absl/strings",

@@ -15,12 +15,17 @@
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
 
-namespace mediapipe {
-namespace api2 {
-namespace test {
+namespace mediapipe::api2::builder {
+namespace {
+
+using ::mediapipe::api2::test::Bar;
+using ::mediapipe::api2::test::FloatAdder;
+using ::mediapipe::api2::test::Foo;
+using ::mediapipe::api2::test::Foo2;
+using ::mediapipe::api2::test::FooBar1;
 
 TEST(BuilderTest, BuildGraph) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   auto& bar = graph.AddNode("Bar");
   graph.In("IN").SetName("base") >> foo.In("BASE");
@@ -49,22 +54,19 @@ TEST(BuilderTest, BuildGraph) {
 }
 
 TEST(BuilderTest, CopyableSource) {
-  builder::Graph graph;
-  builder::Source<int> a = graph[Input<int>("A")];
-  a.SetName("a");
-  builder::Source<int> b = graph[Input<int>("B")];
-  b.SetName("b");
-  builder::SideSource<float> side_a = graph[SideInput<float>("SIDE_A")];
-  side_a.SetName("side_a");
-  builder::SideSource<float> side_b = graph[SideInput<float>("SIDE_B")];
-  side_b.SetName("side_b");
-  builder::Destination<int> out = graph[Output<int>("OUT")];
-  builder::SideDestination<float> side_out =
-      graph[SideOutput<float>("SIDE_OUT")];
+  Graph graph;
+  Source<int> a = graph.In("A").SetName("a").Cast<int>();
+  Source<int> b = graph.In("B").SetName("b").Cast<int>();
+  SideSource<float> side_a =
+      graph.SideIn("SIDE_A").SetName("side_a").Cast<float>();
+  SideSource<float> side_b =
+      graph.SideIn("SIDE_B").SetName("side_b").Cast<float>();
+  Destination<int> out = graph.Out("OUT").Cast<int>();
+  SideDestination<float> side_out = graph.SideOut("SIDE_OUT").Cast<float>();
 
-  builder::Source<int> input = a;
+  Source<int> input = a;
   input = b;
-  builder::SideSource<float> side_input = side_b;
+  SideSource<float> side_input = side_b;
   side_input = side_a;
 
   input >> out;
@@ -83,31 +85,27 @@ TEST(BuilderTest, CopyableSource) {
 }
 
 TEST(BuilderTest, BuildGraphWithFunctions) {
-  builder::Graph graph;
+  Graph graph;
 
-  builder::Source<int> base = graph[Input<int>("IN")];
-  base.SetName("base");
-  builder::SideSource<float> side = graph[SideInput<float>("SIDE")];
-  side.SetName("side");
+  Source<int> base = graph.In("IN").SetName("base").Cast<int>();
+  SideSource<float> side = graph.SideIn("SIDE").SetName("side").Cast<float>();
 
-  auto foo_fn = [](builder::Source<int> base, builder::SideSource<float> side,
-                   builder::Graph& graph) {
+  auto foo_fn = [](Source<int> base, SideSource<float> side, Graph& graph) {
     auto& foo = graph.AddNode("Foo");
-    base >> foo[Input<int>("BASE")];
-    side >> foo[SideInput<float>("SIDE")];
-    return foo[Output<double>("OUT")];
+    base >> foo.In("BASE");
+    side >> foo.SideIn("SIDE");
+    return foo.Out("OUT")[0].Cast<double>();
   };
-  builder::Source<double> foo_out = foo_fn(base, side, graph);
+  Source<double> foo_out = foo_fn(base, side, graph);
 
-  auto bar_fn = [](builder::Source<double> in, builder::Graph& graph) {
+  auto bar_fn = [](Source<double> in, Graph& graph) {
     auto& bar = graph.AddNode("Bar");
-    in >> bar[Input<double>("IN")];
-    return bar[Output<double>("OUT")];
+    in >> bar.In("IN");
+    return bar.Out("OUT")[0].Cast<double>();
   };
-  builder::Source<double> bar_out = bar_fn(foo_out, graph);
-  bar_out.SetName("out");
+  Source<double> bar_out = bar_fn(foo_out, graph);
 
-  bar_out >> graph[Output<double>("OUT")];
+  bar_out.SetName("out") >> graph.Out("OUT");
 
   CalculatorGraphConfig expected =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@@ -131,7 +129,7 @@ TEST(BuilderTest, BuildGraphWithFunctions) {
 
 template <class FooT>
 void BuildGraphTypedTest() {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode<FooT>();
   auto& bar = graph.AddNode<Bar>();
   graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
@@ -161,12 +159,12 @@ void BuildGraphTypedTest() {
   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
 }
 
-TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<Foo>(); }
+TEST(BuilderTest, BuildGraphTyped) { BuildGraphTypedTest<test::Foo>(); }
 
-TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<Foo2>(); }
+TEST(BuilderTest, BuildGraphTyped2) { BuildGraphTypedTest<test::Foo2>(); }
 
 TEST(BuilderTest, FanOut) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   auto& adder = graph.AddNode("FloatAdder");
   graph.In("IN").SetName("base") >> foo.In("BASE");
@@ -194,9 +192,9 @@ TEST(BuilderTest, FanOut) {
 }
 
 TEST(BuilderTest, TypedMultiple) {
-  builder::Graph graph;
-  auto& foo = graph.AddNode<Foo>();
-  auto& adder = graph.AddNode<FloatAdder>();
+  Graph graph;
+  auto& foo = graph.AddNode<test::Foo>();
+  auto& adder = graph.AddNode<test::FloatAdder>();
   graph.In("IN").SetName("base") >> foo.In(MPP_TAG("BASE"));
   foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[0];
   foo.Out(MPP_TAG("OUT")) >> adder.In(MPP_TAG("IN"))[1];
@@ -222,14 +220,14 @@ TEST(BuilderTest, TypedMultiple) {
 }
 
 TEST(BuilderTest, TypedByPorts) {
-  builder::Graph graph;
-  auto& foo = graph.AddNode<Foo>();
+  Graph graph;
+  auto& foo = graph.AddNode<test::Foo>();
   auto& adder = graph.AddNode<FloatAdder>();
 
-  graph[FooBar1::kIn].SetName("base") >> foo[Foo::kBase];
+  graph.In(FooBar1::kIn).SetName("base") >> foo[Foo::kBase];
   foo[Foo::kOut] >> adder[FloatAdder::kIn][0];
   foo[Foo::kOut] >> adder[FloatAdder::kIn][1];
-  adder[FloatAdder::kOut].SetName("out") >> graph[FooBar1::kOut];
+  adder[FloatAdder::kOut].SetName("out") >> graph.Out(FooBar1::kOut);
 
   CalculatorGraphConfig expected =
       mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
@@ -251,7 +249,7 @@ TEST(BuilderTest, TypedByPorts) {
 }
 
 TEST(BuilderTest, PacketGenerator) {
-  builder::Graph graph;
+  Graph graph;
   auto& generator = graph.AddPacketGenerator("FloatGenerator");
   graph.SideIn("IN") >> generator.SideIn("IN");
   generator.SideOut("OUT") >> graph.SideOut("OUT");
@@ -270,7 +268,7 @@ TEST(BuilderTest, PacketGenerator) {
 }
 
 TEST(BuilderTest, EmptyTag) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   graph.In("A").SetName("a") >> foo.In("")[0];
   graph.In("C").SetName("c") >> foo.In("")[2];
@@ -302,7 +300,7 @@ TEST(BuilderTest, StringLikeTags) {
   const std::string kB = "B";
   constexpr absl::string_view kC = "C";
 
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   graph.In(kA).SetName("a") >> foo.In(kA);
   graph.In(kB).SetName("b") >> foo.In(kB);
@@ -324,7 +322,7 @@ TEST(BuilderTest, StringLikeTags) {
 }
 
 TEST(BuilderTest, GraphIndexes) {
-  builder::Graph graph;
+  Graph graph;
   auto& foo = graph.AddNode("Foo");
   graph.In(0).SetName("a") >> foo.In("")[0];
   graph.In(1).SetName("c") >> foo.In("")[2];
@@ -376,28 +374,27 @@ class AnyAndSameTypeCalculator : public NodeIntf {
 };
 
 TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
-  builder::Graph graph;
-  builder::Source<AnyType> any_input = graph[Input<AnyType>{"GRAPH_ANY_INPUT"}];
-  builder::Source<int> int_input = graph[Input<int>{"GRAPH_INT_INPUT"}];
+  Graph graph;
+  Source<AnyType> any_input = graph.In("GRAPH_ANY_INPUT");
+  Source<int> int_input = graph.In("GRAPH_INT_INPUT").Cast<int>();
 
   auto& node = graph.AddNode("AnyAndSameTypeCalculator");
   any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
   int_input >> node[AnyAndSameTypeCalculator::kIntInput];
 
-  builder::Source<AnyType> any_type_output =
+  Source<AnyType> any_type_output =
       node[AnyAndSameTypeCalculator::kAnyTypeOutput];
   any_type_output.SetName("any_type_output");
 
-  builder::Source<AnyType> same_type_output =
+  Source<AnyType> same_type_output =
       node[AnyAndSameTypeCalculator::kSameTypeOutput];
   same_type_output.SetName("same_type_output");
-  builder::Source<AnyType> recursive_same_type_output =
+  Source<AnyType> recursive_same_type_output =
       node[AnyAndSameTypeCalculator::kRecursiveSameTypeOutput];
   recursive_same_type_output.SetName("recursive_same_type_output");
-  builder::Source<int> same_int_output =
-      node[AnyAndSameTypeCalculator::kSameIntOutput];
+  Source<int> same_int_output = node[AnyAndSameTypeCalculator::kSameIntOutput];
   same_int_output.SetName("same_int_output");
-  builder::Source<int> recursive_same_int_type_output =
+  Source<int> recursive_same_int_type_output =
       node[AnyAndSameTypeCalculator::kRecursiveSameIntOutput];
   recursive_same_int_type_output.SetName("recursive_same_int_type_output");
 
@@ -420,15 +417,16 @@ TEST(BuilderTest, AnyAndSameTypeHandledProperly) {
 }
 
 TEST(BuilderTest, AnyTypeCanBeCast) {
-  builder::Graph graph;
-  builder::Source<std::string> any_input =
+  Graph graph;
+  Source<std::string> any_input =
       graph.In("GRAPH_ANY_INPUT").Cast<std::string>();
 
   auto& node = graph.AddNode("AnyAndSameTypeCalculator");
   any_input >> node[AnyAndSameTypeCalculator::kAnyTypeInput];
-  builder::Source<double> any_type_output =
-      node[AnyAndSameTypeCalculator::kAnyTypeOutput].Cast<double>();
-  any_type_output.SetName("any_type_output");
+  Source<double> any_type_output =
+      node[AnyAndSameTypeCalculator::kAnyTypeOutput]
+          .SetName("any_type_output")
+          .Cast<double>();
 
   any_type_output >> graph.Out("GRAPH_ANY_OUTPUT").Cast<double>();
 
@@ -446,11 +444,11 @@ TEST(BuilderTest, AnyTypeCanBeCast) {
 }
 
 TEST(BuilderTest, MultiPortIsCastToMultiPort) {
-  builder::Graph graph;
-  builder::MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
-  builder::MultiSource<int> int_input = any_input.Cast<int>();
-  builder::MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
-  builder::MultiDestination<int> int_output = any_output.Cast<int>();
+  Graph graph;
+  MultiSource<AnyType> any_input = graph.In("ANY_INPUT");
+  MultiSource<int> int_input = any_input.Cast<int>();
+  MultiDestination<AnyType> any_output = graph.Out("ANY_OUTPUT");
+  MultiDestination<int> int_output = any_output.Cast<int>();
   int_input >> int_output;
 
   CalculatorGraphConfig expected =
@@ -462,11 +460,11 @@ TEST(BuilderTest, MultiPortIsCastToMultiPort) {
 }
 
 TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
-  builder::Graph graph;
-  builder::MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
-  builder::Source<AnyType> any_input = any_multi_input;
-  builder::MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
-  builder::Destination<AnyType> any_output = any_multi_output;
+  Graph graph;
+  MultiSource<AnyType> any_multi_input = graph.In("ANY_INPUT");
+  Source<AnyType> any_input = any_multi_input;
+  MultiDestination<AnyType> any_multi_output = graph.Out("ANY_OUTPUT");
+  Destination<AnyType> any_output = any_multi_output;
   any_input >> any_output;
 
   CalculatorGraphConfig expected =
@@ -478,11 +476,11 @@ TEST(BuilderTest, MultiPortCanBeSlicedToSinglePort) {
 }
 
 TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
-  builder::Graph graph;
-  builder::Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
-  builder::Source<AnyType> any_input = graph.In("ANY_OUTPUT");
-  builder::Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
-  builder::Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
+  Graph graph;
+  Source<int> int_input = graph.In("INT_INPUT").Cast<int>();
+  Source<AnyType> any_input = graph.In("ANY_OUTPUT");
+  Destination<int> int_output = graph.Out("INT_OUTPUT").Cast<int>();
+  Destination<AnyType> any_output = graph.Out("ANY_OUTPUT");
   int_input >> int_output;
   any_input >> any_output;
 
@@ -496,6 +494,5 @@ TEST(BuilderTest, SinglePortAccessWorksThroughSlicing) {
   EXPECT_THAT(graph.GetConfig(), EqualsProto(expected));
 }
 
-}  // namespace test
-}  // namespace api2
-}  // namespace mediapipe
+}  // namespace
+}  // namespace mediapipe::api2::builder

@@ -557,8 +557,8 @@ class OutputSidePacketAccess {
     if (output_) output_->Set(ToOldPacket(std::move(packet)));
   }
 
-  void Set(const T& payload) { Set(MakePacket<T>(payload)); }
-  void Set(T&& payload) { Set(MakePacket<T>(std::move(payload))); }
+  void Set(const T& payload) { Set(api2::MakePacket<T>(payload)); }
+  void Set(T&& payload) { Set(api2::MakePacket<T>(std::move(payload))); }
 
  private:
  OutputSidePacketAccess(OutputSidePacket* output) : output_(output) {}

@@ -20,9 +20,14 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
 
 licenses(["notice"])
 
-package(default_visibility = [
-    "//mediapipe:__subpackages__",
-])
+package_group(
+    name = "mediapipe_internal",
+    packages = [
+        "//mediapipe/...",
+    ],
+)
+
+package(default_visibility = ["mediapipe_internal"])
 
 bzl_library(
     name = "expand_template_bzl",
@@ -214,6 +219,9 @@ cc_library(
     name = "registration",
     srcs = ["registration.cc"],
     hdrs = ["registration.h"],
+    visibility = [
+        "mediapipe_internal",
+    ],
     deps = [
         ":registration_token",
         "//mediapipe/framework/port:logging",

@@ -26,7 +26,7 @@ licenses(["notice"])
 mediapipe_proto_library(
     name = "detection_proto",
     srcs = ["detection.proto"],
-    deps = ["//mediapipe/framework/formats:location_data_proto"],
+    deps = [":location_data_proto"],
 )
 
 mediapipe_register_type(
@@ -38,7 +38,7 @@ mediapipe_register_type(
         "::std::vector<::mediapipe::Detection>",
         "::std::vector<::mediapipe::DetectionList>",
     ],
-    deps = ["//mediapipe/framework/formats:detection_cc_proto"],
+    deps = [":detection_cc_proto"],
 )
 
 mediapipe_proto_library(
@@ -105,8 +105,8 @@ cc_library(
     srcs = ["matrix.cc"],
     hdrs = ["matrix.h"],
     deps = [
+        ":matrix_data_cc_proto",
         "//mediapipe/framework:port",
-        "//mediapipe/framework/formats:matrix_data_cc_proto",
         "//mediapipe/framework/port:core_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:ret_check",
@@ -142,7 +142,7 @@ cc_library(
     srcs = ["image_frame.cc"],
     hdrs = ["image_frame.h"],
     deps = [
-        "//mediapipe/framework/formats:image_format_cc_proto",
+        ":image_format_cc_proto",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
@@ -166,8 +166,8 @@ cc_library(
     srcs = ["image_frame_opencv.cc"],
     hdrs = ["image_frame_opencv.h"],
     deps = [
+        ":image_format_cc_proto",
         ":image_frame",
-        "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/port:opencv_core",
     ],
 )
@@ -194,7 +194,7 @@ cc_library(
     deps = [
         "@com_google_protobuf//:protobuf",
         "//mediapipe/framework/formats/annotation:locus_cc_proto",
-        "//mediapipe/framework/formats:location_data_cc_proto",
+        ":location_data_cc_proto",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -245,7 +245,7 @@ cc_library(
     name = "video_stream_header",
     hdrs = ["video_stream_header.h"],
     deps = [
-        "//mediapipe/framework/formats:image_format_cc_proto",
+        ":image_format_cc_proto",
     ],
 )
 
@@ -263,9 +263,9 @@ cc_test(
     size = "small",
     srcs = ["image_frame_opencv_test.cc"],
     deps = [
+        ":image_format_cc_proto",
         ":image_frame",
         ":image_frame_opencv",
-        "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",
@@ -324,8 +324,8 @@ cc_library(
         "//conditions:default": [],
     }),
     deps = [
-        "//mediapipe/framework/formats:image_format_cc_proto",
-        "//mediapipe/framework/formats:image_frame",
+        ":image_format_cc_proto",
+        ":image_frame",
         "@com_google_absl//absl/synchronization",
         "//mediapipe/framework:port",
         "//mediapipe/framework:type_map",
@@ -354,7 +354,7 @@ cc_library(
     hdrs = ["image_multi_pool.h"],
     deps = [
         ":image",
-        "//mediapipe/framework/formats:image_frame_pool",
+        ":image_frame_pool",
         "//mediapipe/framework:port",
         "//mediapipe/framework/port:logging",
         "@com_google_absl//absl/memory",
@@ -390,7 +390,7 @@ cc_library(
     ],
     deps = [
         ":image",
-        "//mediapipe/framework/formats:image_format_cc_proto",
+        ":image_format_cc_proto",
         "//mediapipe/framework/port:logging",
         "//mediapipe/framework/port:opencv_core",
         "//mediapipe/framework/port:statusor",
@@ -428,7 +428,10 @@ cc_library(
         "tensor.cc",
         "tensor_ahwb.cc",
     ],
-    hdrs = ["tensor.h"],
+    hdrs = [
+        "tensor.h",
+        "tensor_internal.h",
+    ],
     copts = select({
         "//mediapipe:apple": [
             "-x objective-c++",
@@ -452,6 +455,7 @@ cc_library(
         ],
     }),
     deps = [
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/synchronization",
         "//mediapipe/framework:port",

@@ -38,11 +38,11 @@ cc_library(
     srcs = ["optical_flow_field.cc"],
     hdrs = ["optical_flow_field.h"],
     deps = [
+        ":optical_flow_field_data_cc_proto",
         "//mediapipe/framework:type_map",
         "//mediapipe/framework/deps:mathutil",
         "//mediapipe/framework/formats:location",
         "//mediapipe/framework/formats:location_opencv",
-        "//mediapipe/framework/formats/motion:optical_flow_field_data_cc_proto",
         "//mediapipe/framework/port:file_helpers",
         "//mediapipe/framework/port:integral_types",
         "//mediapipe/framework/port:logging",

@@ -246,10 +246,10 @@ Tensor::OpenGlTexture2dView::GetLayoutDimensions(const Tensor::Shape& shape,
       return Tensor::OpenGlTexture2dView::Layout::kAligned;
     }
   }
-  // The best performance of a compute shader can be achived with textures'
+  // The best performance of a compute shader can be achieved with textures'
   // width multiple of 256. Making minimum fixed width of 256 waste memory for
   // small tensors. The optimal balance memory-vs-performance is power of 2.
-  // The texture width and height are choosen to be closer to square.
+  // The texture width and height are chosen to be closer to square.
   float power = std::log2(std::sqrt(static_cast<float>(num_pixels)));
   w = 1 << static_cast<int>(power);
   int h = (num_pixels + w - 1) / w;
@@ -326,7 +326,7 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
   auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
   AllocateOpenGlBuffer();
   if (!(valid_ & kValidOpenGlBuffer)) {
-    // If the call succeds then AHWB -> SSBO are synchronized so any usage of
+    // If the call succeeds then AHWB -> SSBO are synchronized so any usage of
     // the SSBO is correct after this call.
     if (!InsertAhwbToSsboFence()) {
       glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
@@ -348,8 +348,10 @@ Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
   };
 }
 
-Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const {
+Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView(
+    uint64_t source_location_hash) const {
   auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
+  TrackAhwbUsage(source_location_hash);
   AllocateOpenGlBuffer();
   valid_ = kValidOpenGlBuffer;
   return {opengl_buffer_, std::move(lock), nullptr};
@@ -385,6 +387,7 @@ void Tensor::Move(Tensor* src) {
   src->element_type_ = ElementType::kNone;  // Mark as invalidated.
   cpu_buffer_ = src->cpu_buffer_;
   src->cpu_buffer_ = nullptr;
+  ahwb_tracking_key_ = src->ahwb_tracking_key_;
 #if MEDIAPIPE_METAL_ENABLED
   device_ = src->device_;
   src->device_ = nil;
@@ -589,8 +592,10 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
   return {cpu_buffer_, std::move(lock)};
 }
 
-Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
+Tensor::CpuWriteView Tensor::GetCpuWriteView(
+    uint64_t source_location_hash) const {
   auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
+  TrackAhwbUsage(source_location_hash);
   AllocateCpuBuffer();
   valid_ = kValidCpu;
 #ifdef MEDIAPIPE_TENSOR_USE_AHWB
@@ -620,24 +625,4 @@ void Tensor::AllocateCpuBuffer() const {
   }
 }
 
-void Tensor::SetPreferredStorageType(StorageType type) {
-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-  if (__builtin_available(android 26, *)) {
-    use_ahwb_ = type == StorageType::kAhwb;
-    VLOG(4) << "Tensor: use of AHardwareBuffer is "
-            << (use_ahwb_ ? "allowed" : "not allowed");
-  }
-#else
-  VLOG(4) << "Tensor: use of AHardwareBuffer is not allowed";
-#endif  // MEDIAPIPE_TENSOR_USE_AHWB
-}
-
-Tensor::StorageType Tensor::GetPreferredStorageType() {
-#ifdef MEDIAPIPE_TENSOR_USE_AHWB
-  return use_ahwb_ ? StorageType::kAhwb : StorageType::kDefault;
-#else
-  return StorageType::kDefault;
-#endif  // MEDIAPIPE_TENSOR_USE_AHWB
-}
-
 }  // namespace mediapipe

@@ -24,8 +24,9 @@
 #include <utility>
 #include <vector>
 
-#include "absl/memory/memory.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/synchronization/mutex.h"
+#include "mediapipe/framework/formats/tensor_internal.h"
 #include "mediapipe/framework/port.h"
 
 #if MEDIAPIPE_METAL_ENABLED
@@ -48,6 +49,22 @@
 #include "mediapipe/gpu/gl_context.h"
 #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
 
+#if defined __has_builtin
+#if __has_builtin(__builtin_LINE)
+#define builtin_LINE __builtin_LINE
+#endif
+#if __has_builtin(__builtin_FILE)
+#define builtin_FILE __builtin_FILE
+#endif
+#endif
+
+#ifndef builtin_LINE
+#define builtin_LINE() 0
+#endif
+#ifndef builtin_FILE
+#define builtin_FILE() ""
+#endif
+
 namespace mediapipe {
 
 // Tensor is a container of multi-dimensional data that supports sharing the
@@ -65,7 +82,7 @@ namespace mediapipe {
 //   GLuint buffer = view.buffer();
 // Then the buffer can be bound to the GPU command buffer.
 // ...binding the buffer to the command buffer...
-// ...commiting command buffer and releasing the view...
+// ...committing command buffer and releasing the view...
 //
 // The following request for the CPU view will be blocked until the GPU view is
 // released and the GPU task is finished.
@@ -161,7 +178,9 @@ class Tensor {
   using CpuReadView = CpuView<const void>;
   CpuReadView GetCpuReadView() const;
   using CpuWriteView = CpuView<void>;
-  CpuWriteView GetCpuWriteView() const;
+  CpuWriteView GetCpuWriteView(
+      uint64_t source_location_hash =
+          tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
 
 #if MEDIAPIPE_METAL_ENABLED
   // TODO: id<MTLBuffer> vs. MtlBufferView.
@@ -305,7 +324,9 @@ class Tensor {
   // A valid OpenGL context must be bound to the calling thread due to possible
   // GPU resource allocation.
   OpenGlBufferView GetOpenGlBufferReadView() const;
-  OpenGlBufferView GetOpenGlBufferWriteView() const;
+  OpenGlBufferView GetOpenGlBufferWriteView(
+      uint64_t source_location_hash =
+          tensor_internal::FnvHash64(builtin_FILE(), builtin_LINE())) const;
 #endif  // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
 
   const Shape& shape() const { return shape_; }
@@ -408,9 +429,13 @@ class Tensor {
   mutable std::function<void()> release_callback_;
   bool AllocateAHardwareBuffer(int size_alignment = 0) const;
   void CreateEglSyncAndFd() const;
-  // Use Ahwb for other views: OpenGL / CPU buffer.
 #endif  // MEDIAPIPE_TENSOR_USE_AHWB
-  static inline bool use_ahwb_ = false;
+  // Use Ahwb for other views: OpenGL / CPU buffer.
+  mutable bool use_ahwb_ = false;
+  mutable uint64_t ahwb_tracking_key_ = 0;
+  // TODO: Tracks all unique tensors. Can grow to a large number. LRU
+  // can be more predicted.
+  static inline absl::flat_hash_set<uint64_t> ahwb_usage_track_;
   // Expects the target SSBO to be already bound.
   bool AllocateAhwbMapToSsbo() const;
   bool InsertAhwbToSsboFence() const;
@@ -419,6 +444,8 @@ class Tensor {
   void* MapAhwbToCpuRead() const;
   void* MapAhwbToCpuWrite() const;
   void MoveCpuOrSsboToAhwb() const;
+  // Set current tracking key, set "use ahwb" if the key is already marked.
+  void TrackAhwbUsage(uint64_t key) const;
 
 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
   mutable std::shared_ptr<mediapipe::GlContext> gl_context_;

@@ -212,9 +212,6 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
   CHECK(!(valid_ & kValidOpenGlTexture2d))
       << "Tensor conversion between OpenGL texture and AHardwareBuffer is not "
         "supported.";
-  CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
-      << "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
-         "supported on target system.";
   bool transfer = !ahwb_;
   CHECK(AllocateAHardwareBuffer())
       << "AHardwareBuffer is not supported on the target system.";
@@ -268,6 +265,10 @@ Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
 }
 
 bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
+  // Mark current tracking key as Ahwb-use.
+  ahwb_usage_track_.insert(ahwb_tracking_key_);
+  use_ahwb_ = true;
+
   if (__builtin_available(android 26, *)) {
     if (ahwb_ == nullptr) {
       AHardwareBuffer_Desc desc = {};
@@ -315,7 +316,13 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
       ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, -1, nullptr, &dest);
   CHECK(error == 0) << "AHardwareBuffer_lock " << error;
   }
-  if (valid_ & kValidOpenGlBuffer) {
+  if (valid_ & kValidCpu) {
+    std::memcpy(dest, cpu_buffer_, bytes());
+    // Free CPU memory because next time AHWB is mapped instead.
+    free(cpu_buffer_);
+    cpu_buffer_ = nullptr;
+    valid_ &= ~kValidCpu;
+  } else if (valid_ & kValidOpenGlBuffer) {
     gl_context_->Run([this, dest]() {
       glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
       const void* src = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
@@ -326,11 +333,9 @@ void Tensor::MoveCpuOrSsboToAhwb() const {
     });
     opengl_buffer_ = GL_INVALID_INDEX;
     gl_context_ = nullptr;
-  } else if (valid_ & kValidCpu) {
-    std::memcpy(dest, cpu_buffer_, bytes());
-    // Free CPU memory because next time AHWB is mapped instead.
-    free(cpu_buffer_);
-    cpu_buffer_ = nullptr;
+    // Reset OpenGL Buffer validness. The OpenGL buffer will be allocated on top
+    // of the Ahwb at the next request to the OpenGlBufferView.
+    valid_ &= ~kValidOpenGlBuffer;
   } else {
     LOG(FATAL) << "Can't convert tensor with mask " << valid_ << " into AHWB.";
   }
@@ -446,6 +451,16 @@ void* Tensor::MapAhwbToCpuWrite() const {
   return nullptr;
 }
 
+void Tensor::TrackAhwbUsage(uint64_t source_location_hash) const {
+  if (ahwb_tracking_key_ == 0) {
+    ahwb_tracking_key_ = source_location_hash;
+    for (int dim : shape_.dims) {
+      ahwb_tracking_key_ = tensor_internal::FnvHash64(ahwb_tracking_key_, dim);
+    }
+  }
+  use_ahwb_ = ahwb_usage_track_.contains(ahwb_tracking_key_);
+}
+
 #else  // MEDIAPIPE_TENSOR_USE_AHWB
 
 bool Tensor::AllocateAhwbMapToSsbo() const { return false; }
@@ -454,6 +469,7 @@ void Tensor::MoveAhwbStuff(Tensor* src) {}
 void Tensor::ReleaseAhwbStuff() {}
 void* Tensor::MapAhwbToCpuRead() const { return nullptr; }
 void* Tensor::MapAhwbToCpuWrite() const { return nullptr; }
+void Tensor::TrackAhwbUsage(uint64_t key) const {}
 
 #endif  // MEDIAPIPE_TENSOR_USE_AHWB

@@ -152,6 +152,36 @@ TEST_F(TensorAhwbGpuTest, TestReplacingCpuByAhwb) {
   {
     auto view = tensor.GetAHardwareBufferReadView();
     EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
   }
   auto ptr = tensor.GetCpuReadView().buffer<float>();
   EXPECT_NE(ptr, nullptr);
+  std::vector<float> reference;
+  reference.resize(num_elements);
+  for (int i = 0; i < num_elements; i++) {
+    reference[i] = static_cast<float>(i) / 10.0f;
+  }
+  EXPECT_THAT(absl::Span<const float>(ptr, num_elements),
+              testing::Pointwise(testing::FloatEq(), reference));
+}
+
+TEST_F(TensorAhwbGpuTest, TestReplacingGpuByAhwb) {
+  // Request the GPU view to get the ssbo allocated internally.
+  // Request Ahwb view then to transform the storage into Ahwb.
+  Tensor::SetPreferredStorageType(Tensor::StorageType::kDefault);
+  constexpr size_t num_elements = 20;
+  Tensor tensor{Tensor::ElementType::kFloat32, Tensor::Shape({num_elements})};
+  RunInGlContext([&tensor] {
+    auto ssbo_view = tensor.GetOpenGlBufferWriteView();
+    auto ssbo_name = ssbo_view.name();
+    EXPECT_GT(ssbo_name, 0);
+    FillGpuBuffer(ssbo_name, tensor.shape().num_elements(),
+                  tensor.element_type());
+  });
+  {
+    auto view = tensor.GetAHardwareBufferReadView();
+    EXPECT_NE(view.handle(), nullptr);
+    view.SetReadingFinishedFunc([](bool) { return true; });
+  }
+  auto ptr = tensor.GetCpuReadView().buffer<float>();
+  EXPECT_NE(ptr, nullptr);

@@ -1,71 +0,0 @@
-#ifndef MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
-#define MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
-
-#if !defined(MEDIAPIPE_NO_JNI) && \
-    (__ANDROID_API__ >= 26 ||     \
-     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
-
-#include <android/hardware_buffer.h>
-
-#include <cstdint>
-
-#include "mediapipe/framework/formats/tensor_buffer.h"
-#include "mediapipe/framework/formats/tensor_internal.h"
-#include "mediapipe/framework/formats/tensor_v2.h"
-
-namespace mediapipe {
-
-// Supports:
-// - float 16 and 32 bits
-// - signed / unsigned integers 8,16,32 bits
-class TensorHardwareBufferView;
-struct TensorHardwareBufferViewDescriptor : public Tensor::ViewDescriptor {
-  using ViewT = TensorHardwareBufferView;
-  TensorBufferDescriptor buffer;
-};
-
-class TensorHardwareBufferView : public Tensor::View {
- public:
-  TENSOR_UNIQUE_VIEW_TYPE_ID();
-  ~TensorHardwareBufferView() = default;
-
-  const TensorHardwareBufferViewDescriptor& descriptor() const override {
-    return descriptor_;
-  }
-  AHardwareBuffer* handle() const { return ahwb_handle_; }
-
- protected:
-  TensorHardwareBufferView(int access_capability, Tensor::View::Access access,
-                           Tensor::View::State state,
-                           const TensorHardwareBufferViewDescriptor& desc,
-                           AHardwareBuffer* ahwb_handle)
-      : Tensor::View(kId, access_capability, access, state),
-        descriptor_(desc),
-        ahwb_handle_(ahwb_handle) {}
-
- private:
-  bool MatchDescriptor(
-      uint64_t view_type_id,
-      const Tensor::ViewDescriptor& base_descriptor) const override {
-    if (!Tensor::View::MatchDescriptor(view_type_id, base_descriptor))
-      return false;
-    auto descriptor =
-        static_cast<const TensorHardwareBufferViewDescriptor&>(base_descriptor);
-    return descriptor.buffer.format == descriptor_.buffer.format &&
-           descriptor.buffer.size_alignment <=
-               descriptor_.buffer.size_alignment &&
-           descriptor_.buffer.size_alignment %
-                   descriptor.buffer.size_alignment ==
-               0;
-  }
-  const TensorHardwareBufferViewDescriptor& descriptor_;
-  AHardwareBuffer* ahwb_handle_ = nullptr;
-};
-
-}  // namespace mediapipe
-
-#endif  // !defined(MEDIAPIPE_NO_JNI) && \
-           (__ANDROID_API__ >= 26 || \
-            defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
-
-#endif  // MEDIAPIPE_FRAMEWORK_FORMATS_TENSOR_HARDWARE_BUFFER_H_
@ -1,216 +0,0 @@
|
|||
#if !defined(MEDIAPIPE_NO_JNI) && \
|
||||
(__ANDROID_API__ >= 26 || \
|
||||
defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/framework/formats/tensor_backend.h"
|
||||
#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
|
||||
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
|
||||
#include "mediapipe/framework/formats/tensor_v2.h"
|
||||
#include "util/task/status_macros.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
class TensorCpuViewImpl : public TensorCpuView {
|
||||
public:
|
||||
TensorCpuViewImpl(int access_capabilities, Tensor::View::Access access,
|
||||
Tensor::View::State state,
|
||||
const TensorCpuViewDescriptor& descriptor, void* pointer,
|
||||
AHardwareBuffer* ahwb_handle)
|
||||
: TensorCpuView(access_capabilities, access, state, descriptor, pointer),
|
||||
ahwb_handle_(ahwb_handle) {}
|
||||
~TensorCpuViewImpl() {
|
||||
// If handle_ is null then this view is constructed in GetViews with no
|
||||
// access.
|
||||
if (ahwb_handle_) {
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_unlock(ahwb_handle_, nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
AHardwareBuffer* ahwb_handle_;
|
||||
};
|
||||
|
||||
class TensorHardwareBufferViewImpl : public TensorHardwareBufferView {
|
||||
public:
|
||||
TensorHardwareBufferViewImpl(
|
||||
int access_capability, Tensor::View::Access access,
|
||||
Tensor::View::State state,
|
||||
const TensorHardwareBufferViewDescriptor& descriptor,
|
||||
AHardwareBuffer* handle)
|
||||
: TensorHardwareBufferView(access_capability, access, state, descriptor,
|
||||
handle) {}
|
||||
~TensorHardwareBufferViewImpl() = default;
|
||||
};
|
||||
|
||||
class HardwareBufferCpuStorage : public TensorStorage {
|
||||
public:
|
||||
~HardwareBufferCpuStorage() {
|
||||
if (!ahwb_handle_) return;
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_release(ahwb_handle_);
|
||||
}
|
||||
}
|
||||
|
||||
static absl::Status CanProvide(
|
||||
int access_capability, const Tensor::Shape& shape, uint64_t view_type_id,
|
||||
const Tensor::ViewDescriptor& base_descriptor) {
|
||||
// TODO: use AHardwareBuffer_isSupported for API >= 29.
|
||||
static const bool is_ahwb_supported = [] {
|
||||
if (__builtin_available(android 26, *)) {
|
||||
AHardwareBuffer_Desc desc = {};
|
||||
// Aligned to the largest possible virtual memory page size.
|
||||
constexpr uint32_t kPageSize = 16384;
|
||||
desc.width = kPageSize;
|
||||
desc.height = 1;
|
||||
desc.layers = 1;
|
||||
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
|
||||
desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
|
||||
AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
|
||||
AHardwareBuffer* handle;
|
||||
if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false;
|
||||
AHardwareBuffer_release(handle);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
if (!is_ahwb_supported) {
|
||||
return absl::UnavailableError(
|
||||
"AHardwareBuffer is not supported on the platform.");
|
||||
}
|
||||
|
||||
if (view_type_id != TensorCpuView::kId &&
|
||||
view_type_id != TensorHardwareBufferView::kId) {
|
||||
return absl::InvalidArgumentError(
|
||||
"A view type is not supported by this storage.");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
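The probe in `CanProvide` is worth calling out: AHWB support cannot be assumed from the API level alone, so the storage attempts one small allocation and caches the verdict. A minimal standalone sketch of the same idea (availability guards omitted for brevity; the NDK calls are the real ones):

```cpp
#include <android/hardware_buffer.h>

// Returns whether AHardwareBuffer allocation actually works on this device,
// probing once with a small BLOB buffer and caching the result.
bool IsAhwbSupported() {
  static const bool supported = [] {
    AHardwareBuffer_Desc desc = {};
    desc.width = 16384;  // One conservatively sized page of bytes.
    desc.height = 1;
    desc.layers = 1;
    desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
    desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
                 AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
    AHardwareBuffer* handle = nullptr;
    if (AHardwareBuffer_allocate(&desc, &handle) != 0) return false;
    AHardwareBuffer_release(handle);
    return true;
  }();
  return supported;
}
```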

  std::vector<std::unique_ptr<Tensor::View>> GetViews(uint64_t latest_version) {
    std::vector<std::unique_ptr<Tensor::View>> result;
    auto update_state = latest_version == version_
                            ? Tensor::View::State::kUpToDate
                            : Tensor::View::State::kOutdated;
    if (ahwb_handle_) {
      result.push_back(
          std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
              kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
              hw_descriptor_, ahwb_handle_)));

      result.push_back(std::unique_ptr<Tensor::View>(new TensorCpuViewImpl(
          kAccessCapability, Tensor::View::Access::kNoAccess, update_state,
          cpu_descriptor_, nullptr, nullptr)));
    }
    return result;
  }

  absl::StatusOr<std::unique_ptr<Tensor::View>> GetView(
      Tensor::View::Access access, const Tensor::Shape& shape,
      uint64_t latest_version, uint64_t view_type_id,
      const Tensor::ViewDescriptor& base_descriptor, int access_capability) {
    MP_RETURN_IF_ERROR(
        CanProvide(access_capability, shape, view_type_id, base_descriptor));
    const auto& buffer_descriptor =
        view_type_id == TensorHardwareBufferView::kId
            ? static_cast<const TensorHardwareBufferViewDescriptor&>(
                  base_descriptor)
                  .buffer
            : static_cast<const TensorCpuViewDescriptor&>(base_descriptor)
                  .buffer;
    if (!ahwb_handle_) {
      if (__builtin_available(android 26, *)) {
        AHardwareBuffer_Desc desc = {};
        desc.width = TensorBufferSize(buffer_descriptor, shape);
        desc.height = 1;
        desc.layers = 1;
        desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
        // TODO: Use access capabilities to set hints.
        desc.usage = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
                     AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
        auto error = AHardwareBuffer_allocate(&desc, &ahwb_handle_);
        if (error != 0) {
          return absl::UnknownError(
              absl::StrCat("Error allocating hardware buffer: ", error));
        }
        // Fill all possible views to provide it as proto views.
        hw_descriptor_.buffer = buffer_descriptor;
        cpu_descriptor_.buffer = buffer_descriptor;
      }
    }
    if (buffer_descriptor.format != hw_descriptor_.buffer.format ||
        buffer_descriptor.size_alignment >
            hw_descriptor_.buffer.size_alignment ||
        hw_descriptor_.buffer.size_alignment %
                buffer_descriptor.size_alignment >
            0) {
      return absl::AlreadyExistsError(
          "A view with different params is already allocated with this "
          "storage");
    }

    absl::StatusOr<std::unique_ptr<Tensor::View>> result;
    if (view_type_id == TensorHardwareBufferView::kId) {
      result = GetAhwbView(access, shape, base_descriptor);
    } else {
      result = GetCpuView(access, shape, base_descriptor);
    }
    if (result.ok()) version_ = latest_version;
    return result;
  }

 private:
  absl::StatusOr<std::unique_ptr<Tensor::View>> GetAhwbView(
      Tensor::View::Access access, const Tensor::Shape& shape,
      const Tensor::ViewDescriptor& base_descriptor) {
    return std::unique_ptr<Tensor::View>(new TensorHardwareBufferViewImpl(
        kAccessCapability, access, Tensor::View::State::kUpToDate,
        hw_descriptor_, ahwb_handle_));
  }

  absl::StatusOr<std::unique_ptr<Tensor::View>> GetCpuView(
      Tensor::View::Access access, const Tensor::Shape& shape,
      const Tensor::ViewDescriptor& base_descriptor) {
    void* pointer = nullptr;
    if (__builtin_available(android 26, *)) {
      int error =
          AHardwareBuffer_lock(ahwb_handle_,
                               access == Tensor::View::Access::kWriteOnly
                                   ? AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN
                                   : AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
                               -1, nullptr, &pointer);
      if (error != 0) {
        return absl::UnknownError(
            absl::StrCat("Error locking hardware buffer: ", error));
      }
    }
    return std::unique_ptr<Tensor::View>(
        new TensorCpuViewImpl(access == Tensor::View::Access::kWriteOnly
                                  ? Tensor::View::AccessCapability::kWrite
                                  : Tensor::View::AccessCapability::kRead,
                              access, Tensor::View::State::kUpToDate,
                              cpu_descriptor_, pointer, ahwb_handle_));
  }
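Note the lock/unlock pairing here: `GetCpuView` locks the buffer for the requested access and the view's destructor (`~TensorCpuViewImpl` above) unlocks it. A minimal sketch of that round trip in isolation, using the real NDK calls:

```cpp
#include <android/hardware_buffer.h>

// Maps the buffer into CPU address space for reading. Returns nullptr on
// failure. -1 means "no fence to wait on"; nullptr locks the whole buffer.
void* LockForRead(AHardwareBuffer* ahwb) {
  void* ptr = nullptr;
  int error = AHardwareBuffer_lock(
      ahwb, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN, -1, nullptr, &ptr);
  return error == 0 ? ptr : nullptr;
}

// Must be called once per successful lock, as the view destructor does.
void Unlock(AHardwareBuffer* ahwb) {
  AHardwareBuffer_unlock(ahwb, nullptr);
}
```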

  static constexpr int kAccessCapability =
      Tensor::View::AccessCapability::kRead |
      Tensor::View::AccessCapability::kWrite;
  TensorHardwareBufferViewDescriptor hw_descriptor_;
  AHardwareBuffer* ahwb_handle_ = nullptr;

  TensorCpuViewDescriptor cpu_descriptor_;
  uint64_t version_ = 0;
};
TENSOR_REGISTER_STORAGE(HardwareBufferCpuStorage);

}  // namespace
}  // namespace mediapipe

#endif  // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
        // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

@@ -1,76 +0,0 @@

#if !defined(MEDIAPIPE_NO_JNI) && \
    (__ANDROID_API__ >= 26 || \
     defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))
#include <android/hardware_buffer.h>

#include <cstdint>

#include "mediapipe/framework/formats/tensor_cpu_buffer.h"
#include "mediapipe/framework/formats/tensor_hardware_buffer.h"
#include "mediapipe/framework/formats/tensor_v2.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"

namespace mediapipe {

namespace {

class TensorHardwareBufferTest : public ::testing::Test {
 public:
  TensorHardwareBufferTest() {}
  ~TensorHardwareBufferTest() override {}
};

TEST_F(TensorHardwareBufferTest, TestFloat32) {
  Tensor tensor{Tensor::Shape({1})};
  {
    MP_ASSERT_OK_AND_ASSIGN(
        auto view,
        tensor.GetView<Tensor::View::Access::kWriteOnly>(
            TensorHardwareBufferViewDescriptor{
                .buffer = {.format =
                               TensorBufferDescriptor::Format::kFloat32}}));
    EXPECT_NE(view->handle(), nullptr);
  }
  {
    const auto& const_tensor = tensor;
    MP_ASSERT_OK_AND_ASSIGN(
        auto view,
        const_tensor.GetView<Tensor::View::Access::kReadOnly>(
            TensorCpuViewDescriptor{
                .buffer = {.format =
                               TensorBufferDescriptor::Format::kFloat32}}));
    EXPECT_NE(view->data<void>(), nullptr);
  }
}

TEST_F(TensorHardwareBufferTest, TestInt8Padding) {
  Tensor tensor{Tensor::Shape({1})};

  {
    MP_ASSERT_OK_AND_ASSIGN(
        auto view,
        tensor.GetView<Tensor::View::Access::kWriteOnly>(
            TensorHardwareBufferViewDescriptor{
                .buffer = {.format = TensorBufferDescriptor::Format::kInt8,
                           .size_alignment = 4}}));
    EXPECT_NE(view->handle(), nullptr);
  }
  {
    const auto& const_tensor = tensor;
    MP_ASSERT_OK_AND_ASSIGN(
        auto view,
        const_tensor.GetView<Tensor::View::Access::kReadOnly>(
            TensorCpuViewDescriptor{
                .buffer = {.format = TensorBufferDescriptor::Format::kInt8}}));
    EXPECT_NE(view->data<void>(), nullptr);
  }
}

}  // namespace

}  // namespace mediapipe

#endif  // !defined(MEDIAPIPE_NO_JNI) && (__ANDROID_API__ >= 26 ||
        // defined(__ANDROID_UNAVAILABLE_SYMBOLS_ARE_WEAK__))

@@ -18,8 +18,6 @@
#include <cstdint>
#include <type_traits>

#include "mediapipe/framework/tool/type_util.h"

namespace mediapipe {

// Generates unique view id at compile-time using FILE and LINE.

@@ -41,10 +39,12 @@ namespace tensor_internal {
// https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
constexpr uint64_t kFnvPrime = 0x00000100000001B3;
constexpr uint64_t kFnvOffsetBias = 0xcbf29ce484222325;
constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) {
  return (str[0] == 0) ? hash : FnvHash64(str + 1, (hash ^ str[0]) * kFnvPrime);
constexpr uint64_t FnvHash64(uint64_t value1, uint64_t value2) {
  return (value2 ^ value1) * kFnvPrime;
}
constexpr uint64_t FnvHash64(const char* str, uint64_t hash = kFnvOffsetBias) {
  return (str[0] == 0) ? hash : FnvHash64(str + 1, FnvHash64(hash, str[0]));
}
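This refactor is behavior-preserving: the string overload now routes each character through the new two-argument combine, so one primitive serves both compile-time string hashing and folding plain integers (as the AHWB tracking key does earlier in this diff). A small illustrative usage sketch, assuming the `tensor_internal` namespace shown above:

```cpp
// Both overloads are constexpr, so view ids can be fixed at compile time
// from __FILE__ (and, with an extra fold, __LINE__).
constexpr uint64_t kDemoViewId =
    mediapipe::tensor_internal::FnvHash64(__FILE__);
static_assert(kDemoViewId != 0, "evaluated entirely at compile time");

// The two-argument combine then folds integers (e.g. shape dims) into a key.
constexpr uint64_t kDemoKey =
    mediapipe::tensor_internal::FnvHash64(kDemoViewId, 224);
```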

template <typename... Ts>
struct TypeList {
  static constexpr std::size_t size{sizeof...(Ts)};

@@ -88,8 +88,8 @@ cc_library(
    srcs = ["default_input_stream_handler.cc"],
    hdrs = ["default_input_stream_handler.h"],
    deps = [
        ":default_input_stream_handler_cc_proto",
        "//mediapipe/framework:input_stream_handler",
        "//mediapipe/framework/stream_handler:default_input_stream_handler_cc_proto",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,

@@ -110,8 +110,8 @@ cc_library(
    srcs = ["fixed_size_input_stream_handler.cc"],
    deps = [
        ":default_input_stream_handler",
        ":fixed_size_input_stream_handler_cc_proto",
        "//mediapipe/framework:input_stream_handler",
        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
    ],
    alwayslink = 1,
)

@@ -159,13 +159,13 @@ cc_library(
    name = "sync_set_input_stream_handler",
    srcs = ["sync_set_input_stream_handler.cc"],
    deps = [
        ":sync_set_input_stream_handler_cc_proto",
        "//mediapipe/framework:collection",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework:input_stream_handler",
        "//mediapipe/framework:mediapipe_options_cc_proto",
        "//mediapipe/framework:packet_set",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto",
        "//mediapipe/framework/tool:tag_map",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",

@@ -177,10 +177,10 @@ cc_library(
    name = "timestamp_align_input_stream_handler",
    srcs = ["timestamp_align_input_stream_handler.cc"],
    deps = [
        ":timestamp_align_input_stream_handler_cc_proto",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework:input_stream_handler",
        "//mediapipe/framework:timestamp",
        "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler_cc_proto",
        "//mediapipe/framework/tool:validate_name",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/synchronization",

@@ -243,6 +243,7 @@ cc_test(
    srcs = ["set_input_stream_handler_test.cc"],
    deps = [
        ":fixed_size_input_stream_handler",
        ":fixed_size_input_stream_handler_cc_proto",
        ":mux_input_stream_handler",
        "//mediapipe/calculators/core:mux_calculator",
        "//mediapipe/calculators/core:pass_through_calculator",

@@ -251,7 +252,6 @@ cc_test(
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
    ],
)

@@ -272,13 +272,13 @@ cc_test(
    srcs = ["fixed_size_input_stream_handler_test.cc"],
    deps = [
        ":fixed_size_input_stream_handler",
        ":fixed_size_input_stream_handler_cc_proto",
        "//mediapipe/calculators/core:counting_source_calculator",
        "//mediapipe/calculators/core:pass_through_calculator",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler_cc_proto",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/synchronization",
    ],

@@ -289,11 +289,11 @@ cc_test(
    srcs = ["sync_set_input_stream_handler_test.cc"],
    deps = [
        ":sync_set_input_stream_handler",
        ":sync_set_input_stream_handler_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:test_calculators",
        "//mediapipe/framework/port:gtest_main",
        "//mediapipe/framework/port:parse_text_proto",
        "//mediapipe/framework/stream_handler:sync_set_input_stream_handler_cc_proto",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
@@ -299,6 +299,7 @@ mediapipe_cc_test(
    requires_full_emulation = False,
    deps = [
        ":node_chain_subgraph_cc_proto",
        ":node_chain_subgraph_options_lib",
        ":options_field_util",
        ":options_registry",
        ":options_syntax_util",

@@ -313,7 +314,6 @@ mediapipe_cc_test(
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/testdata:night_light_calculator_cc_proto",
        "//mediapipe/framework/testdata:night_light_calculator_options_lib",
        "//mediapipe/framework/tool:node_chain_subgraph_options_lib",
        "//mediapipe/util:header_util",
        "@com_google_absl//absl/strings",
    ],

@@ -422,9 +422,9 @@ cc_library(
    srcs = ["source.cc"],
    visibility = ["//visibility:public"],
    deps = [
        ":source_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:source_cc_proto",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/strings",
    ],

@@ -485,13 +485,13 @@ cc_library(
    hdrs = ["template_expander.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":calculator_graph_template_cc_proto",
        ":proto_util_lite",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework/port:logging",
        "//mediapipe/framework/port:numbers",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:calculator_graph_template_cc_proto",
        "@com_google_absl//absl/strings",
    ],
)

@@ -506,6 +506,7 @@ cc_library(
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":calculator_graph_template_cc_proto",
        ":proto_util_lite",
        "//mediapipe/framework:calculator_cc_proto",
        "//mediapipe/framework/deps:proto_descriptor_cc_proto",

@@ -515,7 +516,6 @@ cc_library(
        "//mediapipe/framework/port:map_util",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:calculator_graph_template_cc_proto",
        "@com_google_absl//absl/base:core_headers",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",

@@ -661,8 +661,8 @@ cc_library(
    hdrs = ["simulation_clock_executor.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":simulation_clock",
        "//mediapipe/framework:thread_pool_executor",
        "//mediapipe/framework/tool:simulation_clock",
    ],
)

@@ -789,10 +789,10 @@ cc_library(
    visibility = ["//visibility:public"],
    deps = [
        ":name_util",
        ":switch_container_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:switch_container_cc_proto",
    ],
)

@@ -805,6 +805,7 @@ cc_library(
    deps = [
        ":container_util",
        ":options_util",
        ":switch_container_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework/deps:mathutil",

@@ -814,7 +815,6 @@ cc_library(
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
        "//mediapipe/framework/tool:switch_container_cc_proto",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,

@@ -841,6 +841,7 @@ cc_library(
    ],
    deps = [
        ":container_util",
        ":switch_container_cc_proto",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework:collection_item_id",
        "//mediapipe/framework:input_stream_shard",

@@ -850,7 +851,6 @@ cc_library(
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
        "//mediapipe/framework/tool:switch_container_cc_proto",
    ],
    alwayslink = 1,
)

@@ -893,6 +893,7 @@ cc_library(
        ":container_util",
        ":name_util",
        ":subgraph_expansion",
        ":switch_container_cc_proto",
        ":switch_demux_calculator",
        ":switch_mux_calculator",
        "//mediapipe/calculators/core:packet_sequencer_calculator",

@@ -904,7 +905,6 @@ cc_library(
        "//mediapipe/framework/port:core_proto",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:switch_container_cc_proto",
        "@com_google_absl//absl/strings",
    ],
    alwayslink = 1,

@@ -564,6 +564,7 @@ cc_library(
    name = "gpu_shared_data_internal_stub",
    visibility = ["//visibility:private"],
    deps = [
        ":gl_context_options_cc_proto",
        ":graph_support",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_node",

@@ -571,7 +572,6 @@ cc_library(
        "//mediapipe/framework:port",
        "//mediapipe/framework/deps:no_destructor",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/gpu:gl_context_options_cc_proto",
    ],
)

@@ -592,7 +592,7 @@ cc_library(
    }),
    visibility = ["//visibility:private"],
    deps = [
        "//mediapipe/gpu:gl_context_options_cc_proto",
        ":gl_context_options_cc_proto",
        ":graph_support",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:executor",

@@ -833,10 +833,10 @@ cc_library(
    deps = [
        ":gl_base",
        ":gl_simple_shaders",
        ":scale_mode_cc_proto",
        ":shader_util",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/gpu:scale_mode_cc_proto",
    ],
)

@@ -907,8 +907,8 @@ proto_library(
    srcs = ["gl_scaler_calculator.proto"],
    visibility = ["//visibility:public"],
    deps = [
        ":scale_mode_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/gpu:scale_mode_proto",
    ],
)

@@ -930,6 +930,7 @@ cc_library(
    deps = [
        ":gl_calculator_helper",
        ":gl_quad_renderer",
        ":gl_scaler_calculator_cc_proto",
        ":gl_simple_shaders",
        ":shader_util",
        "//mediapipe/framework:calculator_framework",

@@ -937,7 +938,6 @@ cc_library(
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/framework/tool:options_util",
        "//mediapipe/gpu:gl_scaler_calculator_cc_proto",
    ],
    alwayslink = 1,
)

@@ -950,13 +950,13 @@ cc_library(
        ":egl_surface_holder",
        ":gl_calculator_helper",
        ":gl_quad_renderer",
        ":gl_surface_sink_calculator_cc_proto",
        ":gpu_buffer",
        ":shader_util",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/port:status",
        "//mediapipe/gpu:gl_surface_sink_calculator_cc_proto",
        "@com_google_absl//absl/synchronization",
    ],
    alwayslink = 1,

@@ -966,8 +966,8 @@ proto_library(
    name = "gl_surface_sink_calculator_proto",
    srcs = ["gl_surface_sink_calculator.proto"],
    deps = [
        ":scale_mode_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/gpu:scale_mode_proto",
    ],
)

@@ -15,10 +15,13 @@
package com.google.mediapipe.framework;

import android.graphics.Bitmap;
import android.graphics.PixelFormat;
import android.media.Image;
import com.google.mediapipe.framework.image.BitmapExtractor;
import com.google.mediapipe.framework.image.ByteBufferExtractor;
import com.google.mediapipe.framework.image.MPImage;
import com.google.mediapipe.framework.image.MPImageProperties;
import com.google.mediapipe.framework.image.MediaImageExtractor;
import java.nio.ByteBuffer;

// TODO: use Preconditions in this file.

@@ -97,7 +100,17 @@ public class AndroidPacketCreator extends PacketCreator {
      }
      return Packet.create(nativeCreateRgbaImage(mediapipeGraph.getNativeHandle(), bitmap));
    }

    if (properties.getStorageType() == MPImage.STORAGE_TYPE_MEDIA_IMAGE) {
      Image mediaImage = MediaImageExtractor.extract(image);
      if (mediaImage.getFormat() != PixelFormat.RGBA_8888) {
        throw new UnsupportedOperationException("Android media image must use RGBA_8888 config.");
      }
      return createImage(
          mediaImage.getPlanes()[0].getBuffer(),
          mediaImage.getWidth(),
          mediaImage.getHeight(),
          /* numChannels= */ 4);
    }
    // Unsupported type.
    throw new UnsupportedOperationException(
        "Unsupported Image container type: " + properties.getStorageType());

@@ -14,6 +14,10 @@

package com.google.mediapipe.framework;

import com.google.common.flogger.FluentLogger;
import java.util.HashSet;
import java.util.Set;

/**
 * A {@link TextureFrame} that represents a texture produced by MediaPipe.
 *

@@ -21,6 +25,7 @@ package com.google.mediapipe.framework;
 * method.
 */
public class GraphTextureFrame implements TextureFrame {
  private static final FluentLogger logger = FluentLogger.forEnclosingClass();
  private long nativeBufferHandle;
  // We cache these to be able to get them without a JNI call.
  private int textureName;

@@ -30,6 +35,8 @@ public class GraphTextureFrame implements TextureFrame {
  // True when created with PacketGetter.getTextureFrameDeferredSync(). This will result in gpuWait
  // when calling getTextureName().
  private final boolean deferredSync;
  private final Set<Long> activeConsumerContextHandleSet = new HashSet<>();
  private int refCount = 1;

  GraphTextureFrame(long nativeHandle, long timestamp) {
    this(nativeHandle, timestamp, false);

@@ -54,17 +61,19 @@ public class GraphTextureFrame implements TextureFrame {
   * condition if release() is called after the if-check for nativeBufferHandle is already passed.
   */
  @Override
  public int getTextureName() {
  public synchronized int getTextureName() {
    // Return special texture id 0 if handle is 0 i.e. frame is already released.
    if (nativeBufferHandle == 0) {
      return 0;
    }
    // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
    // PacketGetter.getTextureFrameDeferredSync().
    if (deferredSync) {
      // Note that, if a CPU wait has already been done, the sync point will have been
      // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
      nativeGpuWait(nativeBufferHandle);
    if (activeConsumerContextHandleSet.add(nativeGetCurrentExternalContextHandle())) {
      // Gpu wait only if deferredSync is true, such as when this GraphTextureFrame is created using
      // PacketGetter.getTextureFrameDeferredSync().
      if (deferredSync) {
        // Note that, if a CPU wait has already been done, the sync point will have been
        // cleared and this will turn into a no-op. See GlFenceSyncPoint::Wait.
        nativeGpuWait(nativeBufferHandle);
      }
    }
    return textureName;
  }

@@ -86,15 +95,31 @@ public class GraphTextureFrame implements TextureFrame {
    return timestamp;
  }

  @Override
  public boolean supportsRetain() {
    return true;
  }

  @Override
  public synchronized void retain() {
    // TODO: check that refCount is > 0 and handle is not 0.
    refCount++;
  }

  /**
   * Releases a reference to the underlying buffer.
   *
   * <p>The consumer calls this when it is done using the texture.
   */
  @Override
  public void release() {
    GlSyncToken consumerToken =
        new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
  public synchronized void release() {
    GlSyncToken consumerToken = null;
    // Note that this remove should be moved to the other overload of release when b/68808951 is
    // addressed.
    if (activeConsumerContextHandleSet.remove(nativeGetCurrentExternalContextHandle())) {
      consumerToken =
          new GraphGlSyncToken(nativeCreateSyncTokenForCurrentExternalContext(nativeBufferHandle));
    }
    release(consumerToken);
  }

@@ -108,18 +133,40 @@ public class GraphTextureFrame implements TextureFrame {
   * currently cannot create a GlSyncToken, so they cannot call this method.
   */
  @Override
  public void release(GlSyncToken consumerSyncToken) {
    if (nativeBufferHandle != 0) {
      long token = consumerSyncToken == null ? 0 : consumerSyncToken.nativeToken();
      nativeReleaseBuffer(nativeBufferHandle, token);
      nativeBufferHandle = 0;
  public synchronized void release(GlSyncToken consumerSyncToken) {
    if (nativeBufferHandle == 0) {
      if (consumerSyncToken != null) {
        logger.atWarning().log("release with sync token, but handle is 0");
      }
      return;
    }

    if (consumerSyncToken != null) {
      long token = consumerSyncToken.nativeToken();
      nativeDidRead(nativeBufferHandle, token);
      // We should remove the token's context from activeConsumerContextHandleSet here, but for now
      // we do it in the release(void) overload.
      consumerSyncToken.release();
    }

    refCount--;
    if (refCount <= 0) {
      nativeReleaseBuffer(nativeBufferHandle);
      nativeBufferHandle = 0;
    }
  }

  private native void nativeReleaseBuffer(long nativeHandle, long consumerSyncToken);
  @Override
  protected void finalize() throws Throwable {
    if (refCount >= 0 || nativeBufferHandle != 0) {
      logger.atWarning().log("release was not called before finalize");
    }
    if (!activeConsumerContextHandleSet.isEmpty()) {
      logger.atWarning().log("active consumers did not release with sync before finalize");
    }
  }

  private native void nativeReleaseBuffer(long nativeHandle);

  private native int nativeGetTextureName(long nativeHandle);
  private native int nativeGetWidth(long nativeHandle);

@@ -128,4 +175,8 @@ public class GraphTextureFrame implements TextureFrame {
  private native void nativeGpuWait(long nativeHandle);

  private native long nativeCreateSyncTokenForCurrentExternalContext(long nativeHandle);

  private native long nativeGetCurrentExternalContextHandle();

  private native void nativeDidRead(long nativeHandle, long consumerSyncToken);
}

@@ -59,4 +59,18 @@ public interface TextureFrame extends TextureReleaseCallback {
   */
  @Override
  void release(GlSyncToken syncToken);

  /**
   * If this method returns true, this object supports the retain method, and can be used with
   * multiple consumers. Call retain for each additional consumer beyond the first; each consumer
   * should call release.
   */
  default boolean supportsRetain() {
    return false;
  }

  /** Increments the reference count. Only available with some implementations of TextureFrame. */
  default void retain() {
    throw new UnsupportedOperationException();
  }
}

@@ -15,20 +15,16 @@
#include "mediapipe/java/com/google/mediapipe/framework/jni/graph_texture_frame_jni.h"

#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_context.h"
#include "mediapipe/gpu/gl_texture_buffer.h"
#include "mediapipe/java/com/google/mediapipe/framework/jni/jni_util.h"

using mediapipe::GlTextureBufferSharedPtr;

JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)(
    JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) {
    JNIEnv* env, jobject thiz, jlong nativeHandle) {
  GlTextureBufferSharedPtr* buffer =
      reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle);
  if (consumerSyncToken) {
    mediapipe::GlSyncToken& token =
        *reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken);
    (*buffer)->DidRead(token);
  }
  delete buffer;
}

@@ -84,3 +80,18 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
  }
  return reinterpret_cast<jlong>(token);
}

JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
    nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz) {
  return reinterpret_cast<jlong>(
      mediapipe::GlContext::GetCurrentNativeContext());
}

JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)(
    JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken) {
  GlTextureBufferSharedPtr* buffer =
      reinterpret_cast<GlTextureBufferSharedPtr*>(nativeHandle);
  mediapipe::GlSyncToken& token =
      *reinterpret_cast<mediapipe::GlSyncToken*>(consumerSyncToken);
  (*buffer)->DidRead(token);
}
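A note on the pattern these JNI functions share: the Java-side `long` handle is a pointer to a heap-allocated `shared_ptr`, so the Java object holds one strong reference that only `nativeReleaseBuffer` drops. A minimal sketch of that idiom (the type and function names here are illustrative, not MediaPipe API):

```cpp
#include <cstdint>
#include <memory>

struct GlTextureBuffer {};  // Stand-in for the real buffer type.
using SharedBuffer = std::shared_ptr<GlTextureBuffer>;

// Wraps a strong reference behind an opaque integer handle for Java.
std::intptr_t WrapAsHandle(SharedBuffer buffer) {
  return reinterpret_cast<std::intptr_t>(new SharedBuffer(std::move(buffer)));
}

// Mirrors nativeReleaseBuffer: deleting the heap shared_ptr drops that
// reference; the buffer itself dies when the last owner lets go.
void ReleaseHandle(std::intptr_t handle) {
  delete reinterpret_cast<SharedBuffer*>(handle);
}
```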

@@ -26,7 +26,7 @@ extern "C" {

// Releases a native mediapipe::GpuBuffer.
JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeReleaseBuffer)(
    JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken);
    JNIEnv* env, jobject thiz, jlong nativeHandle);

JNIEXPORT jint JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeGetTextureName)(
    JNIEnv* env, jobject thiz, jlong nativeHandle);

@@ -44,6 +44,12 @@ JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
    nativeCreateSyncTokenForCurrentExternalContext)(JNIEnv* env, jobject thiz,
                                                    jlong nativeHandle);

JNIEXPORT void JNICALL GRAPH_TEXTURE_FRAME_METHOD(nativeDidRead)(
    JNIEnv* env, jobject thiz, jlong nativeHandle, jlong consumerSyncToken);

JNIEXPORT jlong JNICALL GRAPH_TEXTURE_FRAME_METHOD(
    nativeGetCurrentExternalContextHandle)(JNIEnv* env, jobject thiz);

#ifdef __cplusplus
}  // extern "C"
#endif  // __cplusplus

@@ -17,3 +17,6 @@ from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.text import text_classifier

# Remove duplicated and non-public API
del python

@@ -29,3 +29,12 @@ BertModelOptions = model_options.BertModelOptions
SupportedModels = model_spec.SupportedModels
TextClassifier = text_classifier.TextClassifier
TextClassifierOptions = text_classifier_options.TextClassifierOptions

# Remove duplicated and non-public API
del hyperparameters
del dataset
del model_options
del model_spec
del preprocessor  # pylint: disable=undefined-variable
del text_classifier
del text_classifier_options

@@ -33,7 +33,6 @@ from mediapipe.model_maker.python.text.text_classifier import preprocessor
from mediapipe.model_maker.python.text.text_classifier import text_classifier_options
from mediapipe.tasks.python.metadata.metadata_writers import metadata_writer
from mediapipe.tasks.python.metadata.metadata_writers import text_classifier as text_classifier_writer
from official.nlp import optimization


def _validate(options: text_classifier_options.TextClassifierOptions):

@@ -417,8 +416,22 @@ class _BertClassifier(TextClassifier):
    total_steps = self._hparams.steps_per_epoch * self._hparams.epochs
    warmup_steps = int(total_steps * 0.1)
    initial_lr = self._hparams.learning_rate
    self._optimizer = optimization.create_optimizer(initial_lr, total_steps,
                                                    warmup_steps)
    # Implements linear decay of the learning rate.
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=initial_lr,
        decay_steps=total_steps,
        end_learning_rate=0.0,
        power=1.0)
    if warmup_steps:
      lr_schedule = model_util.WarmUp(
          initial_learning_rate=initial_lr,
          decay_schedule_fn=lr_schedule,
          warmup_steps=warmup_steps)

    self._optimizer = tf.keras.optimizers.experimental.AdamW(
        lr_schedule, weight_decay=0.01, epsilon=1e-6, global_clipnorm=1.0)
    self._optimizer.exclude_from_weight_decay(
        var_names=["LayerNorm", "layer_norm", "bias"])
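In other words, this change swaps the `official.nlp` helper for an explicit schedule: warmup over the first steps, then linear (power-1 polynomial) decay to zero. Roughly, with lr_0 the initial rate, t_w the warmup steps, and T the total steps; this is a sketch of the intended shape, and the exact warmup interpolation is whatever `model_util.WarmUp` implements:

```latex
\mathrm{lr}(t) \approx
\begin{cases}
  \mathrm{lr}_0 \cdot t / t_w, & t < t_w \quad \text{(warmup)} \\
  \mathrm{lr}_0 \cdot (1 - t / T), & t_w \le t \le T \quad \text{(linear decay)}
\end{cases}
```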

  def _save_vocab(self, vocab_filepath: str):
    tf.io.gfile.copy(

@@ -146,6 +146,8 @@ py_test(
    tags = ["notsan"],
    deps = [
        ":gesture_recognizer_import",
        ":hyperparameters",
        ":model_options",
        "//mediapipe/model_maker/python/core/utils:test_util",
        "//mediapipe/tasks/python/test:test_utils",
    ],

@@ -25,3 +25,12 @@ HParams = hyperparameters.HParams
Dataset = dataset.Dataset
HandDataPreprocessingParams = dataset.HandDataPreprocessingParams
GestureRecognizerOptions = gesture_recognizer_options.GestureRecognizerOptions

# Remove duplicated and non-public API
del constants  # pylint: disable=undefined-variable
del dataset
del gesture_recognizer
del gesture_recognizer_options
del hyperparameters
del metadata_writer  # pylint: disable=undefined-variable
del model_options

@@ -173,15 +173,20 @@ class GestureRecognizer(classifier.Classifier):
        batch_size=None,
        dtype=tf.float32,
        name='hand_embedding')

    x = tf.keras.layers.BatchNormalization()(inputs)
    x = tf.keras.layers.ReLU()(x)
    x = inputs
    dropout_rate = self._model_options.dropout_rate
    x = tf.keras.layers.Dropout(rate=dropout_rate, name='dropout')(x)
    for i, width in enumerate(self._model_options.layer_widths):
      x = tf.keras.layers.BatchNormalization()(x)
      x = tf.keras.layers.ReLU()(x)
      x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
      x = tf.keras.layers.Dense(width, name=f'custom_gesture_recognizer_{i}')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Dropout(rate=dropout_rate)(x)
    outputs = tf.keras.layers.Dense(
        self._num_classes,
        activation='softmax',
        name='custom_gesture_recognizer')(
        name='custom_gesture_recognizer_out')(
            x)

    self._model = tf.keras.Model(inputs=inputs, outputs=outputs)

@@ -23,6 +23,8 @@ import tensorflow as tf

from mediapipe.model_maker.python.core.utils import test_util
from mediapipe.model_maker.python.vision import gesture_recognizer
from mediapipe.model_maker.python.vision.gesture_recognizer import hyperparameters
from mediapipe.model_maker.python.vision.gesture_recognizer import model_options
from mediapipe.tasks.python.test import test_utils

_TEST_DATA_DIR = 'mediapipe/model_maker/python/vision/gesture_recognizer/testdata'

@@ -48,11 +50,11 @@ class GestureRecognizerTest(tf.test.TestCase):
    self._train_data, self._validation_data = all_data.split(0.9)

  def test_gesture_recognizer_model(self):
    model_options = gesture_recognizer.ModelOptions()
    mo = gesture_recognizer.ModelOptions()
    hparams = gesture_recognizer.HParams(
        export_dir=tempfile.mkdtemp(), epochs=2)
    gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
        model_options=model_options, hparams=hparams)
        model_options=mo, hparams=hparams)
    model = gesture_recognizer.GestureRecognizer.create(
        train_data=self._train_data,
        validation_data=self._validation_data,

@@ -60,12 +62,38 @@ class GestureRecognizerTest(tf.test.TestCase):

    self._test_accuracy(model)

  def test_export_gesture_recognizer_model(self):
    model_options = gesture_recognizer.ModelOptions()
  @unittest_mock.patch.object(
      tf.keras.layers, 'Dense', wraps=tf.keras.layers.Dense)
  def test_gesture_recognizer_model_layer_widths(self, mock_dense):
    layer_widths = [64, 32]
    mo = gesture_recognizer.ModelOptions(layer_widths=layer_widths)
    hparams = gesture_recognizer.HParams(
        export_dir=tempfile.mkdtemp(), epochs=2)
    gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
        model_options=model_options, hparams=hparams)
        model_options=mo, hparams=hparams)
    model = gesture_recognizer.GestureRecognizer.create(
        train_data=self._train_data,
        validation_data=self._validation_data,
        options=gesture_recognizer_options)
    expected_calls = [
        unittest_mock.call(w, name=f'custom_gesture_recognizer_{i}')
        for i, w in enumerate(layer_widths)
    ]
    expected_calls.append(
        unittest_mock.call(
            len(self._train_data.label_names),
            activation='softmax',
            name='custom_gesture_recognizer_out'))
    self.assertLen(mock_dense.call_args_list, len(expected_calls))
    mock_dense.assert_has_calls(expected_calls)
    self._test_accuracy(model)

  def test_export_gesture_recognizer_model(self):
    mo = gesture_recognizer.ModelOptions()
    hparams = gesture_recognizer.HParams(
        export_dir=tempfile.mkdtemp(), epochs=2)
    gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
        model_options=mo, hparams=hparams)
    model = gesture_recognizer.GestureRecognizer.create(
        train_data=self._train_data,
        validation_data=self._validation_data,

@@ -102,12 +130,12 @@ class GestureRecognizerTest(tf.test.TestCase):
    self.assertGreater(accuracy, threshold)

  @unittest_mock.patch.object(
      gesture_recognizer.hyperparameters,
      hyperparameters,
      'HParams',
      autospec=True,
      return_value=gesture_recognizer.HParams(epochs=1))
  @unittest_mock.patch.object(
      gesture_recognizer.model_options,
      model_options,
      'GestureRecognizerModelOptions',
      autospec=True,
      return_value=gesture_recognizer.ModelOptions())

@@ -122,11 +150,11 @@ class GestureRecognizerTest(tf.test.TestCase):
    mock_model_options.assert_called_once()

  def test_continual_training_by_loading_checkpoint(self):
    model_options = gesture_recognizer.ModelOptions()
    mo = gesture_recognizer.ModelOptions()
    hparams = gesture_recognizer.HParams(
        export_dir=tempfile.mkdtemp(), epochs=2)
    gesture_recognizer_options = gesture_recognizer.GestureRecognizerOptions(
        model_options=model_options, hparams=hparams)
        model_options=mo, hparams=hparams)
    mock_stdout = io.StringIO()
    with mock.patch('sys.stdout', mock_stdout):
      model = gesture_recognizer.GestureRecognizer.create(
@@ -14,6 +14,7 @@
"""Configurable model options for gesture recognizer models."""

import dataclasses
from typing import List


@dataclasses.dataclass

@@ -23,5 +24,10 @@ class GestureRecognizerModelOptions:
  Attributes:
    dropout_rate: The fraction of the input units to drop, used in dropout
      layer.
    layer_widths: A list of hidden layer widths for the gesture model. Each
      element in the list will create a new hidden layer with the specified
      width. The hidden layers are separated with BatchNorm, Dropout, and ReLU.
      Defaults to an empty list (no hidden layers).
  """
  dropout_rate: float = 0.05
  layer_widths: List[int] = dataclasses.field(default_factory=list)
@@ -121,7 +121,9 @@ py_library(
    srcs = ["image_classifier_test.py"],
    data = ["//mediapipe/model_maker/python/vision/image_classifier/testdata"],
    deps = [
        ":hyperparameters",
        ":image_classifier_import",
        ":model_options",
        "//mediapipe/tasks/python/test:test_utils",
    ],
)

@@ -27,3 +27,12 @@ ModelOptions = model_options.ImageClassifierModelOptions
ModelSpec = model_spec.ModelSpec
SupportedModels = model_spec.SupportedModels
ImageClassifierOptions = image_classifier_options.ImageClassifierOptions

# Remove duplicated and non-public API
del dataset
del hyperparameters
del image_classifier
del image_classifier_options
del model_options
del model_spec
del train_image_classifier_lib  # pylint: disable=undefined-variable

@@ -24,6 +24,8 @@ import numpy as np
import tensorflow as tf

from mediapipe.model_maker.python.vision import image_classifier
from mediapipe.model_maker.python.vision.image_classifier import hyperparameters
from mediapipe.model_maker.python.vision.image_classifier import model_options
from mediapipe.tasks.python.test import test_utils


@@ -159,15 +161,15 @@ class ImageClassifierTest(tf.test.TestCase, parameterized.TestCase):
    self.assertGreaterEqual(accuracy, threshold)

  @unittest_mock.patch.object(
      image_classifier.hyperparameters,
      hyperparameters,
      'HParams',
      autospec=True,
      return_value=image_classifier.HParams(epochs=1))
      return_value=hyperparameters.HParams(epochs=1))
  @unittest_mock.patch.object(
      image_classifier.model_options,
      model_options,
      'ImageClassifierModelOptions',
      autospec=True,
      return_value=image_classifier.ModelOptions())
      return_value=model_options.ImageClassifierModelOptions())
  def test_create_hparams_and_model_options_if_none_in_image_classifier_options(
      self, mock_hparams, mock_model_options):
    options = image_classifier.ImageClassifierOptions(

@@ -1,5 +1,5 @@
absl-py
mediapipe==0.9.1
mediapipe==0.9.0.1
numpy
opencv-python
tensorflow>=2.10

@@ -28,6 +28,8 @@ import PIL.Image
from mediapipe.python._framework_bindings import image
from mediapipe.python._framework_bindings import image_frame

TEST_IMAGE_PATH = 'mediapipe/python/solutions/testdata'

Image = image.Image
ImageFormat = image_frame.ImageFormat

@@ -187,5 +189,26 @@ class ImageTest(absltest.TestCase):
    gc.collect()
    self.assertEqual(sys.getrefcount(rgb_image), initial_ref_count)

  def test_image_create_from_cvmat(self):
    image_path = os.path.join(os.path.dirname(__file__),
                              'solutions/testdata/hands.jpg')
    mat = cv2.imread(image_path).astype(np.uint8)
    mat = cv2.cvtColor(mat, cv2.COLOR_BGR2RGB)
    rgb_image = Image(image_format=ImageFormat.SRGB, data=mat)
    self.assertEqual(rgb_image.width, 720)
    self.assertEqual(rgb_image.height, 382)
    self.assertEqual(rgb_image.channels, 3)
    self.assertEqual(rgb_image.image_format, ImageFormat.SRGB)
    self.assertTrue(np.array_equal(mat, rgb_image.numpy_view()))

  def test_image_create_from_file(self):
    image_path = os.path.join(os.path.dirname(__file__),
                              'solutions/testdata/hands.jpg')
    loaded_image = Image.create_from_file(image_path)
    self.assertEqual(loaded_image.width, 720)
    self.assertEqual(loaded_image.height, 382)
    self.assertEqual(loaded_image.channels, 3)
    self.assertEqual(loaded_image.image_format, ImageFormat.SRGB)

if __name__ == '__main__':
  absltest.main()

@@ -157,7 +157,7 @@ class PacketTest(absltest.TestCase):
    p.timestamp = 0
    self.assertAlmostEqual(packet_getter.get_float(p), 0.42)
    self.assertEqual(p.timestamp, 0)
    p2 = packet_creator.create_float(np.float(0.42))
    p2 = packet_creator.create_float(float(0.42))
    p2.timestamp = 0
    self.assertAlmostEqual(packet_getter.get_float(p2), 0.42)
    self.assertEqual(p2.timestamp, 0)

@@ -48,16 +48,20 @@ void ImageSubmodule(pybind11::module* module) {
  become immutable after creation.

  Creation examples:
    import cv2
    cv_mat = cv2.imread(input_file)[:, :, ::-1]
    rgb_frame = mp.Image(format=ImageFormat.SRGB, data=cv_mat)
    gray_frame = mp.Image(
        format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))

    from PIL import Image
    pil_img = Image.new('RGB', (60, 30), color = 'red')
    image = mp.Image(
        format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
  ```python
    import cv2
    cv_mat = cv2.imread(input_file)[:, :, ::-1]
    rgb_frame = mp.Image(image_format=ImageFormat.SRGB, data=cv_mat)
    gray_frame = mp.Image(
        image_format=ImageFormat.GRAY,
        data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))

    from PIL import Image
    pil_img = Image.new('RGB', (60, 30), color = 'red')
    image = mp.Image(
        image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
  ```

  The pixel data in an Image can be retrieved as a numpy ndarray by calling
  `Image.numpy_view()`. The returned numpy ndarray is a reference to the

@@ -65,15 +69,18 @@ void ImageSubmodule(pybind11::module* module) {
  numpy ndarray, it's required to obtain a copy of it.

  Pixel data retrieval examples:
    for channel in range(num_channel):
      for col in range(width):
        for row in range(height):
          print(image[row, col, channel])

    output_ndarray = image.numpy_view()
    print(output_ndarray[0, 0, 0])
    copied_ndarray = np.copy(output_ndarray)
    copied_ndarray[0,0,0] = 0
  ```python
    for channel in range(num_channel):
      for col in range(width):
        for row in range(height):
          print(image[row, col, channel])

    output_ndarray = image.numpy_view()
    print(output_ndarray[0, 0, 0])
    copied_ndarray = np.copy(output_ndarray)
    copied_ndarray[0,0,0] = 0
  ```
  )doc",
      py::dynamic_attr());

@@ -156,9 +163,11 @@ void ImageSubmodule(pybind11::module* module) {
    An unwritable numpy ndarray.

  Examples:
    ```
    output_ndarray = image.numpy_view()
    copied_ndarray = np.copy(output_ndarray)
    copied_ndarray[0,0,0] = 0
    ```
  )doc");

  image.def(

@@ -191,10 +200,12 @@ void ImageSubmodule(pybind11::module* module) {
    IndexError: If the index is invalid or out of bounds.

  Examples:
    ```
    for channel in range(num_channel):
      for col in range(width):
        for row in range(height):
          print(image[row, col, channel])
    ```
  )doc");

  image

@@ -224,7 +235,9 @@ void ImageSubmodule(pybind11::module* module) {
    A boolean.

  Examples:
    ```
    image.is_aligned(16)
    ```
  )doc");

  image.def_static(

@@ -83,14 +83,15 @@ void ImageFrameSubmodule(pybind11::module* module) {
  Creation examples:
    import cv2
    cv_mat = cv2.imread(input_file)[:, :, ::-1]
    rgb_frame = mp.ImageFrame(format=ImageFormat.SRGB, data=cv_mat)
    rgb_frame = mp.ImageFrame(image_format=ImageFormat.SRGB, data=cv_mat)
    gray_frame = mp.ImageFrame(
        format=ImageFormat.GRAY, data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))
        image_format=ImageFormat.GRAY,
        data=cv2.cvtColor(cv_mat, cv2.COLOR_RGB2GRAY))

    from PIL import Image
    pil_img = Image.new('RGB', (60, 30), color = 'red')
    image_frame = mp.ImageFrame(
        format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))
        image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_img))

  The pixel data in an ImageFrame can be retrieved as a numpy ndarray by calling
  `ImageFrame.numpy_view()`. The returned numpy ndarray is a reference to the

@@ -23,4 +23,3 @@ objc_library(
    ],
    module_name = "MPPCommon",
)
@ -1,25 +1,25 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* @enum TFLSupportErrorCode
|
||||
* This enum specifies error codes for TensorFlow Lite Task Library.
|
||||
* It maintains a 1:1 mapping to TfLiteSupportErrorCode of C libray.
|
||||
* @enum MPPTasksErrorCode
|
||||
* This enum specifies error codes for the MediaPipe Task Library.
|
||||
* It maintains a 1:1 mapping to MediaPipeTasksStatus of the C++ library.
|
||||
*/
|
||||
typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
|
||||
|
||||
|
@ -48,16 +48,16 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
|
|||
MPPTasksErrorCodeFileReadError,
|
||||
// I/O error when mmap-ing file.
|
||||
MPPTasksErrorCodeFileMmapError,
|
||||
// ZIP I/O error when unpacMPPTasksErrorCodeing the zip file.
|
||||
// ZIP I/O error when unpacking the zip file.
|
||||
MPPTasksErrorCodeFileZipError,
|
||||
|
||||
// TensorFlow Lite metadata error codes.
|
||||
|
||||
// Unexpected schema version (aMPPTasksErrorCodea file_identifier) in the Metadata FlatBuffer.
|
||||
// Unexpected schema version (aka file_identifier) in the Metadata FlatBuffer.
|
||||
MPPTasksErrorCodeMetadataInvalidSchemaVersionError = 200,
|
||||
// No such associated file within metadata, or file has not been pacMPPTasksErrorCodeed.
|
||||
// No such associated file within metadata, or file has not been packed.
|
||||
MPPTasksErrorCodeMetadataAssociatedFileNotFoundError,
|
||||
// ZIP I/O error when unpacMPPTasksErrorCodeing an associated file.
|
||||
// ZIP I/O error when unpacking an associated file.
|
||||
MPPTasksErrorCodeMetadataAssociatedFileZipError,
|
||||
// Inconsistency error between the metadata and actual TF Lite model.
|
||||
// E.g.: number of labels and output tensor values differ.
|
||||
|
@ -167,11 +167,10 @@ typedef NS_ENUM(NSUInteger, MPPTasksErrorCode) {
|
|||
// Task graph config is invalid.
|
||||
MPPTasksErrorCodeInvalidTaskGraphConfigError,
|
||||
|
||||
// The first error code in MPPTasksErrorCode (for internal use only).
|
||||
MPPTasksErrorCodeFirst = MPPTasksErrorCodeError,
|
||||
|
||||
/**
|
||||
* The last error code in TFLSupportErrorCode (for internal use only).
|
||||
*/
|
||||
// The last error code in MPPTasksErrorCode (for internal use only).
|
||||
MPPTasksErrorCodeLast = MPPTasksErrorCodeInvalidTaskGraphConfigError,
|
||||
|
||||
} NS_SWIFT_NAME(TasksErrorCode);
|
||||
|
|
|
@ -24,7 +24,7 @@ extern NSString *const MPPTasksErrorDomain;
|
|||
@interface MPPCommonUtils : NSObject
|
||||
|
||||
/**
|
||||
* Creates and saves an NSError in the Mediapipe task library domain, with the given code and
|
||||
* Creates and saves an NSError in the MediaPipe task library domain, with the given code and
|
||||
* description.
|
||||
*
|
||||
* @param code Error code.
|
||||
|
@ -51,9 +51,9 @@ extern NSString *const MPPTasksErrorDomain;
|
|||
description:(NSString *)description;
|
||||
|
||||
/**
|
||||
* Converts an absl status to an NSError.
|
||||
* Converts an absl::Status to an NSError.
|
||||
*
|
||||
* @param status absl status.
|
||||
* @param status absl::Status.
|
||||
* @param error Pointer to the memory location where the created error should be saved. If `nil`,
|
||||
* no error will be saved.
|
||||
*/
|
||||
|
@ -61,15 +61,15 @@ extern NSString *const MPPTasksErrorDomain;
|
|||
|
||||
/**
|
||||
* Allocates a block of memory with the specified size and returns a pointer to it. If memory
|
||||
* cannot be allocated because of an invalid memSize, it saves an error. In other cases, it
|
||||
* cannot be allocated because of an invalid `memSize`, it saves an error. In other cases, it
|
||||
* terminates program execution.
|
||||
*
|
||||
* @param memSize The size of the block of memory to be allocated.
|
||||
* @param error Pointer to the memory location where errors, if any, should be saved. If `nil`, no
|
||||
* error will be saved.
|
||||
*
|
||||
* @return Pointer to the allocated block of memory on successfull allocation. nil in case as
|
||||
* error is encountered because of invalid memSize. If failure is due to any other reason, method
|
||||
* @return Pointer to the allocated block of memory on successful allocation. `nil` in case an
|
||||
* error is encountered because of an invalid `memSize`. If the failure is due to any other reason, the method
|
||||
* terminates program execution.
|
||||
*/
|
||||
+ (void *)mallocWithSize:(size_t)memSize error:(NSError **)error;
|
||||
|
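The contract documented above — report a recoverable error for an invalid size, terminate on a genuine allocation failure — can be sketched in plain C++ as follows (the function name and the boolean error channel are illustrative, not part of this commit):
```
#include <cstddef>
#include <cstdlib>

// Allocates `mem_size` bytes. An invalid (zero) size is reported through
// `error` and returns nullptr; running out of memory aborts the program,
// matching the behavior described in the header comment above.
void* MallocWithSize(size_t mem_size, bool* error) {
  *error = false;
  if (mem_size == 0) {
    *error = true;  // Recoverable: the caller inspects the error flag.
    return nullptr;
  }
  void* ptr = std::malloc(mem_size);
  if (ptr == nullptr) {
    std::abort();  // Unrecoverable: terminate, as documented.
  }
  return ptr;
}
```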
|
|
@ -24,7 +24,7 @@
|
|||
#include "mediapipe/tasks/cc/common.h"
|
||||
|
||||
/** Error domain of MediaPipe task library errors. */
|
||||
NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
|
||||
NSString *const MPPTasksErrorDomain = @"com.google.mediapipe.tasks";
|
||||
|
||||
@implementation MPPCommonUtils
|
||||
|
||||
|
@ -68,7 +68,7 @@ NSString *const MPPTasksErrorDomain = @"org.mediapipe.tasks";
|
|||
if (status.ok()) {
|
||||
return YES;
|
||||
}
|
||||
// Payload of absl::Status created by the Media Pipe task library stores an appropriate value of
|
||||
// Payload of absl::Status created by the MediaPipe task library stores an appropriate value of
|
||||
// the enum MediaPipeTasksStatus. The integer value corresponding to the MediaPipeTasksStatus enum
|
||||
// stored in the payload is extracted here to later map to the appropriate error code to be
|
||||
// returned. In cases where the enum is not stored in (payload is NULL or the payload string
|
||||
|
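As the comment above explains, the task-specific code travels as a payload on the `absl::Status`. A minimal C++ sketch of the extraction, assuming the payload is keyed by a type URL and serialized as a decimal string — the real key and encoding are defined in `mediapipe/tasks/cc/common.h`, and the constant below is a placeholder:
```
#include <optional>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/numbers.h"

// Placeholder for the payload key defined in mediapipe/tasks/cc/common.h.
constexpr char kTasksStatusPayloadUrl[] =
    "type.googleapis.com/mediapipe.tasks.MediaPipeTasksStatus";

// Returns the MediaPipeTasksStatus value attached to `status`, if any.
std::optional<int> ExtractTasksStatusCode(const absl::Status& status) {
  auto payload = status.GetPayload(kTasksStatusPayloadUrl);
  if (!payload.has_value()) {
    return std::nullopt;  // No task-specific code was attached.
  }
  int code = 0;
  if (!absl::SimpleAtoi(std::string(*payload), &code)) {
    return std::nullopt;  // Payload string is not a valid integer.
  }
  return code;
}
```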
|
|
@ -17,25 +17,38 @@
|
|||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* Holds settings for any single iOS Mediapipe classification task.
|
||||
* Holds settings for any single iOS MediaPipe classification task.
|
||||
*/
|
||||
NS_SWIFT_NAME(ClassifierOptions)
|
||||
@interface MPPClassifierOptions : NSObject <NSCopying>
|
||||
|
||||
/** If set, all classes in this list will be filtered out from the results . */
|
||||
@property(nonatomic, copy) NSArray<NSString *> *labelDenyList;
|
||||
|
||||
/** If set, all classes not in this list will be filtered out from the results . */
|
||||
@property(nonatomic, copy) NSArray<NSString *> *labelAllowList;
|
||||
|
||||
/** Display names local for display names*/
|
||||
/** The locale to use for display names specified through the TFLite Model
|
||||
* Metadata, if any. Defaults to English.
|
||||
*/
|
||||
@property(nonatomic, copy) NSString *displayNamesLocale;
|
||||
|
||||
/** Results with score threshold greater than this value are returned . */
|
||||
/** The maximum number of top-scored classification results to return. If < 0,
|
||||
* all available results will be returned. If 0, an invalid argument error is
|
||||
* returned.
|
||||
*/
|
||||
@property(nonatomic) NSInteger maxResults;
|
||||
|
||||
/** Score threshold to override the one provided in the model metadata (if any).
|
||||
* Results below this value are rejected.
|
||||
*/
|
||||
@property(nonatomic) float scoreThreshold;
|
||||
|
||||
/** Limit to the number of classes that can be returned in results. */
|
||||
@property(nonatomic) NSInteger maxResults;
|
||||
/** The allowlist of category names. If non-empty, detection results whose
|
||||
* category name is not in this set will be filtered out. Duplicate or unknown
|
||||
* category names are ignored. Mutually exclusive with categoryDenylist.
|
||||
*/
|
||||
@property(nonatomic, copy) NSArray<NSString *> *categoryAllowlist;
|
||||
|
||||
/** The denylist of category names. If non-empty, detection results whose
|
||||
* category name is in this set will be filtered out. Duplicate or unknown
|
||||
* category names are ignored. Mutually exclusive with categoryAllowlist.
|
||||
*/
|
||||
@property(nonatomic, copy) NSArray<NSString *> *categoryDenylist;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -30,8 +30,8 @@
|
|||
|
||||
classifierOptions.scoreThreshold = self.scoreThreshold;
|
||||
classifierOptions.maxResults = self.maxResults;
|
||||
classifierOptions.labelDenyList = self.labelDenyList;
|
||||
classifierOptions.labelAllowList = self.labelAllowList;
|
||||
classifierOptions.categoryDenylist = self.categoryDenylist;
|
||||
classifierOptions.categoryAllowlist = self.categoryAllowlist;
|
||||
classifierOptions.displayNamesLocale = self.displayNamesLocale;
|
||||
|
||||
return classifierOptions;
|
||||
|
|
|
@ -20,17 +20,23 @@ using ClassifierOptionsProto = ::mediapipe::tasks::components::processors::proto
|
|||
}
|
||||
|
||||
@implementation MPPClassifierOptions (Helpers)
|
||||
|
||||
- (void)copyToProto:(ClassifierOptionsProto *)classifierOptionsProto {
|
||||
classifierOptionsProto->Clear();
|
||||
|
||||
if (self.displayNamesLocale) {
|
||||
classifierOptionsProto->set_display_names_locale(self.displayNamesLocale.cppString);
|
||||
}
|
||||
|
||||
classifierOptionsProto->set_max_results((int)self.maxResults);
|
||||
|
||||
classifierOptionsProto->set_score_threshold(self.scoreThreshold);
|
||||
for (NSString *category in self.labelAllowList) {
|
||||
|
||||
for (NSString *category in self.categoryAllowlist) {
|
||||
classifierOptionsProto->add_category_allowlist(category.cppString);
|
||||
}
|
||||
|
||||
for (NSString *category in self.labelDenyList) {
|
||||
for (NSString *category in self.categoryDenylist) {
|
||||
classifierOptionsProto->add_category_denylist(category.cppString);
|
||||
}
|
||||
}
|
||||
|
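The helper above maps each renamed property onto the C++ `ClassifierOptions` proto. A condensed C++ sketch of the same mapping, using only the setters that appear in this diff (the free function wrapping them is illustrative):
```
#include <string>
#include <vector>

#include "mediapipe/tasks/cc/components/processors/proto/classifier_options.pb.h"

using ClassifierOptionsProto =
    ::mediapipe::tasks::components::processors::proto::ClassifierOptions;

// Populates a ClassifierOptions proto the same way the Objective-C
// category helper above does.
ClassifierOptionsProto MakeClassifierOptions(
    const std::string& display_names_locale, int max_results,
    float score_threshold, const std::vector<std::string>& allowlist,
    const std::vector<std::string>& denylist) {
  ClassifierOptionsProto proto;
  proto.set_display_names_locale(display_names_locale);
  proto.set_max_results(max_results);
  proto.set_score_threshold(score_threshold);
  for (const std::string& category : allowlist) {
    proto.add_category_allowlist(category);
  }
  for (const std::string& category : denylist) {
    proto.add_category_denylist(category);
  }
  return proto;
}
```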
|
|
@ -16,19 +16,35 @@ package(default_visibility = ["//mediapipe/tasks:internal"])
|
|||
|
||||
licenses(["notice"])
|
||||
|
||||
objc_library(
|
||||
name = "MPPBaseOptions",
|
||||
srcs = ["sources/MPPBaseOptions.m"],
|
||||
hdrs = ["sources/MPPBaseOptions.h"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskOptions",
|
||||
srcs = ["sources/MPPTaskOptions.m"],
|
||||
hdrs = ["sources/MPPTaskOptions.h"],
|
||||
copts = [
|
||||
"-ObjC++",
|
||||
"-std=c++17",
|
||||
],
|
||||
deps = [
|
||||
":MPPBaseOptions",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskResult",
|
||||
srcs = ["sources/MPPTaskResult.m"],
|
||||
hdrs = ["sources/MPPTaskResult.h"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskOptionsProtocol",
|
||||
hdrs = ["sources/MPPTaskOptionsProtocol.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskInfo",
|
||||
srcs = ["sources/MPPTaskInfo.mm"],
|
||||
|
@ -64,32 +80,12 @@ objc_library(
|
|||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskResult",
|
||||
srcs = ["sources/MPPTaskResult.m"],
|
||||
hdrs = ["sources/MPPTaskResult.h"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPBaseOptions",
|
||||
srcs = ["sources/MPPBaseOptions.m"],
|
||||
hdrs = ["sources/MPPBaseOptions.h"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskOptionsProtocol",
|
||||
hdrs = ["sources/MPPTaskOptionsProtocol.h"],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_cc_proto",
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "MPPTaskManager",
|
||||
srcs = ["sources/MPPTaskManager.mm"],
|
||||
hdrs = ["sources/MPPTaskManager.h"],
|
||||
name = "MPPTaskRunner",
|
||||
srcs = ["sources/MPPTaskRunner.mm"],
|
||||
hdrs = ["sources/MPPTaskRunner.h"],
|
||||
deps = [
|
||||
"//mediapipe/tasks/cc/core:task_runner",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/tasks/ios/common/utils:MPPCommonUtils",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptionsProtocol.h"
|
||||
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
|
@ -55,7 +54,7 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
outputStreams:(NSArray<NSString *> *)outputStreams
|
||||
taskOptions:(id<MPPTaskOptionsProtocol>)taskOptions
|
||||
enableFlowLimiting:(BOOL)enableFlowLimiting
|
||||
error:(NSError **)error;
|
||||
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
/**
|
||||
* Creates a MediaPipe Task protobuf message from the MPPTaskInfo instance.
|
||||
|
|
|
@ -24,9 +24,9 @@
|
|||
namespace {
|
||||
using CalculatorGraphConfig = ::mediapipe::CalculatorGraphConfig;
|
||||
using Node = ::mediapipe::CalculatorGraphConfig::Node;
|
||||
using ::mediapipe::InputStreamInfo;
|
||||
using ::mediapipe::CalculatorOptions;
|
||||
using ::mediapipe::FlowLimiterCalculatorOptions;
|
||||
using ::mediapipe::InputStreamInfo;
|
||||
} // namespace
|
||||
|
||||
@implementation MPPTaskInfo
|
||||
|
@ -82,45 +82,46 @@ using ::mediapipe::FlowLimiterCalculatorOptions;
|
|||
graph_config.add_output_stream(cpp_output_stream);
|
||||
}
|
||||
|
||||
if (self.enableFlowLimiting) {
|
||||
Node *flow_limit_calculator_node = graph_config.add_node();
|
||||
|
||||
flow_limit_calculator_node->set_calculator("FlowLimiterCalculator");
|
||||
|
||||
InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info();
|
||||
input_stream_info->set_tag_index("FINISHED");
|
||||
input_stream_info->set_back_edge(true);
|
||||
|
||||
FlowLimiterCalculatorOptions *flow_limit_calculator_options =
|
||||
flow_limit_calculator_node->mutable_options()->MutableExtension(
|
||||
FlowLimiterCalculatorOptions::ext);
|
||||
flow_limit_calculator_options->set_max_in_flight(1);
|
||||
flow_limit_calculator_options->set_max_in_queue(1);
|
||||
|
||||
for (NSString *inputStream in self.inputStreams) {
|
||||
graph_config.add_input_stream(inputStream.cppString);
|
||||
|
||||
NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream];
|
||||
flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString);
|
||||
|
||||
NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream];
|
||||
task_subgraph_node->add_input_stream(taskInputStream.cppString);
|
||||
|
||||
NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream];
|
||||
flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString);
|
||||
}
|
||||
|
||||
NSString *firstOutputStream = self.outputStreams[0];
|
||||
auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString;
|
||||
flow_limit_calculator_node->add_input_stream(finished_output_stream);
|
||||
} else {
|
||||
if (!self.enableFlowLimiting) {
|
||||
for (NSString *inputStream in self.inputStreams) {
|
||||
auto cpp_input_stream = inputStream.cppString;
|
||||
task_subgraph_node->add_input_stream(cpp_input_stream);
|
||||
graph_config.add_input_stream(cpp_input_stream);
|
||||
}
|
||||
return graph_config;
|
||||
}
|
||||
|
||||
Node *flow_limit_calculator_node = graph_config.add_node();
|
||||
|
||||
flow_limit_calculator_node->set_calculator("FlowLimiterCalculator");
|
||||
|
||||
InputStreamInfo *input_stream_info = flow_limit_calculator_node->add_input_stream_info();
|
||||
input_stream_info->set_tag_index("FINISHED");
|
||||
input_stream_info->set_back_edge(true);
|
||||
|
||||
FlowLimiterCalculatorOptions *flow_limit_calculator_options =
|
||||
flow_limit_calculator_node->mutable_options()->MutableExtension(
|
||||
FlowLimiterCalculatorOptions::ext);
|
||||
flow_limit_calculator_options->set_max_in_flight(1);
|
||||
flow_limit_calculator_options->set_max_in_queue(1);
|
||||
|
||||
for (NSString *inputStream in self.inputStreams) {
|
||||
graph_config.add_input_stream(inputStream.cppString);
|
||||
|
||||
NSString *strippedInputStream = [MPPTaskInfo stripTagIndex:inputStream];
|
||||
flow_limit_calculator_node->add_input_stream(strippedInputStream.cppString);
|
||||
|
||||
NSString *taskInputStream = [MPPTaskInfo addStreamNamePrefix:inputStream];
|
||||
task_subgraph_node->add_input_stream(taskInputStream.cppString);
|
||||
|
||||
NSString *strippedTaskInputStream = [MPPTaskInfo stripTagIndex:taskInputStream];
|
||||
flow_limit_calculator_node->add_output_stream(strippedTaskInputStream.cppString);
|
||||
}
|
||||
|
||||
NSString *firstOutputStream = self.outputStreams[0];
|
||||
auto finished_output_stream = "FINISHED:" + firstOutputStream.cppString;
|
||||
flow_limit_calculator_node->add_input_stream(finished_output_stream);
|
||||
|
||||
return graph_config;
|
||||
}
|
||||
|
||||
|
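The refactor above replaces the `if/else` with an early return for the no-flow-limiting case; the `FlowLimiterCalculator` wiring itself is unchanged. A condensed C++ sketch of that wiring, with illustrative stream names:
```
#include <string>

#include "mediapipe/calculators/core/flow_limiter_calculator.pb.h"
#include "mediapipe/framework/calculator.pb.h"

// Inserts a FlowLimiterCalculator in front of a task subgraph, mirroring the
// Objective-C++ logic above: a FINISHED back edge closes the loop, and at
// most one packet is in flight plus one queued.
void AddFlowLimiter(mediapipe::CalculatorGraphConfig* graph_config,
                    const std::string& input_stream,       // e.g. "image_in"
                    const std::string& finished_stream) {  // e.g. "FINISHED:out"
  mediapipe::CalculatorGraphConfig::Node* node = graph_config->add_node();
  node->set_calculator("FlowLimiterCalculator");

  // Mark the FINISHED input as a back edge so the graph remains acyclic.
  mediapipe::InputStreamInfo* info = node->add_input_stream_info();
  info->set_tag_index("FINISHED");
  info->set_back_edge(true);

  auto* options = node->mutable_options()->MutableExtension(
      mediapipe::FlowLimiterCalculatorOptions::ext);
  options->set_max_in_flight(1);
  options->set_max_in_queue(1);

  node->add_input_stream(input_stream);
  node->add_input_stream(finished_stream);
}
```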
|
|
@ -1,14 +1,17 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
|
||||
|
||||
|
@ -19,27 +22,13 @@ NS_ASSUME_NONNULL_BEGIN
|
|||
* this class.
|
||||
*/
|
||||
NS_SWIFT_NAME(TaskOptions)
|
||||
|
||||
@interface MPPTaskOptions : NSObject <NSCopying>
|
||||
/**
|
||||
* Base options for configuring the MediaPipe task.
|
||||
*/
|
||||
@property(nonatomic, copy) MPPBaseOptions *baseOptions;
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPTaskOptions` with the absolute path to the model file
|
||||
* stored locally on the device, set to the given model path.
|
||||
*
|
||||
* @discussion The external model file must be a single standalone TFLite file. It could be packed
|
||||
* with TFLite Model Metadata[1] and associated files if they exist. Failure to provide the
|
||||
* necessary metadata and associated files might result in errors. Check the [documentation]
|
||||
* (https://www.tensorflow.org/lite/convert/metadata) for each task's specific requirements.
|
||||
*
|
||||
* @param modelPath An absolute path to a TensorFlow Lite model file stored locally on the device.
|
||||
*
|
||||
* @return An instance of `MPPTaskOptions` initialized to the given model path.
|
||||
*/
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskOptions.h"
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPBaseOptions.h"
|
||||
|
||||
|
@ -25,12 +25,12 @@
|
|||
return self;
|
||||
}
|
||||
|
||||
- (instancetype)initWithModelPath:(NSString *)modelPath {
|
||||
self = [self init];
|
||||
if (self) {
|
||||
_baseOptions.modelAssetPath = modelPath;
|
||||
}
|
||||
return self;
|
||||
- (id)copyWithZone:(NSZone *)zone {
|
||||
MPPTaskOptions *taskOptions = [[MPPTaskOptions alloc] init];
|
||||
|
||||
taskOptions.baseOptions = self.baseOptions;
|
||||
|
||||
return taskOptions;
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
|
@ -1,26 +1,29 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
#include "mediapipe/framework/calculator_options.pb.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* Any mediapipe task options should confirm to this protocol.
|
||||
* Any MediaPipe task options should conform to this protocol.
|
||||
*/
|
||||
@protocol MPPTaskOptionsProtocol
|
||||
|
||||
/**
|
||||
* Copies the iOS Mediapipe task options to an object of mediapipe::CalculatorOptions proto.
|
||||
* Copies the iOS MediaPipe task options to an object of mediapipe::CalculatorOptions proto.
|
||||
*/
|
||||
- (void)copyToProto:(mediapipe::CalculatorOptions *)optionsProto;
|
||||
|
||||
|
|
|
@ -1,30 +1,36 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* MediaPipe Tasks options base class. Any MediaPipe task-specific options class should extend
|
||||
* MediaPipe Tasks result base class. Any MediaPipe task result class should extend
|
||||
* this class.
|
||||
*/
|
||||
NS_SWIFT_NAME(TaskResult)
|
||||
|
||||
@interface MPPTaskResult : NSObject <NSCopying>
|
||||
/**
|
||||
* Base options for configuring the Mediapipe task.
|
||||
* Timestamp that is associated with the task result object.
|
||||
*/
|
||||
@property(nonatomic, assign, readonly) long timeStamp;
|
||||
@property(nonatomic, assign, readonly) long timestamp;
|
||||
|
||||
- (instancetype)initWithTimeStamp:(long)timeStamp;
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
- (instancetype)initWithTimestamp:(long)timestamp NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
@end
|
||||
|
||||
|
|
|
@ -1,27 +1,31 @@
|
|||
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
|
||||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskResult.h"
|
||||
|
||||
@implementation MPPTaskResult
|
||||
|
||||
- (instancetype)initWithTimeStamp:(long)timeStamp {
|
||||
self = [self init];
|
||||
- (instancetype)initWithTimestamp:(long)timestamp {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_timeStamp = timeStamp;
|
||||
_timestamp = timestamp;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (id)copyWithZone:(NSZone *)zone {
|
||||
return [[MPPTaskResult alloc] initWithTimestamp:self.timestamp];
|
||||
}
|
||||
|
||||
@end
|
||||
|
|
50
mediapipe/tasks/ios/core/sources/MPPTaskRunner.h
Normal file
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include "mediapipe/framework/calculator.pb.h"
|
||||
#include "mediapipe/tasks/cc/core/task_runner.h"
|
||||
|
||||
NS_ASSUME_NONNULL_BEGIN
|
||||
|
||||
/**
|
||||
* This class is used to create and call appropriate methods on the C++ Task Runner.
|
||||
*/
|
||||
|
||||
@interface MPPTaskRunner : NSObject
|
||||
|
||||
/**
|
||||
* Initializes a new `MPPTaskRunner` with the MediaPipe task graph config proto.
|
||||
*
|
||||
* @param graphConfig A MediaPipe task graph config proto.
|
||||
*
|
||||
* @return An instance of `MPPTaskRunner` initialized to the given graph config proto.
|
||||
*/
|
||||
- (instancetype)initWithCalculatorGraphConfig:(mediapipe::CalculatorGraphConfig)graphConfig
|
||||
error:(NSError **)error NS_DESIGNATED_INITIALIZER;
|
||||
|
||||
- (absl::StatusOr<mediapipe::tasks::core::PacketMap>)
|
||||
process:(const mediapipe::tasks::core::PacketMap &)packetMap
|
||||
error:(NSError **)error;
|
||||
|
||||
- (void)close;
|
||||
|
||||
- (instancetype)init NS_UNAVAILABLE;
|
||||
|
||||
+ (instancetype)new NS_UNAVAILABLE;
|
||||
|
||||
@end
|
||||
|
||||
NS_ASSUME_NONNULL_END
|
56
mediapipe/tasks/ios/core/sources/MPPTaskRunner.mm
Normal file
|
@ -0,0 +1,56 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#import "mediapipe/tasks/ios/core/sources/MPPTaskRunner.h"
|
||||
#import "mediapipe/tasks/ios/common/utils/sources/MPPCommonUtils.h"
|
||||
|
||||
namespace {
|
||||
using ::mediapipe::CalculatorGraphConfig;
|
||||
using ::mediapipe::Packet;
|
||||
using ::mediapipe::tasks::core::PacketMap;
|
||||
using TaskRunnerCpp = ::mediapipe::tasks::core::TaskRunner;
|
||||
} // namespace
|
||||
|
||||
@interface MPPTaskRunner () {
|
||||
// The underlying C++ task runner.
|
||||
std::unique_ptr<TaskRunnerCpp> _cppTaskRunner;
|
||||
}
|
||||
@end
|
||||
|
||||
@implementation MPPTaskRunner
|
||||
|
||||
- (instancetype)initWithCalculatorGraphConfig:(CalculatorGraphConfig)graphConfig
|
||||
error:(NSError **)error {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
auto taskRunnerResult = TaskRunnerCpp::Create(std::move(graphConfig));
|
||||
|
||||
if (![MPPCommonUtils checkCppError:taskRunnerResult.status() toError:error]) {
|
||||
return nil;
|
||||
}
|
||||
|
||||
_cppTaskRunner = std::move(taskRunnerResult.value());
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (absl::StatusOr<PacketMap>)process:(const PacketMap &)packetMap {
|
||||
return _cppTaskRunner->Process(packetMap);
|
||||
}
|
||||
|
||||
- (void)close {
|
||||
_cppTaskRunner->Close();
|
||||
}
|
||||
|
||||
@end
|
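The new wrapper forwards to the C++ `TaskRunner`. A minimal sketch of the underlying calls — `Create`, `Process`, `Close` — exactly as the `.mm` file above uses them; the helper function itself is illustrative:
```
#include <utility>

#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/tasks/cc/core/task_runner.h"

using ::mediapipe::tasks::core::PacketMap;
using ::mediapipe::tasks::core::TaskRunner;

// Runs a single synchronous invocation through the C++ task runner.
absl::StatusOr<PacketMap> RunOnce(mediapipe::CalculatorGraphConfig config,
                                  PacketMap inputs) {
  // Create() consumes the graph config, as in the wrapper's initializer.
  auto runner = TaskRunner::Create(std::move(config));
  if (!runner.ok()) {
    return runner.status();
  }
  // Process() blocks until the graph emits the output packets.
  absl::StatusOr<PacketMap> result = (*runner)->Process(std::move(inputs));
  (*runner)->Close();  // Shut down the graph, as the wrapper's -close does.
  return result;
}
```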
|
@ -203,6 +203,8 @@ public final class AudioClassifier extends BaseAudioTaskApi {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<AudioClassifierOptions>builder()
|
||||
.setTaskName(AudioClassifier.class.getSimpleName())
|
||||
.setTaskRunningModeName(options.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -200,6 +200,8 @@ public final class AudioEmbedder extends BaseAudioTaskApi {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<AudioEmbedderOptions>builder()
|
||||
.setTaskName(AudioEmbedder.class.getSimpleName())
|
||||
.setTaskRunningModeName(options.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -22,6 +22,7 @@ android_library(
|
|||
],
|
||||
manifest = "AndroidManifest.xml",
|
||||
deps = [
|
||||
":logging",
|
||||
"//mediapipe/calculators/core:flow_limiter_calculator_java_proto_lite",
|
||||
"//mediapipe/calculators/tensor:inference_calculator_java_proto_lite",
|
||||
"//mediapipe/framework:calculator_java_proto_lite",
|
||||
|
@ -37,11 +38,22 @@ android_library(
|
|||
],
|
||||
)
|
||||
|
||||
android_library(
|
||||
name = "logging",
|
||||
srcs = glob(
|
||||
["logging/*.java"],
|
||||
),
|
||||
deps = [
|
||||
"//third_party:autovalue",
|
||||
"@maven//:com_google_guava_guava",
|
||||
],
|
||||
)
|
||||
|
||||
load("//mediapipe/tasks/java/com/google/mediapipe/tasks:mediapipe_tasks_aar.bzl", "mediapipe_tasks_core_aar")
|
||||
|
||||
mediapipe_tasks_core_aar(
|
||||
name = "tasks_core",
|
||||
srcs = glob(["*.java"]) + [
|
||||
srcs = glob(["**/*.java"]) + [
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/containers:java_src",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/processors:java_src",
|
||||
"//mediapipe/tasks/java/com/google/mediapipe/tasks/components/utils:java_src",
|
||||
|
|
|
@ -32,6 +32,12 @@ public abstract class TaskInfo<T extends TaskOptions> {
|
|||
/** Builder for {@link TaskInfo}. */
|
||||
@AutoValue.Builder
|
||||
public abstract static class Builder<T extends TaskOptions> {
|
||||
/** Sets the MediaPipe task name. */
|
||||
public abstract Builder<T> setTaskName(String value);
|
||||
|
||||
/** Sets the MediaPipe task running mode name. */
|
||||
public abstract Builder<T> setTaskRunningModeName(String value);
|
||||
|
||||
/** Sets the MediaPipe task graph name. */
|
||||
public abstract Builder<T> setTaskGraphName(String value);
|
||||
|
||||
|
@ -71,6 +77,10 @@ public abstract class TaskInfo<T extends TaskOptions> {
|
|||
}
|
||||
}
|
||||
|
||||
abstract String taskName();
|
||||
|
||||
abstract String taskRunningModeName();
|
||||
|
||||
abstract String taskGraphName();
|
||||
|
||||
abstract T taskOptions();
|
||||
|
@ -82,7 +92,7 @@ public abstract class TaskInfo<T extends TaskOptions> {
|
|||
abstract Boolean enableFlowLimiting();
|
||||
|
||||
public static <T extends TaskOptions> Builder<T> builder() {
|
||||
return new AutoValue_TaskInfo.Builder<T>();
|
||||
return new AutoValue_TaskInfo.Builder<T>().setTaskName("").setTaskRunningModeName("");
|
||||
}
|
||||
|
||||
/* Returns a list of the output stream names without the stream tags. */
|
||||
|
|
|
@ -21,6 +21,8 @@ import com.google.mediapipe.framework.AndroidPacketCreator;
|
|||
import com.google.mediapipe.framework.Graph;
|
||||
import com.google.mediapipe.framework.MediaPipeException;
|
||||
import com.google.mediapipe.framework.Packet;
|
||||
import com.google.mediapipe.tasks.core.logging.TasksStatsLogger;
|
||||
import com.google.mediapipe.tasks.core.logging.TasksStatsDummyLogger;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
|
@ -34,6 +36,7 @@ public class TaskRunner implements AutoCloseable {
|
|||
private final Graph graph;
|
||||
private final ModelResourcesCache modelResourcesCache;
|
||||
private final AndroidPacketCreator packetCreator;
|
||||
private final TasksStatsLogger statsLogger;
|
||||
private long lastSeenTimestamp = Long.MIN_VALUE;
|
||||
private ErrorListener errorListener;
|
||||
|
||||
|
@ -51,6 +54,8 @@ public class TaskRunner implements AutoCloseable {
|
|||
Context context,
|
||||
TaskInfo<? extends TaskOptions> taskInfo,
|
||||
OutputHandler<? extends TaskResult, ?> outputHandler) {
|
||||
TasksStatsLogger statsLogger =
|
||||
TasksStatsDummyLogger.create(context, taskInfo.taskName(), taskInfo.taskRunningModeName());
|
||||
AndroidAssetUtil.initializeNativeAssetManager(context);
|
||||
Graph mediapipeGraph = new Graph();
|
||||
mediapipeGraph.loadBinaryGraph(taskInfo.generateGraphConfig());
|
||||
|
@ -58,12 +63,15 @@ public class TaskRunner implements AutoCloseable {
|
|||
mediapipeGraph.setServiceObject(new ModelResourcesCacheService(), graphModelResourcesCache);
|
||||
mediapipeGraph.addMultiStreamCallback(
|
||||
taskInfo.outputStreamNames(),
|
||||
outputHandler::run,
|
||||
/*observeTimestampBounds=*/ outputHandler.handleTimestampBoundChanges());
|
||||
packets -> {
|
||||
outputHandler.run(packets);
|
||||
statsLogger.recordInvocationEnd(packets.get(0).getTimestamp());
|
||||
},
|
||||
/* observeTimestampBounds= */ outputHandler.handleTimestampBoundChanges());
|
||||
mediapipeGraph.startRunningGraph();
|
||||
// Waits until all calculators are opened and the graph is fully started.
|
||||
mediapipeGraph.waitUntilGraphIdle();
|
||||
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler);
|
||||
return new TaskRunner(mediapipeGraph, graphModelResourcesCache, outputHandler, statsLogger);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -91,7 +99,10 @@ public class TaskRunner implements AutoCloseable {
|
|||
* @param inputs a map containing (input stream {@link String}, data {@link Packet}) pairs.
|
||||
*/
|
||||
public synchronized TaskResult process(Map<String, Packet> inputs) {
|
||||
addPackets(inputs, generateSyntheticTimestamp());
|
||||
long syntheticInputTimestamp = generateSyntheticTimestamp();
|
||||
// TODO: Support recording GPU input arrival.
|
||||
statsLogger.recordCpuInputArrival(syntheticInputTimestamp);
|
||||
addPackets(inputs, syntheticInputTimestamp);
|
||||
graph.waitUntilGraphIdle();
|
||||
lastSeenTimestamp = outputHandler.getLatestOutputTimestamp();
|
||||
return outputHandler.retrieveCachedTaskResult();
|
||||
|
@ -112,6 +123,7 @@ public class TaskRunner implements AutoCloseable {
|
|||
*/
|
||||
public synchronized TaskResult process(Map<String, Packet> inputs, long inputTimestamp) {
|
||||
validateInputTimstamp(inputTimestamp);
|
||||
statsLogger.recordCpuInputArrival(inputTimestamp);
|
||||
addPackets(inputs, inputTimestamp);
|
||||
graph.waitUntilGraphIdle();
|
||||
return outputHandler.retrieveCachedTaskResult();
|
||||
|
@ -132,6 +144,7 @@ public class TaskRunner implements AutoCloseable {
|
|||
*/
|
||||
public synchronized void send(Map<String, Packet> inputs, long inputTimestamp) {
|
||||
validateInputTimstamp(inputTimestamp);
|
||||
statsLogger.recordCpuInputArrival(inputTimestamp);
|
||||
addPackets(inputs, inputTimestamp);
|
||||
}
|
||||
|
||||
|
@ -145,6 +158,7 @@ public class TaskRunner implements AutoCloseable {
|
|||
graphStarted.set(false);
|
||||
graph.closeAllPacketSources();
|
||||
graph.waitUntilGraphDone();
|
||||
statsLogger.logSessionEnd();
|
||||
} catch (MediaPipeException e) {
|
||||
reportError(e);
|
||||
}
|
||||
|
@ -154,6 +168,7 @@ public class TaskRunner implements AutoCloseable {
|
|||
// Waits until all calculators are opened and the graph is fully restarted.
|
||||
graph.waitUntilGraphIdle();
|
||||
graphStarted.set(true);
|
||||
statsLogger.logSessionStart();
|
||||
} catch (MediaPipeException e) {
|
||||
reportError(e);
|
||||
}
|
||||
|
@ -169,6 +184,7 @@ public class TaskRunner implements AutoCloseable {
|
|||
graphStarted.set(false);
|
||||
graph.closeAllPacketSources();
|
||||
graph.waitUntilGraphDone();
|
||||
statsLogger.logSessionEnd();
|
||||
if (modelResourcesCache != null) {
|
||||
modelResourcesCache.release();
|
||||
}
|
||||
|
@ -247,12 +263,15 @@ public class TaskRunner implements AutoCloseable {
|
|||
private TaskRunner(
|
||||
Graph graph,
|
||||
ModelResourcesCache modelResourcesCache,
|
||||
OutputHandler<? extends TaskResult, ?> outputHandler) {
|
||||
OutputHandler<? extends TaskResult, ?> outputHandler,
|
||||
TasksStatsLogger statsLogger) {
|
||||
this.outputHandler = outputHandler;
|
||||
this.graph = graph;
|
||||
this.modelResourcesCache = modelResourcesCache;
|
||||
this.packetCreator = new AndroidPacketCreator(graph);
|
||||
this.statsLogger = statsLogger;
|
||||
graphStarted.set(true);
|
||||
this.statsLogger.logSessionStart();
|
||||
}
|
||||
|
||||
/** Reports error. */
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core.logging;
|
||||
|
||||
import android.content.Context;
|
||||
|
||||
/** A dummy MediaPipe Tasks stats logger that has all methods as no-ops. */
|
||||
public class TasksStatsDummyLogger implements TasksStatsLogger {
|
||||
|
||||
/**
|
||||
* Creates the MediaPipe Tasks stats dummy logger.
|
||||
*
|
||||
* @param context a {@link Context}.
|
||||
* @param taskNameStr the task api name.
|
||||
* @param taskRunningModeStr the task running mode string representation.
|
||||
*/
|
||||
public static TasksStatsDummyLogger create(
|
||||
Context context, String taskNameStr, String taskRunningModeStr) {
|
||||
return new TasksStatsDummyLogger();
|
||||
}
|
||||
|
||||
private TasksStatsDummyLogger() {}
|
||||
|
||||
/** Logs the start of a MediaPipe Tasks API session. */
|
||||
@Override
|
||||
public void logSessionStart() {}
|
||||
|
||||
/**
|
||||
* Records MediaPipe Tasks API receiving CPU input data.
|
||||
*
|
||||
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
|
||||
* invocation.
|
||||
*/
|
||||
@Override
|
||||
public void recordCpuInputArrival(long packetTimestamp) {}
|
||||
|
||||
/**
|
||||
* Records MediaPipe Tasks API receiving GPU input data.
|
||||
*
|
||||
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
|
||||
* invocation.
|
||||
*/
|
||||
@Override
|
||||
public void recordGpuInputArrival(long packetTimestamp) {}
|
||||
|
||||
/**
|
||||
* Records the end of a MediaPipe Tasks API invocation.
|
||||
*
|
||||
* @param packetTimestamp the output packet timestamp that acts as the identifier of the api
|
||||
* invocation.
|
||||
*/
|
||||
@Override
|
||||
public void recordInvocationEnd(long packetTimestamp) {}
|
||||
|
||||
/** Logs the MediaPipe Tasks API periodic invocation report. */
|
||||
@Override
|
||||
public void logInvocationReport(StatsSnapshot stats) {}
|
||||
|
||||
/** Logs the Tasks API session end event. */
|
||||
@Override
|
||||
public void logSessionEnd() {}
|
||||
|
||||
/** Logs the MediaPipe Tasks API initialization error. */
|
||||
@Override
|
||||
public void logInitError() {}
|
||||
}
|
|
@ -0,0 +1,98 @@
|
|||
// Copyright 2022 The MediaPipe Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.google.mediapipe.tasks.core.logging;
|
||||
|
||||
import com.google.auto.value.AutoValue;
|
||||
|
||||
/** The stats logger interface that defines what MediaPipe Tasks events to log. */
|
||||
public interface TasksStatsLogger {
|
||||
/** Task stats snapshot. */
|
||||
@AutoValue
|
||||
abstract static class StatsSnapshot {
|
||||
static StatsSnapshot create(
|
||||
int cpuInputCount,
|
||||
int gpuInputCount,
|
||||
int finishedCount,
|
||||
int droppedCount,
|
||||
long totalLatencyMs,
|
||||
long peakLatencyMs,
|
||||
long elapsedTimeMs) {
|
||||
return new AutoValue_TasksStatsLogger_StatsSnapshot(
|
||||
cpuInputCount,
|
||||
gpuInputCount,
|
||||
finishedCount,
|
||||
droppedCount,
|
||||
totalLatencyMs,
|
||||
peakLatencyMs,
|
||||
elapsedTimeMs);
|
||||
}
|
||||
|
||||
static StatsSnapshot createDefault() {
|
||||
return new AutoValue_TasksStatsLogger_StatsSnapshot(0, 0, 0, 0, 0, 0, 0);
|
||||
}
|
||||
|
||||
abstract int cpuInputCount();
|
||||
|
||||
abstract int gpuInputCount();
|
||||
|
||||
abstract int finishedCount();
|
||||
|
||||
abstract int droppedCount();
|
||||
|
||||
abstract long totalLatencyMs();
|
||||
|
||||
abstract long peakLatencyMs();
|
||||
|
||||
abstract long elapsedTimeMs();
|
||||
}
|
||||
|
||||
/** Logs the start of a MediaPipe Tasks API session. */
|
||||
public void logSessionStart();
|
||||
|
||||
/**
|
||||
* Records MediaPipe Tasks API receiving CPU input data.
|
||||
*
|
||||
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
|
||||
* invocation.
|
||||
*/
|
||||
public void recordCpuInputArrival(long packetTimestamp);
|
||||
|
||||
/**
|
||||
* Records MediaPipe Tasks API receiving GPU input data.
|
||||
*
|
||||
* @param packetTimestamp the input packet timestamp that acts as the identifier of the api
|
||||
* invocation.
|
||||
*/
|
||||
public void recordGpuInputArrival(long packetTimestamp);
|
||||
|
||||
/**
|
||||
* Records the end of a MediaPipe Tasks API invocation.
|
||||
*
|
||||
* @param packetTimestamp the output packet timestamp that acts as the identifier of the api
|
||||
* invocation.
|
||||
*/
|
||||
public void recordInvocationEnd(long packetTimestamp);
|
||||
|
||||
/** Logs the MediaPipe Tasks API periodic invocation report. */
|
||||
public void logInvocationReport(StatsSnapshot stats);
|
||||
|
||||
/** Logs the Tasks API session end event. */
|
||||
public void logSessionEnd();
|
||||
|
||||
/** Logs the MediaPipe Tasks API initialization error. */
|
||||
public void logInitError();
|
||||
|
||||
// TODO: Logs more error types.
|
||||
}
|
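The interface above, paired with `TasksStatsDummyLogger` earlier in this commit, is a textbook null object: callers log unconditionally and the no-op implementation absorbs the calls. The same shape in C++, with names mirroring the Java as a sketch rather than any MediaPipe API:
```
#include <cstdint>

// Subset of the TasksStatsLogger hooks, translated for illustration.
class StatsLogger {
 public:
  virtual ~StatsLogger() = default;
  virtual void LogSessionStart() = 0;
  virtual void RecordCpuInputArrival(int64_t packet_timestamp) = 0;
  virtual void RecordInvocationEnd(int64_t packet_timestamp) = 0;
  virtual void LogSessionEnd() = 0;
};

// Null object: every hook is a no-op, so the runner can call the logger
// unconditionally without null checks or feature flags.
class DummyStatsLogger final : public StatsLogger {
 public:
  void LogSessionStart() override {}
  void RecordCpuInputArrival(int64_t /*packet_timestamp*/) override {}
  void RecordInvocationEnd(int64_t /*packet_timestamp*/) override {}
  void LogSessionEnd() override {}
};
```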
|
@ -169,6 +169,7 @@ public final class TextClassifier implements AutoCloseable {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<TextClassifierOptions>builder()
|
||||
.setTaskName(TextClassifier.class.getSimpleName())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -159,6 +159,7 @@ public final class TextEmbedder implements AutoCloseable {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<TextEmbedderOptions>builder()
|
||||
.setTaskName(TextEmbedder.class.getSimpleName())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -194,6 +194,8 @@ public final class GestureRecognizer extends BaseVisionTaskApi {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<GestureRecognizerOptions>builder()
|
||||
.setTaskName(GestureRecognizer.class.getSimpleName())
|
||||
.setTaskRunningModeName(recognizerOptions.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -183,6 +183,8 @@ public final class HandLandmarker extends BaseVisionTaskApi {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<HandLandmarkerOptions>builder()
|
||||
.setTaskName(HandLandmarker.class.getSimpleName())
|
||||
.setTaskRunningModeName(landmarkerOptions.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -197,6 +197,8 @@ public final class ImageClassifier extends BaseVisionTaskApi {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ImageClassifierOptions>builder()
|
||||
.setTaskName(ImageClassifier.class.getSimpleName())
|
||||
.setTaskRunningModeName(options.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -180,6 +180,8 @@ public final class ImageEmbedder extends BaseVisionTaskApi {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ImageEmbedderOptions>builder()
|
||||
.setTaskName(ImageEmbedder.class.getSimpleName())
|
||||
.setTaskRunningModeName(options.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -190,6 +190,8 @@ public final class ObjectDetector extends BaseVisionTaskApi {
|
|||
TaskRunner.create(
|
||||
context,
|
||||
TaskInfo.<ObjectDetectorOptions>builder()
|
||||
.setTaskName(ObjectDetector.class.getSimpleName())
|
||||
.setTaskRunningModeName(detectorOptions.runningMode().name())
|
||||
.setTaskGraphName(TASK_GRAPH_NAME)
|
||||
.setInputStreams(INPUT_STREAMS)
|
||||
.setOutputStreams(OUTPUT_STREAMS)
|
||||
|
|
|
@ -86,7 +86,7 @@ describe('convertBaseOptionsToProto()', () => {
|
|||
it('can enable CPU delegate', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'cpu',
|
||||
delegate: 'CPU',
|
||||
});
|
||||
expect(baseOptionsProto.toObject()).toEqual(mockBytesResult);
|
||||
});
|
||||
|
@ -94,7 +94,7 @@ describe('convertBaseOptionsToProto()', () => {
|
|||
it('can enable GPU delegate', async () => {
|
||||
const baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'gpu',
|
||||
delegate: 'GPU',
|
||||
});
|
||||
expect(baseOptionsProto.toObject()).toEqual({
|
||||
...mockBytesResult,
|
||||
|
@ -117,7 +117,7 @@ describe('convertBaseOptionsToProto()', () => {
|
|||
it('can reset delegate', async () => {
|
||||
let baseOptionsProto = await convertBaseOptionsToProto({
|
||||
modelAssetBuffer: new Uint8Array(mockBytes),
|
||||
delegate: 'gpu',
|
||||
delegate: 'GPU',
|
||||
});
|
||||
// Clear backend
|
||||
baseOptionsProto =
|
||||
|
|
|
@ -71,7 +71,7 @@ async function configureExternalFile(
|
|||
/** Configures the `acceleration` option. */
|
||||
function configureAcceleration(options: BaseOptions, proto: BaseOptionsProto) {
|
||||
const acceleration = proto.getAcceleration() ?? new Acceleration();
|
||||
if (options.delegate === 'gpu') {
|
||||
if (options.delegate === 'GPU') {
|
||||
acceleration.setGpu(new InferenceCalculatorOptions.Delegate.Gpu());
|
||||
} else {
|
||||
acceleration.setTflite(new InferenceCalculatorOptions.Delegate.TfLite());
|
||||
|
|
|
@ -44,22 +44,14 @@ async function isSimdSupported(): Promise<boolean> {
|
|||
}
|
||||
|
||||
async function createFileset(
|
||||
taskName: string, basePath: string = '.'): Promise<WasmFileset> {
|
||||
if (await isSimdSupported()) {
|
||||
return {
|
||||
wasmLoaderPath:
|
||||
`${basePath}/${taskName}_wasm_internal.js`,
|
||||
wasmBinaryPath:
|
||||
`${basePath}/${taskName}_wasm_internal.wasm`,
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
wasmLoaderPath:
|
||||
`${basePath}/${taskName}_wasm_nosimd_internal.js`,
|
||||
wasmBinaryPath:
|
||||
`${basePath}/${taskName}_wasm_nosimd_internal.wasm`,
|
||||
};
|
||||
}
|
||||
taskName: string, basePath: string = ''): Promise<WasmFileset> {
|
||||
const suffix =
|
||||
await isSimdSupported() ? 'wasm_internal' : 'wasm_nosimd_internal';
|
||||
|
||||
return {
|
||||
wasmLoaderPath: `${basePath}/${taskName}_${suffix}.js`,
|
||||
wasmBinaryPath: `${basePath}/${taskName}_${suffix}.wasm`,
|
||||
};
|
||||
}
|
||||
|
||||
// tslint:disable:class-as-namespace
|
||||
|
|
|
@ -31,7 +31,7 @@ export declare interface BaseOptions {
|
|||
modelAssetBuffer?: Uint8Array|undefined;
|
||||
|
||||
/** Overrides the default backend to use for the provided model. */
|
||||
delegate?: 'cpu'|'gpu'|undefined;
|
||||
delegate?: 'CPU'|'GPU'|undefined;
|
||||
}
|
||||
|
||||
/** Options to configure MediaPipe Tasks in general. */
|
||||
|
|
|
@ -1028,7 +1028,9 @@ export class GraphRunner {
|
|||
// Set up our TS listener to receive any packets for this stream, and
|
||||
// additionally reformat our Uint8Array into a Float32Array for the user.
|
||||
this.setListener(outputStreamName, (data: Uint8Array) => {
|
||||
const floatArray = new Float32Array(data.buffer); // Should be very fast
|
||||
// Should be very fast
|
||||
const floatArray =
|
||||
new Float32Array(data.buffer, data.byteOffset, data.length / 4);
|
||||
callbackFcn(floatArray);
|
||||
});
|
||||
|
||||
|
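The JavaScript fix above matters because `data.buffer` is the whole underlying ArrayBuffer: a view built without `data.byteOffset` reads from the wrong place whenever the Uint8Array is a sub-slice. The same pitfall exists in C++ when reinterpreting a sub-range of bytes; a small sketch of the safe pattern:
```
#include <cstddef>
#include <cstring>
#include <vector>

// Copies `num_floats` floats out of `bytes`, starting at `byte_offset`.
// memcpy honors the offset explicitly and sidesteps the alignment and
// aliasing hazards of reinterpret_cast on an arbitrary byte pointer.
std::vector<float> FloatsFromBytes(const std::vector<unsigned char>& bytes,
                                   size_t byte_offset, size_t num_floats) {
  std::vector<float> out(num_floats);
  std::memcpy(out.data(), bytes.data() + byte_offset,
              num_floats * sizeof(float));
  return out;
}
```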
|
2
setup.py
|
@ -490,10 +490,10 @@ setuptools.setup(
|
|||
'Operating System :: MacOS :: MacOS X',
|
||||
'Operating System :: Microsoft :: Windows',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Programming Language :: Python :: 3.11',
|
||||
'Programming Language :: Python :: 3 :: Only',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
|
|