Project import generated by Copybara.

GitOrigin-RevId: 6e5aa035cd1f6a9333962df5d3ab97a05bd5744e
This commit is contained in:
MediaPipe Team 2022-06-23 12:35:07 -07:00 committed by Sebastian Schmidt
parent 4a20e9909d
commit c688862570
144 changed files with 5772 additions and 2118 deletions

View File

@ -1 +1 @@
5.0.0 5.2.0

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
FROM ubuntu:18.04 FROM ubuntu:20.04
MAINTAINER <mediapipe@google.com> MAINTAINER <mediapipe@google.com>
@ -42,6 +42,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
software-properties-common && \ software-properties-common && \
add-apt-repository -y ppa:openjdk-r/ppa && \ add-apt-repository -y ppa:openjdk-r/ppa && \
apt-get update && apt-get install -y openjdk-8-jdk && \ apt-get update && apt-get install -y openjdk-8-jdk && \
apt-get install -y mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev && \
apt-get install -y mesa-utils && \
apt-get clean && \ apt-get clean && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
@ -50,13 +52,13 @@ RUN pip3 install --upgrade setuptools
RUN pip3 install wheel RUN pip3 install wheel
RUN pip3 install future RUN pip3 install future
RUN pip3 install six==1.14.0 RUN pip3 install six==1.14.0
RUN pip3 install tensorflow==1.14.0 RUN pip3 install tensorflow==2.2.0
RUN pip3 install tf_slim RUN pip3 install tf_slim
RUN ln -s /usr/bin/python3 /usr/bin/python RUN ln -s /usr/bin/python3 /usr/bin/python
# Install bazel # Install bazel
ARG BAZEL_VERSION=5.0.0 ARG BAZEL_VERSION=5.2.0
RUN mkdir /bazel && \ RUN mkdir /bazel && \
wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\ wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\
azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \ azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \

View File

@ -35,8 +35,9 @@ http_archive(
http_archive( http_archive(
name = "rules_cc", name = "rules_cc",
strip_prefix = "rules_cc-main", strip_prefix = "rules_cc-2f8c04c04462ab83c545ab14c0da68c3b4c96191",
urls = ["https://github.com/bazelbuild/rules_cc/archive/main.zip"], # The commit can be updated if the build passes. Last updated 6/23/22.
urls = ["https://github.com/bazelbuild/rules_cc/archive/2f8c04c04462ab83c545ab14c0da68c3b4c96191.zip"],
) )
http_archive( http_archive(

View File

@ -244,6 +244,7 @@ cc_test(
"//mediapipe/framework/formats:time_series_header_cc_proto", "//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:test_util",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
], ],
) )

View File

@ -20,8 +20,12 @@
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
namespace mediapipe { namespace mediapipe {
namespace {
constexpr char kTestPackageRoot[] = "mediapipe/calculators/audio";
TEST(AudioDecoderCalculatorTest, TestWAV) { TEST(AudioDecoderCalculatorTest, TestWAV) {
CalculatorGraphConfig::Node node_config = CalculatorGraphConfig::Node node_config =
@ -37,9 +41,8 @@ TEST(AudioDecoderCalculatorTest, TestWAV) {
})pb"); })pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>( runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
file::JoinPath("./", file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/audio/" "sine_wave_1k_44100_mono_2_sec_wav.audio"));
"testdata/sine_wave_1k_44100_mono_2_sec_wav.audio"));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
MP_EXPECT_OK(runner.Outputs() MP_EXPECT_OK(runner.Outputs()
.Tag("AUDIO_HEADER") .Tag("AUDIO_HEADER")
@ -68,9 +71,8 @@ TEST(AudioDecoderCalculatorTest, Test48KWAV) {
})pb"); })pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>( runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
file::JoinPath("./", file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/audio/" "sine_wave_1k_48000_stereo_2_sec_wav.audio"));
"testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio"));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
MP_EXPECT_OK(runner.Outputs() MP_EXPECT_OK(runner.Outputs()
.Tag("AUDIO_HEADER") .Tag("AUDIO_HEADER")
@ -99,9 +101,8 @@ TEST(AudioDecoderCalculatorTest, TestMP3) {
})pb"); })pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>( runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
file::JoinPath("./", file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/audio/" "sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
"testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
MP_EXPECT_OK(runner.Outputs() MP_EXPECT_OK(runner.Outputs()
.Tag("AUDIO_HEADER") .Tag("AUDIO_HEADER")
@ -130,9 +131,8 @@ TEST(AudioDecoderCalculatorTest, TestAAC) {
})pb"); })pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>( runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
file::JoinPath("./", file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/audio/" "sine_wave_1k_44100_stereo_2_sec_aac.audio"));
"testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio"));
MP_ASSERT_OK(runner.Run()); MP_ASSERT_OK(runner.Run());
MP_EXPECT_OK(runner.Outputs() MP_EXPECT_OK(runner.Outputs()
.Tag("AUDIO_HEADER") .Tag("AUDIO_HEADER")
@ -147,4 +147,5 @@ TEST(AudioDecoderCalculatorTest, TestAAC) {
std::ceil(44100.0 * 2 / 1024)); std::ceil(44100.0 * 2 / 1024));
} }
} // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -20,24 +20,22 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "Eigen/Core"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "audio/dsp/spectrogram/spectrogram.h" #include "audio/dsp/spectrogram/spectrogram.h"
#include "audio/dsp/window_functions.h" #include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h" #include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status_builder.h" #include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/util/time_series_util.h" #include "mediapipe/util/time_series_util.h"
namespace mediapipe { namespace mediapipe {
namespace {
constexpr char kFrameDurationTag[] = "FRAME_DURATION";
constexpr char kFrameOverlapTag[] = "FRAME_OVERLAP";
} // namespace
// MediaPipe Calculator for computing the "spectrogram" (short-time Fourier // MediaPipe Calculator for computing the "spectrogram" (short-time Fourier
// transform squared-magnitude, by default) of a multichannel input // transform squared-magnitude, by default) of a multichannel input
// time series, including optionally overlapping frames. Options are // time series, including optionally overlapping frames. Options are
@ -46,11 +44,14 @@ namespace mediapipe {
// //
// Result is a MatrixData record (for single channel input and when the // Result is a MatrixData record (for single channel input and when the
// allow_multichannel_input flag is false), or a vector of MatrixData records, // allow_multichannel_input flag is false), or a vector of MatrixData records,
// one for each channel (when the allow_multichannel_input flag is set). The // one for each channel (when the allow_multichannel_input flag is set). Each
// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex // waveform frame is converted to frequency by a fast Fourier transform whose
// values, or squared/linear/dB magnitudes, depending on the output_type option. // size, n_fft, is the smallest power of two large enough to enclose the frame
// Each input packet will result in zero or one output packets, each containing // length of round(frame_duration_seconds * sample_rate).The rows of each
// one Matrix for each channel of the input, where each Matrix has one or more // spectrogram matrix(result) correspond to the n_fft/2+1 unique complex values,
// or squared/linear/dB magnitudes, depending on the output_type option. Each
// input packet will result in zero or one output packets, each containing one
// Matrix for each channel of the input, where each Matrix has one or more
// columns of spectral values, one for each complete frame of input samples. If // columns of spectral values, one for each complete frame of input samples. If
// the input packet contains too few samples to trigger a new output frame, no // the input packet contains too few samples to trigger a new output frame, no
// output packet is generated (since zero-length packets are not legal since // output packet is generated (since zero-length packets are not legal since
@ -71,6 +72,22 @@ class SpectrogramCalculator : public CalculatorBase {
// Input stream with TimeSeriesHeader. // Input stream with TimeSeriesHeader.
); );
if (cc->InputSidePackets().HasTag(kFrameDurationTag)) {
cc->InputSidePackets()
.Tag(kFrameDurationTag)
.Set<double>(
// Optional side packet for frame_duration_seconds if provided.
);
}
if (cc->InputSidePackets().HasTag(kFrameOverlapTag)) {
cc->InputSidePackets()
.Tag(kFrameOverlapTag)
.Set<double>(
// Optional side packet for frame_overlap_seconds if provided.
);
}
SpectrogramCalculatorOptions spectrogram_options = SpectrogramCalculatorOptions spectrogram_options =
cc->Options<SpectrogramCalculatorOptions>(); cc->Options<SpectrogramCalculatorOptions>();
if (!spectrogram_options.allow_multichannel_input()) { if (!spectrogram_options.allow_multichannel_input()) {
@ -184,27 +201,47 @@ class SpectrogramCalculator : public CalculatorBase {
// Fixed scale factor applied to output values (regardless of type). // Fixed scale factor applied to output values (regardless of type).
double output_scale_; double output_scale_;
static const float kLnPowerToDb; static const float kLnSquaredMagnitudeToDb;
}; };
REGISTER_CALCULATOR(SpectrogramCalculator); REGISTER_CALCULATOR(SpectrogramCalculator);
// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0). // DECIBELS = 20*log10(LINEAR_MAGNITUDE) = 10*Log10(SQUARED_MAGNITUDE)
const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518; // =10/ln(10)*ln(SQUARED_MAGNITUDE).
// Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;
absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) { absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
SpectrogramCalculatorOptions spectrogram_options = SpectrogramCalculatorOptions spectrogram_options =
cc->Options<SpectrogramCalculatorOptions>(); cc->Options<SpectrogramCalculatorOptions>();
// Provide frame_duration_seconds and frame_overlap_seconds either from static
// options, or dynamically from a side packet, the side packet one will
// override the options one if provided.
double frame_duration_seconds = 0;
double frame_overlap_seconds = 0;
if (cc->InputSidePackets().HasTag(kFrameDurationTag)) {
frame_duration_seconds =
cc->InputSidePackets().Tag(kFrameDurationTag).Get<double>();
} else {
frame_duration_seconds = spectrogram_options.frame_duration_seconds();
}
if (cc->InputSidePackets().HasTag(kFrameOverlapTag)) {
frame_overlap_seconds =
cc->InputSidePackets().Tag(kFrameOverlapTag).Get<double>();
} else {
frame_overlap_seconds = spectrogram_options.frame_overlap_seconds();
}
use_local_timestamp_ = spectrogram_options.use_local_timestamp(); use_local_timestamp_ = spectrogram_options.use_local_timestamp();
if (spectrogram_options.frame_duration_seconds() <= 0.0) { if (frame_duration_seconds <= 0.0) {
// TODO: return an error. // TODO: return an error.
} }
if (spectrogram_options.frame_overlap_seconds() >= if (frame_overlap_seconds >= frame_duration_seconds) {
spectrogram_options.frame_duration_seconds()) {
// TODO: return an error. // TODO: return an error.
} }
if (spectrogram_options.frame_overlap_seconds() < 0.0) { if (frame_overlap_seconds < 0.0) {
// TODO: return an error. // TODO: return an error.
} }
@ -220,10 +257,8 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
// TODO: return an error. // TODO: return an error.
} }
frame_duration_samples_ = frame_duration_samples_ = round(frame_duration_seconds * input_sample_rate_);
round(spectrogram_options.frame_duration_seconds() * input_sample_rate_); frame_overlap_samples_ = round(frame_overlap_seconds * input_sample_rate_);
frame_overlap_samples_ =
round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_);
pad_final_packet_ = spectrogram_options.pad_final_packet(); pad_final_packet_ = spectrogram_options.pad_final_packet();
output_type_ = spectrogram_options.output_type(); output_type_ = spectrogram_options.output_type();
@ -419,7 +454,7 @@ absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
return ProcessVectorToOutput( return ProcessVectorToOutput(
input_stream, input_stream,
+[](const Matrix& col) -> const Matrix { +[](const Matrix& col) -> const Matrix {
return kLnPowerToDb * col.array().log().matrix(); return kLnSquaredMagnitudeToDb * col.array().log().matrix();
}, cc); }, cc);
} }
// clang-format on // clang-format on

View File

@ -32,7 +32,11 @@ message SpectrogramCalculatorOptions {
// Duration of overlap between adjacent windows. // Duration of overlap between adjacent windows.
// Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds). // Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds).
// Required that 0 <= frame_overlap_seconds < frame_duration_seconds. // Note the frame_rate here is not the MediaPipe packet rate, the frame here
// means each Fourier transform analysis waveform frame, the output MediaPipe
// packet rate will the the same as input, if frame rate is lower than input
// packet rate, will result in intermittent empty output packets. Required
// that 0 <= frame_overlap_seconds < frame_duration_seconds.
optional double frame_overlap_seconds = 2 [default = 0.0]; optional double frame_overlap_seconds = 2 [default = 0.0];
// Whether to pad the final packet with zeros. If true, guarantees that // Whether to pad the final packet with zeros. If true, guarantees that
@ -42,6 +46,11 @@ message SpectrogramCalculatorOptions {
// Output value type can be squared-magnitude, linear-magnitude, // Output value type can be squared-magnitude, linear-magnitude,
// deciBels (dB, = 20*log10(linear_magnitude)), or std::complex. // deciBels (dB, = 20*log10(linear_magnitude)), or std::complex.
// Their relationship:
// COMPLEX c = Re + Im*i;
// SQUARED_MAGNITUDE = Re^2 + Im^2;
// LINEAR_MAGNITUDE = sqrt(SQUARED_MAGNITUDE);
// DECIBELS = 20*log10(LINEAR_MAGNITUDE) = 10*log10(SQUARED_MAGNITUDE);
enum OutputType { enum OutputType {
SQUARED_MAGNITUDE = 0; SQUARED_MAGNITUDE = 0;
LINEAR_MAGNITUDE = 1; LINEAR_MAGNITUDE = 1;

View File

@ -557,6 +557,22 @@ cc_library(
alwayslink = 1, alwayslink = 1,
) )
cc_test(
name = "packet_cloner_calculator_test",
srcs = ["packet_cloner_calculator_test.cc"],
deps = [
":packet_cloner_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/strings",
],
)
cc_library( cc_library(
name = "packet_inner_join_calculator", name = "packet_inner_join_calculator",
srcs = ["packet_inner_join_calculator.cc"], srcs = ["packet_inner_join_calculator.cc"],

View File

@ -73,8 +73,17 @@ typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark>
ConcatenateLandmarkVectorCalculator; ConcatenateLandmarkVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator); MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator);
typedef ConcatenateVectorCalculator<::mediapipe::LandmarkList>
ConcatenateLandmarkListVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListVectorCalculator);
typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList> typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList>
ConcatenateLandmarListVectorCalculator; ConcatenateNormalizedLandmarkListVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListVectorCalculator);
// For backwards compatibility, keep the version with the typo.
using ConcatenateLandmarListVectorCalculator =
ConcatenateNormalizedLandmarkListVectorCalculator;
MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator); MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator);
typedef ConcatenateVectorCalculator<mediapipe::ClassificationList> typedef ConcatenateVectorCalculator<mediapipe::ClassificationList>

View File

@ -32,8 +32,8 @@ constexpr char kOptionsTag[] = "OPTIONS";
// FlowLimiterCalculator is used to limit the number of frames in flight // FlowLimiterCalculator is used to limit the number of frames in flight
// by dropping input frames when necessary. // by dropping input frames when necessary.
// //
// The input stream "FINISH" is used to signal the FlowLimiterCalculator // The input stream "FINISHED" is used to signal the FlowLimiterCalculator
// when a frame is finished processing. Either a non-empty "FINISH" packet // when a frame is finished processing. Either a non-empty "FINISHED" packet
// or a timestamp bound should be received for each processed frame. // or a timestamp bound should be received for each processed frame.
// //
// The combination of `max_in_flight: 1` and `max_in_queue: 1` generally gives // The combination of `max_in_flight: 1` and `max_in_queue: 1` generally gives

View File

@ -16,9 +16,10 @@
// For every packet that appears in B, outputs the most recent packet from each // For every packet that appears in B, outputs the most recent packet from each
// of the A_i on a separate stream. // of the A_i on a separate stream.
#include <string_view>
#include <vector> #include <vector>
#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h" #include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -34,7 +35,18 @@ namespace mediapipe {
// calculator: "PacketClonerCalculator" // calculator: "PacketClonerCalculator"
// input_stream: "first_base_signal" // input_stream: "first_base_signal"
// input_stream: "second_base_signal" // input_stream: "second_base_signal"
// input_stream: "tick_signal" // input_stream: "tick_signal" # or input_stream: "TICK:tick_signal"
// output_stream: "cloned_first_base_signal"
// output_stream: "cloned_second_base_signal"
// }
//
// Or you can use "TICK" tag and put corresponding input stream at any location,
// for example at the very beginning:
// node {
// calculator: "PacketClonerCalculator"
// input_stream: "TICK:tick_signal"
// input_stream: "first_base_signal"
// input_stream: "second_base_signal"
// output_stream: "cloned_first_base_signal" // output_stream: "cloned_first_base_signal"
// output_stream: "cloned_second_base_signal" // output_stream: "cloned_second_base_signal"
// } // }
@ -46,12 +58,13 @@ namespace mediapipe {
class PacketClonerCalculator : public CalculatorBase { class PacketClonerCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
const int tick_signal_index = cc->Inputs().NumEntries() - 1; const Ids ids = GetIds(*cc);
for (int i = 0; i < tick_signal_index; ++i) { for (const auto& in_out : ids.inputs_outputs) {
cc->Inputs().Index(i).SetAny(); auto& input = cc->Inputs().Get(in_out.in);
cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i)); input.SetAny();
cc->Outputs().Get(in_out.out).SetSameAs(&input);
} }
cc->Inputs().Index(tick_signal_index).SetAny(); cc->Inputs().Get(ids.tick_id).SetAny();
return absl::OkStatus(); return absl::OkStatus();
} }
@ -65,13 +78,15 @@ class PacketClonerCalculator : public CalculatorBase {
output_empty_packets_before_all_inputs_received_ = output_empty_packets_before_all_inputs_received_ =
calculator_options.output_packets_only_when_all_inputs_received(); calculator_options.output_packets_only_when_all_inputs_received();
// Parse input streams. // Prepare input and output ids.
tick_signal_index_ = cc->Inputs().NumEntries() - 1; ids_ = GetIds(*cc);
current_.resize(tick_signal_index_); current_.resize(ids_.inputs_outputs.size());
// Pass along the header for each stream if present. // Pass along the header for each stream if present.
for (int i = 0; i < tick_signal_index_; ++i) { for (const auto& in_out : ids_.inputs_outputs) {
if (!cc->Inputs().Index(i).Header().IsEmpty()) { auto& input = cc->Inputs().Get(in_out.in);
cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header()); if (!input.Header().IsEmpty()) {
cc->Outputs().Get(in_out.out).SetHeader(input.Header());
} }
} }
return absl::OkStatus(); return absl::OkStatus();
@ -79,17 +94,18 @@ class PacketClonerCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final {
// Store input signals. // Store input signals.
for (int i = 0; i < tick_signal_index_; ++i) { for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
if (!cc->Inputs().Index(i).Value().IsEmpty()) { const auto& input = cc->Inputs().Get(ids_.inputs_outputs[i].in);
current_[i] = cc->Inputs().Index(i).Value(); if (!input.IsEmpty()) {
current_[i] = input.Value();
} }
} }
// Output according to the TICK signal. // Output according to the TICK signal.
if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) { if (!cc->Inputs().Get(ids_.tick_id).IsEmpty()) {
if (output_only_when_all_inputs_received_) { if (output_only_when_all_inputs_received_) {
// Return if one of the input is null. // Return if one of the input is null.
for (int i = 0; i < tick_signal_index_; ++i) { for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
if (current_[i].IsEmpty()) { if (current_[i].IsEmpty()) {
if (output_empty_packets_before_all_inputs_received_) { if (output_empty_packets_before_all_inputs_received_) {
SetAllNextTimestampBounds(cc); SetAllNextTimestampBounds(cc);
@ -99,12 +115,12 @@ class PacketClonerCalculator : public CalculatorBase {
} }
} }
// Output each stream. // Output each stream.
for (int i = 0; i < tick_signal_index_; ++i) { for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
auto& output = cc->Outputs().Get(ids_.inputs_outputs[i].out);
if (!current_[i].IsEmpty()) { if (!current_[i].IsEmpty()) {
cc->Outputs().Index(i).AddPacket( output.AddPacket(current_[i].At(cc->InputTimestamp()));
current_[i].At(cc->InputTimestamp()));
} else { } else {
cc->Outputs().Index(i).SetNextTimestampBound( output.SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream()); cc->InputTimestamp().NextAllowedInStream());
} }
} }
@ -113,15 +129,44 @@ class PacketClonerCalculator : public CalculatorBase {
} }
private: private:
struct Ids {
struct InputOutput {
CollectionItemId in;
CollectionItemId out;
};
CollectionItemId tick_id;
std::vector<InputOutput> inputs_outputs;
};
template <typename CC>
static Ids GetIds(CC& cc) {
Ids ids;
static constexpr absl::string_view kEmptyTag = "";
int num_inputs_to_clone = cc.Inputs().NumEntries(kEmptyTag);
static constexpr absl::string_view kTickTag = "TICK";
if (cc.Inputs().HasTag(kTickTag)) {
ids.tick_id = cc.Inputs().GetId(kTickTag, 0);
} else {
--num_inputs_to_clone;
ids.tick_id = cc.Inputs().GetId(kEmptyTag, num_inputs_to_clone);
}
for (int i = 0; i < num_inputs_to_clone; ++i) {
ids.inputs_outputs.push_back({.in = cc.Inputs().GetId(kEmptyTag, i),
.out = cc.Outputs().GetId(kEmptyTag, i)});
}
return ids;
}
void SetAllNextTimestampBounds(CalculatorContext* cc) { void SetAllNextTimestampBounds(CalculatorContext* cc) {
for (int j = 0; j < tick_signal_index_; ++j) { for (const auto& in_out : ids_.inputs_outputs) {
cc->Outputs().Index(j).SetNextTimestampBound( cc->Outputs()
cc->InputTimestamp().NextAllowedInStream()); .Get(in_out.out)
.SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
} }
} }
std::vector<Packet> current_; std::vector<Packet> current_;
int tick_signal_index_; Ids ids_;
bool output_only_when_all_inputs_received_; bool output_only_when_all_inputs_received_;
bool output_empty_packets_before_all_inputs_received_; bool output_empty_packets_before_all_inputs_received_;
}; };

View File

@ -0,0 +1,349 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/str_cat.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/framework/tool/sink.h"
namespace mediapipe {
namespace {
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::Value;
MATCHER_P2(IntPacket, value, ts, "") {
return Value(arg.template Get<int>(), Eq(value)) &&
Value(arg.Timestamp(), Eq(Timestamp(ts)));
}
MATCHER_P2(FloatPacket, value, ts, "") {
return Value(arg.template Get<float>(), Eq(value)) &&
Value(arg.Timestamp(), Eq(Timestamp(ts)));
}
template <typename T>
absl::Status SendPacket(const std::string& input_name, T value, int ts,
CalculatorGraph& graph) {
return graph.AddPacketToInputStream(input_name,
MakePacket<T>(value).At(Timestamp(ts)));
}
struct Params {
bool use_tick_tag = false;
};
class PacketClonerCalculatorTest : public testing::TestWithParam<Params> {};
TEST_P(PacketClonerCalculatorTest, ClonesSingleInputSameTimestamps) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
if (GetParam().use_tick_tag) {
return R"pb(
input_stream: 'in1'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'in1'
input_stream: 'TICK:tick'
output_stream: 'out1'
})pb";
}
return R"pb(
input_stream: 'in1'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'in1'
input_stream: 'tick'
output_stream: 'out1'
})pb";
}());
std::vector<Packet> out1;
tool::AddVectorSink("out1", &graph_config, &out1);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(out1, ElementsAre(IntPacket(1, 10000)));
}
TEST_P(PacketClonerCalculatorTest, ClonesSingleInputEarlierTimestamps) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
if (GetParam().use_tick_tag) {
return R"pb(
input_stream: 'in1'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'in1'
input_stream: 'TICK:tick'
output_stream: 'out1'
})pb";
}
return R"pb(
input_stream: 'in1'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'in1'
input_stream: 'tick'
output_stream: 'out1'
})pb";
}());
std::vector<Packet> out1;
tool::AddVectorSink("out1", &graph_config, &out1);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
// PacketClonerCalculator is non-ImmediateInputStreamHandler
// PacketClonerCalculator waits for "in1" to arrive for ts=5000
MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/5000, graph));
// Newer tick at ts=10000, should NOT trigger output for ts=5000
// PacketClonerCalculator waits for "in1" to arrive for ts=10000
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("tick", 1001, /*ts=*/10001, graph));
MP_ASSERT_OK(SendPacket("tick", 1002, /*ts=*/10002, graph));
// Newer "in1" at ts=15000, should trigger output for ts=10000
MP_ASSERT_OK(SendPacket("in1", 2, /*ts=*/15000, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(out1, ElementsAre(IntPacket(1, 10000), IntPacket(1, 10001),
IntPacket(1, 10002)));
}
TEST_P(PacketClonerCalculatorTest, ClonesFiveInputs) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
if (GetParam().use_tick_tag) {
return R"pb(
input_stream: 'in1'
input_stream: 'in2'
input_stream: 'in3'
input_stream: 'in4'
input_stream: 'in5'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'in1'
input_stream: 'in2'
input_stream: 'in3'
input_stream: 'in4'
input_stream: 'in5'
output_stream: 'out1'
output_stream: 'out2'
output_stream: 'out3'
input_stream: 'TICK:tick' # arbitrary location
output_stream: 'out4'
output_stream: 'out5'
}
)pb";
}
return R"pb(
input_stream: 'in1'
input_stream: 'in2'
input_stream: 'in3'
input_stream: 'in4'
input_stream: 'in5'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'in1'
input_stream: 'in2'
input_stream: 'in3'
input_stream: 'in4'
input_stream: 'in5'
input_stream: 'tick'
output_stream: 'out1'
output_stream: 'out2'
output_stream: 'out3'
output_stream: 'out4'
output_stream: 'out5'
}
)pb";
}());
constexpr int kNumToClone = 5;
std::array<std::vector<Packet>, kNumToClone> outs;
for (int i = 0; i < kNumToClone; ++i) {
tool::AddVectorSink(absl::StrCat("out", i + 1), &graph_config, &outs[i]);
}
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(SendPacket("in1", 10, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("in2", 20.0f, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("in3", 30, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("in4", 40.0f, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("in5", 50, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
// Below "tick" packets won't trigger output, until newer inputs are sent,
// because inputs are missing and ImmediateInputStreamHandler is not
// configured.
MP_ASSERT_OK(SendPacket("tick", 1001, /*ts=*/10001, graph));
MP_ASSERT_OK(SendPacket("tick", 1002, /*ts=*/10002, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(outs, ElementsAre(ElementsAre(IntPacket(10, 10000)),
ElementsAre(FloatPacket(20.0f, 10000)),
ElementsAre(IntPacket(30, 10000)),
ElementsAre(FloatPacket(40.0f, 10000)),
ElementsAre(IntPacket(50, 10000))));
MP_ASSERT_OK(SendPacket("in1", 100, /*ts=*/20000, graph));
MP_ASSERT_OK(SendPacket("in2", 200.0f, /*ts=*/20000, graph));
MP_ASSERT_OK(SendPacket("in3", 300, /*ts=*/20000, graph));
MP_ASSERT_OK(SendPacket("in4", 400.0f, /*ts=*/20000, graph));
MP_ASSERT_OK(SendPacket("in5", 500, /*ts=*/20000, graph));
MP_ASSERT_OK(SendPacket("tick", 2000, /*ts=*/20000, graph));
// Below "tick" packets won't trigger output, because inputs are missing and
// ImmediateInputStreamHandler is not configured.
MP_ASSERT_OK(SendPacket("tick", 2001, /*ts=*/20001, graph));
MP_ASSERT_OK(SendPacket("tick", 2002, /*ts=*/20002, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
outs,
ElementsAre(
ElementsAre(IntPacket(10, 10000), IntPacket(10, 10001),
IntPacket(10, 10002), IntPacket(100, 20000)),
ElementsAre(FloatPacket(20.0f, 10000), FloatPacket(20.0f, 10001),
FloatPacket(20.0f, 10002), FloatPacket(200.0f, 20000)),
ElementsAre(IntPacket(30, 10000), IntPacket(30, 10001),
IntPacket(30, 10002), IntPacket(300, 20000)),
ElementsAre(FloatPacket(40.0f, 10000), FloatPacket(40.0f, 10001),
FloatPacket(40.0f, 10002), FloatPacket(400.0f, 20000)),
ElementsAre(IntPacket(50, 10000), IntPacket(50, 10001),
IntPacket(50, 10002), IntPacket(500, 20000))));
}
TEST_P(PacketClonerCalculatorTest,
ClonesTwoInputsWithImmediateInputStreamHandler) {
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
if (GetParam().use_tick_tag) {
return R"pb(
input_stream: 'in1'
input_stream: 'in2'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'TICK:tick'
input_stream: 'in1'
input_stream: 'in2'
output_stream: 'out1'
output_stream: 'out2'
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
})pb";
}
return R"pb(
input_stream: 'in1'
input_stream: 'in2'
input_stream: 'tick'
node {
calculator: 'PacketClonerCalculator'
input_stream: 'in1'
input_stream: 'in2'
input_stream: 'tick'
output_stream: 'out1'
output_stream: 'out2'
input_stream_handler {
input_stream_handler: "ImmediateInputStreamHandler"
}
})pb";
}());
constexpr int kNumToClone = 2;
std::array<std::vector<Packet>, kNumToClone> outs;
for (int i = 0; i < kNumToClone; ++i) {
tool::AddVectorSink(absl::StrCat("out", i + 1), &graph_config, &outs[i]);
}
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
MP_ASSERT_OK(graph.StartRun({}));
// No packets to clone.
MP_ASSERT_OK(SendPacket("tick", 0, /*ts=*/0, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Cloning current packets.
MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("in2", 10.0f, /*ts=*/10000, graph));
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Cloning past packets.
MP_ASSERT_OK(SendPacket("tick", 1500, /*ts=*/15000, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Cloning past packets.
MP_ASSERT_OK(SendPacket("in1", 2, /*ts=*/10001, graph));
MP_ASSERT_OK(SendPacket("in2", 20.0f, /*ts=*/10001, graph));
MP_ASSERT_OK(SendPacket("tick", 2000, /*ts=*/20000, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Cloning future packets.
MP_ASSERT_OK(SendPacket("in1", 3, /*ts=*/30000, graph));
MP_ASSERT_OK(SendPacket("in2", 30.0f, /*ts=*/30000, graph));
// Waiting to ensure newer packets (ts=30000) to clone would get into the
// cloner before tick (ts=25000) does.
MP_ASSERT_OK(graph.WaitUntilIdle());
MP_ASSERT_OK(SendPacket("tick", 3000, /*ts=*/25000, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Cloning packets having different timestamps.
MP_ASSERT_OK(SendPacket("in1", 4, /*ts=*/38000, graph));
MP_ASSERT_OK(SendPacket("in2", 40.0f, /*ts=*/39000, graph));
MP_ASSERT_OK(SendPacket("tick", 4000, /*ts=*/40000, graph));
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_THAT(
outs,
ElementsAre(
ElementsAre(IntPacket(1, 10000), IntPacket(1, 15000),
IntPacket(2, 20000), IntPacket(3, 25000),
IntPacket(4, 40000)),
ElementsAre(FloatPacket(10.0f, 10000), FloatPacket(10.0f, 15000),
FloatPacket(20.0f, 20000), FloatPacket(30.0f, 25000),
FloatPacket(40.0f, 40000))));
}
INSTANTIATE_TEST_SUITE_P(PacketClonerCalculator, PacketClonerCalculatorTest,
testing::ValuesIn({Params{.use_tick_tag = false},
Params{.use_tick_tag = true}}));
} // anonymous namespace
} // namespace mediapipe

View File

@ -157,9 +157,7 @@ absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
} }
} }
if (absl::Status status = strategy_->Process(cc); !status.ok()) { MP_RETURN_IF_ERROR(strategy_->Process(cc));
return status; // Avoid MP_RETURN_IF_ERROR macro for external release.
}
last_packet_ = cc->Inputs().Get(input_data_id_).Value(); last_packet_ = cc->Inputs().Get(input_data_id_).Value();

View File

@ -626,11 +626,8 @@ cc_library(
"//mediapipe/framework/formats:image_format_cc_proto", "//mediapipe/framework/formats:image_format_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_opencv",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
] + select({ ] + select({
@ -641,6 +638,13 @@ cc_library(
"//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
], ],
}) + select({
"//mediapipe/framework/port:disable_opencv": [],
"//conditions:default": [
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/formats:image_opencv",
"//mediapipe/framework/port:opencv_core",
],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -727,7 +731,6 @@ cc_library(
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":affine_transformation", ":affine_transformation",
":affine_transformation_runner_opencv",
":warp_affine_calculator_cc_proto", ":warp_affine_calculator_cc_proto",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor", "@com_google_absl//absl/status:statusor",
@ -745,6 +748,9 @@ cc_library(
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
":affine_transformation_runner_gl", ":affine_transformation_runner_gl",
], ],
}) + select({
"//mediapipe/framework/port:disable_opencv": [],
"//conditions:default": [":affine_transformation_runner_opencv"],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -799,3 +805,21 @@ cc_test(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
cc_library(
name = "yuv_to_image_calculator",
srcs = ["yuv_to_image_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:yuv_image",
"//third_party/libyuv",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
],
alwayslink = 1,
)

View File

@ -21,10 +21,7 @@
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_format.pb.h" #include "mediapipe/framework/formats/image_format.pb.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
@ -34,6 +31,12 @@
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
#if !MEDIAPIPE_DISABLE_OPENCV
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/port/opencv_core_inc.h"
#endif // !MEDIAPIPE_DISABLE_OPENCV
namespace mediapipe { namespace mediapipe {
namespace { namespace {
@ -163,7 +166,11 @@ absl::Status SegmentationSmoothingCalculator::Process(CalculatorContext* cc) {
return absl::InternalError("GPU processing is disabled."); return absl::InternalError("GPU processing is disabled.");
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
#if !MEDIAPIPE_DISABLE_OPENCV
MP_RETURN_IF_ERROR(RenderCpu(cc)); MP_RETURN_IF_ERROR(RenderCpu(cc));
#else
return absl::InternalError("OpenCV processing is disabled.");
#endif // !MEDIAPIPE_DISABLE_OPENCV
} }
return absl::OkStatus(); return absl::OkStatus();
@ -181,6 +188,7 @@ absl::Status SegmentationSmoothingCalculator::Close(CalculatorContext* cc) {
} }
absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) { absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) {
#if !MEDIAPIPE_DISABLE_OPENCV
// Setup source images. // Setup source images.
const auto& current_frame = cc->Inputs().Tag(kCurrentMaskTag).Get<Image>(); const auto& current_frame = cc->Inputs().Tag(kCurrentMaskTag).Get<Image>();
auto current_mat = mediapipe::formats::MatView(&current_frame); auto current_mat = mediapipe::formats::MatView(&current_frame);
@ -245,6 +253,7 @@ absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) {
cc->Outputs() cc->Outputs()
.Tag(kOutputMaskTag) .Tag(kOutputMaskTag)
.AddPacket(MakePacket<Image>(output_frame).At(cc->InputTimestamp())); .AddPacket(MakePacket<Image>(output_frame).At(cc->InputTimestamp()));
#endif // !MEDIAPIPE_DISABLE_OPENCV
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -24,7 +24,9 @@
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#if !MEDIAPIPE_DISABLE_OPENCV
#include "mediapipe/calculators/image/affine_transformation_runner_opencv.h" #include "mediapipe/calculators/image/affine_transformation_runner_opencv.h"
#endif // !MEDIAPIPE_DISABLE_OPENCV
#include "mediapipe/calculators/image/warp_affine_calculator.pb.h" #include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
#include "mediapipe/framework/api2/node.h" #include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
@ -54,6 +56,7 @@ AffineTransformation::BorderMode GetBorderMode(
template <typename ImageT> template <typename ImageT>
class WarpAffineRunnerHolder {}; class WarpAffineRunnerHolder {};
#if !MEDIAPIPE_DISABLE_OPENCV
template <> template <>
class WarpAffineRunnerHolder<ImageFrame> { class WarpAffineRunnerHolder<ImageFrame> {
public: public:
@ -69,6 +72,7 @@ class WarpAffineRunnerHolder<ImageFrame> {
private: private:
std::unique_ptr<RunnerType> runner_; std::unique_ptr<RunnerType> runner_;
}; };
#endif // !MEDIAPIPE_DISABLE_OPENCV
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
template <> template <>
@ -113,7 +117,9 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
mediapipe::Image> { mediapipe::Image> {
public: public:
absl::Status Open(CalculatorContext* cc) { absl::Status Open(CalculatorContext* cc) {
#if !MEDIAPIPE_DISABLE_OPENCV
MP_RETURN_IF_ERROR(cpu_holder_.Open(cc)); MP_RETURN_IF_ERROR(cpu_holder_.Open(cc));
#endif // !MEDIAPIPE_DISABLE_OPENCV
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(gpu_holder_.Open(cc)); MP_RETURN_IF_ERROR(gpu_holder_.Open(cc));
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
@ -133,6 +139,7 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
return absl::UnavailableError("GPU support is disabled"); return absl::UnavailableError("GPU support is disabled");
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} }
#if !MEDIAPIPE_DISABLE_OPENCV
ASSIGN_OR_RETURN(auto* runner, cpu_holder_.GetRunner()); ASSIGN_OR_RETURN(auto* runner, cpu_holder_.GetRunner());
const auto& frame_ptr = input.GetImageFrameSharedPtr(); const auto& frame_ptr = input.GetImageFrameSharedPtr();
// Wrap image into image frame. // Wrap image into image frame.
@ -143,10 +150,15 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
ASSIGN_OR_RETURN(auto result, ASSIGN_OR_RETURN(auto result,
runner->Run(image_frame, matrix, size, border_mode)); runner->Run(image_frame, matrix, size, border_mode));
return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result))); return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
#else
return absl::UnavailableError("OpenCV support is disabled");
#endif // !MEDIAPIPE_DISABLE_OPENCV
} }
private: private:
#if !MEDIAPIPE_DISABLE_OPENCV
WarpAffineRunnerHolder<ImageFrame> cpu_holder_; WarpAffineRunnerHolder<ImageFrame> cpu_holder_;
#endif // !MEDIAPIPE_DISABLE_OPENCV
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
WarpAffineRunnerHolder<mediapipe::GpuBuffer> gpu_holder_; WarpAffineRunnerHolder<mediapipe::GpuBuffer> gpu_holder_;
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
@ -200,8 +212,10 @@ class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl<InterfaceT> {
} // namespace } // namespace
#if !MEDIAPIPE_DISABLE_OPENCV
MEDIAPIPE_NODE_IMPLEMENTATION( MEDIAPIPE_NODE_IMPLEMENTATION(
WarpAffineCalculatorImpl<WarpAffineCalculatorCpu>); WarpAffineCalculatorImpl<WarpAffineCalculatorCpu>);
#endif // !MEDIAPIPE_DISABLE_OPENCV
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
MEDIAPIPE_NODE_IMPLEMENTATION( MEDIAPIPE_NODE_IMPLEMENTATION(
WarpAffineCalculatorImpl<WarpAffineCalculatorGpu>); WarpAffineCalculatorImpl<WarpAffineCalculatorGpu>);

View File

@ -70,11 +70,13 @@ class WarpAffineCalculatorIntf : public mediapipe::api2::NodeIntf {
static constexpr mediapipe::api2::Output<ImageT> kOutImage{"IMAGE"}; static constexpr mediapipe::api2::Output<ImageT> kOutImage{"IMAGE"};
}; };
#if !MEDIAPIPE_DISABLE_OPENCV
class WarpAffineCalculatorCpu : public WarpAffineCalculatorIntf<ImageFrame> { class WarpAffineCalculatorCpu : public WarpAffineCalculatorIntf<ImageFrame> {
public: public:
MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorCpu, kInImage, kMatrix, MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorCpu, kInImage, kMatrix,
kOutputSize, kOutImage); kOutputSize, kOutImage);
}; };
#endif // !MEDIAPIPE_DISABLE_OPENCV
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
class WarpAffineCalculatorGpu class WarpAffineCalculatorGpu
: public WarpAffineCalculatorIntf<mediapipe::GpuBuffer> { : public WarpAffineCalculatorIntf<mediapipe::GpuBuffer> {

View File

@ -0,0 +1,123 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <utility>
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "libyuv/convert_argb.h"
#include "libyuv/video_common.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/yuv_image.h"
namespace mediapipe {
namespace api2 {
namespace {
// Utility function to convert FourCC enum to string, for error messages.
std::string FourCCToString(libyuv::FourCC fourcc) {
char buf[5];
buf[0] = (fourcc >> 24) & 0xff;
buf[1] = (fourcc >> 16) & 0xff;
buf[2] = (fourcc >> 8) & 0xff;
buf[3] = (fourcc)&0xff;
buf[4] = 0;
return std::string(buf);
}
} // namespace
// Converts a `YUVImage` into an RGB `Image` using libyuv.
//
// The input `YUVImage` is expected to be in the NV12, NV21, YV12 or I420 (aka
// YV21) format (as per the `fourcc()` property). This covers the most commonly
// used YUV image formats used on mobile devices. Other formats are not
// supported and wil result in an `InvalidArgumentError`.
class YUVToImageCalculator : public Node {
public:
static constexpr Input<YUVImage> kInput{"YUV_IMAGE"};
static constexpr Output<Image> kOutput{"IMAGE"};
MEDIAPIPE_NODE_CONTRACT(kInput, kOutput);
absl::Status Process(CalculatorContext* cc) override {
const auto& yuv_image = *kInput(cc);
// Check that the format is supported.
auto format = yuv_image.fourcc();
if (format != libyuv::FOURCC_NV12 && format != libyuv::FOURCC_NV21 &&
format != libyuv::FOURCC_YV12 && format != libyuv::FOURCC_I420) {
return absl::InvalidArgumentError(
absl::StrFormat("Unsupported YUVImage format: %s. Only NV12, NV21, "
"YV12 and I420 (aka YV21) are supported.",
FourCCToString(format)));
}
// Build a transient ImageFrameSharedPtr with default alignment to host
// conversion results.
ImageFrameSharedPtr image_frame = std::make_shared<ImageFrame>(
ImageFormat::SRGB, yuv_image.width(), yuv_image.height());
// Perform actual conversion.
switch (format) {
case libyuv::FOURCC_NV12:
// 8-bit Y plane followed by an interleaved 8-bit U/V plane with 2×2
// subsampling.
libyuv::NV12ToRAW(
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
yuv_image.stride(1), image_frame->MutablePixelData(),
image_frame->WidthStep(), yuv_image.width(), yuv_image.height());
break;
case libyuv::FOURCC_NV21:
// 8-bit Y plane followed by an interleaved 8-bit V/U plane with 2×2
// subsampling.
libyuv::NV21ToRAW(
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
yuv_image.stride(1), image_frame->MutablePixelData(),
image_frame->WidthStep(), yuv_image.width(), yuv_image.height());
break;
case libyuv::FOURCC_I420:
// Also known as YV21.
// 8-bit Y plane followed by 8-bit 2×2 subsampled U and V planes.
libyuv::I420ToRAW(
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
yuv_image.stride(1), yuv_image.data(2), yuv_image.stride(2),
image_frame->MutablePixelData(), image_frame->WidthStep(),
yuv_image.width(), yuv_image.height());
break;
case libyuv::FOURCC_YV12:
// 8-bit Y plane followed by 8-bit 2×2 subsampled V and U planes.
libyuv::I420ToRAW(
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(2),
yuv_image.stride(2), yuv_image.data(1), yuv_image.stride(1),
image_frame->MutablePixelData(), image_frame->WidthStep(),
yuv_image.width(), yuv_image.height());
break;
default:
// This should never happen (caught by checks above).
return absl::InternalError("Unsupported YUVImage format.");
}
// Finally, build and send an Image object that takes ownership of the
// transient ImageFrameSharedPtr object.
kOutput(cc).Send(std::make_unique<Image>(std::move(image_frame)));
return absl::OkStatus();
}
};
MEDIAPIPE_REGISTER_NODE(YUVToImageCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -40,6 +40,63 @@ selects.config_setting_group(
], ],
) )
mediapipe_proto_library(
name = "audio_to_tensor_calculator_proto",
srcs = ["audio_to_tensor_calculator.proto"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
"//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto",
],
)
cc_library(
name = "audio_to_tensor_calculator",
srcs = ["audio_to_tensor_calculator.cc"],
visibility = [
"//mediapipe/framework:mediapipe_internal",
],
deps = [
":audio_to_tensor_calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/api2:port",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/formats:time_series_header_cc_proto",
"//mediapipe/util:time_series_util",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_audio_tools//audio/dsp:resampler_q",
"@org_tensorflow//tensorflow/lite/c:common",
],
alwayslink = 1,
)
cc_test(
name = "audio_to_tensor_calculator_test",
srcs = ["audio_to_tensor_calculator_test.cc"],
deps = [
":audio_to_tensor_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:timestamp",
"//mediapipe/framework/api2:packet",
"//mediapipe/framework/formats:matrix",
"//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/strings",
"@com_google_audio_tools//audio/dsp:resampler_q",
"@org_tensorflow//tensorflow/lite/c:common",
],
)
mediapipe_proto_library( mediapipe_proto_library(
name = "inference_calculator_proto", name = "inference_calculator_proto",
srcs = ["inference_calculator.proto"], srcs = ["inference_calculator.proto"],
@ -50,6 +107,14 @@ mediapipe_proto_library(
], ],
) )
# This target defines the "InferenceCalculator" component, which looks for the available concrete
# implementations linked into the current binary and picks the one to use.
# You can depend on :inference_calculator instead if you want to automatically include a default
# set of implementations tailored for the current build configuration.
# If you want to have precise control of which implementations to include (e.g. for strict binary
# size concerns), depend on those implementations directly, and do not depend on
# :inference_calculator.
# In all cases, use "InferenceCalulator" in your graphs.
cc_library( cc_library(
name = "inference_calculator_interface", name = "inference_calculator_interface",
srcs = ["inference_calculator.cc"], srcs = ["inference_calculator.cc"],
@ -62,8 +127,9 @@ cc_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}), }),
visibility = ["//visibility:public"],
deps = [ deps = [
":inference_calculator_cc_proto", ":inference_calculator_options_lib",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/api2:node", "//mediapipe/framework/api2:node",
"//mediapipe/framework/api2:packet", "//mediapipe/framework/api2:packet",
@ -85,18 +151,31 @@ cc_library(
name = "inference_calculator_gl", name = "inference_calculator_gl",
srcs = ["inference_calculator_gl.cc"], srcs = ["inference_calculator_gl.cc"],
tags = ["nomac"], # config problem with cpuinfo via TF tags = ["nomac"], # config problem with cpuinfo via TF
visibility = ["//visibility:public"],
deps = [ deps = [
"inference_calculator_interface", ":inference_calculator_interface",
"//mediapipe/framework/deps:file_path",
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/util/tflite:config",
"//mediapipe/util/tflite:tflite_gpu_runner",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/status", "@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework_stable", "@org_tensorflow//tensorflow/lite:framework_stable",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape", ],
alwayslink = 1,
)
cc_library(
name = "inference_calculator_gl_advanced",
srcs = ["inference_calculator_gl_advanced.cc"],
tags = ["nomac"],
visibility = ["//visibility:public"],
deps = [
":inference_calculator_interface",
"//mediapipe/framework/deps:file_path",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/util/tflite:tflite_gpu_runner",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@org_tensorflow//tensorflow/lite:framework_stable",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -113,6 +192,7 @@ cc_library(
"-framework MetalKit", "-framework MetalKit",
], ],
tags = ["ios"], tags = ["ios"],
visibility = ["//visibility:public"],
deps = [ deps = [
"inference_calculator_interface", "inference_calculator_interface",
"//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalHelper",
@ -142,6 +222,7 @@ cc_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}), }),
visibility = ["//visibility:public"],
deps = [ deps = [
":inference_calculator_interface", ":inference_calculator_interface",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
@ -161,9 +242,13 @@ cc_library(
cc_library( cc_library(
name = "inference_calculator_gl_if_compute_shader_available", name = "inference_calculator_gl_if_compute_shader_available",
visibility = ["//visibility:public"],
deps = selects.with_or({ deps = selects.with_or({
":compute_shader_unavailable": [], ":compute_shader_unavailable": [],
"//conditions:default": [":inference_calculator_gl"], "//conditions:default": [
":inference_calculator_gl",
":inference_calculator_gl_advanced",
],
}), }),
) )
@ -484,6 +569,7 @@ cc_library(
"//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/util:label_map_cc_proto",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe:android": [
@ -506,6 +592,7 @@ mediapipe_proto_library(
deps = [ deps = [
"//mediapipe/framework:calculator_options_proto", "//mediapipe/framework:calculator_options_proto",
"//mediapipe/framework:calculator_proto", "//mediapipe/framework:calculator_proto",
"//mediapipe/util:label_map_proto",
], ],
) )
@ -672,6 +759,7 @@ cc_library(
], ],
"//conditions:default": [], "//conditions:default": [],
}), }),
visibility = ["//visibility:public"],
deps = [ deps = [
":image_to_tensor_converter", ":image_to_tensor_converter",
":image_to_tensor_utils", ":image_to_tensor_utils",
@ -858,9 +946,7 @@ cc_library(
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"//mediapipe/framework/formats:image", "//mediapipe/framework/formats:image",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_opencv",
"//mediapipe/framework/formats:tensor", "//mediapipe/framework/formats:tensor",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework:calculator_context", "//mediapipe/framework:calculator_context",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
@ -890,6 +976,12 @@ cc_library(
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util",
], ],
}) + select({
"//mediapipe/framework/port:disable_opencv": [],
"//conditions:default": [
"//mediapipe/framework/formats:image_opencv",
"//mediapipe/framework/port:opencv_imgproc",
],
}), }),
alwayslink = 1, alwayslink = 1,
) )

View File

@ -0,0 +1,401 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "audio/dsp/resampler_q.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/util/time_series_util.h"
namespace mediapipe {
namespace api2 {
// Converts audio buffers into tensors, possibly with resampling, buffering
// and framing, according to specified inputs and options. All input audio
// buffers will be first resampled from the input sample rate to the target
// sample rate if they are not equal. The resampled audio data (with the
// buffered samples from the previous runs in the streaming mode) will be broken
// into fixed-sized, possibly overlapping frames. Finally, all frames will be
// converted to and outputted as MediaPipe Tensors. The last output tensor will
// be zero-padding if the remaining samples are insufficient.
//
// This calculator assumes that the input timestamps refer to the first
// sample in each Matrix. The output timestamps follow this same convention.
// One Process() call may output multiple tensors packets. The timestamps of
// the output packets are determined by the timestamp of the previous output
// packet, the target sample rate, and the number of samples advanced after the
// previous output.
//
// The calculator has two running modes:
// Streaming mode: when "streaming_mode" is set to true in the calculator
// options, the calculator treats the input audio stream as a continuous
// stream. Thus, any samples that are not consumed in the previous runs will
// be cached in a global sample buffer. The audio data resampled from the
// current raw audio input will be appended to the global sample buffer.
// The calculator will process the global sample buffer and output as many
// tensors as possible.
// Non-streaming mode: when "streaming_mode" is set to false in the calculator
// options, the calculators treats the packets in the input audio stream as
// a batch of unrelated audio buffers. In each Process() call, the input
// buffer will be frist resampled, and framed as fixed-sized, possibly
// overlapping tensors. The last tensor produced by a Process() invocation
// will be zero-padding if the remaining samples are insufficient. As the
// calculator treats the input packets as unrelated, all samples will be
// processed immediately and no samples will be cached in the global sample
// buffer.
//
// Inputs:
// AUDIO - mediapipe::Matrix
// The audio data represented as mediapipe::Matrix.
// SAMPLE_RATE - double @Optional
// The sample rate of the corresponding audio data in the "AUDIO" stream.
// If a sample rate packet is provided at Timestamp::PreStream(), the sample
// rate will be used as the sample rate of every audio packets in the
// "AUDIO" stream. Note that one and only one of the "AUDIO" stream's time
// series header or the "SAMPLE_RATE" stream can exist.
//
// Outputs:
// TENSORS - std::vector<Tensor>
// Vector containing a single Tensor that represents a fix-sized audio
// frame.
// TIMESTAMPS - std::vector<Timestamp> @Optional
// Vector containing the output timestamps emitted by the current Process()
// invocation. In the non-streaming mode, the vector contains all of the
// output timestamps for an input audio buffer.
//
// Example:
// node {
// calculator: "AudioToTensorCalculator"
// input_stream: "AUDIO:audio"
// output_stream: "TENSORS:tensors"
// output_stream: "TIMESTAMPS:timestamps"
// options {
// [mediapipe.AudioToTensorCalculatorOptions.ext] {
// num_channels: 2
// num_samples: 512
// num_overlapping_samples: 64
// target_sample_rate: 16000
// streaming_mode: true # or false
// }
// }
// }
class AudioToTensorCalculator : public Node {
public:
static constexpr Input<Matrix> kAudioIn{"AUDIO"};
// TODO: Removes this optional input stream when the "AUDIO" stream
// uses the new mediapipe audio data containers that carry audio metatdata,
// such as sample rate.
static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
// A vector of the output timestamps emitted by the current Process()
// invocation. The packet timestamp is the last emitted timestamp.
static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{
"TIMESTAMPS"};
MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut,
kTimestampsOut);
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc);
absl::Status Process(CalculatorContext* cc);
absl::Status Close(CalculatorContext* cc);
private:
// The target number of channels.
int num_channels_;
// The target number of samples per channel.
int num_samples_;
// The number of samples per channel to advance after the current frame is
// processed.
int frame_step_;
bool streaming_mode_;
bool check_inconsistent_timestamps_;
Timestamp initial_timestamp_ = Timestamp::Unstarted();
int64 cumulative_input_samples_ = 0;
Timestamp next_output_timestamp_ = Timestamp::Unstarted();
double source_sample_rate_ = -1;
double target_sample_rate_ = -1;
// TODO: Configures QResamplerParams through calculator options.
audio_dsp::QResamplerParams params_;
// A QResampler instance to resample an audio stream.
std::unique_ptr<audio_dsp::QResampler<float>> resampler_;
Matrix sample_buffer_;
int processed_buffer_cols_ = 0;
absl::Status ProcessStreamingData(CalculatorContext* cc);
absl::Status ProcessNonStreamingData(CalculatorContext* cc);
absl::Status SetupStreamingResampler(double input_sample_rate_);
void AppendToSampleBuffer(Matrix buffer_to_append);
absl::StatusOr<std::vector<Tensor>> ConvertToTensor(
const Matrix& frame_to_convert);
absl::Status OutputTensors(const Matrix& buffer, bool should_flush,
CalculatorContext* cc);
};
absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
const auto& options =
cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
if (!options.has_num_channels() || !options.has_num_samples() ||
!options.has_target_sample_rate()) {
return absl::InvalidArgumentError(
"AudioToTensorCalculatorOptions must specifiy "
"`num_channels`, `num_samples`, and `target_sample_rate`.");
}
if (options.streaming_mode()) {
// Explicitly disables tiemstamp offset to disallow the timestamp bound
// from the input streams to be propagated to the output streams.
// In the streaming mode, the output timestamp bound is based on
// next_output_timestamp_, which can be smaller than the current input
// timestamps.
cc->SetTimestampOffset(TimestampDiff::Unset());
}
return absl::OkStatus();
}
absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
const auto& options =
cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
num_channels_ = options.num_channels();
num_samples_ = options.num_samples();
if (options.has_num_overlapping_samples()) {
RET_CHECK_GE(options.num_overlapping_samples(), 0);
RET_CHECK_LT(options.num_overlapping_samples(), num_samples_);
frame_step_ = num_samples_ - options.num_overlapping_samples();
} else {
frame_step_ = num_samples_;
}
target_sample_rate_ = options.target_sample_rate();
streaming_mode_ = options.streaming_mode();
if (streaming_mode_) {
check_inconsistent_timestamps_ = options.check_inconsistent_timestamps();
sample_buffer_.resize(num_channels_, Eigen::NoChange);
}
RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
!kAudioIn(cc).Header().IsEmpty())
<< "Must either specify the time series header of the \"AUDIO\" stream "
"or have the \"SAMPLE_RATE\" stream connected.";
if (!kAudioIn(cc).Header().IsEmpty()) {
mediapipe::TimeSeriesHeader input_header;
MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
kAudioIn(cc).Header(), &input_header));
if (streaming_mode_) {
MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
} else {
source_sample_rate_ = input_header.sample_rate();
}
}
return absl::OkStatus();
}
absl::Status AudioToTensorCalculator::Process(CalculatorContext* cc) {
if (cc->InputTimestamp() == Timestamp::PreStream()) {
double current_source_sample_rate = kAudioSampleRateIn(cc).Get();
if (cc->Options<mediapipe::AudioToTensorCalculatorOptions>()
.streaming_mode()) {
return SetupStreamingResampler(current_source_sample_rate);
} else {
source_sample_rate_ = current_source_sample_rate;
return absl::OkStatus();
}
}
// Sanity checks.
const auto& input_frame = kAudioIn(cc).Get();
if (input_frame.rows() != num_channels_) {
return absl::InvalidArgumentError(absl::StrFormat(
"Audio input has %d channel(s) but the model requires %d channel(s).",
input_frame.rows(), num_channels_));
}
if (num_channels_ > 1 && input_frame.IsRowMajor) {
return absl::InvalidArgumentError(
"The audio data should be stored in column-major.");
}
return streaming_mode_ ? ProcessStreamingData(cc)
: ProcessNonStreamingData(cc);
}
absl::Status AudioToTensorCalculator::Close(CalculatorContext* cc) {
if (!streaming_mode_) {
return absl::OkStatus();
}
if (resampler_) {
Matrix resampled_buffer(num_channels_, 0);
resampler_->Flush(&resampled_buffer);
AppendToSampleBuffer(std::move(resampled_buffer));
}
return OutputTensors(sample_buffer_, /*should_flush=*/true, cc);
}
absl::Status AudioToTensorCalculator::ProcessStreamingData(
CalculatorContext* cc) {
const auto& input_buffer = kAudioIn(cc).Get();
if (initial_timestamp_ == Timestamp::Unstarted()) {
initial_timestamp_ = cc->InputTimestamp();
next_output_timestamp_ = initial_timestamp_;
}
if (source_sample_rate_ != -1 && check_inconsistent_timestamps_) {
mediapipe::time_series_util::LogWarningIfTimestampIsInconsistent(
cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_,
source_sample_rate_);
cumulative_input_samples_ += input_buffer.cols();
}
if (!kAudioSampleRateIn(cc).IsEmpty()) {
double current_source_sample_rate = kAudioSampleRateIn(cc).Get();
if (resampler_) {
RET_CHECK_EQ(current_source_sample_rate, source_sample_rate_);
} else {
MP_RETURN_IF_ERROR(SetupStreamingResampler(current_source_sample_rate));
}
}
if (resampler_) {
Matrix resampled_buffer(num_channels_, 0);
resampler_->ProcessSamples(input_buffer, &resampled_buffer);
AppendToSampleBuffer(std::move(resampled_buffer));
} else {
// Tries to consume the input matrix first to avoid extra data copy.
auto status_or_matrix = kAudioIn(cc).packet().Consume<Matrix>();
if (status_or_matrix.ok()) {
Matrix local_matrix(num_channels_, 0);
local_matrix.swap(*status_or_matrix.value());
AppendToSampleBuffer(std::move(local_matrix));
} else {
AppendToSampleBuffer(input_buffer);
}
}
MP_RETURN_IF_ERROR(OutputTensors(sample_buffer_, /*should_flush=*/false, cc));
// Removes the processed samples from the global sample buffer.
sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() -
processed_buffer_cols_ - 1));
return absl::OkStatus();
}
absl::Status AudioToTensorCalculator::ProcessNonStreamingData(
CalculatorContext* cc) {
initial_timestamp_ = cc->InputTimestamp();
next_output_timestamp_ = initial_timestamp_;
const auto& input_frame = kAudioIn(cc).Get();
double source_sample_rate = kAudioSampleRateIn(cc).GetOr(source_sample_rate_);
if (source_sample_rate != -1 && source_sample_rate != target_sample_rate_) {
std::vector<float> resampled = audio_dsp::QResampleSignal<float>(
source_sample_rate, target_sample_rate_, num_channels_, params_,
input_frame);
Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_,
resampled.size() / num_channels_);
return OutputTensors(matrix_mapping, /*should_flush=*/true, cc);
}
return OutputTensors(input_frame, /*should_flush=*/true, cc);
}
absl::Status AudioToTensorCalculator::SetupStreamingResampler(
double input_sample_rate) {
if (input_sample_rate == source_sample_rate_) {
return absl::OkStatus();
}
source_sample_rate_ = input_sample_rate;
if (source_sample_rate_ != target_sample_rate_) {
resampler_ = absl::make_unique<audio_dsp::QResampler<float>>(
source_sample_rate_, target_sample_rate_, num_channels_, params_);
if (!resampler_) {
return absl::InternalError("Failed to initialize resampler.");
}
}
return absl::OkStatus();
}
void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
sample_buffer_.conservativeResize(
Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols());
sample_buffer_.rightCols(buffer_to_append.cols()).swap(buffer_to_append);
}
absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor(
const Matrix& frame_to_convert) {
Tensor tensor(Tensor::ElementType::kFloat32,
Tensor::Shape({num_channels_, num_samples_}));
auto buffer_view = tensor.GetCpuWriteView();
if (frame_to_convert.size() < num_channels_ * num_samples_) {
std::memset(buffer_view.buffer<float>(), 0, tensor.bytes());
}
std::memcpy(buffer_view.buffer<float>(), frame_to_convert.data(),
frame_to_convert.size() * sizeof(float));
std::vector<Tensor> tensor_vector;
tensor_vector.push_back(std::move(tensor));
return tensor_vector;
}
absl::Status AudioToTensorCalculator::OutputTensors(const Matrix& buffer,
bool should_flush,
CalculatorContext* cc) {
int next_frame_first_col = 0;
std::vector<Timestamp> timestamps;
while ((!streaming_mode_ || !should_flush) &&
next_frame_first_col + num_samples_ <= buffer.cols()) {
ASSIGN_OR_RETURN(auto output_tensor, ConvertToTensor(buffer.block(
0, next_frame_first_col,
num_channels_, num_samples_)));
kTensorsOut(cc).Send(std::move(output_tensor), next_output_timestamp_);
timestamps.push_back(next_output_timestamp_);
next_output_timestamp_ += round(frame_step_ / target_sample_rate_ *
Timestamp::kTimestampUnitsPerSecond);
next_frame_first_col += frame_step_;
}
if (should_flush && next_frame_first_col < buffer.cols()) {
ASSIGN_OR_RETURN(auto output_tensor,
ConvertToTensor(buffer.block(
0, next_frame_first_col, num_channels_,
std::min(num_samples_,
(int)buffer.cols() - next_frame_first_col))));
// In the streaming mode, the flush happens in Close() and a packet at
// Timestamp::Max() will be emitted. In the non-streaming mode, each
// Process() invocation will process the entire buffer completely.
Timestamp timestamp =
streaming_mode_ ? Timestamp::Max() : next_output_timestamp_;
timestamps.push_back(timestamp);
kTensorsOut(cc).Send(std::move(output_tensor), timestamp);
}
if (kTimestampsOut(cc).IsConnected()) {
Timestamp timestamp = timestamps.back();
kTimestampsOut(cc).Send(std::move(timestamps), timestamp);
}
processed_buffer_cols_ = next_frame_first_col - 1;
return absl::OkStatus();
}
MEDIAPIPE_REGISTER_NODE(AudioToTensorCalculator);
} // namespace api2
} // namespace mediapipe

View File

@ -0,0 +1,46 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message AudioToTensorCalculatorOptions {
extend mediapipe.CalculatorOptions {
optional AudioToTensorCalculatorOptions ext = 448635064;
}
// The required number of channels the output audio tensor has.
optional int64 num_channels = 1;
// The required number of samples per channel the output audio tensor has.
optional int64 num_samples = 2;
// The number of overlapping samples per channel the output audio tensor has.
optional int64 num_overlapping_samples = 3 [default = 0];
// The target number of samples per second (hertz) of the audio buffers that
// will be converted into tensors.
optional double target_sample_rate = 4;
// Whether to treat the input audio stream as a continous stream or a batch
// of unrelated audio buffers.
optional bool streaming_mode = 5 [default = true];
// Set to false to disable checks for jitter in timestamp values. Useful with
// live audio input.
optional bool check_inconsistent_timestamps = 6 [default = true];
}

View File

@ -0,0 +1,483 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <memory>
#include <string>
#include <vector>
#include "absl/strings/substitute.h"
#include "audio/dsp/resampler_q.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"
namespace mediapipe {
namespace {
std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples,
int timestamp) {
auto matrix = std::make_unique<Matrix>(num_channels, num_samples);
for (int c = 0; c < num_channels; ++c) {
for (int i = 0; i < num_samples; ++i) {
// A float value with the sample, channel, and timestamp separated by a
// few orders of magnitude, for easy parsing by humans.
(*matrix)(c, i) = timestamp / 10000 + i + c / 100.0;
}
}
return matrix;
}
std::unique_ptr<Matrix> ResampleBuffer(const Matrix& input_matrix,
double resampling_factor) {
audio_dsp::QResamplerParams params;
std::vector<float> resampled;
int num_channels = input_matrix.rows();
std::vector<float> input_data(input_matrix.data(),
input_matrix.data() + input_matrix.size());
resampled = audio_dsp::QResampleSignal<float>(
1, resampling_factor, num_channels, params, input_data);
Matrix res = Eigen::Map<Matrix>(resampled.data(), num_channels,
resampled.size() / num_channels);
return std::make_unique<Matrix>(std::move(res));
}
class AudioToTensorCalculatorNonStreamingModeTest : public ::testing::Test {
protected:
void SetUp() override {}
void Run(int num_samples, int num_overlapping_samples,
double resampling_factor, const Matrix& input_matrix) {
double input_sample_rate = 10000;
double target_sample_rate = input_sample_rate * resampling_factor;
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "audio"
input_stream: "sample_rate"
output_stream: "tensors"
output_stream: "timestamps"
node {
calculator: "AudioToTensorCalculator"
input_stream: "AUDIO:audio"
input_stream: "SAMPLE_RATE:sample_rate"
output_stream: "TENSORS:tensors"
output_stream: "TIMESTAMPS:timestamps"
options {
[mediapipe.AudioToTensorCalculatorOptions.ext] {
num_channels: $0
num_samples: $1
num_overlapping_samples: $2
target_sample_rate: $3
streaming_mode: false
}
}
}
)",
/*$0=*/input_matrix.rows(),
/*$1=*/num_samples, /*$2=*/num_overlapping_samples,
/*$3=*/target_sample_rate));
tool::AddVectorSink("tensors", &graph_config, &tensors_packets_);
tool::AddVectorSink("timestamps", &graph_config, &timestamps_packets_);
// Run the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_ASSERT_OK(graph_.StartRun({}));
// Run with the input matrix multiple times.
for (int i = 0; i < num_iterations_; ++i) {
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"audio",
MakePacket<Matrix>(input_matrix)
.At(Timestamp(i * Timestamp::kTimestampUnitsPerSecond))));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"sample_rate",
MakePacket<double>(input_sample_rate)
.At(Timestamp(i * Timestamp::kTimestampUnitsPerSecond))));
}
MP_ASSERT_OK(graph_.CloseAllInputStreams());
MP_ASSERT_OK(graph_.WaitUntilIdle());
}
void CheckTensorsOutputPackets(const Matrix& expected_matrix,
int sample_offset, int num_tensors_per_input) {
ASSERT_EQ(num_iterations_ * num_tensors_per_input, tensors_packets_.size());
for (int i = 0; i < num_iterations_; ++i) {
for (int j = 0; j < num_tensors_per_input; ++j) {
CheckTensorsOutputPacket(
expected_matrix, tensors_packets_[i * num_tensors_per_input + j],
/*sample_offset*/ sample_offset * j, /*index=*/j);
}
}
}
void CheckTensorsOutputPacket(const Matrix& expected_matrix,
const Packet& packet, int sample_offset,
int index) {
MP_ASSERT_OK(packet.ValidateAsType<std::vector<Tensor>>());
ASSERT_EQ(1, packet.Get<std::vector<Tensor>>().size());
const Tensor& output_tensor = packet.Get<std::vector<Tensor>>()[0];
auto* buffer = output_tensor.GetCpuReadView().buffer<float>();
int num_values = output_tensor.shape().num_elements();
const std::vector<float> output_floats(buffer, buffer + num_values);
for (int i = 0; i < num_values; ++i) {
if (i + sample_offset >= expected_matrix.size()) {
EXPECT_FLOAT_EQ(output_floats[i], 0);
} else {
EXPECT_FLOAT_EQ(output_floats[i],
expected_matrix.coeff((i + sample_offset) % 2,
(i + sample_offset) / 2))
<< "i=" << i << ", sample_offset=" << sample_offset;
}
}
}
void CheckTimestampsOutputPackets(
std::vector<int64> expected_timestamp_values) {
ASSERT_EQ(num_iterations_, timestamps_packets_.size());
for (int i = 0; i < timestamps_packets_.size(); ++i) {
const auto& p = timestamps_packets_[i];
MP_ASSERT_OK(p.ValidateAsType<std::vector<Timestamp>>());
auto output_timestamps = p.Get<std::vector<Timestamp>>();
int64 base_timestamp = i * Timestamp::kTimestampUnitsPerSecond;
std::vector<Timestamp> expected_timestamps;
expected_timestamps.resize(expected_timestamp_values.size());
std::transform(
expected_timestamp_values.begin(), expected_timestamp_values.end(),
expected_timestamps.begin(), [base_timestamp](int64 v) -> Timestamp {
return Timestamp(v + base_timestamp);
});
EXPECT_EQ(expected_timestamps, output_timestamps);
EXPECT_EQ(p.Timestamp(), expected_timestamps.back());
}
}
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); }
private:
CalculatorGraph graph_;
int num_iterations_ = 10;
std::vector<Packet> tensors_packets_;
std::vector<Packet> timestamps_packets_;
};
TEST_F(AudioToTensorCalculatorNonStreamingModeTest,
ConvertToNoOverlappingFp32Tensors) {
auto input_matrix = CreateTestMatrix(2, 8, 0);
Run(/*num_samples=*/4, /*num_overlapping_samples=*/0,
/*resampling_factor=*/1.0f, *input_matrix);
CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/8,
/*num_tensors_per_input=*/2);
CheckTimestampsOutputPackets({0, 400});
CloseGraph();
}
TEST_F(AudioToTensorCalculatorNonStreamingModeTest,
ConvertToOverlappingFp32Tensors) {
auto input_matrix = CreateTestMatrix(2, 8, 0);
Run(/*num_samples=*/4, /*num_overlapping_samples=*/2,
/*resampling_factor=*/1.0f, *input_matrix);
CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/4,
/*num_tensors_per_input=*/4);
CheckTimestampsOutputPackets({0, 200, 400, 600});
CloseGraph();
}
TEST_F(AudioToTensorCalculatorNonStreamingModeTest, TensorsWithZeroPadding) {
auto input_matrix = CreateTestMatrix(2, 7, 0);
Run(/*num_samples=*/4, /*num_overlapping_samples=*/2,
/*resampling_factor=*/1.0f, *input_matrix);
CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/4,
/*num_tensors_per_input=*/3);
CheckTimestampsOutputPackets({0, 200, 400});
CloseGraph();
}
TEST_F(AudioToTensorCalculatorNonStreamingModeTest, Downsampling) {
auto input_matrix = CreateTestMatrix(2, 1024, 0);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
/*resampling_factor=*/0.5f, *input_matrix);
auto expected_matrix =
ResampleBuffer(*input_matrix, /*resampling_factor=*/0.5f);
CheckTensorsOutputPackets(*expected_matrix, /*sample_offset=*/512,
/*num_tensors_per_input=*/3);
CheckTimestampsOutputPackets({0, 51200, 102400});
CloseGraph();
}
TEST_F(AudioToTensorCalculatorNonStreamingModeTest,
DownsamplingWithOverlapping) {
auto input_matrix = CreateTestMatrix(2, 1024, 0);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
/*resampling_factor=*/0.5f, *input_matrix);
auto expected_matrix =
ResampleBuffer(*input_matrix, /*resampling_factor=*/0.5f);
CheckTensorsOutputPackets(*expected_matrix, /*sample_offset=*/384,
/*num_tensors_per_input=*/3);
CheckTimestampsOutputPackets({0, 38400, 76800});
CloseGraph();
}
TEST_F(AudioToTensorCalculatorNonStreamingModeTest, Upsampling) {
auto input_matrix = CreateTestMatrix(2, 1024, 0);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
/*resampling_factor=*/2.0f, *input_matrix);
auto expected_matrix =
ResampleBuffer(*input_matrix, /*resampling_factor=*/2.0f);
CheckTensorsOutputPackets(*expected_matrix,
/*sample_offset=*/512,
/*num_tensors_per_input=*/9);
CheckTimestampsOutputPackets(
{0, 12800, 25600, 38400, 51200, 64000, 76800, 89600, 102400});
CloseGraph();
}
TEST_F(AudioToTensorCalculatorNonStreamingModeTest, UpsamplingWithOverlapping) {
auto input_matrix = CreateTestMatrix(2, 256, 0);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
/*resampling_factor=*/2.0f, *input_matrix);
auto expected_matrix =
ResampleBuffer(*input_matrix, /*resampling_factor=*/2.0f);
CheckTensorsOutputPackets(*expected_matrix,
/*sample_offset=*/384,
/*num_tensors_per_input=*/3);
CheckTimestampsOutputPackets({0, 9600, 19200});
CloseGraph();
}
class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
protected:
void SetUp() override { sample_buffer_ = std::make_unique<Matrix>(2, 0); }
void SetInputBufferNumSamplesPerChannel(int num_samples) {
input_buffer_num_samples_ = num_samples;
}
void SetNumIterations(int num_iterations) {
num_iterations_ = num_iterations;
}
int GetExpectedNumOfSamples() {
Matrix* expected_matrix =
resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
return expected_matrix->cols();
}
void Run(int num_samples, int num_overlapping_samples,
double resampling_factor) {
double input_sample_rate = 10000;
double target_sample_rate = input_sample_rate * resampling_factor;
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
absl::Substitute(R"(
input_stream: "audio"
input_stream: "sample_rate"
output_stream: "tensors"
node {
calculator: "AudioToTensorCalculator"
input_stream: "AUDIO:audio"
input_stream: "SAMPLE_RATE:sample_rate"
output_stream: "TENSORS:tensors"
options {
[mediapipe.AudioToTensorCalculatorOptions.ext] {
num_channels: 2
num_samples: $0
num_overlapping_samples: $1
target_sample_rate: $2
streaming_mode:true
}
}
}
)",
/*$0=*/num_samples, /*$1=*/num_overlapping_samples,
/*$2=*/target_sample_rate));
tool::AddVectorSink("tensors", &graph_config, &tensors_packets_);
// Run the graph.
MP_ASSERT_OK(graph_.Initialize(graph_config));
MP_ASSERT_OK(graph_.StartRun({}));
for (int i = 0; i < num_iterations_; ++i) {
Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i);
auto new_data = CreateTestMatrix(2, input_buffer_num_samples_,
input_timestamp.Value());
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"audio", MakePacket<Matrix>(*new_data).At(input_timestamp)));
MP_ASSERT_OK(graph_.AddPacketToInputStream(
"sample_rate",
MakePacket<double>(input_sample_rate).At(input_timestamp)));
sample_buffer_->conservativeResize(
Eigen::NoChange, sample_buffer_->cols() + new_data->cols());
sample_buffer_->rightCols(new_data->cols()).swap(*new_data);
}
MP_ASSERT_OK(graph_.CloseAllInputStreams());
MP_ASSERT_OK(graph_.WaitUntilIdle());
if (resampling_factor != 1) {
resampled_buffer_ = ResampleBuffer(*sample_buffer_, resampling_factor);
}
}
void CheckTensorsOutputPackets(int sample_offset, int num_packets,
int64 timestamp_interval,
bool output_last_at_close) {
ASSERT_EQ(num_packets, tensors_packets_.size());
for (int i = 0; i < num_packets; ++i) {
if (i == num_packets - 1 && output_last_at_close) {
CheckTensorsOutputPacket(sample_offset * i, i, Timestamp::Max());
} else {
CheckTensorsOutputPacket(sample_offset * i, i,
Timestamp(timestamp_interval * i));
}
}
}
void CheckTensorsOutputPacket(int sample_offset, int index,
Timestamp expected_timestamp) {
const Packet& p = tensors_packets_[index];
MP_ASSERT_OK(p.ValidateAsType<std::vector<Tensor>>());
const Tensor& output_tensor = p.Get<std::vector<Tensor>>()[0];
auto buffer = output_tensor.GetCpuReadView().buffer<float>();
int num_values = output_tensor.shape().num_elements();
std::vector<float> output_floats(buffer, buffer + num_values);
Matrix* expected_matrix =
resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
for (int i = 0; i < num_values; ++i) {
if (i + sample_offset >= expected_matrix->size()) {
EXPECT_FLOAT_EQ(output_floats[i], 0);
} else {
EXPECT_NEAR(output_floats[i],
expected_matrix->coeff((i + sample_offset) % 2,
(i + sample_offset) / 2),
0.001)
<< "i=" << i << ", sample_offset=" << sample_offset
<< ", packet index=" << index;
}
}
EXPECT_EQ(p.Timestamp(), expected_timestamp);
}
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); }
private:
int input_buffer_num_samples_ = 10;
int num_iterations_ = 10;
CalculatorGraph graph_;
std::vector<Packet> tensors_packets_;
std::unique_ptr<Matrix> sample_buffer_;
std::unique_ptr<Matrix> resampled_buffer_;
};
TEST_F(AudioToTensorCalculatorStreamingModeTest,
OutputNoOverlappingFp32Tensors) {
Run(/*num_samples=*/5, /*num_overlapping_samples=*/0,
/*resampling_factor=*/1.0f);
CheckTensorsOutputPackets(
/*sample_offset=*/10,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 5),
/*timestamp_interval=*/500,
/*output_last_at_close=*/false);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputRemainingInCloseMethod) {
Run(/*num_samples=*/6, /*num_overlapping_samples=*/0,
/*resampling_factor=*/1.0f);
CheckTensorsOutputPackets(
/*sample_offset=*/12,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 6),
/*timestamp_interval=*/600,
/*output_last_at_close=*/true);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputOverlappingFp32Tensors) {
SetInputBufferNumSamplesPerChannel(12);
Run(/*num_samples=*/10, /*num_overlapping_samples=*/2,
/*resampling_factor=*/1.0f);
CheckTensorsOutputPackets(
/*sample_offset=*/16,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 8),
/*timestamp_interval=*/800,
/*output_last_at_close=*/true);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, Downsampling) {
SetInputBufferNumSamplesPerChannel(1000);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
/*resampling_factor=*/0.5f);
CheckTensorsOutputPackets(
/*sample_offset=*/512,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256),
/*timestamp_interval=*/51200,
/*output_last_at_close=*/true);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, DownsamplingWithOverlapping) {
SetInputBufferNumSamplesPerChannel(1024);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
/*resampling_factor=*/0.5f);
CheckTensorsOutputPackets(
/*sample_offset=*/384,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192),
/*timestamp_interval=*/38400,
/*output_last_at_close=*/true);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, Upsampling) {
SetInputBufferNumSamplesPerChannel(1000);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
/*resampling_factor=*/2.0f);
CheckTensorsOutputPackets(
/*sample_offset=*/512,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256),
/*timestamp_interval=*/12800,
/*output_last_at_close=*/true);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest, UpsamplingWithOverlapping) {
SetInputBufferNumSamplesPerChannel(1024);
Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
/*resampling_factor=*/2.0f);
CheckTensorsOutputPackets(
/*sample_offset=*/384,
/*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192),
/*timestamp_interval=*/9600,
/*output_last_at_close=*/true);
CloseGraph();
}
TEST_F(AudioToTensorCalculatorStreamingModeTest,
OnlyOutputInCloseIfNoSufficientSamples) {
SetNumIterations(1);
Run(/*num_samples=*/8, /*num_overlapping_samples=*/0,
/*resampling_factor=*/0.5f);
CheckTensorsOutputPackets(
/*sample_offset=*/0,
/*num_packets=*/1,
/*timestamp_interval=*/0,
/*output_last_at_close=*/true);
CloseGraph();
}
} // namespace
} // namespace mediapipe

View File

@ -19,8 +19,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h" #include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/tool/subgraph_expansion.h" #include "mediapipe/framework/tool/subgraph_expansion.h"
#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/api/op_resolver.h"
@ -43,8 +43,19 @@ class InferenceCalculatorSelectorImpl
!options.has_delegate() || // Use GPU delegate if not specified !options.has_delegate() || // Use GPU delegate if not specified
(options.has_delegate() && options.delegate().has_gpu()); (options.has_delegate() && options.delegate().has_gpu());
if (should_use_gpu) { if (should_use_gpu) {
const auto& api = options.delegate().gpu().api();
using Gpu = ::mediapipe::InferenceCalculatorOptions::Delegate::Gpu;
impls.emplace_back("Metal"); impls.emplace_back("Metal");
const bool prefer_gl_advanced =
options.delegate().gpu().use_advanced_gpu_api() &&
(api == Gpu::ANY || api == Gpu::OPENGL || api == Gpu::OPENCL);
if (prefer_gl_advanced) {
impls.emplace_back("GlAdvanced");
impls.emplace_back("Gl"); impls.emplace_back("Gl");
} else {
impls.emplace_back("Gl");
impls.emplace_back("GlAdvanced");
}
} }
impls.emplace_back("Cpu"); impls.emplace_back("Cpu");
for (const auto& suffix : impls) { for (const auto& suffix : impls) {

View File

@ -134,6 +134,10 @@ struct InferenceCalculatorGl : public InferenceCalculator {
static constexpr char kCalculatorName[] = "InferenceCalculatorGl"; static constexpr char kCalculatorName[] = "InferenceCalculatorGl";
}; };
struct InferenceCalculatorGlAdvanced : public InferenceCalculator {
static constexpr char kCalculatorName[] = "InferenceCalculatorGlAdvanced";
};
struct InferenceCalculatorMetal : public InferenceCalculator { struct InferenceCalculatorMetal : public InferenceCalculator {
static constexpr char kCalculatorName[] = "InferenceCalculatorMetal"; static constexpr char kCalculatorName[] = "InferenceCalculatorMetal";
}; };

View File

@ -75,9 +75,10 @@ const std::vector<Param>& GetParams() {
class InferenceCalculatorTest : public testing::TestWithParam<Param> { class InferenceCalculatorTest : public testing::TestWithParam<Param> {
protected: protected:
void SetDelegateForParam(mediapipe::CalculatorGraphConfig_Node* node) { void SetDelegateForParam(mediapipe::CalculatorGraphConfig_Node* node) {
*node->mutable_options() auto options_map = tool::MutableOptionsMap().Initialize(*node);
->MutableExtension(mediapipe::InferenceCalculatorOptions::ext) auto options = options_map.Get<mediapipe::InferenceCalculatorOptions>();
->mutable_delegate() = GetParam().delegate; *options.mutable_delegate() = GetParam().delegate;
options_map.Set(options);
} }
}; };

View File

@ -20,22 +20,8 @@
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/status/status.h" #include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.h" #include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h"
#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
#if defined(MEDIAPIPE_ANDROID)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/filesystem.h"
#include "mediapipe/util/android/file/base/helpers.h"
#endif // ANDROID
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
@ -50,42 +36,22 @@ class InferenceCalculatorGlImpl
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
private: private:
absl::Status ReadGpuCaches();
absl::Status SaveGpuCaches();
absl::Status LoadModel(CalculatorContext* cc); absl::Status LoadModel(CalculatorContext* cc);
absl::Status LoadDelegate(CalculatorContext* cc); absl::Status LoadDelegate(CalculatorContext* cc);
absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc); absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is. // TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_; Packet<TfLiteModelPtr> model_packet_;
#if MEDIAPIPE_TFLITE_GL_INFERENCE
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
bool allow_precision_loss_ = false; bool allow_precision_loss_ = false;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
tflite_gpu_runner_usage_;
#endif // MEDIAPIPE_TFLITE_GL_INFERENCE
TfLiteDelegatePtr delegate_; TfLiteDelegatePtr delegate_;
std::unique_ptr<tflite::Interpreter> interpreter_; std::unique_ptr<tflite::Interpreter> interpreter_;
#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
std::vector<Tensor::Shape> output_shapes_; std::vector<Tensor::Shape> output_shapes_;
std::vector<std::unique_ptr<Tensor>> gpu_buffers_in_; std::vector<std::unique_ptr<Tensor>> gpu_buffers_in_;
std::vector<std::unique_ptr<Tensor>> gpu_buffers_out_; std::vector<std::unique_ptr<Tensor>> gpu_buffers_out_;
#endif // MEDIAPIPE_TFLITE_GPU_SUPPORTED
bool use_advanced_gpu_api_ = false;
bool use_gpu_delegate_ = false;
bool use_kernel_caching_ = false;
std::string cached_kernel_filename_;
bool use_serialized_model_ = false;
std::string serialized_model_path_;
}; };
absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) { absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
@ -93,8 +59,7 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected()) RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
<< "Either model as side packet or model path in options is required."; << "Either model as side packet or model path in options is required.";
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); return mediapipe::GlCalculatorHelper::UpdateContract(cc);
return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
@ -110,46 +75,12 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
<< "for Gpu"; << "for Gpu";
delegate.MergeFrom(input_side_packet_delegate); delegate.MergeFrom(input_side_packet_delegate);
} }
const bool has_delegate = options.has_delegate() || !kDelegate(cc).IsEmpty();
use_advanced_gpu_api_ = has_delegate && delegate.has_gpu() &&
delegate.gpu().use_advanced_gpu_api();
allow_precision_loss_ = delegate.gpu().allow_precision_loss();
tflite_gpu_runner_api_ = delegate.gpu().api();
tflite_gpu_runner_usage_ = delegate.gpu().usage();
use_kernel_caching_ =
use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
use_serialized_model_ = use_advanced_gpu_api_ &&
delegate.gpu().has_serialized_model_dir() &&
delegate.gpu().has_model_token();
use_gpu_delegate_ = !use_advanced_gpu_api_;
if (use_kernel_caching_) {
#ifdef MEDIAPIPE_ANDROID
cached_kernel_filename_ = delegate.gpu().cached_kernel_path() +
mediapipe::File::Basename(options.model_path()) +
".ker";
#endif // MEDIAPIPE_ANDROID
}
if (use_serialized_model_) {
#ifdef MEDIAPIPE_ANDROID
serialized_model_path_ = mediapipe::file::JoinPath(
delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
#endif // MEDIAPIPE_ANDROID
}
// When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
// for everything.
if (!use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR(LoadModel(cc)); MP_RETURN_IF_ERROR(LoadModel(cc));
}
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
MP_RETURN_IF_ERROR( return gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { return LoadDelegateAndAllocateTensors(cc);
return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc) });
: LoadDelegateAndAllocateTensors(cc);
}));
return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
@ -160,23 +91,6 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
RET_CHECK(!input_tensors.empty()); RET_CHECK(!input_tensors.empty());
auto output_tensors = absl::make_unique<std::vector<Tensor>>(); auto output_tensors = absl::make_unique<std::vector<Tensor>>();
if (use_advanced_gpu_api_) {
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> ::mediapipe::Status {
for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
input_tensors[i].GetOpenGlBufferReadView().name(), i));
}
output_tensors->reserve(output_shapes_.size());
for (int i = 0; i < output_shapes_.size(); ++i) {
output_tensors->emplace_back(Tensor::ElementType::kFloat32,
output_shapes_[i]);
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
output_tensors->back().GetOpenGlBufferWriteView().name(), i));
}
return absl::OkStatus();
}));
} else {
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status { [this, &input_tensors]() -> ::mediapipe::Status {
// Explicitly copy input. // Explicitly copy input.
@ -190,16 +104,10 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
} }
return absl::OkStatus(); return absl::OkStatus();
})); }));
}
// Run inference. // Run inference.
if (use_advanced_gpu_api_) {
RET_CHECK(tflite_gpu_runner_->Invoke().ok());
} else {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
}
if (use_gpu_delegate_) {
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &output_tensors]() -> ::mediapipe::Status { [this, &output_tensors]() -> ::mediapipe::Status {
output_tensors->reserve(output_shapes_.size()); output_tensors->reserve(output_shapes_.size());
@ -216,149 +124,20 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
} }
return absl::OkStatus(); return absl::OkStatus();
})); }));
}
// Output tensors are already bound if use_advanced_gpu_api_ is true.
kOutTensors(cc).Send(std::move(output_tensors)); kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_) {
// Save kernel file.
auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
tflite_gpu_runner_->GetSerializedBinaryCache());
std::string cache_str(kernel_cache->begin(), kernel_cache->end());
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
}
if (use_serialized_model_) {
// Save serialized model file.
ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
tflite_gpu_runner_->GetSerializedModel());
absl::string_view serialized_model(
reinterpret_cast<char*>(serialized_model_vec.data()),
serialized_model_vec.size());
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(serialized_model_path_, serialized_model));
}
#endif // MEDIAPIPE_ANDROID
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(SaveGpuCaches()); return gpu_helper_.RunInGlContext([this]() -> absl::Status {
if (use_gpu_delegate_) {
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
gpu_buffers_in_.clear(); gpu_buffers_in_.clear();
gpu_buffers_out_.clear(); gpu_buffers_out_.clear();
// Delegate must outlive the interpreter, hence the order is important. // Delegate must outlive the interpreter, hence the order is important.
interpreter_ = nullptr; interpreter_ = nullptr;
delegate_ = nullptr; delegate_ = nullptr;
return absl::OkStatus(); return absl::OkStatus();
})); });
} else {
// Delegate must outlive the interpreter, hence the order is important.
interpreter_ = nullptr;
delegate_ = nullptr;
}
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
// Load pre-compiled kernel file.
std::string cache_str;
MP_RETURN_IF_ERROR(
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
}
if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
// Load serialized model file.
std::string serialized_model_str;
MP_RETURN_IF_ERROR(
file::GetContents(serialized_model_path_, &serialized_model_str));
std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
serialized_model_str.end());
tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
}
#endif // MEDIAPIPE_ANDROID
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
// Create runner
tflite::gpu::InferenceOptions options;
options.priority1 = allow_precision_loss_
? tflite::gpu::InferencePriority::MIN_LATENCY
: tflite::gpu::InferencePriority::MAX_PRECISION;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
switch (tflite_gpu_runner_usage_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
FAST_SINGLE_ANSWER: {
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
SUSTAINED_SPEED: {
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: {
return absl::InternalError("inference usage need to be specified.");
}
}
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
switch (tflite_gpu_runner_api_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
// Do not need to force any specific API.
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
tflite_gpu_runner_->ForceOpenGL();
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: {
tflite_gpu_runner_->ForceOpenCL();
break;
}
}
if (kSideInOpResolver(cc).IsConnected()) {
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
} else {
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
}
// Create and bind OpenGL buffers for outputs.
// The buffers are created once and their ids are passed to calculator outputs
output_shapes_.resize(tflite_gpu_runner_->outputs_size());
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b,
tflite_gpu_runner_->GetOutputShapes()[i].h,
tflite_gpu_runner_->GetOutputShapes()[i].w,
tflite_gpu_runner_->GetOutputShapes()[i].c};
}
MP_RETURN_IF_ERROR(ReadGpuCaches());
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());
return absl::OkStatus();
} }
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) { absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
@ -375,12 +154,8 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
} }
RET_CHECK(interpreter_); RET_CHECK(interpreter_);
#if defined(__EMSCRIPTEN__)
interpreter_->SetNumThreads(1);
#else
interpreter_->SetNumThreads( interpreter_->SetNumThreads(
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread()); cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
#endif // __EMSCRIPTEN__
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -0,0 +1,285 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
#if defined(MEDIAPIPE_ANDROID)
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/filesystem.h"
#include "mediapipe/util/android/file/base/helpers.h"
#endif // ANDROID
namespace mediapipe {
namespace api2 {
// Runs TFLite GPU delegate API2 directly, bypassing interpreter usage, and
// allows choosing specific API.
//
// To trigger this code path:
// [mediapipe.InferenceCalculatorOptions.ext] {
// delegate {
// gpu {
// use_advanced_gpu_api: true
// api: OPENCL # or OPENGL or ANY
// }
// }
// }
class InferenceCalculatorGlAdvancedImpl
: public NodeImpl<InferenceCalculatorGlAdvanced,
InferenceCalculatorGlAdvancedImpl> {
public:
static absl::Status UpdateContract(CalculatorContract* cc);
absl::Status Open(CalculatorContext* cc) override;
absl::Status Process(CalculatorContext* cc) override;
absl::Status Close(CalculatorContext* cc) override;
private:
absl::Status ReadGpuCaches();
absl::Status SaveGpuCaches();
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
// TfLite requires us to keep the model alive as long as the interpreter is.
Packet<TfLiteModelPtr> model_packet_;
mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
bool allow_precision_loss_ = false;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
tflite_gpu_runner_api_;
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
tflite_gpu_runner_usage_;
std::vector<Tensor::Shape> output_shapes_;
bool use_kernel_caching_ = false;
std::string cached_kernel_filename_;
bool use_serialized_model_ = false;
std::string serialized_model_path_;
};
absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract(
CalculatorContract* cc) {
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
<< "Either model as side packet or model path in options is required.";
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlAdvancedImpl::Open(CalculatorContext* cc) {
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
mediapipe::InferenceCalculatorOptions::Delegate delegate = options.delegate();
if (!kDelegate(cc).IsEmpty()) {
mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate =
kDelegate(cc).Get();
CHECK(input_side_packet_delegate.has_gpu() ||
input_side_packet_delegate.delegate_case() ==
mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
<< "inference_calculator_gl_advanced only supports delegate input side "
"packet for Gpu";
delegate.MergeFrom(input_side_packet_delegate);
}
allow_precision_loss_ = delegate.gpu().allow_precision_loss();
tflite_gpu_runner_api_ = delegate.gpu().api();
tflite_gpu_runner_usage_ = delegate.gpu().usage();
use_kernel_caching_ = delegate.gpu().has_cached_kernel_path();
use_serialized_model_ = delegate.gpu().has_serialized_model_dir() &&
delegate.gpu().has_model_token();
if (use_kernel_caching_) {
#ifdef MEDIAPIPE_ANDROID
cached_kernel_filename_ = delegate.gpu().cached_kernel_path() +
mediapipe::File::Basename(options.model_path()) +
".ker";
#endif // MEDIAPIPE_ANDROID
}
if (use_serialized_model_) {
#ifdef MEDIAPIPE_ANDROID
serialized_model_path_ = mediapipe::file::JoinPath(
delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
#endif // MEDIAPIPE_ANDROID
}
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
return gpu_helper_.RunInGlContext(
[this, &cc]() -> absl::Status { return InitTFLiteGPURunner(cc); });
}
absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
if (kInTensors(cc).IsEmpty()) {
return absl::OkStatus();
}
const auto& input_tensors = *kInTensors(cc);
RET_CHECK(!input_tensors.empty());
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors, &output_tensors]() -> absl::Status {
for (int i = 0; i < input_tensors.size(); ++i) {
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
input_tensors[i].GetOpenGlBufferReadView().name(), i));
}
output_tensors->reserve(output_shapes_.size());
for (int i = 0; i < output_shapes_.size(); ++i) {
output_tensors->emplace_back(Tensor::ElementType::kFloat32,
output_shapes_[i]);
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
output_tensors->back().GetOpenGlBufferWriteView().name(), i));
}
return absl::OkStatus();
}));
// Run inference.
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Invoke());
kOutTensors(cc).Send(std::move(output_tensors));
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlAdvancedImpl::SaveGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_) {
// Save kernel file.
auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
tflite_gpu_runner_->GetSerializedBinaryCache());
std::string cache_str(kernel_cache->begin(), kernel_cache->end());
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
}
if (use_serialized_model_) {
// Save serialized model file.
ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
tflite_gpu_runner_->GetSerializedModel());
absl::string_view serialized_model(
reinterpret_cast<char*>(serialized_model_vec.data()),
serialized_model_vec.size());
MP_RETURN_IF_ERROR(
mediapipe::file::SetContents(serialized_model_path_, serialized_model));
}
#endif // MEDIAPIPE_ANDROID
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlAdvancedImpl::Close(CalculatorContext* cc) {
MP_RETURN_IF_ERROR(SaveGpuCaches());
return gpu_helper_.RunInGlContext([this]() -> absl::Status {
tflite_gpu_runner_.reset();
return absl::OkStatus();
});
}
absl::Status InferenceCalculatorGlAdvancedImpl::ReadGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
// Load pre-compiled kernel file.
std::string cache_str;
MP_RETURN_IF_ERROR(
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
}
if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
// Load serialized model file.
std::string serialized_model_str;
MP_RETURN_IF_ERROR(
file::GetContents(serialized_model_path_, &serialized_model_str));
std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
serialized_model_str.end());
tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
}
#endif // MEDIAPIPE_ANDROID
return absl::OkStatus();
}
absl::Status InferenceCalculatorGlAdvancedImpl::InitTFLiteGPURunner(
CalculatorContext* cc) {
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
const auto& model = *model_packet_.Get();
// Create runner
tflite::gpu::InferenceOptions options;
options.priority1 = allow_precision_loss_
? tflite::gpu::InferencePriority::MIN_LATENCY
: tflite::gpu::InferencePriority::MAX_PRECISION;
options.priority2 = tflite::gpu::InferencePriority::AUTO;
options.priority3 = tflite::gpu::InferencePriority::AUTO;
switch (tflite_gpu_runner_usage_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
FAST_SINGLE_ANSWER: {
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
SUSTAINED_SPEED: {
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: {
return absl::InternalError("inference usage need to be specified.");
}
}
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
switch (tflite_gpu_runner_api_) {
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
// Do not need to force any specific API.
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
tflite_gpu_runner_->ForceOpenGL();
break;
}
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: {
tflite_gpu_runner_->ForceOpenCL();
break;
}
}
if (kSideInOpResolver(cc).IsConnected()) {
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
} else {
tflite::ops::builtin::BuiltinOpResolver op_resolver =
kSideInCustomOpResolver(cc).GetOr(
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
model, op_resolver, /*allow_quant_ops=*/true));
}
// Create and bind OpenGL buffers for outputs.
// The buffers are created once and their ids are passed to calculator outputs
output_shapes_.resize(tflite_gpu_runner_->outputs_size());
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b,
tflite_gpu_runner_->GetOutputShapes()[i].h,
tflite_gpu_runner_->GetOutputShapes()[i].w,
tflite_gpu_runner_->GetOutputShapes()[i].c};
}
MP_RETURN_IF_ERROR(ReadGpuCaches());
return tflite_gpu_runner_->Build();
}
} // namespace api2
} // namespace mediapipe

View File

@ -38,61 +38,13 @@
#endif // defined(__APPLE__) #endif // defined(__APPLE__)
namespace mediapipe { namespace mediapipe {
namespace {
void DoSmokeTest(const std::string& graph_proto) { constexpr int kTensorWidth = 8;
const int width = 8; constexpr int kTensorHeight = 8;
const int height = 8; constexpr int kTensorChannels = 3;
const int channels = 3;
// Prepare input tensor.
auto input_vec = absl::make_unique<std::vector<Tensor>>();
input_vec->emplace_back(Tensor::ElementType::kFloat32,
Tensor::Shape{1, height, width, channels});
{
auto view1 = input_vec->back().GetCpuWriteView();
auto tensor_buffer = view1.buffer<float>();
ASSERT_NE(tensor_buffer, nullptr);
for (int i = 0; i < width * height * channels - 1; i++) {
tensor_buffer[i] = 1;
}
}
// Prepare single calculator graph to and wait for packets. constexpr char kGraphWithModelPathInOption[] = R"(
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
std::vector<Packet> output_packets;
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
CalculatorGraph graph(graph_config);
MP_ASSERT_OK(graph.StartRun({}));
// Push the tensor into the graph.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tensor_in", Adopt(input_vec.release()).At(Timestamp(0))));
// Wait until the calculator done processing.
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, output_packets.size());
// Get and process results.
const std::vector<Tensor>& result_vec =
output_packets[0].Get<std::vector<Tensor>>();
ASSERT_EQ(1, result_vec.size());
const Tensor& result = result_vec[0];
auto view = result.GetCpuReadView();
auto result_buffer = view.buffer<float>();
ASSERT_NE(result_buffer, nullptr);
for (int i = 0; i < width * height * channels - 1; i++) {
ASSERT_EQ(3, result_buffer[i]);
}
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Tests a simple add model that adds an input tensor to itself.
TEST(InferenceCalculatorTest, SmokeTest) {
std::string graph_proto = R"(
input_stream: "tensor_in" input_stream: "tensor_in"
node { node {
calculator: "InferenceCalculator" calculator: "InferenceCalculator"
@ -106,18 +58,7 @@ TEST(InferenceCalculatorTest, SmokeTest) {
} }
} }
)"; )";
// Test CPU inference only. constexpr char kGraphWithModelAsInputSidePacket[] = R"(
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
graph_proto, {{"$delegate", "delegate { tflite {} }"}}));
DoSmokeTest(absl::StrReplaceAll(graph_proto,
{{"$delegate", "delegate { xnnpack {} }"}}));
DoSmokeTest(absl::StrReplaceAll(
graph_proto,
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
}
TEST(InferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
std::string graph_proto = R"(
input_stream: "tensor_in" input_stream: "tensor_in"
node { node {
@ -154,7 +95,84 @@ TEST(InferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
} }
} }
)"; )";
DoSmokeTest(graph_proto);
std::vector<Tensor> CreateInputs() {
std::vector<Tensor> input_vec;
// Prepare input tensor.
input_vec.emplace_back(
Tensor::ElementType::kFloat32,
Tensor::Shape{1, kTensorHeight, kTensorWidth, kTensorChannels});
{
auto view = input_vec.back().GetCpuWriteView();
auto num_elements = input_vec.back().shape().num_elements();
auto tensor_buffer = view.buffer<float>();
for (int i = 0; i < num_elements; i++) {
tensor_buffer[i] = 1;
}
} }
return input_vec;
}
void RunGraphThenClose(CalculatorGraph& graph, std::vector<Tensor> input_vec) {
MP_ASSERT_OK(graph.StartRun({}));
// Push the tensor into the graph.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"tensor_in",
MakePacket<std::vector<Tensor>>(std::move(input_vec)).At(Timestamp(0))));
// Wait until the calculator done processing.
MP_ASSERT_OK(graph.WaitUntilIdle());
// Fully close graph at end, otherwise calculator+tensors are destroyed
// after calling WaitUntilDone().
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
MP_ASSERT_OK(graph.WaitUntilDone());
}
void DoSmokeTest(const std::string& graph_proto) {
auto input_vec = CreateInputs();
// Prepare single calculator graph to and wait for packets.
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
std::vector<Packet> output_packets;
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
CalculatorGraph graph(graph_config);
RunGraphThenClose(graph, std::move(input_vec));
ASSERT_EQ(1, output_packets.size());
// Get and process results.
const std::vector<Tensor>& result_vec =
output_packets[0].Get<std::vector<Tensor>>();
ASSERT_EQ(1, result_vec.size());
const Tensor& result = result_vec[0];
auto view = result.GetCpuReadView();
auto result_buffer = view.buffer<float>();
ASSERT_NE(result_buffer, nullptr);
for (int i = 0; i < result.shape().num_elements(); i++) {
ASSERT_EQ(3, result_buffer[i]);
}
}
// Tests a simple add model that adds an input tensor to itself.
TEST(InferenceCalculatorTest, SmokeTest) {
// Test CPU inference only.
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
kGraphWithModelPathInOption, {{"$delegate", "delegate { tflite {} }"}}));
DoSmokeTest(absl::StrReplaceAll(kGraphWithModelPathInOption,
{{"$delegate", "delegate { xnnpack {} }"}}));
DoSmokeTest(absl::StrReplaceAll(
kGraphWithModelPathInOption,
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
}
TEST(InferenceCalculatorTest, ModelAsInputSidePacketSmokeTest) {
DoSmokeTest(kGraphWithModelAsInputSidePacket);
}
} // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -16,7 +16,6 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h" #include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
@ -25,6 +24,7 @@
#include "mediapipe/framework/formats/classification.pb.h" #include "mediapipe/framework/formats/classification.pb.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_MOBILE) #if defined(MEDIAPIPE_MOBILE)
#include "mediapipe/util/android/file/base/file.h" #include "mediapipe/util/android/file/base/file.h"
@ -35,6 +35,17 @@
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
namespace {
void SetClassificationLabel(const LabelMapItem label_map_item,
Classification* classification) {
classification->set_label(label_map_item.name());
if (label_map_item.has_display_name()) {
classification->set_display_name(label_map_item.display_name());
}
}
} // namespace
// Convert result tensors from classification models into MediaPipe // Convert result tensors from classification models into MediaPipe
// classifications. // classifications.
@ -54,7 +65,6 @@ namespace api2 {
// output_stream: "CLASSIFICATIONS:classifications" // output_stream: "CLASSIFICATIONS:classifications"
// options: { // options: {
// [mediapipe.TensorsToClassificationCalculatorOptions.ext] { // [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
// num_classes: 1024
// min_score_threshold: 0.1 // min_score_threshold: 0.1
// label_map_path: "labelmap.txt" // label_map_path: "labelmap.txt"
// } // }
@ -72,22 +82,35 @@ class TensorsToClassificationCalculator : public Node {
absl::Status Close(CalculatorContext* cc) override; absl::Status Close(CalculatorContext* cc) override;
private: private:
::mediapipe::TensorsToClassificationCalculatorOptions options_;
int top_k_ = 0; int top_k_ = 0;
absl::node_hash_map<int, std::string> label_map_; bool sort_by_descending_score_ = false;
proto_ns::Map<int64, LabelMapItem> local_label_map_;
bool label_map_loaded_ = false; bool label_map_loaded_ = false;
bool is_binary_classification_ = false;
float min_score_threshold_ = std::numeric_limits<float>::lowest();
// Set of allowed or ignored class indices.
struct ClassIndexSet {
absl::flat_hash_set<int> values;
bool is_allowlist;
};
// Allowed or ignored class indices based on provided options.
// These are used to filter out the output classification results.
ClassIndexSet class_index_set_;
bool IsClassIndexAllowed(int class_index);
const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
}; };
MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator); MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator);
absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) { absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
options_ = const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>();
top_k_ = options_.top_k(); top_k_ = options.top_k();
if (options_.has_label_map_path()) { sort_by_descending_score_ = options.sort_by_descending_score();
if (options.has_label_map_path()) {
std::string string_path; std::string string_path;
ASSIGN_OR_RETURN(string_path, ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options_.label_map_path())); PathToResourceAsFile(options.label_map_path()));
std::string label_map_string; std::string label_map_string;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
mediapipe::GetResourceContents(string_path, &label_map_string)); mediapipe::GetResourceContents(string_path, &label_map_string));
@ -96,18 +119,45 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
std::string line; std::string line;
int i = 0; int i = 0;
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
label_map_[i++] = line; LabelMapItem item;
item.set_name(line);
local_label_map_[i++] = item;
} }
label_map_loaded_ = true; label_map_loaded_ = true;
} else if (options_.has_label_map()) { } else if (!options.label_items().empty()) {
for (int i = 0; i < options_.label_map().entries_size(); ++i) { label_map_loaded_ = true;
const auto& entry = options_.label_map().entries(i); } else if (options.has_label_map()) {
RET_CHECK(!label_map_.contains(entry.id())) for (int i = 0; i < options.label_map().entries_size(); ++i) {
const auto& entry = options.label_map().entries(i);
RET_CHECK(!local_label_map_.contains(entry.id()))
<< "Duplicate id found: " << entry.id(); << "Duplicate id found: " << entry.id();
label_map_[entry.id()] = entry.label(); LabelMapItem item;
item.set_name(entry.label());
local_label_map_[entry.id()] = item;
} }
label_map_loaded_ = true; label_map_loaded_ = true;
} }
if (options.has_min_score_threshold()) {
min_score_threshold_ = options.min_score_threshold();
}
is_binary_classification_ = options.binary_classification();
if (is_binary_classification_) {
RET_CHECK(options.allow_classes().empty() &&
options.ignore_classes().empty());
}
if (!options.allow_classes().empty()) {
RET_CHECK(options.ignore_classes().empty());
class_index_set_.is_allowlist = true;
for (int i = 0; i < options.allow_classes_size(); ++i) {
class_index_set_.values.insert(options.allow_classes(i));
}
} else {
class_index_set_.is_allowlist = false;
for (int i = 0; i < options.ignore_classes_size(); ++i) {
class_index_set_.values.insert(options.ignore_classes(i));
}
}
return absl::OkStatus(); return absl::OkStatus();
} }
@ -118,19 +168,19 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
int num_classes = input_tensors[0].shape().num_elements(); int num_classes = input_tensors[0].shape().num_elements();
if (options_.binary_classification()) { if (is_binary_classification_) {
RET_CHECK_EQ(num_classes, 1); RET_CHECK_EQ(num_classes, 1);
// Number of classes for binary classification. // Number of classes for binary classification.
num_classes = 2; num_classes = 2;
} }
if (label_map_loaded_) { if (label_map_loaded_) {
RET_CHECK_EQ(num_classes, label_map_.size()); RET_CHECK_EQ(num_classes, GetLabelMap(cc).size());
} }
auto view = input_tensors[0].GetCpuReadView(); auto view = input_tensors[0].GetCpuReadView();
auto raw_scores = view.buffer<float>(); auto raw_scores = view.buffer<float>();
auto classification_list = absl::make_unique<ClassificationList>(); auto classification_list = absl::make_unique<ClassificationList>();
if (options_.binary_classification()) { if (is_binary_classification_) {
Classification* class_first = classification_list->add_classification(); Classification* class_first = classification_list->add_classification();
Classification* class_second = classification_list->add_classification(); Classification* class_second = classification_list->add_classification();
class_first->set_index(0); class_first->set_index(0);
@ -139,42 +189,49 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
class_second->set_score(1. - raw_scores[0]); class_second->set_score(1. - raw_scores[0]);
if (label_map_loaded_) { if (label_map_loaded_) {
class_first->set_label(label_map_[0]); SetClassificationLabel(GetLabelMap(cc).at(0), class_first);
class_second->set_label(label_map_[1]); SetClassificationLabel(GetLabelMap(cc).at(1), class_second);
} }
} else { } else {
for (int i = 0; i < num_classes; ++i) { for (int i = 0; i < num_classes; ++i) {
if (options_.has_min_score_threshold() && if (!IsClassIndexAllowed(i)) {
raw_scores[i] < options_.min_score_threshold()) { continue;
}
if (raw_scores[i] < min_score_threshold_) {
continue; continue;
} }
Classification* classification = Classification* classification =
classification_list->add_classification(); classification_list->add_classification();
classification->set_index(i); classification->set_index(i);
classification->set_score(raw_scores[i]); classification->set_score(raw_scores[i]);
if (label_map_loaded_) { if (label_map_loaded_) {
classification->set_label(label_map_[i]); SetClassificationLabel(GetLabelMap(cc).at(i), classification);
} }
} }
} }
// Note that partial_sort will raise error when top_k_ >
// classification_list->classification_size().
CHECK_GE(classification_list->classification_size(), top_k_);
auto raw_classification_list = classification_list->mutable_classification(); auto raw_classification_list = classification_list->mutable_classification();
if (top_k_ > 0 && classification_list->classification_size() >= top_k_) { if (top_k_ > 0) {
int desired_size =
std::min(classification_list->classification_size(), top_k_);
std::partial_sort(raw_classification_list->begin(), std::partial_sort(raw_classification_list->begin(),
raw_classification_list->begin() + top_k_, raw_classification_list->begin() + desired_size,
raw_classification_list->end(), raw_classification_list->end(),
[](const Classification a, const Classification b) { [](const Classification a, const Classification b) {
return a.score() > b.score(); return a.score() > b.score();
}); });
if (desired_size >= top_k_) {
// Resizes the underlying list to have only top_k_ classifications. // Resizes the underlying list to have only top_k_ classifications.
raw_classification_list->DeleteSubrange( raw_classification_list->DeleteSubrange(
top_k_, raw_classification_list->size() - top_k_); top_k_, raw_classification_list->size() - top_k_);
} }
} else if (sort_by_descending_score_) {
std::sort(raw_classification_list->begin(), raw_classification_list->end(),
[](const Classification a, const Classification b) {
return a.score() > b.score();
});
}
kOutClassificationList(cc).Send(std::move(classification_list)); kOutClassificationList(cc).Send(std::move(classification_list));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -183,5 +240,24 @@ absl::Status TensorsToClassificationCalculator::Close(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
bool TensorsToClassificationCalculator::IsClassIndexAllowed(int class_index) {
if (class_index_set_.values.empty()) {
return true;
}
if (class_index_set_.is_allowlist) {
return class_index_set_.values.contains(class_index);
} else {
return !class_index_set_.values.contains(class_index);
}
}
const proto_ns::Map<int64, LabelMapItem>&
TensorsToClassificationCalculator::GetLabelMap(CalculatorContext* cc) {
return !local_label_map_.empty()
? local_label_map_
: cc->Options<TensorsToClassificationCalculatorOptions>()
.label_items();
}
} // namespace api2 } // namespace api2
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,6 +19,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/util/label_map.proto";
message TensorsToClassificationCalculatorOptions { message TensorsToClassificationCalculatorOptions {
extend .mediapipe.CalculatorOptions { extend .mediapipe.CalculatorOptions {
@ -38,16 +39,37 @@ message TensorsToClassificationCalculatorOptions {
// Number of highest scoring labels to output. If top_k is not positive then // Number of highest scoring labels to output. If top_k is not positive then
// all labels are used. // all labels are used.
optional int32 top_k = 2; optional int32 top_k = 2;
// Whether results should be sorted by descending score. By default, results
// may or may not be sorted: setting this to true guarantees that the returned
// results will be sorted by descending score.
optional bool sort_by_descending_score = 9;
// Path to a label map file for getting the actual name of class ids. // Path to a label map file for getting the actual name of class ids.
optional string label_map_path = 3; optional string label_map_path = 3;
// Label map. (Can be used instead of label_map_path.) // Label map. (Can be used instead of label_map_path.)
// NOTE: "label_map_path", if specified, takes precedence over "label_map". // NOTE: either "label_map_path" or "label_items", if specified, takes
// precedence over "label_map".
// Deprecated: please use `label_items` instead.
optional LabelMap label_map = 5; optional LabelMap label_map = 5;
// Label items. (Can be used instead of label_map_path.)
// NOTE: "label_map_path", if specified, takes precedence over "label_items".
map<int64, LabelMapItem> label_items = 6;
// Whether the input is a single float for binary classification. // Whether the input is a single float for binary classification.
// When true, only a single float is expected in the input tensor and the // When true, only a single float is expected in the input tensor and the
// label map, if provided, is expected to have exactly two labels. // label map, if provided, is expected to have exactly two labels.
// The single score(float) represent the probability of first label, and // The single score(float) represent the probability of first label, and
// 1 - score is the probabilility of the second label. // 1 - score is the probabilility of the second label.
optional bool binary_classification = 4; optional bool binary_classification = 4;
// The ids of classes that should be ignored during decoding the score for
// each classification. If `ignore_classes` is specified, all the other
// classes that are not in the `ignore_class` field will be considered during
// decoding. `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 ignore_classes = 7 [packed = true];
// The ids of classes that will be allowed during decoding the score for
// each classification. If `allow_classes` is specified, all the other classes
// that are not in the `allow_classes` field will be completely ignored.
// `ignore_classes` and `allow_classes` are mutually exclusive.
repeated int32 allow_classes = 8 [packed = true];
} }

View File

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <algorithm>
#include <limits>
#include <vector> #include <vector>
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
@ -206,4 +208,119 @@ TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithTopK) {
} }
} }
TEST_F(TensorsToClassificationCalculatorTest,
CorrectOutputWithSortByDescendingScore) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
sort_by_descending_score: true
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
// Verify results are sorted by descending score.
EXPECT_EQ(3, classification_list.classification_size());
float score = std::numeric_limits<float>::max();
for (int i = 0; i < classification_list.classification_size(); ++i) {
EXPECT_LE(classification_list.classification(i).score(), score);
score = classification_list.classification(i).score();
}
}
TEST_F(TensorsToClassificationCalculatorTest,
ClassNameAllowlistWithLabelItems) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
label_items {
key: 0
value { name: "ClassA" }
}
label_items {
key: 1
value { name: "ClassB" }
}
label_items {
key: 2
value { name: "ClassC" }
}
allow_classes: 1
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
EXPECT_EQ(1, classification_list.classification_size());
EXPECT_EQ(1, classification_list.classification(0).index());
EXPECT_EQ(0.5, classification_list.classification(0).score());
ASSERT_TRUE(classification_list.classification(0).has_label());
}
TEST_F(TensorsToClassificationCalculatorTest,
ClassNameIgnorelistWithLabelItems) {
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
calculator: "TensorsToClassificationCalculator"
input_stream: "TENSORS:tensors"
output_stream: "CLASSIFICATIONS:classifications"
options {
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
label_items {
key: 0
value { name: "ClassA" }
}
label_items {
key: 1
value { name: "ClassB" }
}
label_items {
key: 2
value { name: "ClassC" }
}
ignore_classes: 1
}
}
)pb"));
BuildGraph(&runner, {0, 0.5, 1});
MP_ASSERT_OK(runner.Run());
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
EXPECT_EQ(1, output_packets_.size());
const auto& classification_list =
output_packets_[0].Get<ClassificationList>();
EXPECT_EQ(2, classification_list.classification_size());
EXPECT_EQ(0, classification_list.classification(0).index());
EXPECT_EQ(0, classification_list.classification(0).score());
ASSERT_TRUE(classification_list.classification(0).has_label());
EXPECT_EQ(2, classification_list.classification(1).index());
EXPECT_EQ(1, classification_list.classification(1).score());
ASSERT_TRUE(classification_list.classification(1).has_label());
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -20,10 +20,8 @@
#include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h" #include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/formats/tensor.h" #include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/gpu/gpu_origin.pb.h" #include "mediapipe/gpu/gpu_origin.pb.h"
@ -37,6 +35,11 @@
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
#if !MEDIAPIPE_DISABLE_OPENCV
#include "mediapipe/framework/formats/image_opencv.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#endif // !MEDIAPIPE_DISABLE_OPENCV
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
#include "tensorflow/lite/delegates/gpu/gl/converters/util.h" #include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
@ -159,9 +162,10 @@ class TensorsToSegmentationCalculator : public CalculatorBase {
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT; return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
} }
#if !MEDIAPIPE_DISABLE_OPENCV
template <class T> template <class T>
absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat); absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat);
#endif // !MEDIAPIPE_DISABLE_OPENCV
::mediapipe::TensorsToSegmentationCalculatorOptions options_; ::mediapipe::TensorsToSegmentationCalculatorOptions options_;
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
@ -283,7 +287,11 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) {
RET_CHECK_FAIL() << "GPU processing disabled."; RET_CHECK_FAIL() << "GPU processing disabled.";
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
#if !MEDIAPIPE_DISABLE_OPENCV
MP_RETURN_IF_ERROR(ProcessCpu(cc)); MP_RETURN_IF_ERROR(ProcessCpu(cc));
#else
RET_CHECK_FAIL() << "OpenCV processing disabled.";
#endif // !MEDIAPIPE_DISABLE_OPENCV
} }
return absl::OkStatus(); return absl::OkStatus();
@ -311,6 +319,7 @@ absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) {
absl::Status TensorsToSegmentationCalculator::ProcessCpu( absl::Status TensorsToSegmentationCalculator::ProcessCpu(
CalculatorContext* cc) { CalculatorContext* cc) {
#if !MEDIAPIPE_DISABLE_OPENCV
// Get input streams, and dimensions. // Get input streams, and dimensions.
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>(); cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
@ -360,10 +369,12 @@ absl::Status TensorsToSegmentationCalculator::ProcessCpu(
cv::resize(small_mask_mat, *output_mat, cv::resize(small_mask_mat, *output_mat,
cv::Size(output_width, output_height)); cv::Size(output_width, output_height));
cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp()); cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp());
#endif // !MEDIAPIPE_DISABLE_OPENCV
return absl::OkStatus(); return absl::OkStatus();
} }
#if !MEDIAPIPE_DISABLE_OPENCV
template <class T> template <class T>
absl::Status TensorsToSegmentationCalculator::ApplyActivation( absl::Status TensorsToSegmentationCalculator::ApplyActivation(
cv::Mat& tensor_mat, cv::Mat* small_mask_mat) { cv::Mat& tensor_mat, cv::Mat* small_mask_mat) {
@ -411,6 +422,7 @@ absl::Status TensorsToSegmentationCalculator::ApplyActivation(
return absl::OkStatus(); return absl::OkStatus();
} }
#endif // !MEDIAPIPE_DISABLE_OPENCV
// Steps: // Steps:
// 1. receive tensor // 1. receive tensor

View File

@ -300,17 +300,26 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
RET_CHECK(options_.batch_size() == 1 || RET_CHECK(options_.batch_size() == 1 ||
options_.recurrent_tag_pair().empty()) options_.recurrent_tag_pair().empty())
<< "To use recurrent_tag_pairs, batch_size must be 1."; << "To use recurrent_tag_pairs, batch_size must be 1.";
// Helper for StrJoin. Prints key (tag) of map<string, string>.
auto TagFormatter =
absl::PairFormatter(absl::StreamFormatter(), "",
[](std::string* out, const std::string& second) {});
for (const auto& tag_pair : options_.recurrent_tag_pair()) { for (const auto& tag_pair : options_.recurrent_tag_pair()) {
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':'); const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
RET_CHECK_EQ(tags.size(), 2) << "recurrent_tag_pair must be a colon " RET_CHECK_EQ(tags.size(), 2) << "recurrent_tag_pair must be a colon "
"separated string with two components: " "separated string with two components: "
<< tag_pair; << tag_pair;
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0])) RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
<< "Can't find tag '" << tags[0] << "' in signature " << "Can't find tag '" << tags[0] << "' in signature "
<< options_.signature_name(); << options_.signature_name() << "; instead found tags "
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1])) RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
<< "Can't find tag '" << tags[1] << "' in signature " << "Can't find tag '" << tags[1] << "' in signature "
<< options_.signature_name(); << options_.signature_name() << " ; instead found tags "
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
recurrent_feed_tags_.insert(tags[0]); recurrent_feed_tags_.insert(tags[0]);
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0]; recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
} }
@ -319,12 +328,14 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
for (const std::string& tag : cc->Inputs().GetTags()) { for (const std::string& tag : cc->Inputs().GetTags()) {
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature " << "Can't find tag '" << tag << "' in signature "
<< options_.signature_name(); << options_.signature_name() << "; instead found tags "
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
} }
for (const std::string& tag : cc->Outputs().GetTags()) { for (const std::string& tag : cc->Outputs().GetTags()) {
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag)) RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
<< "Can't find tag '" << tag << "' in signature " << "Can't find tag '" << tag << "' in signature "
<< options_.signature_name(); << options_.signature_name() << "; instead found tags "
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
} }
{ {

View File

@ -38,6 +38,9 @@
namespace mediapipe { namespace mediapipe {
using ::testing::AllOf;
using ::testing::HasSubstr;
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace { namespace {
@ -199,8 +202,8 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
auto run_status = runner_->Run(); auto run_status = runner_->Run();
ASSERT_FALSE(run_status.ok()); ASSERT_FALSE(run_status.ok());
EXPECT_THAT(run_status.ToString(), EXPECT_THAT(run_status.ToString(),
testing::HasSubstr("TensorFlowInferenceCalculator")); HasSubstr("TensorFlowInferenceCalculator"));
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B")); EXPECT_THAT(run_status.ToString(), HasSubstr("Tag B"));
} }
TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) { TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
@ -238,8 +241,8 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
auto run_status = runner_->Run(); auto run_status = runner_->Run();
ASSERT_FALSE(run_status.ok()); ASSERT_FALSE(run_status.ok());
EXPECT_THAT(run_status.ToString(), EXPECT_THAT(run_status.ToString(),
testing::HasSubstr("TensorFlowInferenceCalculator")); HasSubstr("TensorFlowInferenceCalculator"));
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B")); EXPECT_THAT(run_status.ToString(), HasSubstr("Tag B"));
} }
TEST_F(TensorflowInferenceCalculatorTest, BadTag) { TEST_F(TensorflowInferenceCalculatorTest, BadTag) {
@ -255,7 +258,12 @@ TEST_F(TensorflowInferenceCalculatorTest, BadTag) {
runner_ = absl::make_unique<CalculatorRunner>(config); runner_ = absl::make_unique<CalculatorRunner>(config);
AddSessionInputSidePacket(); AddSessionInputSidePacket();
ASSERT_FALSE(runner_->Run().ok()); absl::Status status = runner_->Run();
ASSERT_FALSE(status.ok());
EXPECT_THAT(
status.message(),
AllOf(HasSubstr("Can't find tag 'BAD' in signature"),
HasSubstr("instead found tags A, B, EXPENSIVE, MULTIPLIED")));
} }
TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) { TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
@ -740,7 +748,7 @@ TEST_F(TensorflowInferenceCalculatorTest, BatchedInputTooBigBatch) {
ASSERT_FALSE(status.ok()); ASSERT_FALSE(status.ok());
EXPECT_THAT( EXPECT_THAT(
status.message(), status.message(),
::testing::HasSubstr( HasSubstr(
"has more packets than batch capacity. batch_size: 2 packets: 3")); "has more packets than batch capacity. batch_size: 2 packets: 3"));
} }

View File

@ -301,6 +301,8 @@ cc_library(
":detection_label_id_to_text_calculator_cc_proto", ":detection_label_id_to_text_calculator_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/container:node_hash_map",
"//mediapipe/framework/port:core_proto",
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",

View File

@ -16,6 +16,8 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/util/label_map.pb.h" #include "mediapipe/util/label_map.pb.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
@ -55,9 +57,9 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
private: private:
// Local label map built from the calculator options' `label_map_path` or // Local label map built from the calculator options' `label_map_path` or
// `label` field. // `label` field.
LabelMap local_label_map_; proto_ns::Map<int64, LabelMapItem> local_label_map_;
bool keep_label_id_; bool keep_label_id_;
const LabelMap& GetLabelMap(CalculatorContext* cc); const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
}; };
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator); REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
@ -72,13 +74,12 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) { absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
const auto& options = const auto& options = cc->Options<DetectionLabelIdToTextCalculatorOptions>();
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
if (options.has_label_map_path()) { if (options.has_label_map_path()) {
RET_CHECK(!options.has_label_map() && options.label().empty()) RET_CHECK(options.label_items().empty() && options.label().empty())
<< "Only can set one of the following fields in the CalculatorOptions: " << "Only can set one of the following fields in the CalculatorOptions: "
"label_map_path, label, and label_map."; "label_map_path, label, and label_items.";
std::string string_path; std::string string_path;
ASSIGN_OR_RETURN(string_path, ASSIGN_OR_RETURN(string_path,
PathToResourceAsFile(options.label_map_path())); PathToResourceAsFile(options.label_map_path()));
@ -91,16 +92,16 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
while (std::getline(stream, line)) { while (std::getline(stream, line)) {
LabelMapItem item; LabelMapItem item;
item.set_name(line); item.set_name(line);
(*local_label_map_.mutable_index_to_item())[i++] = item; local_label_map_[i++] = item;
} }
} else if (!options.label().empty()) { } else if (!options.label().empty()) {
RET_CHECK(!options.has_label_map()) RET_CHECK(options.label_items().empty())
<< "Only can set one of the following fields in the CalculatorOptions: " << "Only can set one of the following fields in the CalculatorOptions: "
"label_map_path, label, and label_map."; "label_map_path, label, and label_items.";
for (int i = 0; i < options.label_size(); ++i) { for (int i = 0; i < options.label_size(); ++i) {
LabelMapItem item; LabelMapItem item;
item.set_name(options.label(i)); item.set_name(options.label(i));
(*local_label_map_.mutable_index_to_item())[i] = item; local_label_map_[i] = item;
} }
} }
keep_label_id_ = options.keep_label_id(); keep_label_id_ = options.keep_label_id();
@ -115,9 +116,8 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
Detection& output_detection = output_detections.back(); Detection& output_detection = output_detections.back();
bool has_text_label = false; bool has_text_label = false;
for (const int32 label_id : output_detection.label_id()) { for (const int32 label_id : output_detection.label_id()) {
if (GetLabelMap(cc).index_to_item().find(label_id) != if (GetLabelMap(cc).contains(label_id)) {
GetLabelMap(cc).index_to_item().end()) { auto item = GetLabelMap(cc).at(label_id);
auto item = GetLabelMap(cc).index_to_item().at(label_id);
output_detection.add_label(item.name()); output_detection.add_label(item.name());
if (item.has_display_name()) { if (item.has_display_name()) {
output_detection.add_display_name(item.display_name()); output_detection.add_display_name(item.display_name());
@ -136,13 +136,12 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
return absl::OkStatus(); return absl::OkStatus();
} }
const LabelMap& DetectionLabelIdToTextCalculator::GetLabelMap( const proto_ns::Map<int64, LabelMapItem>&
CalculatorContext* cc) { DetectionLabelIdToTextCalculator::GetLabelMap(CalculatorContext* cc) {
return !local_label_map_.index_to_item().empty() return !local_label_map_.empty()
? local_label_map_ ? local_label_map_
: cc->Options< : cc->Options<DetectionLabelIdToTextCalculatorOptions>()
::mediapipe::DetectionLabelIdToTextCalculatorOptions>() .label_items();
.label_map();
} }
} // namespace mediapipe } // namespace mediapipe

View File

@ -38,6 +38,6 @@ message DetectionLabelIdToTextCalculatorOptions {
// output detections. // output detections.
optional bool keep_label_id = 3; optional bool keep_label_id = 3;
// Label map. // Identifying information for each classification label.
optional LabelMap label_map = 4; map<int64, LabelMapItem> label_items = 4;
} }

View File

@ -426,6 +426,7 @@ cc_test(
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:opencv_core", "//mediapipe/framework/port:opencv_core",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:test_util",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
], ],
) )
@ -451,6 +452,7 @@ cc_test(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video", "//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/tool:test_util",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",
], ],
) )
@ -534,6 +536,7 @@ cc_test(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler", "//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
"//mediapipe/framework/tool:test_util",
"//mediapipe/util/tracking:box_tracker_cc_proto", "//mediapipe/util/tracking:box_tracker_cc_proto",
"//mediapipe/util/tracking:tracking_cc_proto", "//mediapipe/util/tracking:tracking_cc_proto",
"@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:flag",

View File

@ -120,7 +120,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
// back. To get correct image format, we read the first frame from the video // back. To get correct image format, we read the first frame from the video
// and get the number of channels. // and get the number of channels.
cv::Mat frame; cv::Mat frame;
cap_->read(frame); ReadFrame(frame);
if (frame.empty()) { if (frame.empty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Fail to read any frames from the video file at " << "Fail to read any frames from the video file at "
@ -193,13 +193,13 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
Timestamp timestamp(cap_->get(cv::CAP_PROP_POS_MSEC) * 1000); Timestamp timestamp(cap_->get(cv::CAP_PROP_POS_MSEC) * 1000);
if (format_ == ImageFormat::GRAY8) { if (format_ == ImageFormat::GRAY8) {
cv::Mat frame = formats::MatView(image_frame.get()); cv::Mat frame = formats::MatView(image_frame.get());
cap_->read(frame); ReadFrame(frame);
if (frame.empty()) { if (frame.empty()) {
return tool::StatusStop(); return tool::StatusStop();
} }
} else { } else {
cv::Mat tmp_frame; cv::Mat tmp_frame;
cap_->read(tmp_frame); ReadFrame(tmp_frame);
if (tmp_frame.empty()) { if (tmp_frame.empty()) {
return tool::StatusStop(); return tool::StatusStop();
} }
@ -234,6 +234,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
return absl::OkStatus(); return absl::OkStatus();
} }
// Sometimes an empty frame is returned even though there are more frames.
void ReadFrame(cv::Mat& frame) {
cap_->read(frame);
if (frame.empty()) {
cap_->read(frame); // Try again.
}
}
private: private:
std::unique_ptr<cv::VideoCapture> cap_; std::unique_ptr<cv::VideoCapture> cap_;
int width_; int width_;

View File

@ -24,6 +24,7 @@
#include "mediapipe/framework/port/opencv_core_inc.h" #include "mediapipe/framework/port/opencv_core_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
namespace mediapipe { namespace mediapipe {
@ -32,6 +33,7 @@ namespace {
constexpr char kVideoTag[] = "VIDEO"; constexpr char kVideoTag[] = "VIDEO";
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM"; constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH"; constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) { TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
CalculatorGraphConfig::Node node_config = CalculatorGraphConfig::Node node_config =
@ -41,10 +43,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
output_stream: "VIDEO:video" output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>( runner.MutableSidePackets()->Tag(kInputFilePathTag) =
file::JoinPath("./", MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/video/" "format_MP4_AVC720P_AAC.video"));
"testdata/format_MP4_AVC720P_AAC.video"));
MP_EXPECT_OK(runner.Run()); MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
@ -87,10 +88,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
output_stream: "VIDEO:video" output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>( runner.MutableSidePackets()->Tag(kInputFilePathTag) =
file::JoinPath("./", MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/video/" "format_FLV_H264_AAC.video"));
"testdata/format_FLV_H264_AAC.video"));
MP_EXPECT_OK(runner.Run()); MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
@ -131,10 +131,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
output_stream: "VIDEO:video" output_stream: "VIDEO:video"
output_stream: "VIDEO_PRESTREAM:video_prestream")pb"); output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
CalculatorRunner runner(node_config); CalculatorRunner runner(node_config);
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>( runner.MutableSidePackets()->Tag(kInputFilePathTag) =
file::JoinPath("./", MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/video/" "format_MKV_VP8_VORBIS.video"));
"testdata/format_MKV_VP8_VORBIS.video"));
MP_EXPECT_OK(runner.Run()); MP_EXPECT_OK(runner.Run());
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1); EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);

View File

@ -28,10 +28,14 @@
#include "mediapipe/framework/port/opencv_video_inc.h" #include "mediapipe/framework/port/opencv_video_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
// Temporarily disable the test. // Temporarily disable the test.
// TODO: Investigate the “Could not open codec 'libx264'” error with // TODO: Investigate the “Could not open codec 'libx264'” error with
// opencv2. // opencv2.
@ -59,10 +63,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, DISABLED_TestMp4Avc720pVideo) {
} }
)pb"); )pb");
std::map<std::string, Packet> input_side_packets; std::map<std::string, Packet> input_side_packets;
input_side_packets["input_file_path"] = MakePacket<std::string>( input_side_packets["input_file_path"] =
file::JoinPath("./", MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/video/" "format_MP4_AVC720P_AAC.video"));
"testdata/format_MP4_AVC720P_AAC.video"));
const std::string output_file_path = "/tmp/tmp_video.mp4"; const std::string output_file_path = "/tmp/tmp_video.mp4";
DeletingFile deleting_file(output_file_path, true); DeletingFile deleting_file(output_file_path, true);
input_side_packets["output_file_path"] = input_side_packets["output_file_path"] =
@ -120,10 +123,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestFlvH264Video) {
} }
)pb"); )pb");
std::map<std::string, Packet> input_side_packets; std::map<std::string, Packet> input_side_packets;
input_side_packets["input_file_path"] = MakePacket<std::string>( input_side_packets["input_file_path"] =
file::JoinPath("./", MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/video/" "format_FLV_H264_AAC.video"));
"testdata/format_FLV_H264_AAC.video"));
const std::string output_file_path = "/tmp/tmp_video.avi"; const std::string output_file_path = "/tmp/tmp_video.avi";
DeletingFile deleting_file(output_file_path, true); DeletingFile deleting_file(output_file_path, true);
input_side_packets["output_file_path"] = input_side_packets["output_file_path"] =
@ -183,10 +185,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestMkvVp8Video) {
} }
)pb"); )pb");
std::map<std::string, Packet> input_side_packets; std::map<std::string, Packet> input_side_packets;
input_side_packets["input_file_path"] = MakePacket<std::string>( input_side_packets["input_file_path"] =
file::JoinPath("./", MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
"/mediapipe/calculators/video/" "format_MKV_VP8_VORBIS.video"));
"testdata/format_MKV_VP8_VORBIS.video"));
const std::string output_file_path = "/tmp/tmp_video.mkv"; const std::string output_file_path = "/tmp/tmp_video.mkv";
DeletingFile deleting_file(output_file_path, true); DeletingFile deleting_file(output_file_path, true);
input_side_packets["output_file_path"] = input_side_packets["output_file_path"] =

View File

@ -33,39 +33,16 @@
#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/tool/test_util.h"
#include "mediapipe/util/tracking/box_tracker.pb.h" #include "mediapipe/util/tracking/box_tracker.pb.h"
#include "mediapipe/util/tracking/tracking.pb.h" #include "mediapipe/util/tracking/tracking.pb.h"
#ifdef __APPLE__
#include <CoreFoundation/CoreFoundation.h>
#endif // defined(__APPLE__)
namespace mediapipe { namespace mediapipe {
namespace { namespace {
using ::testing::FloatNear; using ::testing::FloatNear;
using ::testing::Test; using ::testing::Test;
std::string GetTestDir() { constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
#ifdef __APPLE__
char path[1024];
CFURLRef bundle_url = CFBundleCopyBundleURL(CFBundleGetMainBundle());
CFURLGetFileSystemRepresentation(
bundle_url, true, reinterpret_cast<UInt8*>(path), sizeof(path));
CFRelease(bundle_url);
return mediapipe::file::JoinPath(path, "testdata");
#elif defined(__ANDROID__)
char path[1024];
getcwd(path, sizeof(path));
return mediapipe::file::JoinPath(path,
"mediapipe/calculators/video/testdata");
#else
return mediapipe::file::JoinPath(
"./",
// This should match the path of the output files
// of the genrule() that generates test model files.
"mediapipe/calculators/video/testdata");
#endif // defined(__APPLE__)
}
bool LoadBinaryTestGraph(const std::string& graph_path, bool LoadBinaryTestGraph(const std::string& graph_path,
CalculatorGraphConfig* config) { CalculatorGraphConfig* config) {
@ -85,7 +62,7 @@ class TrackingGraphTest : public Test {
TrackingGraphTest() {} TrackingGraphTest() {}
void SetUp() override { void SetUp() override {
test_dir_ = GetTestDir(); test_dir_ = mediapipe::GetTestDataDir(kTestPackageRoot);
const auto graph_path = file::JoinPath(test_dir_, "tracker.binarypb"); const auto graph_path = file::JoinPath(test_dir_, "tracker.binarypb");
ASSERT_TRUE(LoadBinaryTestGraph(graph_path, &config_)); ASSERT_TRUE(LoadBinaryTestGraph(graph_path, &config_));

View File

@ -15,10 +15,10 @@
package com.google.mediapipe.examples.facedetection; package com.google.mediapipe.examples.facedetection;
import android.opengl.GLES20; import android.opengl.GLES20;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.mediapipe.solutioncore.ResultGlRenderer; import com.google.mediapipe.solutioncore.ResultGlRenderer;
import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; import com.google.mediapipe.solutions.facedetection.FaceDetectionResult;
import com.google.mediapipe.solutions.facedetection.FaceKeypoint; import com.google.mediapipe.solutions.facedetection.FaceKeypoint;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.nio.FloatBuffer; import java.nio.FloatBuffer;

View File

@ -23,9 +23,9 @@ import android.graphics.Color;
import android.graphics.Matrix; import android.graphics.Matrix;
import android.graphics.Paint; import android.graphics.Paint;
import androidx.appcompat.widget.AppCompatImageView; import androidx.appcompat.widget.AppCompatImageView;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
import com.google.mediapipe.solutions.facedetection.FaceDetectionResult; import com.google.mediapipe.solutions.facedetection.FaceDetectionResult;
import com.google.mediapipe.solutions.facedetection.FaceKeypoint; import com.google.mediapipe.solutions.facedetection.FaceKeypoint;
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
/** An ImageView implementation for displaying {@link FaceDetectionResult}. */ /** An ImageView implementation for displaying {@link FaceDetectionResult}. */
public class FaceDetectionResultImageView extends AppCompatImageView { public class FaceDetectionResultImageView extends AppCompatImageView {

View File

@ -279,35 +279,45 @@ mediapipe::autoflip::RectF ShiftDetection(
} }
absl::Status UpdateRanges(const SalientRegion& region, absl::Status UpdateRanges(const SalientRegion& region,
const float shift_vertical, const float shift_vertical,
const float shift_horizontal, float* xmin, const float shift_horizontal,
float* xmax, float* ymin, float* ymax) { const float pad_vertical, const float pad_horizontal,
float* xmin, float* xmax, float* ymin, float* ymax) {
if (!region.has_location_normalized()) { if (!region.has_location_normalized()) {
return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "SalientRegion did not have location normalized set."; << "SalientRegion did not have location normalized set.";
} }
auto location = ShiftDetection(region.location_normalized(), shift_vertical, auto location = ShiftDetection(region.location_normalized(), shift_vertical,
shift_horizontal); shift_horizontal);
*xmin = fmin(*xmin, location.x());
*xmax = fmax(*xmax, location.x() + location.width()); const float x_padding = pad_horizontal * location.width();
*ymin = fmin(*ymin, location.y()); const float y_padding = pad_vertical * location.height();
*ymax = fmax(*ymax, location.y() + location.height());
*xmin = fmin(*xmin, location.x() - x_padding);
*xmax = fmax(*xmax, location.x() + location.width() + x_padding);
*ymin = fmin(*ymin, location.y() - y_padding);
*ymax = fmax(*ymax, location.y() + location.height() + y_padding);
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status UpdateRanges(const mediapipe::Detection& detection, absl::Status UpdateRanges(const mediapipe::Detection& detection,
const float shift_vertical, const float shift_vertical,
const float shift_horizontal, float* xmin, const float shift_horizontal,
float* xmax, float* ymin, float* ymax) { const float pad_vertical, const float pad_horizontal,
float* xmin, float* xmax, float* ymin, float* ymax) {
RET_CHECK(detection.location_data().format() == RET_CHECK(detection.location_data().format() ==
mediapipe::LocationData::RELATIVE_BOUNDING_BOX) mediapipe::LocationData::RELATIVE_BOUNDING_BOX)
<< "Face detection input is lacking required relative_bounding_box()"; << "Face detection input is lacking required relative_bounding_box()";
const auto& location = const auto& location =
ShiftDetection(detection.location_data().relative_bounding_box(), ShiftDetection(detection.location_data().relative_bounding_box(),
shift_vertical, shift_horizontal); shift_vertical, shift_horizontal);
*xmin = fmin(*xmin, location.xmin());
*xmax = fmax(*xmax, location.xmin() + location.width()); const float x_padding = pad_horizontal * location.width();
*ymin = fmin(*ymin, location.ymin()); const float y_padding = pad_vertical * location.height();
*ymax = fmax(*ymax, location.ymin() + location.height());
*xmin = fmin(*xmin, location.xmin() - x_padding);
*xmax = fmax(*xmax, location.xmin() + location.width() + x_padding);
*ymin = fmin(*ymin, location.ymin() - y_padding);
*ymax = fmax(*ymax, location.ymin() + location.height() + y_padding);
return absl::OkStatus(); return absl::OkStatus();
} }
@ -818,7 +828,9 @@ absl::Status ContentZoomingCalculator::GetDetectionsBox(
*only_required_found = true; *only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges( MP_RETURN_IF_ERROR(UpdateRanges(
region, options_.detection_shift_vertical(), region, options_.detection_shift_vertical(),
options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax)); options_.detection_shift_horizontal(),
options_.extra_vertical_padding(),
options_.extra_horizontal_padding(), xmin, xmax, ymin, ymax));
} }
} }
@ -864,7 +876,9 @@ absl::Status ContentZoomingCalculator::GetDetectionsBox(
*only_required_found = true; *only_required_found = true;
MP_RETURN_IF_ERROR(UpdateRanges( MP_RETURN_IF_ERROR(UpdateRanges(
detection, options_.detection_shift_vertical(), detection, options_.detection_shift_vertical(),
options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax)); options_.detection_shift_horizontal(),
options_.extra_vertical_padding(),
options_.extra_horizontal_padding(), xmin, xmax, ymin, ymax));
} }
} }
} }

View File

@ -19,7 +19,7 @@ package mediapipe.autoflip;
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto"; import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
// NextTag: 19 // NextTag: 21
message ContentZoomingCalculatorOptions { message ContentZoomingCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
optional ContentZoomingCalculatorOptions ext = 313091992; optional ContentZoomingCalculatorOptions ext = 313091992;
@ -45,12 +45,17 @@ message ContentZoomingCalculatorOptions {
optional int64 height = 2; optional int64 height = 2;
} }
optional Size target_size = 8; optional Size target_size = 8;
// Amount to shift an input detection as a ratio of the size (positive:
// Amount to shift an input detection, as a ratio of its size (positive:
// down/right, negative: up/left). Use a negative value to increase padding // down/right, negative: up/left). Use a negative value to increase padding
// above/left of an object, positive to increase padding below/right of an // above/left of an object, positive to increase padding below/right of an
// object. // object. (Applies to one side only)
optional float detection_shift_vertical = 11 [default = 0.0]; optional float detection_shift_vertical = 11 [default = 0.0];
optional float detection_shift_horizontal = 12 [default = 0.0]; optional float detection_shift_horizontal = 12 [default = 0.0];
// Amount to pad around an input detection, as a ratio of its size.
// (Applies to both sides)
optional float extra_vertical_padding = 19 [default = 0.0];
optional float extra_horizontal_padding = 20 [default = 0.0];
// Defines the smallest value in degrees the camera is permitted to zoom. // Defines the smallest value in degrees the camera is permitted to zoom.
optional float max_zoom_value_deg = 13 [default = 35]; optional float max_zoom_value_deg = 13 [default = 35];

View File

@ -35,7 +35,9 @@ objc_library(
"CoreMedia", "CoreMedia",
"UIKit", "UIKit",
], ],
visibility = ["//mediapipe:__subpackages__"], visibility = [
"//mediapipe:__subpackages__",
],
deps = [ deps = [
"//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_framework_ios",
"//mediapipe/objc:mediapipe_input_sources_ios", "//mediapipe/objc:mediapipe_input_sources_ios",

View File

@ -115,7 +115,10 @@ mediapipe_proto_library(
name = "packet_test_proto", name = "packet_test_proto",
testonly = 1, testonly = 1,
srcs = ["packet_test.proto"], srcs = ["packet_test.proto"],
visibility = ["//mediapipe/framework:__subpackages__"], visibility = [
":mediapipe_internal",
"//mediapipe/framework:__subpackages__",
],
) )
mediapipe_proto_library( mediapipe_proto_library(
@ -973,6 +976,7 @@ cc_library(
], ],
}), }),
visibility = [ visibility = [
"//fitbit/research/sensing/mobisense:__subpackages__",
"//mediapipe/calculators:__subpackages__", "//mediapipe/calculators:__subpackages__",
"//mediapipe/framework:__subpackages__", "//mediapipe/framework:__subpackages__",
"//mediapipe/framework/port:__pkg__", "//mediapipe/framework/port:__pkg__",
@ -1427,6 +1431,7 @@ cc_test(
"//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler", "//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler",
"//mediapipe/framework/tool:sink", "//mediapipe/framework/tool:sink",
"//mediapipe/framework/tool:status_util", "//mediapipe/framework/tool:status_util",
"//mediapipe/gpu:graph_support",
"@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:fixed_array",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -149,6 +149,7 @@ cc_library(
"//mediapipe/framework:calculator_contract", "//mediapipe/framework:calculator_contract",
"//mediapipe/framework:output_side_packet", "//mediapipe/framework:output_side_packet",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/tool:type_util",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )

View File

@ -27,41 +27,34 @@
#include "mediapipe/framework/calculator_contract.h" #include "mediapipe/framework/calculator_contract.h"
#include "mediapipe/framework/output_side_packet.h" #include "mediapipe/framework/output_side_packet.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/tool/type_util.h"
namespace mediapipe { namespace mediapipe {
namespace api2 { namespace api2 {
// typeid is not constexpr, but a pointer to this is.
template <typename T>
size_t get_type_hash() {
return typeid(T).hash_code();
}
using type_id_fptr = size_t (*)();
// This is a base class for various types of port. It is not meant to be used // This is a base class for various types of port. It is not meant to be used
// directly by node code. // directly by node code.
class PortBase { class PortBase {
public: public:
constexpr PortBase(std::size_t tag_size, const char* tag, constexpr PortBase(std::size_t tag_size, const char* tag, TypeId type_id,
type_id_fptr get_type_id, bool optional, bool multiple) bool optional, bool multiple)
: tag_(tag_size, tag), : tag_(tag_size, tag),
optional_(optional), optional_(optional),
multiple_(multiple), multiple_(multiple),
type_id_getter_(get_type_id) {} type_id_(type_id) {}
bool IsOptional() const { return optional_; } bool IsOptional() const { return optional_; }
bool IsMultiple() const { return multiple_; } bool IsMultiple() const { return multiple_; }
const char* Tag() const { return tag_.data(); } const char* Tag() const { return tag_.data(); }
size_t type_id() const { return type_id_getter_(); } TypeId type_id() const { return type_id_; }
const const_str tag_; const const_str tag_;
const bool optional_; const bool optional_;
const bool multiple_; const bool multiple_;
protected: protected:
type_id_fptr type_id_getter_; TypeId type_id_;
}; };
// These four base classes are used to distinguish between ports of different // These four base classes are used to distinguish between ports of different
@ -340,7 +333,7 @@ class PortCommon : public Base {
template <std::size_t N> template <std::size_t N>
explicit constexpr PortCommon(const char (&tag)[N]) explicit constexpr PortCommon(const char (&tag)[N])
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV) {} : Base(N, tag, kTypeId<ValueT>, IsOptionalV, IsMultipleV) {}
using PayloadT = ActualPayloadT<ValueT>; using PayloadT = ActualPayloadT<ValueT>;
@ -428,7 +421,7 @@ class SideFallbackT : public Base {
template <std::size_t N> template <std::size_t N>
explicit constexpr SideFallbackT(const char (&tag)[N]) explicit constexpr SideFallbackT(const char (&tag)[N])
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV), : Base(N, tag, kTypeId<ValueT>, IsOptionalV, IsMultipleV),
stream_port(tag), stream_port(tag),
side_port(tag) {} side_port(tag) {}

View File

@ -8,7 +8,7 @@ namespace {
TEST(PortTest, IntInput) { TEST(PortTest, IntInput) {
static constexpr auto port = Input<int>("FOO"); static constexpr auto port = Input<int>("FOO");
EXPECT_EQ(port.type_id(), typeid(int).hash_code()); EXPECT_EQ(port.type_id(), kTypeId<int>);
} }
TEST(PortTest, OptionalInput) { TEST(PortTest, OptionalInput) {

View File

@ -59,7 +59,7 @@ class CalculatorContract {
const CalculatorOptions& Options() const { return node_config_->options(); } const CalculatorOptions& Options() const { return node_config_->options(); }
// Returns the name given to this node. // Returns the name given to this node.
const std::string& GetNodeName() { return node_name_; } const std::string& GetNodeName() const { return node_name_; }
// Returns the options given to this calculator. Template argument T must // Returns the options given to this calculator. Template argument T must
// be the type of the protobuf extension message or the protobuf::Any // be the type of the protobuf extension message or the protobuf::Any

View File

@ -120,10 +120,10 @@ CalculatorGraph::CalculatorGraph()
counter_factory_ = absl::make_unique<BasicCounterFactory>(); counter_factory_ = absl::make_unique<BasicCounterFactory>();
} }
CalculatorGraph::CalculatorGraph(const CalculatorGraphConfig& config) CalculatorGraph::CalculatorGraph(CalculatorGraphConfig config)
: CalculatorGraph() { : CalculatorGraph() {
counter_factory_ = absl::make_unique<BasicCounterFactory>(); counter_factory_ = absl::make_unique<BasicCounterFactory>();
MEDIAPIPE_CHECK_OK(Initialize(config)); MEDIAPIPE_CHECK_OK(Initialize(std::move(config)));
} }
// Defining the destructor here lets us use incomplete types in the header; // Defining the destructor here lets us use incomplete types in the header;
@ -429,18 +429,17 @@ absl::Status CalculatorGraph::Initialize(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status CalculatorGraph::Initialize( absl::Status CalculatorGraph::Initialize(CalculatorGraphConfig input_config) {
const CalculatorGraphConfig& input_config) { return Initialize(std::move(input_config), {});
return Initialize(input_config, {});
} }
absl::Status CalculatorGraph::Initialize( absl::Status CalculatorGraph::Initialize(
const CalculatorGraphConfig& input_config, CalculatorGraphConfig input_config,
const std::map<std::string, Packet>& side_packets) { const std::map<std::string, Packet>& side_packets) {
auto validated_graph = absl::make_unique<ValidatedGraphConfig>(); auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
MP_RETURN_IF_ERROR(validated_graph->Initialize( MP_RETURN_IF_ERROR(validated_graph->Initialize(
input_config, /*graph_registry=*/nullptr, /*graph_options=*/nullptr, std::move(input_config), /*graph_registry=*/nullptr,
&service_manager_)); /*graph_options=*/nullptr, &service_manager_));
return Initialize(std::move(validated_graph), side_packets); return Initialize(std::move(validated_graph), side_packets);
} }
@ -675,6 +674,7 @@ absl::Status CalculatorGraph::PrepareForRun(
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(PrepareServices()); MP_RETURN_IF_ERROR(PrepareServices());
#if !MEDIAPIPE_DISABLE_GPU #if !MEDIAPIPE_DISABLE_GPU
// TODO: should we do this on each run, or only once?
MP_RETURN_IF_ERROR(PrepareGpu()); MP_RETURN_IF_ERROR(PrepareGpu());
additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp); additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp);
#endif // !MEDIAPIPE_DISABLE_GPU #endif // !MEDIAPIPE_DISABLE_GPU
@ -1251,7 +1251,9 @@ void CalculatorGraph::Resume() { scheduler_.Resume(); }
absl::Status CalculatorGraph::SetExecutorInternal( absl::Status CalculatorGraph::SetExecutorInternal(
const std::string& name, std::shared_ptr<Executor> executor) { const std::string& name, std::shared_ptr<Executor> executor) {
if (!executors_.emplace(name, executor).second) { auto [it, inserted] = executors_.emplace(name, executor);
if (!inserted) {
if (it->second == executor) return absl::OkStatus();
return mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC) return mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC)
<< "SetExecutor must be called only once for the executor \"" << name << "SetExecutor must be called only once for the executor \"" << name
<< "\""; << "\"";

View File

@ -119,17 +119,17 @@ class CalculatorGraph {
// Initializes the graph from its proto description (using Initialize()) // Initializes the graph from its proto description (using Initialize())
// and crashes if something goes wrong. // and crashes if something goes wrong.
explicit CalculatorGraph(const CalculatorGraphConfig& config); explicit CalculatorGraph(CalculatorGraphConfig config);
virtual ~CalculatorGraph(); virtual ~CalculatorGraph();
// Initializes the graph from a its proto description. // Initializes the graph from a its proto description.
// side_packets that are provided at this stage are common across all Run() // side_packets that are provided at this stage are common across all Run()
// invocations and could be used to execute PacketGenerators immediately. // invocations and could be used to execute PacketGenerators immediately.
absl::Status Initialize(const CalculatorGraphConfig& config, absl::Status Initialize(CalculatorGraphConfig config,
const std::map<std::string, Packet>& side_packets); const std::map<std::string, Packet>& side_packets);
// Convenience version which does not take side packets. // Convenience version which does not take side packets.
absl::Status Initialize(const CalculatorGraphConfig& config); absl::Status Initialize(CalculatorGraphConfig config);
// Initializes the CalculatorGraph from the specified graph and subgraph // Initializes the CalculatorGraph from the specified graph and subgraph
// configs. Template graph and subgraph configs can be specified through // configs. Template graph and subgraph configs can be specified through
@ -272,7 +272,6 @@ class CalculatorGraph {
absl::Status CloseInputStream(const std::string& stream_name); absl::Status CloseInputStream(const std::string& stream_name);
// Closes all the graph input streams. // Closes all the graph input streams.
// TODO: deprecate this function in favor of CloseAllPacketSources.
absl::Status CloseAllInputStreams(); absl::Status CloseAllInputStreams();
// Closes all the graph input streams and source calculator nodes. // Closes all the graph input streams and source calculator nodes.

View File

@ -60,6 +60,7 @@
#include "mediapipe/framework/tool/sink.h" #include "mediapipe/framework/tool/sink.h"
#include "mediapipe/framework/tool/status_util.h" #include "mediapipe/framework/tool/status_util.h"
#include "mediapipe/framework/type_map.h" #include "mediapipe/framework/type_map.h"
#include "mediapipe/gpu/graph_support.h"
namespace mediapipe { namespace mediapipe {
@ -2059,6 +2060,26 @@ TEST(CalculatorGraph, HandlersRun) {
input_side_packets.at("unavailable_input_counter2"))); input_side_packets.at("unavailable_input_counter2")));
} }
TEST(CalculatorGraph, CalculatorGraphConfigCopyElision) {
CalculatorGraph graph;
CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
input_stream: 'in'
node {
calculator: 'PassThroughCalculator'
input_stream: 'in'
output_stream: 'out'
}
)pb");
// config is consumed and never copied, which avoid copying data.
MP_ASSERT_OK(graph.Initialize(std::move(config)));
MP_EXPECT_OK(graph.StartRun({}));
MP_EXPECT_OK(
graph.AddPacketToInputStream("in", MakePacket<int>(1).At(Timestamp(1))));
MP_EXPECT_OK(graph.CloseInputStream("in"));
MP_EXPECT_OK(graph.WaitUntilDone());
}
// Test that calling SetOffset() in Calculator::Process() results in the // Test that calling SetOffset() in Calculator::Process() results in the
// absl::StatusCode::kFailedPrecondition error. // absl::StatusCode::kFailedPrecondition error.
TEST(CalculatorGraph, SetOffsetInProcess) { TEST(CalculatorGraph, SetOffsetInProcess) {

View File

@ -11,10 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
//
// Forked from mediapipe/framework/calculator_profile.proto.
// The forked proto must remain identical to the original proto and should be
// ONLY used by mediapipe open source project.
syntax = "proto2"; syntax = "proto2";
@ -24,6 +20,7 @@ import "mediapipe/framework/calculator.proto";
option java_package = "com.google.mediapipe.proto"; option java_package = "com.google.mediapipe.proto";
option java_outer_classname = "CalculatorProfileProto"; option java_outer_classname = "CalculatorProfileProto";
option objc_class_prefix = "MediaPipe";
// Stores the profiling information. // Stores the profiling information.
// //

View File

@ -88,7 +88,10 @@ cc_library(
testonly = True, testonly = True,
hdrs = ["message_matchers.h"], hdrs = ["message_matchers.h"],
# Use this library through "mediapipe/framework/port:gtest_main". # Use this library through "mediapipe/framework/port:gtest_main".
visibility = ["//mediapipe/framework/port:__pkg__"], visibility = [
"//mediapipe/framework/port:__pkg__",
"//third_party/visionai/algorithms/tracking:__pkg__",
],
deps = [ deps = [
"//mediapipe/framework/port:core_proto", "//mediapipe/framework/port:core_proto",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
@ -137,7 +140,6 @@ cc_library(
hdrs = ["image_resizer.h"], hdrs = ["image_resizer.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework/port:integral_types",
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
], ],
) )

View File

@ -15,7 +15,6 @@
#ifndef MEDIAPIPE_DEPS_IMAGE_RESIZER_H_ #ifndef MEDIAPIPE_DEPS_IMAGE_RESIZER_H_
#define MEDIAPIPE_DEPS_IMAGE_RESIZER_H_ #define MEDIAPIPE_DEPS_IMAGE_RESIZER_H_
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h" #include "mediapipe/framework/port/opencv_imgproc_inc.h"
namespace mediapipe { namespace mediapipe {

View File

@ -140,9 +140,22 @@ _encode_binary_proto = rule(
) )
def encode_binary_proto(name, input, message_type, deps, **kwargs): def encode_binary_proto(name, input, message_type, deps, **kwargs):
if type(input) == type("string"):
input_label = input
textproto_srcs = [input]
elif type(input) == type(dict()):
# We cannot accept a select, as macros are unable to manipulate selects.
input_label = select(input)
srcs_dict = dict()
for k, v in input.items():
srcs_dict[k] = [v]
textproto_srcs = select(srcs_dict)
else:
fail("input should be a string or a dict, got %s" % input)
_encode_binary_proto( _encode_binary_proto(
name = name, name = name,
input = input, input = input_label,
message_type = message_type, message_type = message_type,
deps = deps, deps = deps,
**kwargs **kwargs

View File

@ -448,6 +448,7 @@ cc_library(
srcs = srcs =
[ [
"tensor.cc", "tensor.cc",
"tensor_ahwb.cc",
], ],
hdrs = ["tensor.h"], hdrs = ["tensor.h"],
copts = select({ copts = select({
@ -463,6 +464,9 @@ cc_library(
"-framework MetalKit", "-framework MetalKit",
], ],
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:android": [
"-landroid",
],
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [

View File

@ -19,7 +19,9 @@ package mediapipe;
// Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of // Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of
// the joint and its visibility. // the joint and its visibility.
message Joint { message Joint {
// Joint rotation in 6D contineous representation. // Joint rotation in 6D contineous representation ordered as
// [a1, b1, a2, b2, a3, b3].
//
// Such representation is more sutable for NN model training and can be // Such representation is more sutable for NN model training and can be
// converted to quaternions and Euler angles if needed. Details can be found // converted to quaternions and Euler angles if needed. Details can be found
// in https://arxiv.org/abs/1812.07035. // in https://arxiv.org/abs/1812.07035.

View File

@ -20,6 +20,9 @@
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h"
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
#include <mach/mach_init.h> #include <mach/mach_init.h>
@ -319,28 +322,41 @@ void Tensor::AllocateOpenGlTexture2d() const {
Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const { Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
LOG_IF(FATAL, valid_ == kValidNone) LOG_IF(FATAL, valid_ == kValidNone)
<< "Tensor must be written prior to read from."; << "Tensor must be written prior to read from.";
LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidOpenGlBuffer))) LOG_IF(FATAL, !(valid_ & (kValidCpu |
<< "Tensor conversion between different GPU resources is not supported " #ifdef MEDIAPIPE_TENSOR_USE_AHWB
"yet."; kValidAHardwareBuffer |
#endif // MEDIAPIPE_TENSOR_USE_AHWB
kValidOpenGlBuffer)))
<< "Tensor conversion between different GPU resources is not supported.";
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_)); auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
AllocateOpenGlBuffer(); AllocateOpenGlBuffer();
if (!(valid_ & kValidOpenGlBuffer)) { if (!(valid_ & kValidOpenGlBuffer)) {
// If the call succeds then AHWB -> SSBO are synchronized so any usage of
// the SSBO is correct after this call.
if (!InsertAhwbToSsboFence()) {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
void* ptr = void* ptr =
glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(), glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
GL_MAP_INVALIDATE_BUFFER_BIT | GL_MAP_WRITE_BIT); GL_MAP_INVALIDATE_BUFFER_BIT | GL_MAP_WRITE_BIT);
std::memcpy(ptr, cpu_buffer_, bytes()); std::memcpy(ptr, cpu_buffer_, bytes());
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER); glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
}
valid_ |= kValidOpenGlBuffer; valid_ |= kValidOpenGlBuffer;
} }
return {opengl_buffer_, std::move(lock)}; return {opengl_buffer_, std::move(lock),
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
&ssbo_read_
#else
nullptr
#endif // MEDIAPIPE_TENSOR_USE_AHWB
};
} }
Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const { Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const {
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_)); auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
AllocateOpenGlBuffer(); AllocateOpenGlBuffer();
valid_ = kValidOpenGlBuffer; valid_ = kValidOpenGlBuffer;
return {opengl_buffer_, std::move(lock)}; return {opengl_buffer_, std::move(lock), nullptr};
} }
void Tensor::AllocateOpenGlBuffer() const { void Tensor::AllocateOpenGlBuffer() const {
@ -349,8 +365,11 @@ void Tensor::AllocateOpenGlBuffer() const {
LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread."; LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread.";
glGenBuffers(1, &opengl_buffer_); glGenBuffers(1, &opengl_buffer_);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_); glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
if (!AllocateAhwbMapToSsbo()) {
glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY); glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY);
} }
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
}
} }
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
@ -377,6 +396,8 @@ void Tensor::Move(Tensor* src) {
src->metal_buffer_ = nil; src->metal_buffer_ = nil;
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
MoveAhwbStuff(src);
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
gl_context_ = std::move(src->gl_context_); gl_context_ = std::move(src->gl_context_);
frame_buffer_ = src->frame_buffer_; frame_buffer_ = src->frame_buffer_;
@ -395,27 +416,31 @@ void Tensor::Move(Tensor* src) {
Tensor::Tensor(ElementType element_type, const Shape& shape) Tensor::Tensor(ElementType element_type, const Shape& shape)
: element_type_(element_type), shape_(shape) {} : element_type_(element_type), shape_(shape) {}
void Tensor::Invalidate() {
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
GLuint cleanup_gl_tex = GL_INVALID_INDEX;
GLuint cleanup_gl_fb = GL_INVALID_INDEX;
GLuint cleanup_gl_buf = GL_INVALID_INDEX;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
{
absl::MutexLock lock(&view_mutex_);
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
void Tensor::Invalidate() {
absl::MutexLock lock(&view_mutex_);
// If memory is allocated and not owned by the metal buffer. // If memory is allocated and not owned by the metal buffer.
// TODO: Re-design cpu buffer memory management. // TODO: Re-design cpu buffer memory management.
if (cpu_buffer_ && !metal_buffer_) { if (cpu_buffer_ && !metal_buffer_) {
DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes())); DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes()));
} }
metal_buffer_ = nil; metal_buffer_ = nil;
#else
if (cpu_buffer_) {
free(cpu_buffer_);
}
#endif // MEDIAPIPE_METAL_ENABLED
cpu_buffer_ = nullptr; cpu_buffer_ = nullptr;
}
#else
void Tensor::Invalidate() {
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
GLuint cleanup_gl_tex = GL_INVALID_INDEX;
GLuint cleanup_gl_fb = GL_INVALID_INDEX;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
GLuint cleanup_gl_buf = GL_INVALID_INDEX;
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
{
absl::MutexLock lock(&view_mutex_);
ReleaseAhwbStuff();
// Don't need to wait for the resource to be deleted bacause if will be // Don't need to wait for the resource to be deleted bacause if will be
// released on last reference deletion inside the OpenGL driver. // released on last reference deletion inside the OpenGL driver.
@ -429,28 +454,44 @@ void Tensor::Invalidate() {
} }
// Do not hold the view mutex while invoking GlContext::RunWithoutWaiting, // Do not hold the view mutex while invoking GlContext::RunWithoutWaiting,
// since that method may acquire the context's own lock. // since that method may acquire the context's own lock.
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX ||
cleanup_gl_buf != GL_INVALID_INDEX)
gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
, if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX ||
cleanup_gl_buf cleanup_gl_buf != GL_INVALID_INDEX) {
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31 gl_context_->RunWithoutWaiting(
]() { [cleanup_gl_tex, cleanup_gl_fb, cleanup_gl_buf]() {
glDeleteTextures(1, &cleanup_gl_tex); glDeleteTextures(1, &cleanup_gl_tex);
glDeleteFramebuffers(1, &cleanup_gl_fb); glDeleteFramebuffers(1, &cleanup_gl_fb);
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
glDeleteBuffers(1, &cleanup_gl_buf); glDeleteBuffers(1, &cleanup_gl_buf);
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
}); });
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
} }
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX) {
gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb]() {
glDeleteTextures(1, &cleanup_gl_tex);
glDeleteFramebuffers(1, &cleanup_gl_fb);
});
}
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
if (cpu_buffer_) {
free(cpu_buffer_);
}
cpu_buffer_ = nullptr;
}
#endif // MEDIAPIPE_METAL_ENABLED
Tensor::CpuReadView Tensor::GetCpuReadView() const { Tensor::CpuReadView Tensor::GetCpuReadView() const {
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_); auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
LOG_IF(FATAL, valid_ == kValidNone) LOG_IF(FATAL, valid_ == kValidNone)
<< "Tensor must be written prior to read from."; << "Tensor must be written prior to read from.";
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
void* ptr = MapAhwbToCpuRead();
if (ptr) {
valid_ |= kValidCpu;
return {ptr, ahwb_, nullptr, std::move(lock)};
}
#endif // MEDIAPIPE_TENSOR_USE_AHWB
AllocateCpuBuffer(); AllocateCpuBuffer();
if (!(valid_ & kValidCpu)) { if (!(valid_ & kValidCpu)) {
// GPU-to-CPU synchronization and read-back. // GPU-to-CPU synchronization and read-back.
@ -512,18 +553,33 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
valid_ |= kValidCpu; valid_ |= kValidCpu;
} }
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
return {cpu_buffer_, nullptr, nullptr, std::move(lock)};
#else
return {cpu_buffer_, std::move(lock)}; return {cpu_buffer_, std::move(lock)};
#endif // MEDIAPIPE_TENSOR_USE_AHWB
} }
Tensor::CpuWriteView Tensor::GetCpuWriteView() const { Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_); auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
AllocateCpuBuffer(); AllocateCpuBuffer();
valid_ = kValidCpu; valid_ = kValidCpu;
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
void* ptr = MapAhwbToCpuWrite();
if (ptr) {
return {ptr, ahwb_, &fence_fd_, std::move(lock)};
}
return {cpu_buffer_, nullptr, nullptr, std::move(lock)};
#else
return {cpu_buffer_, std::move(lock)}; return {cpu_buffer_, std::move(lock)};
#endif // MEDIAPIPE_TENSOR_USE_AHWB
} }
void Tensor::AllocateCpuBuffer() const { void Tensor::AllocateCpuBuffer() const {
if (!cpu_buffer_) { if (!cpu_buffer_) {
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
if (AllocateAHardwareBuffer()) return;
#endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_METAL_ENABLED #if MEDIAPIPE_METAL_ENABLED
cpu_buffer_ = AllocateVirtualMemory(bytes()); cpu_buffer_ = AllocateVirtualMemory(bytes());
#else #else
@ -532,4 +588,10 @@ void Tensor::AllocateCpuBuffer() const {
} }
} }
void Tensor::SetPreferredStorageType(StorageType type) {
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
use_ahwb_ = type == StorageType::kAhwb;
#endif // MEDIAPIPE_TENSOR_USE_AHWB
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -30,6 +30,16 @@
#import <Metal/Metal.h> #import <Metal/Metal.h>
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#if __ANDROID_API__ < 26
#error MEDIAPIPE_TENSOR_USE_AHWB requires NDK version 26 or higher to be specified.
#endif // __ANDROID_API__ < 26
#include <android/hardware_buffer.h>
#include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gl_context.h" #include "mediapipe/gpu/gl_context.h"
@ -108,14 +118,37 @@ class Tensor {
return static_cast<typename std::tuple_element< return static_cast<typename std::tuple_element<
std::is_const<T>::value, std::tuple<P*, const P*> >::type>(buffer_); std::is_const<T>::value, std::tuple<P*, const P*> >::type>(buffer_);
} }
CpuView(CpuView&& src) : View(std::move(src)), buffer_(src.buffer_) { CpuView(CpuView&& src) : View(std::move(src)) {
src.buffer_ = nullptr; buffer_ = std::exchange(src.buffer_, nullptr);
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
ahwb_ = std::exchange(src.ahwb_, nullptr);
fence_fd_ = std::exchange(src.fence_fd_, nullptr);
#endif // MEDIAPIPE_TENSOR_USE_AHWB
} }
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
~CpuView() {
if (ahwb_) {
auto error = AHardwareBuffer_unlock(ahwb_, fence_fd_);
CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
}
}
#endif // MEDIAPIPE_TENSOR_USE_AHWB
protected: protected:
friend class Tensor; friend class Tensor;
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
CpuView(T* buffer, AHardwareBuffer* ahwb, int* fence_fd,
std::unique_ptr<absl::MutexLock>&& lock)
: View(std::move(lock)),
buffer_(buffer),
fence_fd_(fence_fd),
ahwb_(ahwb) {}
AHardwareBuffer* ahwb_;
int* fence_fd_;
#else
CpuView(T* buffer, std::unique_ptr<absl::MutexLock>&& lock) CpuView(T* buffer, std::unique_ptr<absl::MutexLock>&& lock)
: View(std::move(lock)), buffer_(buffer) {} : View(std::move(lock)), buffer_(buffer) {}
#endif // MEDIAPIPE_TENSOR_USE_AHWB
T* buffer_; T* buffer_;
}; };
using CpuReadView = CpuView<const void>; using CpuReadView = CpuView<const void>;
@ -150,6 +183,60 @@ class Tensor {
MtlBufferView GetMtlBufferWriteView(id<MTLDevice> device) const; MtlBufferView GetMtlBufferWriteView(id<MTLDevice> device) const;
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
class AHardwareBufferView : public View {
public:
AHardwareBuffer* handle() const { return handle_; }
AHardwareBufferView(AHardwareBufferView&& src) : View(std::move(src)) {
handle_ = std::exchange(src.handle_, nullptr);
file_descriptor_ = src.file_descriptor_;
fence_fd_ = std::exchange(src.fence_fd_, nullptr);
ahwb_written_ = std::exchange(src.ahwb_written_, nullptr);
release_callback_ = std::exchange(src.release_callback_, nullptr);
}
int file_descriptor() const { return file_descriptor_; }
void SetReadingFinishedFunc(std::function<bool()>&& func) {
CHECK(ahwb_written_)
<< "AHWB write view can't accept 'reading finished callback'";
*ahwb_written_ = std::move(func);
}
void SetWritingFinishedFD(int fd) {
CHECK(fence_fd_)
<< "AHWB read view can't accept 'writing finished file descriptor'";
*fence_fd_ = fd;
}
// The function is called when the tensor is released.
void SetReleaseCallback(std::function<void()> callback) {
*release_callback_ = std::move(callback);
}
protected:
friend class Tensor;
AHardwareBufferView(AHardwareBuffer* handle, int file_descriptor,
int* fence_fd, std::function<bool()>* ahwb_written,
std::function<void()>* release_callback,
std::unique_ptr<absl::MutexLock>&& lock)
: View(std::move(lock)),
handle_(handle),
file_descriptor_(file_descriptor),
fence_fd_(fence_fd),
ahwb_written_(ahwb_written),
release_callback_(release_callback) {}
AHardwareBuffer* handle_;
int file_descriptor_;
// The view sets some Tensor's fields. The view is released prior to tensor.
int* fence_fd_;
std::function<bool()>* ahwb_written_;
std::function<void()>* release_callback_;
};
AHardwareBufferView GetAHardwareBufferReadView() const;
// size_alignment is an optional argument to tell the API to allocate
// a buffer that is padded to multiples of size_alignment bytes.
// size_alignment must be power of 2, i.e. 2, 4, 8, 16, 64, etc.
// If size_alignment is 0, then the buffer will not be padded.
AHardwareBufferView GetAHardwareBufferWriteView(int size_alignment = 0) const;
#endif // MEDIAPIPE_TENSOR_USE_AHWB
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
// TODO: Use GlTextureView instead. // TODO: Use GlTextureView instead.
// Only float32 textures are supported with 1/2/3/4 depths. // Only float32 textures are supported with 1/2/3/4 depths.
@ -188,16 +275,23 @@ class Tensor {
class OpenGlBufferView : public View { class OpenGlBufferView : public View {
public: public:
GLuint name() const { return name_; } GLuint name() const { return name_; }
OpenGlBufferView(OpenGlBufferView&& src) OpenGlBufferView(OpenGlBufferView&& src) : View(std::move(src)) {
: View(std::move(src)), name_(src.name_) { name_ = std::exchange(src.name_, GL_INVALID_INDEX);
src.name_ = GL_INVALID_INDEX; ssbo_read_ = std::exchange(src.ssbo_read_, nullptr);
}
~OpenGlBufferView() {
if (ssbo_read_) {
*ssbo_read_ = glFenceSync(GL_SYNC_GPU_COMMANDS_COMPLETE, 0);
}
} }
protected: protected:
friend class Tensor; friend class Tensor;
OpenGlBufferView(GLuint name, std::unique_ptr<absl::MutexLock>&& lock) OpenGlBufferView(GLuint name, std::unique_ptr<absl::MutexLock>&& lock,
: View(std::move(lock)), name_(name) {} GLsync* ssbo_read)
: View(std::move(lock)), name_(name), ssbo_read_(ssbo_read) {}
GLuint name_; GLuint name_;
GLsync* ssbo_read_;
}; };
// A valid OpenGL context must be bound to the calling thread due to possible // A valid OpenGL context must be bound to the calling thread due to possible
// GPU resource allocation. // GPU resource allocation.
@ -223,16 +317,26 @@ class Tensor {
} }
int bytes() const { return shape_.num_elements() * element_size(); } int bytes() const { return shape_.num_elements() * element_size(); }
bool ready_on_cpu() const { return valid_ & kValidCpu; } bool ready_on_cpu() const {
return valid_ & (kValidAHardwareBuffer | kValidCpu);
}
bool ready_on_gpu() const { bool ready_on_gpu() const {
return valid_ & return valid_ & (kValidMetalBuffer | kValidOpenGlBuffer |
(kValidMetalBuffer | kValidOpenGlBuffer | kValidOpenGlTexture2d); kValidAHardwareBuffer | kValidOpenGlTexture2d);
} }
bool ready_as_metal_buffer() const { return valid_ & kValidMetalBuffer; } bool ready_as_metal_buffer() const { return valid_ & kValidMetalBuffer; }
bool ready_as_opengl_buffer() const { return valid_ & kValidOpenGlBuffer; } bool ready_as_opengl_buffer() const {
return valid_ & (kValidAHardwareBuffer | kValidOpenGlBuffer);
}
bool ready_as_opengl_texture_2d() const { bool ready_as_opengl_texture_2d() const {
return valid_ & kValidOpenGlTexture2d; return valid_ & kValidOpenGlTexture2d;
} }
// Sets the type of underlying resource that is going to be allocated.
enum class StorageType {
kDefault,
kAhwb,
};
static void SetPreferredStorageType(StorageType type);
private: private:
void Move(Tensor*); void Move(Tensor*);
@ -248,6 +352,7 @@ class Tensor {
kValidMetalBuffer = 1 << 1, kValidMetalBuffer = 1 << 1,
kValidOpenGlBuffer = 1 << 2, kValidOpenGlBuffer = 1 << 2,
kValidOpenGlTexture2d = 1 << 3, kValidOpenGlTexture2d = 1 << 3,
kValidAHardwareBuffer = 1 << 5,
}; };
// A list of resource which are currently allocated and synchronized between // A list of resource which are currently allocated and synchronized between
// each-other: valid_ = kValidCpu | kValidMetalBuffer; // each-other: valid_ = kValidCpu | kValidMetalBuffer;
@ -264,6 +369,34 @@ class Tensor {
void AllocateMtlBuffer(id<MTLDevice> device) const; void AllocateMtlBuffer(id<MTLDevice> device) const;
#endif // MEDIAPIPE_METAL_ENABLED #endif // MEDIAPIPE_METAL_ENABLED
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
mutable AHardwareBuffer* ahwb_ = nullptr;
// Signals when GPU finished writing into SSBO so AHWB can be used then. Or
// signals when writing into AHWB has been finished so GPU can read from SSBO.
// Sync and FD are bound together.
mutable EGLSyncKHR fence_sync_ = EGL_NO_SYNC_KHR;
// This FD signals when the writing into the SSBO has been finished.
mutable int ssbo_written_ = -1;
// An externally set FD that is wrapped with the EGL sync then to synchronize
// AHWB -> OpenGL SSBO.
mutable int fence_fd_ = -1;
// Reading from SSBO has been finished so SSBO can be released.
mutable GLsync ssbo_read_ = 0;
// An externally set function that signals when it is safe to release AHWB.
mutable std::function<bool()> ahwb_written_;
mutable std::function<void()> release_callback_;
bool AllocateAHardwareBuffer(int size_alignment = 0) const;
void CreateEglSyncAndFd() const;
// Use Ahwb for other views: OpenGL / CPU buffer.
static inline bool use_ahwb_ = false;
#endif // MEDIAPIPE_TENSOR_USE_AHWB
bool AllocateAhwbMapToSsbo() const;
bool InsertAhwbToSsboFence() const;
void MoveAhwbStuff(Tensor* src);
void ReleaseAhwbStuff();
void* MapAhwbToCpuRead() const;
void* MapAhwbToCpuWrite() const;
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30 #if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
mutable std::shared_ptr<mediapipe::GlContext> gl_context_; mutable std::shared_ptr<mediapipe::GlContext> gl_context_;
mutable GLuint opengl_texture2d_ = GL_INVALID_INDEX; mutable GLuint opengl_texture2d_ = GL_INVALID_INDEX;

View File

@ -0,0 +1,382 @@
#include <cstdint>
#include <utility>
#include "mediapipe/framework/formats/tensor.h"
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/gpu/gl_base.h"
#include "third_party/GL/gl/include/EGL/egl.h"
#include "third_party/GL/gl/include/EGL/eglext.h"
#endif // MEDIAPIPE_TENSOR_USE_AHWB
namespace mediapipe {
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
namespace {
PFNGLBUFFERSTORAGEEXTERNALEXTPROC glBufferStorageExternalEXT;
PFNEGLGETNATIVECLIENTBUFFERANDROIDPROC eglGetNativeClientBufferANDROID;
PFNEGLDUPNATIVEFENCEFDANDROIDPROC eglDupNativeFenceFDANDROID;
PFNEGLCREATESYNCKHRPROC eglCreateSyncKHR;
PFNEGLWAITSYNCKHRPROC eglWaitSyncKHR;
PFNEGLCLIENTWAITSYNCKHRPROC eglClientWaitSyncKHR;
PFNEGLDESTROYSYNCKHRPROC eglDestroySyncKHR;
bool IsGlSupported() {
static const bool extensions_allowed = [] {
eglGetNativeClientBufferANDROID =
reinterpret_cast<PFNEGLGETNATIVECLIENTBUFFERANDROIDPROC>(
eglGetProcAddress("eglGetNativeClientBufferANDROID"));
glBufferStorageExternalEXT =
reinterpret_cast<PFNGLBUFFERSTORAGEEXTERNALEXTPROC>(
eglGetProcAddress("glBufferStorageExternalEXT"));
eglDupNativeFenceFDANDROID =
reinterpret_cast<PFNEGLDUPNATIVEFENCEFDANDROIDPROC>(
eglGetProcAddress("eglDupNativeFenceFDANDROID"));
eglCreateSyncKHR = reinterpret_cast<PFNEGLCREATESYNCKHRPROC>(
eglGetProcAddress("eglCreateSyncKHR"));
eglWaitSyncKHR = reinterpret_cast<PFNEGLWAITSYNCKHRPROC>(
eglGetProcAddress("eglWaitSyncKHR"));
eglClientWaitSyncKHR = reinterpret_cast<PFNEGLCLIENTWAITSYNCKHRPROC>(
eglGetProcAddress("eglClientWaitSyncKHR"));
eglDestroySyncKHR = reinterpret_cast<PFNEGLDESTROYSYNCKHRPROC>(
eglGetProcAddress("eglDestroySyncKHR"));
return eglClientWaitSyncKHR && eglWaitSyncKHR &&
eglGetNativeClientBufferANDROID && glBufferStorageExternalEXT &&
eglCreateSyncKHR && eglDupNativeFenceFDANDROID && eglDestroySyncKHR;
}();
return extensions_allowed;
}
absl::Status MapAHardwareBufferToGlBuffer(AHardwareBuffer* handle, size_t size,
GLuint name) {
if (!IsGlSupported()) {
return absl::UnknownError(
"No GL extension functions found to bind AHardwareBuffer and "
"OpenGL buffer");
}
EGLClientBuffer native_buffer = eglGetNativeClientBufferANDROID(handle);
if (!native_buffer) {
return absl::UnknownError("Can't get native buffer");
}
glBufferStorageExternalEXT(GL_SHADER_STORAGE_BUFFER, 0, size, native_buffer,
GL_MAP_READ_BIT | GL_MAP_WRITE_BIT |
GL_MAP_COHERENT_BIT_EXT |
GL_MAP_PERSISTENT_BIT_EXT);
if (glGetError() == GL_NO_ERROR) {
return absl::OkStatus();
} else {
return absl::InternalError("Error in glBufferStorageExternalEXT");
}
}
static inline int AlignedToPowerOf2(int value, int alignment) {
// alignment must be a power of 2
return ((value - 1) | (alignment - 1)) + 1;
}
// This class keeps tensor's resources while the tensor is in use on GPU or TPU
// but is already released on CPU. When a regular OpenGL buffer is bound to the
// GPU queue for execution and released on client side then the buffer is still
// not released because is being used by GPU. OpenGL driver keeps traking of
// that. When OpenGL buffer is build on top of AHWB then the traking is done
// with the DeleyedRelease which, actually, keeps record of all AHWBs allocated
// and releases each of them if already used. EGL/GL fences are used to check
// the status of a buffer.
class DelayedReleaser {
public:
// Non-copyable
DelayedReleaser(const DelayedReleaser&) = delete;
DelayedReleaser& operator=(const DelayedReleaser&) = delete;
// Non-movable
DelayedReleaser(DelayedReleaser&&) = delete;
DelayedReleaser& operator=(DelayedReleaser&&) = delete;
static void Add(AHardwareBuffer* ahwb, GLuint opengl_buffer,
EGLSyncKHR ssbo_sync, GLsync ssbo_read,
std::function<bool()>&& ahwb_written,
std::shared_ptr<mediapipe::GlContext> gl_context,
std::function<void()>&& callback) {
static absl::Mutex mutex;
absl::MutexLock lock(&mutex);
// Using `new` to access a non-public constructor.
to_release_.emplace_back(absl::WrapUnique(new DelayedReleaser(
ahwb, opengl_buffer, ssbo_sync, ssbo_read, std::move(ahwb_written),
gl_context, std::move(callback))));
for (auto it = to_release_.begin(); it != to_release_.end();) {
if ((*it)->IsSignaled()) {
it = to_release_.erase(it);
} else {
++it;
}
}
}
~DelayedReleaser() {
AHardwareBuffer_release(ahwb_);
if (release_callback_) release_callback_();
}
bool IsSignaled() {
CHECK(!(ssbo_read_ && ahwb_written_))
<< "ssbo_read_ and ahwb_written_ cannot both be set";
if (ahwb_written_) {
if (!ahwb_written_()) return false;
gl_context_->Run([this]() {
if (fence_sync_ != EGL_NO_SYNC_KHR && IsGlSupported()) {
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
if (egl_display != EGL_NO_DISPLAY) {
eglDestroySyncKHR(egl_display, fence_sync_);
}
fence_sync_ = EGL_NO_SYNC_KHR;
}
glDeleteBuffers(1, &opengl_buffer_);
opengl_buffer_ = GL_INVALID_INDEX;
});
return true;
}
gl_context_->Run([this]() {
if (ssbo_read_ != 0) {
GLenum status = glClientWaitSync(ssbo_read_, 0,
/* timeout ns = */ 0);
if (status != GL_CONDITION_SATISFIED && status != GL_ALREADY_SIGNALED) {
return;
}
glDeleteSync(ssbo_read_);
ssbo_read_ = 0;
// Don't wait on ssbo_sync because it is ahead of ssbo_read_sync.
if (fence_sync_ != EGL_NO_SYNC_KHR && IsGlSupported()) {
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
if (egl_display != EGL_NO_DISPLAY) {
eglDestroySyncKHR(egl_display, fence_sync_);
}
}
fence_sync_ = EGL_NO_SYNC_KHR;
glDeleteBuffers(1, &opengl_buffer_);
opengl_buffer_ = GL_INVALID_INDEX;
}
});
return opengl_buffer_ == GL_INVALID_INDEX;
}
protected:
AHardwareBuffer* ahwb_;
GLuint opengl_buffer_;
// TODO: use wrapper instead.
EGLSyncKHR fence_sync_;
// TODO: use wrapper instead.
GLsync ssbo_read_;
std::function<bool()> ahwb_written_;
std::shared_ptr<mediapipe::GlContext> gl_context_;
std::function<void()> release_callback_;
static inline std::deque<std::unique_ptr<DelayedReleaser>> to_release_;
DelayedReleaser(AHardwareBuffer* ahwb, GLuint opengl_buffer,
EGLSyncKHR fence_sync, GLsync ssbo_read,
std::function<bool()>&& ahwb_written,
std::shared_ptr<mediapipe::GlContext> gl_context,
std::function<void()>&& callback)
: ahwb_(ahwb),
opengl_buffer_(opengl_buffer),
fence_sync_(fence_sync),
ssbo_read_(ssbo_read),
ahwb_written_(std::move(ahwb_written)),
gl_context_(gl_context),
release_callback_(std::move(callback)) {}
};
} // namespace
Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
CHECK(valid_ != kValidNone) << "Tensor must be written prior to read from.";
CHECK(!(valid_ & kValidOpenGlTexture2d))
<< "Tensor conversion between OpenGL texture and AHardwareBuffer is not "
"supported.";
CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
<< "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
"supported on targe system.";
CHECK(AllocateAHardwareBuffer())
<< "AHardwareBuffer is not supported on the target system.";
valid_ |= kValidAHardwareBuffer;
if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
return {ahwb_,
ssbo_written_,
&fence_fd_, // The FD is created for SSBO -> AHWB synchronization.
&ahwb_written_, // Filled by SetReadingFinishedFunc.
&release_callback_,
std::move(lock)};
}
void Tensor::CreateEglSyncAndFd() const {
gl_context_->Run([this]() {
if (IsGlSupported()) {
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
if (egl_display != EGL_NO_DISPLAY) {
fence_sync_ = eglCreateSyncKHR(egl_display,
EGL_SYNC_NATIVE_FENCE_ANDROID, nullptr);
if (fence_sync_ != EGL_NO_SYNC_KHR) {
ssbo_written_ = eglDupNativeFenceFDANDROID(egl_display, fence_sync_);
if (ssbo_written_ == -1) {
eglDestroySyncKHR(egl_display, fence_sync_);
fence_sync_ = EGL_NO_SYNC_KHR;
}
}
}
}
// Can't use Sync object.
if (fence_sync_ == EGL_NO_SYNC_KHR) glFinish();
});
}
Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
int size_alignment) const {
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
CHECK(AllocateAHardwareBuffer(size_alignment))
<< "AHardwareBuffer is not supported on the target system.";
valid_ = kValidAHardwareBuffer;
return {ahwb_,
/*ssbo_written=*/-1,
&fence_fd_, // For SetWritingFinishedFD.
/*ahwb_written=*/nullptr, // The lifetime is managed by SSBO.
&release_callback_,
std::move(lock)};
}
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
if (!use_ahwb_) return false;
if (ahwb_ == nullptr) {
AHardwareBuffer_Desc desc = {};
if (size_alignment == 0) {
desc.width = bytes();
} else {
// We expect allocations to be page-aligned, implicitly satisfying any
// requirements from Edge TPU. No need to add a check for this,
// since Edge TPU will check for us.
desc.width = AlignedToPowerOf2(bytes(), size_alignment);
}
desc.height = 1;
desc.layers = 1;
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
desc.usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN |
AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER;
return AHardwareBuffer_allocate(&desc, &ahwb_) == 0;
}
return true;
}
bool Tensor::AllocateAhwbMapToSsbo() const {
if (AllocateAHardwareBuffer()) {
if (MapAHardwareBufferToGlBuffer(ahwb_, bytes(), opengl_buffer_).ok()) {
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
return true;
}
// Unable to make OpenGL <-> AHWB binding. Use regular SSBO instead.
AHardwareBuffer_release(ahwb_);
ahwb_ = nullptr;
}
return false;
}
// SSBO is created on top of AHWB. A fence is inserted into the GPU queue before
// the GPU task that is going to read from the SSBO. When the writing into AHWB
// is finished then the GPU reads from the SSBO.
bool Tensor::InsertAhwbToSsboFence() const {
if (!ahwb_) return false;
if (fence_fd_ != -1) {
// Can't wait for FD to be signaled on GPU.
// TODO: wait on CPU instead.
if (!IsGlSupported()) return true;
// Server-side fence.
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
if (egl_display == EGL_NO_DISPLAY) return true;
EGLint sync_attribs[] = {EGL_SYNC_NATIVE_FENCE_FD_ANDROID,
(EGLint)fence_fd_, EGL_NONE};
fence_sync_ = eglCreateSyncKHR(egl_display, EGL_SYNC_NATIVE_FENCE_ANDROID,
sync_attribs);
if (fence_sync_ != EGL_NO_SYNC_KHR) {
eglWaitSyncKHR(egl_display, fence_sync_, 0);
}
}
return true;
}
void Tensor::MoveAhwbStuff(Tensor* src) {
ahwb_ = std::exchange(src->ahwb_, nullptr);
fence_sync_ = std::exchange(src->fence_sync_, EGL_NO_SYNC_KHR);
ssbo_read_ = std::exchange(src->ssbo_read_, static_cast<GLsync>(0));
ssbo_written_ = std::exchange(src->ssbo_written_, -1);
fence_fd_ = std::exchange(src->fence_fd_, -1);
ahwb_written_ = std::move(src->ahwb_written_);
release_callback_ = std::move(src->release_callback_);
}
void Tensor::ReleaseAhwbStuff() {
if (fence_fd_ != -1) {
close(fence_fd_);
fence_fd_ = -1;
}
if (ahwb_) {
if (ssbo_read_ != 0 || fence_sync_ != EGL_NO_SYNC_KHR) {
if (ssbo_written_ != -1) close(ssbo_written_);
DelayedReleaser::Add(ahwb_, opengl_buffer_, fence_sync_, ssbo_read_,
std::move(ahwb_written_), gl_context_,
std::move(release_callback_));
opengl_buffer_ = GL_INVALID_INDEX;
} else {
AHardwareBuffer_release(ahwb_);
}
}
}
void* Tensor::MapAhwbToCpuRead() const {
if (ahwb_) {
if (!(valid_ & kValidCpu) && (valid_ & kValidOpenGlBuffer) &&
ssbo_written_ == -1) {
// EGLSync is failed. Use another synchronization method.
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
glFinish();
}
void* ptr;
auto error =
AHardwareBuffer_lock(ahwb_, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
ssbo_written_, nullptr, &ptr);
CHECK(error == 0) << "AHardwareBuffer_lock " << error;
close(ssbo_written_);
ssbo_written_ = -1;
return ptr;
}
return nullptr;
}
void* Tensor::MapAhwbToCpuWrite() const {
if (ahwb_) {
// TODO: If previously acquired view is GPU write view then need to
// be sure that writing is finished. That's a warning: two consequent write
// views should be interleaved with read view.
void* ptr;
auto error = AHardwareBuffer_lock(
ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, -1, nullptr, &ptr);
CHECK(error == 0) << "AHardwareBuffer_lock " << error;
return ptr;
}
return nullptr;
}
#else // MEDIAPIPE_TENSOR_USE_AHWB
bool Tensor::AllocateAhwbMapToSsbo() const { return false; }
bool Tensor::InsertAhwbToSsboFence() const { return false; }
void Tensor::MoveAhwbStuff(Tensor* src) {}
void Tensor::ReleaseAhwbStuff() {}
void* Tensor::MapAhwbToCpuRead() const { return nullptr; }
void* Tensor::MapAhwbToCpuWrite() const { return nullptr; }
#endif // MEDIAPIPE_TENSOR_USE_AHWB
} // namespace mediapipe

View File

@ -0,0 +1,59 @@
#include "mediapipe/gpu/gpu_test_base.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <android/hardware_buffer.h>
#include "mediapipe/framework/formats/tensor.h"
namespace mediapipe {
#if !MEDIAPIPE_DISABLE_GPU
class TensorAhwbTest : public mediapipe::GpuTestBase {
public:
};
TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{
auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr);
}
{
auto ahwb = tensor.GetAHardwareBufferReadView().handle();
EXPECT_NE(ahwb, nullptr);
}
}
TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{
auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
EXPECT_NE(ahwb, nullptr);
}
{
auto ptr = tensor.GetCpuReadView().buffer<float>();
EXPECT_NE(ptr, nullptr);
}
}
TEST_F(TensorAhwbTest, TestCpuThenGl) {
RunInGlContext([] {
Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
{
auto ptr = tensor.GetCpuWriteView().buffer<float>();
EXPECT_NE(ptr, nullptr);
}
{
auto ssbo = tensor.GetOpenGlBufferReadView().name();
EXPECT_GT(ssbo, 0);
}
});
}
} // namespace mediapipe
#endif // !MEDIAPIPE_DISABLE_GPU
#endif // MEDIAPIPE_TENSOR_USE_AHWB

View File

@ -21,6 +21,8 @@ syntax = "proto2";
package mediapipe; package mediapipe;
option objc_class_prefix = "MediaPipe";
// Header for a uniformly sampled time series stream. Each Packet in // Header for a uniformly sampled time series stream. Each Packet in
// the stream is a Matrix, and each column is a (vector-valued) sample of // the stream is a Matrix, and each column is a (vector-valued) sample of
// the series, i.e. each column corresponds to a distinct sample in time. // the series, i.e. each column corresponds to a distinct sample in time.

View File

@ -204,6 +204,8 @@ absl::Status InputStreamManager::SetNextTimestampBound(const Timestamp bound,
// untimed scheduling policies. // untimed scheduling policies.
if (bound > next_timestamp_bound_) { if (bound > next_timestamp_bound_) {
next_timestamp_bound_ = bound; next_timestamp_bound_ = bound;
VLOG(3) << "Next timestamp bound for input " << name_ << " is "
<< next_timestamp_bound_;
if (queue_.empty()) { if (queue_.empty()) {
// If the queue was not empty then a change to the next_timestamp_bound_ // If the queue was not empty then a change to the next_timestamp_bound_
// is not detectable by the consumer. // is not detectable by the consumer.

View File

@ -168,6 +168,8 @@ void OutputStreamManager::PropagateUpdatesToMirrors(
if (next_timestamp_bound != Timestamp::Unset()) { if (next_timestamp_bound != Timestamp::Unset()) {
absl::MutexLock lock(&stream_mutex_); absl::MutexLock lock(&stream_mutex_);
next_timestamp_bound_ = next_timestamp_bound; next_timestamp_bound_ = next_timestamp_bound;
VLOG(3) << "Next timestamp bound for output " << output_stream_spec_.name
<< " is " << next_timestamp_bound_;
} }
} }
std::list<Packet>* packets_to_propagate = output_stream_shard->OutputQueue(); std::list<Packet>* packets_to_propagate = output_stream_shard->OutputQueue();

View File

@ -106,19 +106,17 @@ std::string Packet::DebugString() const {
return result; return result;
} }
absl::Status Packet::ValidateAsType(const tool::TypeInfo& type_info) const { absl::Status Packet::ValidateAsType(TypeId type_id) const {
if (ABSL_PREDICT_FALSE(IsEmpty())) { if (ABSL_PREDICT_FALSE(IsEmpty())) {
return absl::InternalError( return absl::InternalError(absl::StrCat(
absl::StrCat("Expected a Packet of type: ", "Expected a Packet of type: ", MediaPipeTypeStringOrDemangled(type_id),
MediaPipeTypeStringOrDemangled(type_info),
", but received an empty Packet.")); ", but received an empty Packet."));
} }
bool holder_is_right_type = bool holder_is_right_type = holder_->GetTypeId() == type_id;
holder_->GetTypeInfo().hash_code() == type_info.hash_code();
if (ABSL_PREDICT_FALSE(!holder_is_right_type)) { if (ABSL_PREDICT_FALSE(!holder_is_right_type)) {
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"The Packet stores \"", holder_->DebugTypeName(), "\", but \"", "The Packet stores \"", holder_->DebugTypeName(), "\", but \"",
MediaPipeTypeStringOrDemangled(type_info), "\" was requested.")); MediaPipeTypeStringOrDemangled(type_id), "\" was requested."));
} }
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -21,7 +21,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <typeinfo>
#include "absl/base/macros.h" #include "absl/base/macros.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
@ -69,7 +68,7 @@ absl::StatusOr<Packet> PacketFromDynamicProto(const std::string& type_name,
// The preferred method of creating a Packet is with MakePacket<T>(). // The preferred method of creating a Packet is with MakePacket<T>().
// The Packet typically owns the object that it contains, but // The Packet typically owns the object that it contains, but
// PointToForeign allows a Packet to be constructed which does not // PointToForeign allows a Packet to be constructed which does not
// own it's data. // own its data.
// //
// This class is thread compatible. // This class is thread compatible.
class Packet { class Packet {
@ -180,7 +179,7 @@ class Packet {
// Returns an error if the packet does not contain data of type T. // Returns an error if the packet does not contain data of type T.
template <typename T> template <typename T>
absl::Status ValidateAsType() const { absl::Status ValidateAsType() const {
return ValidateAsType(tool::TypeInfo::Get<T>()); return ValidateAsType(kTypeId<T>);
} }
// Returns an error if the packet is not an instance of // Returns an error if the packet is not an instance of
@ -189,11 +188,7 @@ class Packet {
// Get the type id for the underlying type stored in the Packet. // Get the type id for the underlying type stored in the Packet.
// Crashes if IsEmpty() == true. // Crashes if IsEmpty() == true.
size_t GetTypeId() const { return GetTypeInfo().hash_code(); } TypeId GetTypeId() const;
// Get the type info for the underlying type stored in the Packet.
// Crashes if IsEmpty() == true.
const tool::TypeInfo& GetTypeInfo() const;
// Returns the timestamp. // Returns the timestamp.
class Timestamp Timestamp() const; class Timestamp Timestamp() const;
@ -225,7 +220,7 @@ class Packet {
packet_internal::GetHolderShared(Packet&& packet); packet_internal::GetHolderShared(Packet&& packet);
friend class PacketType; friend class PacketType;
absl::Status ValidateAsType(const tool::TypeInfo& type_info) const; absl::Status ValidateAsType(TypeId type_id) const;
std::shared_ptr<packet_internal::HolderBase> holder_; std::shared_ptr<packet_internal::HolderBase> holder_;
class Timestamp timestamp_; class Timestamp timestamp_;
@ -369,7 +364,7 @@ class HolderBase {
virtual ~HolderBase(); virtual ~HolderBase();
template <typename T> template <typename T>
bool PayloadIsOfType() const { bool PayloadIsOfType() const {
return GetTypeInfo().hash_code() == tool::GetTypeHash<T>(); return GetTypeId() == kTypeId<T>;
} }
// Returns a printable string identifying the type stored in the holder. // Returns a printable string identifying the type stored in the holder.
virtual const std::string DebugTypeName() const = 0; virtual const std::string DebugTypeName() const = 0;
@ -377,7 +372,7 @@ class HolderBase {
// empty string. // empty string.
virtual const std::string RegisteredTypeName() const = 0; virtual const std::string RegisteredTypeName() const = 0;
// Get the type id of the underlying data type. // Get the type id of the underlying data type.
virtual const tool::TypeInfo& GetTypeInfo() const = 0; virtual TypeId GetTypeId() const = 0;
// Downcasts this to Holder<T>. Returns nullptr if deserialization // Downcasts this to Holder<T>. Returns nullptr if deserialization
// failed or if the requested type is not what is stored. // failed or if the requested type is not what is stored.
template <typename T> template <typename T>
@ -428,7 +423,7 @@ StatusOr<std::vector<const proto_ns::MessageLite*>>
ConvertToVectorOfProtoMessageLitePtrs(const T* data, ConvertToVectorOfProtoMessageLitePtrs(const T* data,
/*is_proto_vector=*/std::false_type) { /*is_proto_vector=*/std::false_type) {
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"The Packet stores \"", tool::TypeInfo::Get<T>().name(), "\"", "The Packet stores \"", kTypeId<T>.name(), "\"",
"which is not convertible to vector<proto_ns::MessageLite*>.")); "which is not convertible to vector<proto_ns::MessageLite*>."));
} }
@ -510,9 +505,7 @@ class Holder : public HolderBase {
HolderSupport<T>::EnsureStaticInit(); HolderSupport<T>::EnsureStaticInit();
return *ptr_; return *ptr_;
} }
const tool::TypeInfo& GetTypeInfo() const final { TypeId GetTypeId() const final { return kTypeId<T>; }
return tool::TypeInfo::Get<T>();
}
// Releases the underlying data pointer and transfers the ownership to a // Releases the underlying data pointer and transfers the ownership to a
// unique pointer. // unique pointer.
// This method is dangerous and is only used by Packet::Consume() if the // This method is dangerous and is only used by Packet::Consume() if the
@ -748,9 +741,9 @@ inline Packet& Packet::operator=(Packet&& packet) {
inline bool Packet::IsEmpty() const { return holder_ == nullptr; } inline bool Packet::IsEmpty() const { return holder_ == nullptr; }
inline const tool::TypeInfo& Packet::GetTypeInfo() const { inline TypeId Packet::GetTypeId() const {
CHECK(holder_); CHECK(holder_);
return holder_->GetTypeInfo(); return holder_->GetTypeId();
} }
template <typename T> template <typename T>

View File

@ -18,6 +18,8 @@ syntax = "proto2";
package mediapipe; package mediapipe;
option objc_class_prefix = "MediaPipe";
message PacketTestProto { message PacketTestProto {
// Tests that the tags used to encode the timestamp do not interfere with // Tests that the tags used to encode the timestamp do not interfere with
// proto tags. // proto tags.

View File

@ -127,13 +127,13 @@ bool PacketType::IsOneOf() const {
} }
bool PacketType::IsExactType() const { bool PacketType::IsExactType() const {
return absl::holds_alternative<const tool::TypeInfo*>(type_spec_); return absl::holds_alternative<TypeId>(type_spec_);
} }
const std::string* PacketType::RegisteredTypeName() const { const std::string* PacketType::RegisteredTypeName() const {
if (auto* same_as = SameAsPtr()) return same_as->RegisteredTypeName(); if (auto* same_as = SameAsPtr()) return same_as->RegisteredTypeName();
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec_)) if (auto* type_id = absl::get_if<TypeId>(&type_spec_))
return MediaPipeTypeStringFromTypeId((**type_info).hash_code()); return MediaPipeTypeStringFromTypeId(*type_id);
if (auto* multi_type = absl::get_if<MultiType>(&type_spec_)) if (auto* multi_type = absl::get_if<MultiType>(&type_spec_))
return multi_type->registered_type_name; return multi_type->registered_type_name;
return nullptr; return nullptr;
@ -141,8 +141,8 @@ const std::string* PacketType::RegisteredTypeName() const {
namespace internal { namespace internal {
struct TypeInfoFormatter { struct TypeIdFormatter {
void operator()(std::string* out, const tool::TypeInfo& t) const { void operator()(std::string* out, TypeId t) const {
absl::StrAppend(out, MediaPipeTypeStringOrDemangled(t)); absl::StrAppend(out, MediaPipeTypeStringOrDemangled(t));
} }
}; };
@ -167,12 +167,9 @@ explicit QuoteFormatter(Formatter f) -> QuoteFormatter<Formatter>;
} // namespace internal } // namespace internal
std::string PacketType::TypeNameForOneOf(TypeInfoSpan types) { std::string PacketType::TypeNameForOneOf(TypeIdSpan types) {
return absl::StrCat( return absl::StrCat(
"OneOf<", "OneOf<", absl::StrJoin(types, ", ", internal::TypeIdFormatter()), ">");
absl::StrJoin(types, ", ",
absl::DereferenceFormatter(internal::TypeInfoFormatter())),
">");
} }
std::string PacketType::DebugTypeName() const { std::string PacketType::DebugTypeName() const {
@ -185,8 +182,8 @@ std::string PacketType::DebugTypeName() const {
if (auto* special = absl::get_if<SpecialType>(&type_spec_)) { if (auto* special = absl::get_if<SpecialType>(&type_spec_)) {
return special->name_; return special->name_;
} }
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec_)) { if (auto* type_id = absl::get_if<TypeId>(&type_spec_)) {
return MediaPipeTypeStringOrDemangled(**type_info); return MediaPipeTypeStringOrDemangled(*type_id);
} }
if (auto* multi_type = absl::get_if<MultiType>(&type_spec_)) { if (auto* multi_type = absl::get_if<MultiType>(&type_spec_)) {
return TypeNameForOneOf(multi_type->types); return TypeNameForOneOf(multi_type->types);
@ -194,11 +191,11 @@ std::string PacketType::DebugTypeName() const {
return "[Undefined Type]"; return "[Undefined Type]";
} }
static bool HaveCommonType(absl::Span<const tool::TypeInfo* const> types1, static bool HaveCommonType(absl::Span<const TypeId> types1,
absl::Span<const tool::TypeInfo* const> types2) { absl::Span<const TypeId> types2) {
for (const auto& first : types1) { for (const auto& first : types1) {
for (const auto& second : types2) { for (const auto& second : types2) {
if (first->hash_code() == second->hash_code()) { if (first == second) {
return true; return true;
} }
} }
@ -216,35 +213,34 @@ absl::Status PacketType::Validate(const Packet& packet) const {
// in SetSameAs(). // in SetSameAs().
return GetSameAs()->Validate(packet); return GetSameAs()->Validate(packet);
} }
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec_)) { if (auto* type_id = absl::get_if<TypeId>(&type_spec_)) {
return packet.ValidateAsType(**type_info); return packet.ValidateAsType(*type_id);
} }
if (packet.IsEmpty()) { if (packet.IsEmpty()) {
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Empty packets are not allowed for type: " << DebugTypeName(); << "Empty packets are not allowed for type: " << DebugTypeName();
} }
if (auto* multi_type = absl::get_if<MultiType>(&type_spec_)) { if (auto* multi_type = absl::get_if<MultiType>(&type_spec_)) {
auto* packet_type = &packet.GetTypeInfo(); auto packet_type = packet.GetTypeId();
if (HaveCommonType(multi_type->types, absl::MakeSpan(&packet_type, 1))) { if (HaveCommonType(multi_type->types, absl::MakeSpan(&packet_type, 1))) {
return absl::OkStatus(); return absl::OkStatus();
} else { } else {
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
"The Packet stores \"", packet.DebugTypeName(), "\", but one of ", "The Packet stores \"", packet.DebugTypeName(), "\", but one of ",
absl::StrJoin(multi_type->types, ", ", absl::StrJoin(multi_type->types, ", ",
absl::DereferenceFormatter(internal::QuoteFormatter( internal::QuoteFormatter(internal::TypeIdFormatter())),
internal::TypeInfoFormatter()))),
" was requested.")); " was requested."));
} }
} }
if (auto* special = absl::get_if<SpecialType>(&type_spec_)) { if (auto* special = absl::get_if<SpecialType>(&type_spec_)) {
return special->accept_fn_(&packet.GetTypeInfo()); return special->accept_fn_(packet.GetTypeId());
} }
return absl::OkStatus(); return absl::OkStatus();
} }
PacketType::TypeInfoSpan PacketType::GetTypeSpan(const TypeSpec& type_spec) { PacketType::TypeIdSpan PacketType::GetTypeSpan(const TypeSpec& type_spec) {
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec)) if (auto* type_id = absl::get_if<TypeId>(&type_spec))
return absl::MakeSpan(type_info, 1); return absl::MakeSpan(type_id, 1);
if (auto* multi_type = absl::get_if<MultiType>(&type_spec)) if (auto* multi_type = absl::get_if<MultiType>(&type_spec))
return multi_type->types; return multi_type->types;
return {}; return {};
@ -254,8 +250,8 @@ bool PacketType::IsConsistentWith(const PacketType& other) const {
const PacketType* type1 = GetSameAs(); const PacketType* type1 = GetSameAs();
const PacketType* type2 = other.GetSameAs(); const PacketType* type2 = other.GetSameAs();
TypeInfoSpan types1 = GetTypeSpan(type1->type_spec_); TypeIdSpan types1 = GetTypeSpan(type1->type_spec_);
TypeInfoSpan types2 = GetTypeSpan(type2->type_spec_); TypeIdSpan types2 = GetTypeSpan(type2->type_spec_);
if (!types1.empty() && !types2.empty()) { if (!types1.empty() && !types2.empty()) {
return HaveCommonType(types1, types2); return HaveCommonType(types1, types2);
} }

View File

@ -121,15 +121,15 @@ class PacketType {
// We don't do union-find optimizations in order to avoid a mutex. // We don't do union-find optimizations in order to avoid a mutex.
const PacketType* other; const PacketType* other;
}; };
using TypeInfoSpan = absl::Span<const tool::TypeInfo* const>; using TypeIdSpan = absl::Span<const TypeId>;
struct MultiType { struct MultiType {
TypeInfoSpan types; TypeIdSpan types;
// TODO: refactor RegisteredTypeName, remove. // TODO: refactor RegisteredTypeName, remove.
const std::string* registered_type_name; const std::string* registered_type_name;
}; };
struct SpecialType; struct SpecialType;
using TypeSpec = absl::variant<absl::monostate, const tool::TypeInfo*, using TypeSpec =
MultiType, SameAs, SpecialType>; absl::variant<absl::monostate, TypeId, MultiType, SameAs, SpecialType>;
typedef absl::Status (*AcceptsTypeFn)(const TypeSpec& type); typedef absl::Status (*AcceptsTypeFn)(const TypeSpec& type);
struct SpecialType { struct SpecialType {
std::string name_; std::string name_;
@ -140,8 +140,8 @@ class PacketType {
static absl::Status AcceptNone(const TypeSpec& type); static absl::Status AcceptNone(const TypeSpec& type);
const PacketType* SameAsPtr() const; const PacketType* SameAsPtr() const;
static TypeInfoSpan GetTypeSpan(const TypeSpec& type_spec); static TypeIdSpan GetTypeSpan(const TypeSpec& type_spec);
static std::string TypeNameForOneOf(TypeInfoSpan types); static std::string TypeNameForOneOf(TypeIdSpan types);
TypeSpec type_spec_; TypeSpec type_spec_;
@ -259,14 +259,13 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set);
template <typename T> template <typename T>
PacketType& PacketType::Set() { PacketType& PacketType::Set() {
type_spec_ = &tool::TypeInfo::Get<T>(); type_spec_ = kTypeId<T>;
return *this; return *this;
} }
template <typename... T> template <typename... T>
PacketType& PacketType::SetOneOf() { PacketType& PacketType::SetOneOf() {
static const NoDestructor<std::vector<const tool::TypeInfo*>> types{ static const NoDestructor<std::vector<TypeId>> types{{kTypeId<T>...}};
{&tool::TypeInfo::Get<T>()...}};
static const NoDestructor<std::string> name{TypeNameForOneOf(*types)}; static const NoDestructor<std::string> name{TypeNameForOneOf(*types)};
type_spec_ = MultiType{*types, &*name}; type_spec_ = MultiType{*types, &*name};
return *this; return *this;

View File

@ -43,7 +43,7 @@ const int kDefaultLogFileCount = 2;
const char kDefaultLogFilePrefix[] = "mediapipe_trace_"; const char kDefaultLogFilePrefix[] = "mediapipe_trace_";
// The number of recent timestamps tracked for each input stream. // The number of recent timestamps tracked for each input stream.
const int kPacketInfoRecentCount = 100; const int kPacketInfoRecentCount = 400;
std::string PacketIdToString(const PacketId& packet_id) { std::string PacketIdToString(const PacketId& packet_id) {
return absl::Substitute("stream_name: $0, timestamp_usec: $1", return absl::Substitute("stream_name: $0, timestamp_usec: $1",
@ -507,7 +507,7 @@ int64 GraphProfiler::AddInputStreamTimeSamples(
// This is a condition rather than a failure CHECK because // This is a condition rather than a failure CHECK because
// under certain conditions the consumer calculator's Process() // under certain conditions the consumer calculator's Process()
// can start before the producer calculator's Process() is finished. // can start before the producer calculator's Process() is finished.
LOG_EVERY_N(WARNING, 100) << "Expected packet info is missing for: " LOG_FIRST_N(WARNING, 10) << "Expected packet info is missing for: "
<< PacketIdToString(packet_id); << PacketIdToString(packet_id);
continue; continue;
} }

View File

@ -36,7 +36,7 @@ class SubgraphContext {
public: public:
SubgraphContext() : SubgraphContext(nullptr, nullptr) {} SubgraphContext() : SubgraphContext(nullptr, nullptr) {}
// @node and/or @service_manager can be nullptr. // @node and/or @service_manager can be nullptr.
SubgraphContext(const CalculatorGraphConfig::Node* node, SubgraphContext(CalculatorGraphConfig::Node* node,
const GraphServiceManager* service_manager) const GraphServiceManager* service_manager)
: default_node_(node ? absl::nullopt : default_node_(node ? absl::nullopt
: absl::optional<CalculatorGraphConfig::Node>( : absl::optional<CalculatorGraphConfig::Node>(
@ -48,14 +48,19 @@ class SubgraphContext {
: absl::optional<GraphServiceManager>(GraphServiceManager())), : absl::optional<GraphServiceManager>(GraphServiceManager())),
service_manager_(service_manager ? *service_manager service_manager_(service_manager ? *service_manager
: default_service_manager_.value()), : default_service_manager_.value()),
options_map_(std::move(tool::OptionsMap().Initialize(original_node_))) { options_map_(
} std::move(tool::MutableOptionsMap().Initialize(original_node_))) {}
template <typename T> template <typename T>
const T& Options() { const T& Options() {
return options_map_.Get<T>(); return options_map_.Get<T>();
} }
template <typename T>
T* MutableOptions() {
return options_map_.GetMutable<T>();
}
const CalculatorGraphConfig::Node& OriginalNode() const { const CalculatorGraphConfig::Node& OriginalNode() const {
return original_node_; return original_node_;
} }
@ -67,16 +72,16 @@ class SubgraphContext {
private: private:
// Populated if node is not provided during construction. // Populated if node is not provided during construction.
const absl::optional<CalculatorGraphConfig::Node> default_node_; absl::optional<CalculatorGraphConfig::Node> default_node_;
const CalculatorGraphConfig::Node& original_node_; CalculatorGraphConfig::Node& original_node_;
// Populated if service manager is not provided during construction. // Populated if service manager is not provided during construction.
const absl::optional<GraphServiceManager> default_service_manager_; const absl::optional<GraphServiceManager> default_service_manager_;
const GraphServiceManager& service_manager_; const GraphServiceManager& service_manager_;
tool::OptionsMap options_map_; tool::MutableOptionsMap options_map_;
}; };
// Instances of this class are responsible for providing a subgraph config. // Instances of this class are responsible for providing a subgraph config.

View File

@ -22,6 +22,8 @@ package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
option objc_class_prefix = "MediaPipe";
message RandomMatrixCalculatorOptions { message RandomMatrixCalculatorOptions {
extend CalculatorOptions { extend CalculatorOptions {
optional RandomMatrixCalculatorOptions ext = 52056136; optional RandomMatrixCalculatorOptions ext = 52056136;

View File

@ -198,6 +198,7 @@ cc_library(
":name_util", ":name_util",
":options_registry", ":options_registry",
":proto_util_lite", ":proto_util_lite",
":type_util",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework:packet", "//mediapipe/framework:packet",
"//mediapipe/framework:packet_type", "//mediapipe/framework:packet_type",
@ -277,9 +278,12 @@ cc_library(
hdrs = ["options_registry.h"], hdrs = ["options_registry.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":field_data_cc_proto",
":proto_util_lite",
"//mediapipe/framework/deps:registration", "//mediapipe/framework/deps:registration",
"//mediapipe/framework/port:advanced_proto", "//mediapipe/framework/port:advanced_proto",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:ret_check",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
], ],
@ -334,6 +338,7 @@ cc_library(
hdrs = ["proto_util_lite.h"], hdrs = ["proto_util_lite.h"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":field_data_cc_proto",
"//mediapipe/framework:type_map", "//mediapipe/framework:type_map",
"//mediapipe/framework/port:advanced_proto_lite", "//mediapipe/framework/port:advanced_proto_lite",
"//mediapipe/framework/port:integral_types", "//mediapipe/framework/port:integral_types",
@ -518,9 +523,11 @@ cc_library(
cc_library( cc_library(
name = "type_util", name = "type_util",
hdrs = ["type_util.h"], hdrs = ["type_util.h"],
visibility = ["//mediapipe/framework:mediapipe_internal"], visibility = ["//visibility:public"],
deps = [ deps = [
"//mediapipe/framework:demangle",
"//mediapipe/framework:port", "//mediapipe/framework:port",
"@com_google_absl//absl/base:core_headers",
], ],
) )

View File

@ -3,6 +3,7 @@ syntax = "proto2";
package mediapipe; package mediapipe;
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/framework/calculator_options.proto";
import "mediapipe/framework/deps/proto_descriptor.proto"; import "mediapipe/framework/deps/proto_descriptor.proto";
option java_package = "com.google.mediapipe.proto"; option java_package = "com.google.mediapipe.proto";

View File

@ -7,6 +7,7 @@
#include <vector> #include <vector>
#include "absl/status/status.h" #include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/packet.h" #include "mediapipe/framework/packet.h"
@ -18,6 +19,7 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/name_util.h" #include "mediapipe/framework/tool/name_util.h"
#include "mediapipe/framework/tool/proto_util_lite.h" #include "mediapipe/framework/tool/proto_util_lite.h"
#include "mediapipe/framework/tool/type_util.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
@ -41,165 +43,39 @@ FieldType AsFieldType(proto_ns::FieldDescriptorProto::Type type) {
return static_cast<FieldType>(type); return static_cast<FieldType>(type);
} }
absl::Status WriteValue(const FieldData& value, FieldType field_type,
std::string* field_bytes) {
StringOutputStream sos(field_bytes);
CodedOutputStream out(&sos);
switch (field_type) {
case WireFormatLite::TYPE_INT32:
WireFormatLite::WriteInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_SINT32:
WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_INT64:
WireFormatLite::WriteInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_SINT64:
WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_UINT32:
WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out);
break;
case WireFormatLite::TYPE_UINT64:
WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_DOUBLE:
WireFormatLite::WriteDoubleNoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_FLOAT:
WireFormatLite::WriteFloatNoTag(value.float_value(), &out);
break;
case WireFormatLite::TYPE_BOOL:
WireFormatLite::WriteBoolNoTag(value.bool_value(), &out);
break;
case WireFormatLite::TYPE_ENUM:
WireFormatLite::WriteEnumNoTag(value.enum_value(), &out);
break;
case WireFormatLite::TYPE_STRING:
out.WriteString(value.string_value());
break;
case WireFormatLite::TYPE_MESSAGE:
out.WriteString(value.message_value().value());
break;
default:
return absl::UnimplementedError(
absl::StrCat("Cannot write type: ", field_type));
}
return absl::OkStatus();
}
// Serializes a packet value. // Serializes a packet value.
absl::Status WriteField(const FieldData& packet, const FieldDescriptor* field, absl::Status WriteField(const FieldData& packet, const FieldDescriptor* field,
std::string* result) { std::string* result) {
FieldType field_type = AsFieldType(field->type()); return ProtoUtilLite::WriteValue(packet, field->type(), result);
return WriteValue(packet, field_type, result);
}
template <typename ValueT, FieldType kFieldType>
static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) {
ArrayInputStream ais(field_bytes.data(), field_bytes.size());
CodedInputStream input(&ais);
ValueT result;
if (!WireFormatLite::ReadPrimitive<ValueT, kFieldType>(&input, &result)) {
status->Update(mediapipe::InvalidArgumentError(absl::StrCat(
"Bad serialized value: ", MediaPipeTypeStringOrDemangled<ValueT>(),
".")));
}
return result;
}
absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type,
absl::string_view message_type, FieldData* result) {
absl::Status status;
result->Clear();
switch (field_type) {
case WireFormatLite::TYPE_INT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_INT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_SINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_INT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_INT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_SINT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT32:
result->set_uint32_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT64:
result->set_uint64_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_DOUBLE:
result->set_double_value(
ReadValue<double, WireFormatLite::TYPE_DOUBLE>(field_bytes, &status));
break;
case WireFormatLite::TYPE_FLOAT:
result->set_float_value(
ReadValue<float, WireFormatLite::TYPE_FLOAT>(field_bytes, &status));
break;
case WireFormatLite::TYPE_BOOL:
result->set_bool_value(
ReadValue<bool, WireFormatLite::TYPE_BOOL>(field_bytes, &status));
break;
case WireFormatLite::TYPE_ENUM:
result->set_enum_value(
ReadValue<int32, WireFormatLite::TYPE_ENUM>(field_bytes, &status));
break;
case WireFormatLite::TYPE_STRING:
result->set_string_value(std::string(field_bytes));
break;
case WireFormatLite::TYPE_MESSAGE:
result->mutable_message_value()->set_value(std::string(field_bytes));
result->mutable_message_value()->set_type_url(TypeUrl(message_type));
break;
default:
status = absl::UnimplementedError(
absl::StrCat("Cannot read type: ", field_type));
break;
}
return status;
} }
// Deserializes a packet from a protobuf field. // Deserializes a packet from a protobuf field.
absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field, absl::Status ReadField(absl::string_view bytes, const FieldDescriptor& field,
FieldData* result) { FieldData* result) {
RET_CHECK_NE(field, nullptr); std::string message_type = (field.type() == WireFormatLite::TYPE_MESSAGE)
FieldType field_type = AsFieldType(field->type()); ? field.message_type()->full_name()
std::string message_type = (field_type == WireFormatLite::TYPE_MESSAGE)
? field->message_type()->full_name()
: ""; : "";
return ReadValue(bytes, field_type, message_type, result); return ProtoUtilLite::ReadValue(bytes, field.type(), message_type, result);
} }
// Reads all values from a repeated field. // Reads all values from a repeated field.
absl::Status GetFieldValues(const FieldData& message_data, absl::StatusOr<std::vector<FieldData>> GetFieldValues(
const FieldDescriptor& field, const FieldData& message_data, const FieldDescriptor& field) {
std::vector<FieldData>* result) { std::vector<FieldData> result;
const std::string& message_bytes = message_data.message_value().value(); const std::string& message_bytes = message_data.message_value().value();
FieldType field_type = AsFieldType(field.type());
ProtoUtilLite proto_util;
ProtoUtilLite::ProtoPath proto_path = {{field.number(), 0}}; ProtoUtilLite::ProtoPath proto_path = {{field.number(), 0}};
int count; int count;
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(message_bytes, proto_path,
proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count)); field.type(), &count));
std::vector<std::string> field_values; std::vector<std::string> field_values;
MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, count, MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
field_type, &field_values)); message_bytes, proto_path, count, field.type(), &field_values));
for (int i = 0; i < count; ++i) { for (int i = 0; i < field_values.size(); ++i) {
FieldData r; FieldData r;
MP_RETURN_IF_ERROR(ReadField(field_values[i], &field, &r)); MP_RETURN_IF_ERROR(ReadField(field_values[i], field, &r));
result->push_back(std::move(r)); result.push_back(std::move(r));
} }
return absl::OkStatus(); return result;
} }
// Reads one value from a field. // Reads one value from a field.
@ -207,42 +83,70 @@ absl::Status GetFieldValue(const FieldData& message_data,
const FieldPathEntry& entry, FieldData* result) { const FieldPathEntry& entry, FieldData* result) {
RET_CHECK_NE(entry.field, nullptr); RET_CHECK_NE(entry.field, nullptr);
const std::string& message_bytes = message_data.message_value().value(); const std::string& message_bytes = message_data.message_value().value();
FieldType field_type = AsFieldType(entry.field->type()); FieldType field_type = entry.field->type();
ProtoUtilLite proto_util; int index = std::max(0, entry.index);
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}}; ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), index}};
std::vector<std::string> field_values; std::vector<std::string> field_values;
MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1, MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(message_bytes, proto_path, 1,
field_type, &field_values)); field_type, &field_values));
MP_RETURN_IF_ERROR(ReadField(field_values[0], entry.field, result)); MP_RETURN_IF_ERROR(ReadField(field_values[0], *entry.field, result));
return absl::OkStatus(); return absl::OkStatus();
} }
// Writes one value to a field. // Writes one value to a field.
absl::Status SetFieldValue(const FieldPathEntry& entry, const FieldData& value, absl::Status SetFieldValue(FieldData& result, const FieldPathEntry& entry,
FieldData* result) { const FieldData& value) {
std::vector<FieldData> field_values; int index = std::max(0, entry.index);
ProtoUtilLite proto_util; ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), index}};
FieldType field_type = AsFieldType(entry.field->type()); std::string* message_bytes = result.mutable_message_value()->mutable_value();
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}};
std::string* message_bytes = result->mutable_message_value()->mutable_value();
int field_count; int field_count;
MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path, MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(
field_type, &field_count)); *message_bytes, proto_path, entry.field->type(), &field_count));
if (entry.index > field_count) { if (index > field_count) {
return absl::OutOfRangeError( return absl::OutOfRangeError(
absl::StrCat("Option field index out of range: ", entry.index)); absl::StrCat("Option field index out of range: ", index));
} }
int replace_length = entry.index < field_count ? 1 : 0; int replace_length = index < field_count ? 1 : 0;
std::string field_value; std::string field_value;
MP_RETURN_IF_ERROR(WriteField(value, entry.field, &field_value)); MP_RETURN_IF_ERROR(WriteField(value, entry.field, &field_value));
MP_RETURN_IF_ERROR(proto_util.ReplaceFieldRange( MP_RETURN_IF_ERROR(ProtoUtilLite::ReplaceFieldRange(
message_bytes, proto_path, replace_length, field_type, {field_value})); message_bytes, proto_path, replace_length, entry.field->type(),
{field_value}));
return absl::OkStatus();
}
// Writes several values to a repeated field.
// The specified |values| replace the specified |entry| index,
// or if no index is specified all field values are replaced.
absl::Status SetFieldValues(FieldData& result, const FieldPathEntry& entry,
const std::vector<FieldData>& values) {
if (entry.field == nullptr) {
return absl::InvalidArgumentError("Field not found.");
}
FieldType field_type = entry.field->type();
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), 0}};
std::string* message_bytes = result.mutable_message_value()->mutable_value();
int field_count;
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(*message_bytes, proto_path,
field_type, &field_count));
int replace_start = 0, replace_length = field_count;
if (entry.index > -1) {
replace_start = entry.index;
replace_length = 1;
}
std::vector<std::string> field_values(values.size());
for (int i = 0; i < values.size(); ++i) {
MP_RETURN_IF_ERROR(WriteField(values[i], entry.field, &field_values[i]));
}
proto_path = {{entry.field->number(), replace_start}};
MP_RETURN_IF_ERROR(ProtoUtilLite::ReplaceFieldRange(
message_bytes, proto_path, replace_length, field_type, field_values));
return absl::OkStatus(); return absl::OkStatus();
} }
// Returns true for a field of type "google.protobuf.Any". // Returns true for a field of type "google.protobuf.Any".
bool IsProtobufAny(const FieldDescriptor* field) { bool IsProtobufAny(const FieldDescriptor* field) {
return AsFieldType(field->type()) == FieldType::TYPE_MESSAGE && return field->type() == FieldType::TYPE_MESSAGE &&
field->message_type()->full_name() == kGoogleProtobufAny; field->message_type()->full_name() == kGoogleProtobufAny;
} }
@ -275,9 +179,7 @@ StatusOr<int> FindExtensionIndex(const FieldData& message_data,
} }
std::string& extension_type = entry->extension_type; std::string& extension_type = entry->extension_type;
std::vector<FieldData> field_values; std::vector<FieldData> field_values;
RET_CHECK_NE(entry->field, nullptr); ASSIGN_OR_RETURN(field_values, GetFieldValues(message_data, *entry->field));
MP_RETURN_IF_ERROR(
GetFieldValues(message_data, *entry->field, &field_values));
for (int i = 0; i < field_values.size(); ++i) { for (int i = 0; i < field_values.size(); ++i) {
FieldData extension = ParseProtobufAny(field_values[i]); FieldData extension = ParseProtobufAny(field_values[i]);
if (extension_type == "*" || if (extension_type == "*" ||
@ -290,9 +192,9 @@ StatusOr<int> FindExtensionIndex(const FieldData& message_data,
// Returns true if the value of a field is available. // Returns true if the value of a field is available.
bool HasField(const FieldPath& field_path, const FieldData& message_data) { bool HasField(const FieldPath& field_path, const FieldData& message_data) {
FieldData value; auto value = GetField(message_data, field_path);
return GetField(field_path, message_data, &value).ok() && return value.ok() &&
value.value_case() != mediapipe::FieldData::VALUE_NOT_SET; value->value_case() != mediapipe::FieldData::VALUE_NOT_SET;
} }
// Returns the extension field containing the specified extension-type. // Returns the extension field containing the specified extension-type.
@ -330,43 +232,24 @@ void SetOptionsMessage(
*options_any->mutable_value() = node_options.message_value().value(); *options_any->mutable_value() = node_options.message_value().value();
} }
// Returns the count of values in a repeated field.
int FieldCount(const FieldData& message_data, const FieldDescriptor* field) {
const std::string& message_bytes = message_data.message_value().value();
FieldType field_type = AsFieldType(field->type());
ProtoUtilLite proto_util;
ProtoUtilLite::ProtoPath proto_path = {{field->number(), 0}};
int count;
if (proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count)
.ok()) {
return count;
}
return 0;
}
} // anonymous namespace } // anonymous namespace
// Deserializes a packet containing a MessageLite value. // Deserializes a packet containing a MessageLite value.
absl::Status ReadMessage(const std::string& value, const std::string& type_name, absl::StatusOr<Packet> ReadMessage(const std::string& value,
Packet* result) { const std::string& type_name) {
auto packet = packet_internal::PacketFromDynamicProto(type_name, value); return packet_internal::PacketFromDynamicProto(type_name, value);
if (packet.ok()) {
*result = *packet;
}
return packet.status();
} }
// Merge two options FieldData values. // Merge two options FieldData values.
absl::Status MergeMessages(const FieldData& base, const FieldData& over, absl::StatusOr<FieldData> MergeMessages(const FieldData& base,
FieldData* result) { const FieldData& over) {
FieldData result;
absl::Status status; absl::Status status;
if (over.value_case() == FieldData::VALUE_NOT_SET) { if (over.value_case() == FieldData::VALUE_NOT_SET) {
*result = base; return base;
return status;
} }
if (base.value_case() == FieldData::VALUE_NOT_SET) { if (base.value_case() == FieldData::VALUE_NOT_SET) {
*result = over; return over;
return status;
} }
if (over.value_case() != base.value_case()) { if (over.value_case() != base.value_case()) {
return absl::InvalidArgumentError(absl::StrCat( return absl::InvalidArgumentError(absl::StrCat(
@ -382,10 +265,9 @@ absl::Status MergeMessages(const FieldData& base, const FieldData& over,
absl::Cord merged_value; absl::Cord merged_value;
merged_value.Append(base.message_value().value()); merged_value.Append(base.message_value().value());
merged_value.Append(over.message_value().value()); merged_value.Append(over.message_value().value());
result->mutable_message_value()->set_type_url( result.mutable_message_value()->set_type_url(base.message_value().type_url());
base.message_value().type_url()); result.mutable_message_value()->set_value(std::string(merged_value));
result->mutable_message_value()->set_value(std::string(merged_value)); return result;
return status;
} }
// Returns either the extension field or the repeated protobuf.Any field index // Returns either the extension field or the repeated protobuf.Any field index
@ -439,51 +321,48 @@ FieldPath GetExtensionPath(const std::string& parent_type,
} }
// Returns the requested options protobuf for a graph node. // Returns the requested options protobuf for a graph node.
absl::Status GetNodeOptions(const FieldData& message_data, absl::StatusOr<FieldData> GetNodeOptions(const FieldData& message_data,
const std::string& extension_type, const std::string& extension_type) {
FieldData* result) {
constexpr char kOptionsName[] = "options"; constexpr char kOptionsName[] = "options";
constexpr char kNodeOptionsName[] = "node_options"; constexpr char kNodeOptionsName[] = "node_options";
std::string parent_type = options_field_util::ParseTypeUrl( std::string parent_type = options_field_util::ParseTypeUrl(
std::string(message_data.message_value().type_url())); std::string(message_data.message_value().type_url()));
FieldPath path; FieldPath path;
Status status; absl::Status status;
path = GetExtensionPath(parent_type, extension_type, kOptionsName, false); path = GetExtensionPath(parent_type, extension_type, kOptionsName, false);
status = GetField(path, message_data, result); auto result = GetField(message_data, path);
if (status.ok()) { if (result.ok()) {
return status; return result;
} }
path = GetExtensionPath(parent_type, extension_type, kNodeOptionsName, true); path = GetExtensionPath(parent_type, extension_type, kNodeOptionsName, true);
status = GetField(path, message_data, result); return GetField(message_data, path);
return status;
} }
// Returns the requested options protobuf for a graph. // Returns the requested options protobuf for a graph.
absl::Status GetGraphOptions(const FieldData& message_data, absl::StatusOr<FieldData> GetGraphOptions(const FieldData& message_data,
const std::string& extension_type, const std::string& extension_type) {
FieldData* result) {
constexpr char kOptionsName[] = "options"; constexpr char kOptionsName[] = "options";
constexpr char kGraphOptionsName[] = "graph_options"; constexpr char kGraphOptionsName[] = "graph_options";
std::string parent_type = options_field_util::ParseTypeUrl( std::string parent_type = options_field_util::ParseTypeUrl(
std::string(message_data.message_value().type_url())); std::string(message_data.message_value().type_url()));
FieldPath path; FieldPath path;
Status status; absl::Status status;
path = GetExtensionPath(parent_type, extension_type, kOptionsName, false); path = GetExtensionPath(parent_type, extension_type, kOptionsName, false);
status = GetField(path, message_data, result); auto result = GetField(message_data, path);
if (status.ok()) { if (result.ok()) {
return status; return result;
} }
path = GetExtensionPath(parent_type, extension_type, kGraphOptionsName, true); path = GetExtensionPath(parent_type, extension_type, kGraphOptionsName, true);
status = GetField(path, message_data, result); return GetField(message_data, path);
return status;
} }
// Reads a FieldData value from a protobuf field. // Reads the FieldData values from a protobuf field.
absl::Status GetField(const FieldPath& field_path, absl::StatusOr<std::vector<FieldData>> GetFieldValues(
const FieldData& message_data, FieldData* result) { const FieldData& message_data, const FieldPath& field_path) {
std::vector<FieldData> results;
if (field_path.empty()) { if (field_path.empty()) {
*result->mutable_message_value() = message_data.message_value(); results.push_back(message_data);
return absl::OkStatus(); return results;
} }
FieldPathEntry head = field_path.front(); FieldPathEntry head = field_path.front();
FieldPath tail = field_path; FieldPath tail = field_path;
@ -491,65 +370,101 @@ absl::Status GetField(const FieldPath& field_path,
if (!head.extension_type.empty()) { if (!head.extension_type.empty()) {
MP_RETURN_IF_ERROR(FindExtension(message_data, &head)); MP_RETURN_IF_ERROR(FindExtension(message_data, &head));
} }
if (tail.empty() && FieldCount(message_data, head.field) == 0) { RET_CHECK_NE(head.field, nullptr);
return absl::OkStatus(); ASSIGN_OR_RETURN(results, GetFieldValues(message_data, *head.field));
}
MP_RETURN_IF_ERROR(GetFieldValue(message_data, head, result));
if (IsProtobufAny(head.field)) { if (IsProtobufAny(head.field)) {
*result = ParseProtobufAny(*result); for (int i = 0; i < results.size(); ++i) {
results[i] = ParseProtobufAny(results[i]);
}
}
int index = tail.empty() ? head.index : std::max(0, head.index);
if ((int)results.size() <= index) {
return absl::OutOfRangeError(absl::StrCat(
"Missing feild value: ", head.field ? head.field->name() : "#",
" at index: ", index));
} }
if (!tail.empty()) { if (!tail.empty()) {
FieldData child = *result; FieldData child = results.at(index);
MP_RETURN_IF_ERROR(GetField(tail, child, result)); ASSIGN_OR_RETURN(results, GetFieldValues(child, tail));
} else if (index > -1) {
FieldData child = results.at(index);
results.clear();
results.push_back(child);
} }
return results;
}
// Reads a FieldData value from a protobuf field.
absl::StatusOr<FieldData> GetField(const FieldData& message_data,
const FieldPath& field_path) {
std::vector<FieldData> results;
ASSIGN_OR_RETURN(results, GetFieldValues(message_data, field_path));
if (results.empty()) {
FieldPathEntry tail = field_path.back();
return absl::OutOfRangeError(absl::StrCat(
"Missing feild value: ", tail.field ? tail.field->name() : "##",
" at index: ", tail.index));
}
return results[0];
}
// Writes FieldData values into protobuf field.
absl::Status SetFieldValues(FieldData& message_data,
const FieldPath& field_path,
const std::vector<FieldData>& values) {
if (field_path.empty()) {
if (values.empty()) {
return absl::InvalidArgumentError("Missing feild value.");
}
message_data = values[0];
return absl::OkStatus(); return absl::OkStatus();
} }
// Writes a FieldData value into protobuf field.
absl::Status SetField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data) {
if (field_path.empty()) {
*message_data->mutable_message_value() = value.message_value();
return absl::OkStatus();
}
FieldPathEntry head = field_path.front(); FieldPathEntry head = field_path.front();
FieldPath tail = field_path; FieldPath tail = field_path;
tail.erase(tail.begin()); tail.erase(tail.begin());
if (!head.extension_type.empty()) { if (!head.extension_type.empty()) {
MP_RETURN_IF_ERROR(FindExtension(*message_data, &head)); MP_RETURN_IF_ERROR(FindExtension(message_data, &head));
} }
if (tail.empty()) { if (tail.empty()) {
MP_RETURN_IF_ERROR(SetFieldValue(head, value, message_data)); MP_RETURN_IF_ERROR(SetFieldValues(message_data, head, values));
} else { return absl::OkStatus();
}
FieldData child; FieldData child;
MP_RETURN_IF_ERROR(GetFieldValue(*message_data, head, &child)); MP_RETURN_IF_ERROR(GetFieldValue(message_data, head, &child));
MP_RETURN_IF_ERROR(SetField(tail, value, &child)); MP_RETURN_IF_ERROR(SetFieldValues(child, tail, values));
if (IsProtobufAny(head.field)) { if (IsProtobufAny(head.field)) {
child = SerializeProtobufAny(child); child = SerializeProtobufAny(child);
} }
MP_RETURN_IF_ERROR(SetFieldValue(head, child, message_data)); MP_RETURN_IF_ERROR(SetFieldValue(message_data, head, child));
}
return absl::OkStatus(); return absl::OkStatus();
} }
// Merges a packet value into nested protobuf Message. // Writes a FieldData value into protobuf field.
absl::Status MergeField(const FieldPath& field_path, const FieldData& value, absl::Status SetField(FieldData& message_data, const FieldPath& field_path,
FieldData* message_data) { const FieldData& value) {
absl::Status status; return SetFieldValues(message_data, field_path, {value});
FieldType field_type = field_path.empty()
? FieldType::TYPE_MESSAGE
: AsFieldType(field_path.back().field->type());
std::string message_type =
(value.has_message_value())
? ParseTypeUrl(std::string(value.message_value().type_url()))
: "";
FieldData v = value;
if (field_type == FieldType::TYPE_MESSAGE) {
FieldData b;
status.Update(GetField(field_path, *message_data, &b));
status.Update(MergeMessages(b, v, &v));
} }
status.Update(SetField(field_path, v, message_data));
// Merges FieldData values into nested protobuf Message.
// For each new field index, any previous value is merged with the new value.
absl::Status MergeFieldValues(FieldData& message_data,
const FieldPath& field_path,
const std::vector<FieldData>& values) {
absl::Status status;
FieldType field_type = field_path.empty() ? FieldType::TYPE_MESSAGE
: field_path.back().field->type();
std::vector<FieldData> results = values;
std::vector<FieldData> prevs;
ASSIGN_OR_RETURN(prevs, GetFieldValues(message_data, field_path));
if (field_type == FieldType::TYPE_MESSAGE) {
for (int i = 0; i < std::min(values.size(), prevs.size()); ++i) {
FieldData& v = results[i];
FieldData& b = prevs[i];
ASSIGN_OR_RETURN(v, MergeMessages(b, v));
}
}
status.Update(SetFieldValues(message_data, field_path, results));
return status; return status;
} }
@ -576,34 +491,35 @@ struct ProtoEnum {
int32 value; int32 value;
}; };
absl::Status AsPacket(const FieldData& data, Packet* result) { absl::StatusOr<Packet> AsPacket(const FieldData& data) {
Packet result;
switch (data.value_case()) { switch (data.value_case()) {
case FieldData::ValueCase::kInt32Value: case FieldData::ValueCase::kInt32Value:
*result = MakePacket<int32>(data.int32_value()); result = MakePacket<int32>(data.int32_value());
break; break;
case FieldData::ValueCase::kInt64Value: case FieldData::ValueCase::kInt64Value:
*result = MakePacket<int64>(data.int64_value()); result = MakePacket<int64>(data.int64_value());
break; break;
case FieldData::ValueCase::kUint32Value: case FieldData::ValueCase::kUint32Value:
*result = MakePacket<uint32>(data.uint32_value()); result = MakePacket<uint32>(data.uint32_value());
break; break;
case FieldData::ValueCase::kUint64Value: case FieldData::ValueCase::kUint64Value:
*result = MakePacket<uint64>(data.uint64_value()); result = MakePacket<uint64>(data.uint64_value());
break; break;
case FieldData::ValueCase::kDoubleValue: case FieldData::ValueCase::kDoubleValue:
*result = MakePacket<double>(data.double_value()); result = MakePacket<double>(data.double_value());
break; break;
case FieldData::ValueCase::kFloatValue: case FieldData::ValueCase::kFloatValue:
*result = MakePacket<float>(data.float_value()); result = MakePacket<float>(data.float_value());
break; break;
case FieldData::ValueCase::kBoolValue: case FieldData::ValueCase::kBoolValue:
*result = MakePacket<bool>(data.bool_value()); result = MakePacket<bool>(data.bool_value());
break; break;
case FieldData::ValueCase::kEnumValue: case FieldData::ValueCase::kEnumValue:
*result = MakePacket<ProtoEnum>(data.enum_value()); result = MakePacket<ProtoEnum>(data.enum_value());
break; break;
case FieldData::ValueCase::kStringValue: case FieldData::ValueCase::kStringValue:
*result = MakePacket<std::string>(data.string_value()); result = MakePacket<std::string>(data.string_value());
break; break;
case FieldData::ValueCase::kMessageValue: { case FieldData::ValueCase::kMessageValue: {
auto r = packet_internal::PacketFromDynamicProto( auto r = packet_internal::PacketFromDynamicProto(
@ -612,32 +528,33 @@ absl::Status AsPacket(const FieldData& data, Packet* result) {
if (!r.ok()) { if (!r.ok()) {
return r.status(); return r.status();
} }
*result = r.value(); result = r.value();
break; break;
} }
case FieldData::VALUE_NOT_SET: case FieldData::VALUE_NOT_SET:
*result = Packet(); result = Packet();
} }
return absl::OkStatus(); return result;
} }
absl::Status AsFieldData(Packet packet, FieldData* result) { absl::StatusOr<FieldData> AsFieldData(Packet packet) {
static const auto* kTypeIds = new std::map<size_t, int32>{ static const auto* kTypeIds = new std::map<TypeId, int32>{
{tool::GetTypeHash<int32>(), WireFormatLite::CPPTYPE_INT32}, {kTypeId<int32>, WireFormatLite::CPPTYPE_INT32},
{tool::GetTypeHash<int64>(), WireFormatLite::CPPTYPE_INT64}, {kTypeId<int64>, WireFormatLite::CPPTYPE_INT64},
{tool::GetTypeHash<uint32>(), WireFormatLite::CPPTYPE_UINT32}, {kTypeId<uint32>, WireFormatLite::CPPTYPE_UINT32},
{tool::GetTypeHash<uint64>(), WireFormatLite::CPPTYPE_UINT64}, {kTypeId<uint64>, WireFormatLite::CPPTYPE_UINT64},
{tool::GetTypeHash<double>(), WireFormatLite::CPPTYPE_DOUBLE}, {kTypeId<double>, WireFormatLite::CPPTYPE_DOUBLE},
{tool::GetTypeHash<float>(), WireFormatLite::CPPTYPE_FLOAT}, {kTypeId<float>, WireFormatLite::CPPTYPE_FLOAT},
{tool::GetTypeHash<bool>(), WireFormatLite::CPPTYPE_BOOL}, {kTypeId<bool>, WireFormatLite::CPPTYPE_BOOL},
{tool::GetTypeHash<ProtoEnum>(), WireFormatLite::CPPTYPE_ENUM}, {kTypeId<ProtoEnum>, WireFormatLite::CPPTYPE_ENUM},
{tool::GetTypeHash<std::string>(), WireFormatLite::CPPTYPE_STRING}, {kTypeId<std::string>, WireFormatLite::CPPTYPE_STRING},
}; };
FieldData result;
if (packet.ValidateAsProtoMessageLite().ok()) { if (packet.ValidateAsProtoMessageLite().ok()) {
result->mutable_message_value()->set_value( result.mutable_message_value()->set_value(
packet.GetProtoMessageLite().SerializeAsString()); packet.GetProtoMessageLite().SerializeAsString());
result->mutable_message_value()->set_type_url( result.mutable_message_value()->set_type_url(
TypeUrl(packet.GetProtoMessageLite().GetTypeName())); TypeUrl(packet.GetProtoMessageLite().GetTypeName()));
return absl::OkStatus(); return absl::OkStatus();
} }
@ -649,48 +566,42 @@ absl::Status AsFieldData(Packet packet, FieldData* result) {
switch (kTypeIds->at(packet.GetTypeId())) { switch (kTypeIds->at(packet.GetTypeId())) {
case WireFormatLite::CPPTYPE_INT32: case WireFormatLite::CPPTYPE_INT32:
result->set_int32_value(packet.Get<int32>()); result.set_int32_value(packet.Get<int32>());
break; break;
case WireFormatLite::CPPTYPE_INT64: case WireFormatLite::CPPTYPE_INT64:
result->set_int64_value(packet.Get<int64>()); result.set_int64_value(packet.Get<int64>());
break; break;
case WireFormatLite::CPPTYPE_UINT32: case WireFormatLite::CPPTYPE_UINT32:
result->set_uint32_value(packet.Get<uint32>()); result.set_uint32_value(packet.Get<uint32>());
break; break;
case WireFormatLite::CPPTYPE_UINT64: case WireFormatLite::CPPTYPE_UINT64:
result->set_uint64_value(packet.Get<uint64>()); result.set_uint64_value(packet.Get<uint64>());
break; break;
case WireFormatLite::CPPTYPE_DOUBLE: case WireFormatLite::CPPTYPE_DOUBLE:
result->set_double_value(packet.Get<double>()); result.set_double_value(packet.Get<double>());
break; break;
case WireFormatLite::CPPTYPE_FLOAT: case WireFormatLite::CPPTYPE_FLOAT:
result->set_float_value(packet.Get<float>()); result.set_float_value(packet.Get<float>());
break; break;
case WireFormatLite::CPPTYPE_BOOL: case WireFormatLite::CPPTYPE_BOOL:
result->set_bool_value(packet.Get<bool>()); result.set_bool_value(packet.Get<bool>());
break; break;
case WireFormatLite::CPPTYPE_ENUM: case WireFormatLite::CPPTYPE_ENUM:
result->set_enum_value(packet.Get<ProtoEnum>().value); result.set_enum_value(packet.Get<ProtoEnum>().value);
break; break;
case WireFormatLite::CPPTYPE_STRING: case WireFormatLite::CPPTYPE_STRING:
result->set_string_value(packet.Get<std::string>()); result.set_string_value(packet.Get<std::string>());
break; break;
} }
return absl::OkStatus(); return result;
} }
std::string TypeUrl(absl::string_view type_name) { std::string TypeUrl(absl::string_view type_name) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/"; return ProtoUtilLite::TypeUrl(type_name);
return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name));
} }
std::string ParseTypeUrl(absl::string_view type_url) { std::string ParseTypeUrl(absl::string_view type_url) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/"; return ProtoUtilLite::ParseTypeUrl(type_url);
if (std::string(type_url).rfind(kTypeUrlPrefix, 0) == 0) {
return std::string(
type_url.substr(kTypeUrlPrefix.length(), std::string::npos));
}
return std::string(type_url);
} }
} // namespace options_field_util } // namespace options_field_util

View File

@ -34,30 +34,38 @@ absl::Status SetField(const FieldPath& field_path, const FieldData& value,
FieldData* message_data); FieldData* message_data);
// Reads a field value from a protobuf field. // Reads a field value from a protobuf field.
absl::Status GetField(const FieldPath& field_path, absl::StatusOr<FieldData> GetField(const FieldData& message_data,
const FieldData& message_data, FieldData* result); const FieldPath& field_path);
// Merges a field value into nested protobuf Message. // Reads one or all FieldData values from a protobuf field.
absl::Status MergeField(const FieldPath& field_path, const FieldData& value, absl::StatusOr<std::vector<FieldData>> GetFieldValues(
FieldData* message_data); const FieldData& message_data, const FieldPath& field_path);
// Writes FieldData values into a protobuf field.
absl::Status SetFieldValues(FieldData& message_data,
const FieldPath& field_path,
const std::vector<FieldData>& values);
// Merges FieldData values into a protobuf field.
absl::Status MergeFieldValues(FieldData& message_data,
const FieldPath& field_path,
const std::vector<FieldData>& values);
// Deserializes a packet containing a MessageLite value. // Deserializes a packet containing a MessageLite value.
absl::Status ReadMessage(const std::string& value, const std::string& type_name, absl::StatusOr<Packet> ReadMessage(const std::string& value,
Packet* result); const std::string& type_name);
// Merge two options protobuf field values. // Merge two options protobuf field values.
absl::Status MergeMessages(const FieldData& base, const FieldData& over, absl::StatusOr<FieldData> MergeMessages(const FieldData& base,
FieldData* result); const FieldData& over);
// Returns the requested options protobuf for a graph. // Returns the requested options protobuf for a graph.
absl::Status GetNodeOptions(const FieldData& message_data, absl::StatusOr<FieldData> GetNodeOptions(const FieldData& message_data,
const std::string& extension_type, const std::string& extension_type);
FieldData* result);
// Returns the requested options protobuf for a graph node. // Returns the requested options protobuf for a graph node.
absl::Status GetGraphOptions(const FieldData& message_data, absl::StatusOr<FieldData> GetGraphOptions(const FieldData& message_data,
const std::string& extension_type, const std::string& extension_type);
FieldData* result);
// Sets the node_options field in a Node, and clears the options field. // Sets the node_options field in a Node, and clears the options field.
void SetOptionsMessage(const FieldData& node_options, void SetOptionsMessage(const FieldData& node_options,
@ -67,10 +75,10 @@ void SetOptionsMessage(const FieldData& node_options,
FieldData AsFieldData(const proto_ns::MessageLite& message); FieldData AsFieldData(const proto_ns::MessageLite& message);
// Constructs a Packet for a FieldData proto. // Constructs a Packet for a FieldData proto.
absl::Status AsPacket(const FieldData& data, Packet* result); absl::StatusOr<Packet> AsPacket(const FieldData& data);
// Constructs a FieldData proto for a Packet. // Constructs a FieldData proto for a Packet.
absl::Status AsFieldData(Packet packet, FieldData* result); absl::StatusOr<FieldData> AsFieldData(Packet packet);
// Returns the protobuf type-url for a protobuf type-name. // Returns the protobuf type-url for a protobuf type-name.
std::string TypeUrl(absl::string_view type_name); std::string TypeUrl(absl::string_view type_name);

View File

@ -25,11 +25,12 @@ constexpr char kDescriptorContents[] =
#include "{{DESCRIPTOR_INC_FILE_PATH}}" #include "{{DESCRIPTOR_INC_FILE_PATH}}"
; // NOLINT(whitespace/semicolon) ; // NOLINT(whitespace/semicolon)
mediapipe::proto_ns::FileDescriptorSet ParseFileDescriptorSet( mediapipe::FieldData ReadFileDescriptorSet(const std::string& pb) {
const std::string& pb) { mediapipe::FieldData result;
mediapipe::proto_ns::FileDescriptorSet files; *result.mutable_message_value()->mutable_type_url() =
files.ParseFromString(pb); "proto2.FileDescriptorSet";
return files; *result.mutable_message_value()->mutable_value() = pb;
return result;
} }
} // namespace } // namespace
@ -39,6 +40,6 @@ namespace mediapipe {
template <> template <>
const RegistrationToken tool::OptionsRegistry::registration_token< const RegistrationToken tool::OptionsRegistry::registration_token<
MP_OPTION_TYPE_NS::MP_OPTION_TYPE_NAME> = MP_OPTION_TYPE_NS::MP_OPTION_TYPE_NAME> =
tool::OptionsRegistry::Register(ParseFileDescriptorSet( tool::OptionsRegistry::Register(ReadFileDescriptorSet(
std::string(kDescriptorContents, sizeof(kDescriptorContents) - 1))); std::string(kDescriptorContents, sizeof(kDescriptorContents) - 1)));
} // namespace mediapipe } // namespace mediapipe

View File

@ -30,15 +30,26 @@ struct IsExtension {
template <class T, template <class T,
typename std::enable_if<IsExtension<T>::value, int>::type = 0> typename std::enable_if<IsExtension<T>::value, int>::type = 0>
void GetExtension(const CalculatorOptions& options, T* result) { T* GetExtension(CalculatorOptions& options) {
if (options.HasExtension(T::ext)) { if (options.HasExtension(T::ext)) {
*result = options.GetExtension(T::ext); return options.MutableExtension(T::ext);
} }
return nullptr;
} }
template <class T, template <class T,
typename std::enable_if<!IsExtension<T>::value, int>::type = 0> typename std::enable_if<!IsExtension<T>::value, int>::type = 0>
void GetExtension(const CalculatorOptions& options, T* result) {} T* GetExtension(const CalculatorOptions& options) {
return nullptr;
}
template <class T>
void GetExtension(const CalculatorOptions& options, T* result) {
T* r = GetExtension<T>(*const_cast<CalculatorOptions*>(&options));
if (r) {
*result = *r;
}
}
template <class T> template <class T>
void GetNodeOptions(const CalculatorGraphConfig::Node& node_config, T* result) { void GetNodeOptions(const CalculatorGraphConfig::Node& node_config, T* result) {
@ -53,23 +64,39 @@ void GetNodeOptions(const CalculatorGraphConfig::Node& node_config, T* result) {
#endif #endif
} }
template <class T>
void SetNodeOptions(CalculatorGraphConfig::Node& node_config, const T& value) {
#if defined(MEDIAPIPE_PROTO_LITE) && defined(MEDIAPIPE_PROTO_THIRD_PARTY)
// protobuf::Any is unavailable with third_party/protobuf:protobuf-lite.
#else
for (mediapipe::protobuf::Any& options :
*node_config.mutable_node_options()) {
if (options.Is<T>()) {
options.PackFrom(value);
return;
}
}
node_config.add_node_options()->PackFrom(value);
#endif
}
// A map from object type to object. // A map from object type to object.
class TypeMap { class TypeMap {
public: public:
template <class T> template <class T>
bool Has() const { bool Has() const {
return content_.count(TypeInfo::Get<T>()) > 0; return content_.count(kTypeId<T>) > 0;
} }
template <class T> template <class T>
T* Get() const { T* Get() const {
if (!Has<T>()) { if (!Has<T>()) {
content_[TypeInfo::Get<T>()] = std::make_shared<T>(); content_[kTypeId<T>] = std::make_shared<T>();
} }
return static_cast<T*>(content_[TypeInfo::Get<T>()].get()); return static_cast<T*>(content_[kTypeId<T>].get());
} }
private: private:
mutable std::map<TypeIndex, std::shared_ptr<void>> content_; mutable std::map<TypeId, std::shared_ptr<void>> content_;
}; };
// Extracts the options message of a specified type from a // Extracts the options message of a specified type from a
@ -77,7 +104,7 @@ class TypeMap {
class OptionsMap { class OptionsMap {
public: public:
OptionsMap& Initialize(const CalculatorGraphConfig::Node& node_config) { OptionsMap& Initialize(const CalculatorGraphConfig::Node& node_config) {
node_config_ = &node_config; node_config_ = const_cast<CalculatorGraphConfig::Node*>(&node_config);
return *this; return *this;
} }
@ -97,10 +124,40 @@ class OptionsMap {
return *result; return *result;
} }
const CalculatorGraphConfig::Node* node_config_; CalculatorGraphConfig::Node* node_config_;
TypeMap options_; TypeMap options_;
}; };
class MutableOptionsMap : public OptionsMap {
public:
MutableOptionsMap& Initialize(CalculatorGraphConfig::Node& node_config) {
node_config_ = &node_config;
return *this;
}
template <class T>
void Set(const T& value) const {
*options_.Get<T>() = value;
if (node_config_->has_options()) {
*GetExtension<T>(*node_config_->mutable_options()) = value;
} else {
SetNodeOptions(*node_config_, value);
}
}
template <class T>
T* GetMutable() const {
if (options_.Has<T>()) {
return options_.Get<T>();
}
if (node_config_->has_options()) {
return GetExtension<T>(*node_config_->mutable_options());
}
T* result = options_.Get<T>();
GetNodeOptions(*node_config_, result);
return result;
}
};
} // namespace tool } // namespace tool
} // namespace mediapipe } // namespace mediapipe

View File

@ -1,6 +1,11 @@
#include "mediapipe/framework/tool/options_registry.h" #include "mediapipe/framework/tool/options_registry.h"
#include <string>
#include <vector>
#include "absl/synchronization/mutex.h" #include "absl/synchronization/mutex.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/tool/proto_util_lite.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
@ -9,37 +14,135 @@ namespace {
// Returns a canonical message type name, with any leading "." removed. // Returns a canonical message type name, with any leading "." removed.
std::string CanonicalTypeName(const std::string& type_name) { std::string CanonicalTypeName(const std::string& type_name) {
return (type_name.rfind('.', 0) == 0) ? type_name.substr(1) : type_name; return (absl::StartsWith(type_name, ".")) ? type_name.substr(1) : type_name;
}
// Returns the values from a protobuf field as typed FieldData.
absl::StatusOr<std::vector<FieldData>> GetFieldValues(
const FieldData& message_data, std::string field_name) {
std::string type_name =
ProtoUtilLite::ParseTypeUrl(message_data.message_value().type_url());
const Descriptor* descriptor =
OptionsRegistry::GetProtobufDescriptor(type_name);
RET_CHECK_NE(descriptor, nullptr);
const FieldDescriptor* field = descriptor->FindFieldByName(field_name);
if (field == nullptr) {
return std::vector<FieldData>();
}
ProtoUtilLite::ProtoPath proto_path = {{field->number(), 0}};
ProtoUtilLite::FieldValue mesage_bytes = message_data.message_value().value();
int count;
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(mesage_bytes, proto_path,
field->type(), &count));
std::vector<std::string> field_values;
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
mesage_bytes, proto_path, count, field->type(), &field_values));
std::vector<FieldData> result;
for (int i = 0; i < field_values.size(); ++i) {
FieldData r;
std::string message_type =
field->message_type() ? field->message_type()->full_name() : "";
MP_RETURN_IF_ERROR(ProtoUtilLite::ReadValue(field_values[i], field->type(),
message_type, &r));
result.push_back(std::move(r));
}
return result;
}
// Returns a single value from a protobuf string field.
std::string GetFieldString(const FieldData& message_data,
std::string field_name) {
auto values = GetFieldValues(message_data, field_name);
if (!values->empty()) {
return values->front().string_value();
}
return "";
}
// Registers the descriptors for the descriptor protobufs. These four
// descriptors are required to deserialize descriptors for other protobufs.
// This implementation avoids a code size problem introduced by
// proto_ns::DescriptorProto.
void RegisterDescriptorProtos(
absl::flat_hash_map<std::string, Descriptor>& result) {
std::vector<Descriptor> descriptors = {
{"proto2.FileDescriptorSet",
{
{"file", 1, FieldType::TYPE_MESSAGE, "proto2.FileDescriptorProto"},
}},
{"proto2.FileDescriptorProto",
{
{"package", 2, FieldType::TYPE_STRING, ""},
{"message_type", 4, FieldType::TYPE_MESSAGE,
"proto2.DescriptorProto"},
}},
{"proto2.DescriptorProto",
{
{"name", 1, FieldType::TYPE_STRING, ""},
{"field", 2, FieldType::TYPE_MESSAGE, "proto2.FieldDescriptorProto"},
{"extension", 6, FieldType::TYPE_MESSAGE,
"proto2.FieldDescriptorProto"},
{"nested_type", 3, FieldType::TYPE_MESSAGE,
"proto2.DescriptorProto"},
}},
{"proto2.FieldDescriptorProto",
{
{"name", 1, FieldType::TYPE_STRING, ""},
{"number", 3, FieldType::TYPE_INT32, ""},
{"type", 5, FieldType::TYPE_ENUM, ""},
{"type_name", 6, FieldType::TYPE_STRING, ""},
{"extendee", 2, FieldType::TYPE_STRING, ""},
}},
};
for (const auto& descriptor : descriptors) {
result[descriptor.full_name()] = descriptor;
}
} }
} // namespace } // namespace
RegistrationToken OptionsRegistry::Register( RegistrationToken OptionsRegistry::Register(
const proto_ns::FileDescriptorSet& files) { const FieldData& file_descriptor_set) {
absl::MutexLock lock(&mutex()); auto files = GetFieldValues(file_descriptor_set, "file");
for (auto& file : files.file()) { for (auto& file : *files) {
for (auto& message_type : file.message_type()) { std::string package_name = GetFieldString(file, "package");
Register(message_type, file.package()); auto message_types = GetFieldValues(file, "message_type");
for (auto& message_type : *message_types) {
Register(message_type, package_name);
} }
} }
return RegistrationToken([]() {}); return RegistrationToken([]() {});
} }
void OptionsRegistry::Register(const proto_ns::DescriptorProto& message_type, void OptionsRegistry::Register(const FieldData& message_type,
const std::string& parent_name) { const std::string& parent_name) {
auto full_name = absl::StrCat(parent_name, ".", message_type.name()); std::string name = GetFieldString(message_type, "name");
descriptors()[full_name] = Descriptor(message_type, full_name); std::string full_name = absl::StrCat(parent_name, ".", name);
for (auto& nested : message_type.nested_type()) { Descriptor descriptor(full_name, message_type);
{
absl::MutexLock lock(&mutex());
descriptors()[full_name] = descriptor;
}
auto nested_types = GetFieldValues(message_type, "nested_type");
for (auto& nested : *nested_types) {
Register(nested, full_name); Register(nested, full_name);
} }
for (auto& extension : message_type.extension()) { auto exts = GetFieldValues(message_type, "extension");
extensions()[CanonicalTypeName(extension.extendee())].push_back( for (auto& extension : *exts) {
FieldDescriptor(extension)); FieldDescriptor field(extension);
std::string extendee = GetFieldString(extension, "extendee");
{
absl::MutexLock lock(&mutex());
extensions()[CanonicalTypeName(extendee)].push_back(field);
}
} }
} }
const Descriptor* OptionsRegistry::GetProtobufDescriptor( const Descriptor* OptionsRegistry::GetProtobufDescriptor(
const std::string& type_name) { const std::string& type_name) {
if (descriptors().count("proto2.DescriptorProto") == 0) {
RegisterDescriptorProtos(descriptors());
}
absl::ReaderMutexLock lock(&mutex()); absl::ReaderMutexLock lock(&mutex());
auto it = descriptors().find(CanonicalTypeName(type_name)); auto it = descriptors().find(CanonicalTypeName(type_name));
return (it == descriptors().end()) ? nullptr : &it->second; return (it == descriptors().end()) ? nullptr : &it->second;
@ -73,11 +176,21 @@ absl::Mutex& OptionsRegistry::mutex() {
return *mutex; return *mutex;
} }
Descriptor::Descriptor(const proto_ns::DescriptorProto& proto, Descriptor::Descriptor(const std::string& full_name,
const std::string& full_name) const FieldData& descriptor_proto)
: full_name_(full_name) { : full_name_(full_name) {
for (auto& field : proto.field()) { auto fields = GetFieldValues(descriptor_proto, "field");
fields_[field.name()] = FieldDescriptor(field); for (const auto& field : *fields) {
FieldDescriptor f(field);
fields_[f.name()] = f;
}
}
Descriptor::Descriptor(const std::string& full_name,
const std::vector<FieldDescriptor>& fields)
: full_name_(full_name) {
for (const auto& field : fields) {
fields_[field.name()] = field;
} }
} }
@ -89,20 +202,22 @@ const FieldDescriptor* Descriptor::FindFieldByName(
return (it != fields_.end()) ? &it->second : nullptr; return (it != fields_.end()) ? &it->second : nullptr;
} }
FieldDescriptor::FieldDescriptor(const proto_ns::FieldDescriptorProto& proto) { FieldDescriptor::FieldDescriptor(const FieldData& field_proto) {
name_ = proto.name(); name_ = GetFieldString(field_proto, "name");
message_type_ = CanonicalTypeName(proto.type_name()); number_ = GetFieldValues(field_proto, "number")->front().int32_value();
type_ = proto.type(); type_ = (FieldType)GetFieldValues(field_proto, "type")->front().enum_value();
number_ = proto.number(); message_type_ = CanonicalTypeName(GetFieldString(field_proto, "type_name"));
} }
FieldDescriptor::FieldDescriptor(std::string name, int number, FieldType type,
std::string message_type)
: name_(name), number_(number), type_(type), message_type_(message_type) {}
const std::string& FieldDescriptor::name() const { return name_; } const std::string& FieldDescriptor::name() const { return name_; }
int FieldDescriptor::number() const { return number_; } int FieldDescriptor::number() const { return number_; }
proto_ns::FieldDescriptorProto::Type FieldDescriptor::type() const { FieldType FieldDescriptor::type() const { return type_; }
return type_;
}
const Descriptor* FieldDescriptor::message_type() const { const Descriptor* FieldDescriptor::message_type() const {
return OptionsRegistry::GetProtobufDescriptor(message_type_); return OptionsRegistry::GetProtobufDescriptor(message_type_);

View File

@ -1,15 +1,20 @@
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ #ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
#include <string>
#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_map.h"
#include "mediapipe/framework/deps/registration.h" #include "mediapipe/framework/deps/registration.h"
#include "mediapipe/framework/port/advanced_proto_inc.h" #include "mediapipe/framework/port/advanced_proto_inc.h"
#include "mediapipe/framework/tool/field_data.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
class Descriptor; class Descriptor;
class FieldDescriptor; class FieldDescriptor;
using FieldType = mediapipe::proto_ns::internal::WireFormatLite::FieldType;
using mediapipe::FieldData;
// A static registry that stores descriptors for protobufs used in MediaPipe // A static registry that stores descriptors for protobufs used in MediaPipe
// calculator options. Lite-proto builds do not normally include descriptors. // calculator options. Lite-proto builds do not normally include descriptors.
@ -17,8 +22,8 @@ class FieldDescriptor;
// referenced and specified separately within CalculatorGraphConfigs. // referenced and specified separately within CalculatorGraphConfigs.
class OptionsRegistry { class OptionsRegistry {
public: public:
// Registers the protobuf descriptors for a MessageLite. // Registers the protobuf descriptors for a FileDescriptorSet.
static RegistrationToken Register(const proto_ns::FileDescriptorSet& files); static RegistrationToken Register(const FieldData& file_descriptor_set);
// Finds the descriptor for a protobuf. // Finds the descriptor for a protobuf.
static const Descriptor* GetProtobufDescriptor(const std::string& type_name); static const Descriptor* GetProtobufDescriptor(const std::string& type_name);
@ -28,8 +33,8 @@ class OptionsRegistry {
std::vector<const FieldDescriptor*>* result); std::vector<const FieldDescriptor*>* result);
private: private:
// Registers protobuf descriptors a MessageLite and nested types. // Registers protobuf descriptors for a message type and nested types.
static void Register(const proto_ns::DescriptorProto& message_type, static void Register(const FieldData& message_type,
const std::string& parent_name); const std::string& parent_name);
static absl::flat_hash_map<std::string, Descriptor>& descriptors(); static absl::flat_hash_map<std::string, Descriptor>& descriptors();
@ -46,9 +51,10 @@ class OptionsRegistry {
// avoids a code size problem introduced by proto_ns::FieldDescriptor. // avoids a code size problem introduced by proto_ns::FieldDescriptor.
class Descriptor { class Descriptor {
public: public:
Descriptor() {} Descriptor() = default;
Descriptor(const proto_ns::DescriptorProto& proto, Descriptor(const std::string& full_name, const FieldData& descriptor_proto);
const std::string& full_name); Descriptor(const std::string& full_name,
const std::vector<FieldDescriptor>& fields);
const std::string& full_name() const; const std::string& full_name() const;
const FieldDescriptor* FindFieldByName(const std::string& name) const; const FieldDescriptor* FindFieldByName(const std::string& name) const;
@ -61,18 +67,20 @@ class Descriptor {
// avoids a code size problem introduced by proto_ns::FieldDescriptor. // avoids a code size problem introduced by proto_ns::FieldDescriptor.
class FieldDescriptor { class FieldDescriptor {
public: public:
FieldDescriptor() {} FieldDescriptor() = default;
FieldDescriptor(const proto_ns::FieldDescriptorProto& proto); FieldDescriptor(const FieldData& field_proto);
FieldDescriptor(std::string name, int number, FieldType type,
std::string message_type);
const std::string& name() const; const std::string& name() const;
int number() const; int number() const;
proto_ns::FieldDescriptorProto::Type type() const; FieldType type() const;
const Descriptor* message_type() const; const Descriptor* message_type() const;
private: private:
std::string name_; std::string name_;
std::string message_type_;
proto_ns::FieldDescriptorProto::Type type_;
int number_; int number_;
FieldType type_;
std::string message_type_;
}; };
} // namespace tool } // namespace tool

View File

@ -91,8 +91,7 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper {
int index; int index;
if (absl::SimpleAtoi(option_name, &index)) { if (absl::SimpleAtoi(option_name, &index)) {
result.back().index = index; result.back().index = index;
} } else if (!ExtensionType(option_name).empty()) {
if (!ExtensionType(option_name).empty()) {
std::string extension_type = std::string(ExtensionType(option_name)); std::string extension_type = std::string(ExtensionType(option_name));
result.push_back({nullptr, 0, extension_type}); result.push_back({nullptr, 0, extension_type});
descriptor = OptionsRegistry::GetProtobufDescriptor(extension_type); descriptor = OptionsRegistry::GetProtobufDescriptor(extension_type);
@ -102,7 +101,7 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper {
} }
auto field = descriptor->FindFieldByName(std::string(option_name)); auto field = descriptor->FindFieldByName(std::string(option_name));
descriptor = field ? field->message_type() : nullptr; descriptor = field ? field->message_type() : nullptr;
result.push_back({std::move(field), 0}); result.push_back({std::move(field), -1});
} }
} }
return result; return result;

View File

@ -26,10 +26,9 @@ namespace mediapipe {
namespace tool { namespace tool {
using options_field_util::FieldPath; using options_field_util::FieldPath;
using options_field_util::GetField;
using options_field_util::GetGraphOptions; using options_field_util::GetGraphOptions;
using options_field_util::GetNodeOptions; using options_field_util::GetNodeOptions;
using options_field_util::MergeField; using options_field_util::MergeFieldValues;
using options_field_util::MergeMessages; using options_field_util::MergeMessages;
// Returns the type for the root options message if specified. // Returns the type for the root options message if specified.
@ -56,10 +55,19 @@ std::string MessageType(FieldData message) {
std::string(message.message_value().type_url())); std::string(message.message_value().type_url()));
} }
// Assigns the value from a StatusOr if avialable.
#define ASSIGN_IF_OK(lhs, rexpr) \
{ \
auto statusor = (rexpr); \
if (statusor.ok()) { \
lhs = statusor.value(); \
} \
}
// Copy literal options from graph_options to node_options. // Copy literal options from graph_options to node_options.
absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node, absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
CalculatorGraphConfig* config) { CalculatorGraphConfig* config) {
Status status; absl::Status status;
FieldData graph_data = options_field_util::AsFieldData(*config); FieldData graph_data = options_field_util::AsFieldData(*config);
FieldData parent_data = options_field_util::AsFieldData(parent_node); FieldData parent_data = options_field_util::AsFieldData(parent_node);
@ -75,25 +83,26 @@ absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]); std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]);
std::string node_extension_type = ExtensionType(node_tag); std::string node_extension_type = ExtensionType(node_tag);
FieldData graph_options; FieldData graph_options;
GetGraphOptions(graph_data, graph_extension_type, &graph_options) ASSIGN_IF_OK(graph_options,
.IgnoreError(); GetGraphOptions(graph_data, graph_extension_type));
FieldData parent_options; FieldData parent_options;
GetNodeOptions(parent_data, graph_extension_type, &parent_options) ASSIGN_IF_OK(parent_options,
.IgnoreError(); GetNodeOptions(parent_data, graph_extension_type));
status.Update( ASSIGN_OR_RETURN(graph_options,
MergeMessages(graph_options, parent_options, &graph_options)); MergeMessages(graph_options, parent_options));
FieldData node_options; FieldData node_options;
status.Update( ASSIGN_OR_RETURN(node_options,
GetNodeOptions(node_data, node_extension_type, &node_options)); GetNodeOptions(node_data, node_extension_type));
if (!node_options.has_message_value() || if (!node_options.has_message_value() ||
!graph_options.has_message_value()) { !graph_options.has_message_value()) {
continue; continue;
} }
FieldPath graph_path = GetPath(graph_tag, MessageType(graph_options)); FieldPath graph_path = GetPath(graph_tag, MessageType(graph_options));
FieldPath node_path = GetPath(node_tag, MessageType(node_options)); FieldPath node_path = GetPath(node_tag, MessageType(node_options));
FieldData packet_data; std::vector<FieldData> packet_data;
status.Update(GetField(graph_path, graph_options, &packet_data)); ASSIGN_OR_RETURN(packet_data, GetFieldValues(graph_options, graph_path));
status.Update(MergeField(node_path, packet_data, &node_options)); MP_RETURN_IF_ERROR(
MergeFieldValues(node_options, node_path, packet_data));
options_field_util::SetOptionsMessage(node_options, &node); options_field_util::SetOptionsMessage(node_options, &node);
} }
node.clear_option_value(); node.clear_option_value();
@ -105,7 +114,7 @@ absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node, absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node,
CalculatorGraphConfig* config) { CalculatorGraphConfig* config) {
MP_RETURN_IF_ERROR(CopyLiteralOptions(parent_node, config)); MP_RETURN_IF_ERROR(CopyLiteralOptions(parent_node, config));
return mediapipe::OkStatus(); return absl::OkStatus();
} }
} // namespace tool } // namespace tool

View File

@ -13,8 +13,10 @@
// limitations under the License. // limitations under the License.
#include <memory> #include <memory>
#include <sstream>
#include <vector> #include <vector>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/gmock.h" #include "mediapipe/framework/port/gmock.h"
@ -30,23 +32,27 @@
namespace mediapipe { namespace mediapipe {
namespace { namespace {
using ::mediapipe::proto_ns::FieldDescriptorProto;
using FieldType = ::mediapipe::proto_ns::FieldDescriptorProto::Type; using FieldType = ::mediapipe::proto_ns::FieldDescriptorProto::Type;
using ::testing::HasSubstr;
// Assigns the value from a StatusOr if avialable.
#define ASSERT_AND_ASSIGN(lhs, rexpr) \
{ \
auto statusor = (rexpr); \
MP_ASSERT_OK(statusor); \
lhs = statusor.value(); \
}
// A test Calculator using DeclareOptions and DefineOptions. // A test Calculator using DeclareOptions and DefineOptions.
class NightLightCalculator : public CalculatorBase { class NightLightCalculator : public CalculatorBase {
public: public:
static absl::Status GetContract(CalculatorContract* cc) { static absl::Status GetContract(CalculatorContract* cc) {
return mediapipe::OkStatus(); return absl::OkStatus();
} }
absl::Status Open(CalculatorContext* cc) final { absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
return mediapipe::OkStatus();
}
absl::Status Process(CalculatorContext* cc) final { absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
return mediapipe::OkStatus();
}
private: private:
NightLightCalculatorOptions options_; NightLightCalculatorOptions options_;
@ -124,7 +130,7 @@ TEST_F(OptionsUtilTest, CopyLiteralOptions) {
CalculatorGraph graph; CalculatorGraph graph;
graph_config.set_num_threads(4); graph_config.set_num_threads(4);
MP_EXPECT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {})); MP_ASSERT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {}));
CalculatorGraphConfig expanded_config = graph.Config(); CalculatorGraphConfig expanded_config = graph.Config();
expanded_config.clear_executor(); expanded_config.clear_executor();
@ -236,8 +242,8 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
tool::options_field_util::FieldPath field_path = tool::options_field_util::FieldPath field_path =
syntax_util.OptionFieldPath(split[1], descriptor); syntax_util.OptionFieldPath(split[1], descriptor);
EXPECT_EQ(field_path.size(), 2); EXPECT_EQ(field_path.size(), 2);
EXPECT_TRUE(Equals(field_path[0], "sub_options", 0, "")); EXPECT_TRUE(Equals(field_path[0], "sub_options", -1, ""));
EXPECT_TRUE(Equals(field_path[1], "num_lights", 0, "")); EXPECT_TRUE(Equals(field_path[1], "num_lights", -1, ""));
{ {
// NightLightCalculatorOptions in Node.options. // NightLightCalculatorOptions in Node.options.
@ -252,11 +258,11 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
auto path = field_path; auto path = field_path;
std::string node_extension_type = ExtensionType(std::string(split[1])); std::string node_extension_type = ExtensionType(std::string(split[1]));
FieldData node_options; FieldData node_options;
MP_EXPECT_OK(tool::options_field_util::GetNodeOptions( ASSERT_AND_ASSIGN(node_options, tool::options_field_util::GetNodeOptions(
node_data, node_extension_type, &node_options)); node_data, node_extension_type));
FieldData packet_data; FieldData packet_data;
MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options, ASSERT_AND_ASSIGN(packet_data, tool::options_field_util::GetField(
&packet_data)); node_options, field_path));
EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value); EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value);
EXPECT_EQ(packet_data.int32_value(), 33); EXPECT_EQ(packet_data.int32_value(), 33);
} }
@ -273,11 +279,11 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
auto path = field_path; auto path = field_path;
std::string node_extension_type = ExtensionType(std::string(split[1])); std::string node_extension_type = ExtensionType(std::string(split[1]));
FieldData node_options; FieldData node_options;
MP_EXPECT_OK(tool::options_field_util::GetNodeOptions( ASSERT_AND_ASSIGN(node_options, tool::options_field_util::GetNodeOptions(
node_data, node_extension_type, &node_options)); node_data, node_extension_type));
FieldData packet_data; FieldData packet_data;
MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options, ASSERT_AND_ASSIGN(packet_data, tool::options_field_util::GetField(
&packet_data)); node_options, field_path));
EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value); EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value);
EXPECT_EQ(packet_data.int32_value(), 33); EXPECT_EQ(packet_data.int32_value(), 33);
} }
@ -285,5 +291,333 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
// TODO: Test with specified extension_type. // TODO: Test with specified extension_type.
} }
// Constructs the field path for a string of field names.
FieldPath MakeFieldPath(std::string tag, FieldData message_data) {
tool::OptionsSyntaxUtil syntax_util;
const tool::Descriptor* descriptor =
tool::OptionsRegistry::GetProtobufDescriptor(
tool::options_field_util::ParseTypeUrl(
message_data.message_value().type_url()));
return syntax_util.OptionFieldPath(tag, descriptor);
}
// Returns the field path addressing the entire specified field.
FieldPath EntireField(FieldPath field_path) {
field_path.back().index = -1;
return field_path;
}
// Converts an int to a FieldData record.
FieldData AsFieldData(int v) {
return tool::options_field_util::AsFieldData(MakePacket<int>(v)).value();
}
// Equality comparison for field contents.
template <typename T>
absl::Status Equals(const T& v1, const T& v2) {
RET_CHECK_EQ(v1, v2);
return absl::OkStatus();
}
// Equality comparison for protobuf field contents.
// The generic Equals() fails because MessageLite lacks operator==().
// The protobuf comparison is performed using testing::EqualsProto.
using LightBundle = NightLightCalculatorOptions::LightBundle;
template <>
absl::Status Equals<LightBundle>(const LightBundle& v1, const LightBundle& v2) {
std::string s_1, s_2;
v1.SerializeToString(&s_1);
v2.SerializeToString(&s_2);
RET_CHECK(s_1 == s_2);
return absl::OkStatus();
}
// Equality comparison for FieldData vectors.
template <typename FieldType>
absl::Status Equals(std::vector<FieldData> b1, std::vector<FieldData> b2) {
using tool::options_field_util::AsPacket;
RET_CHECK_EQ(b1.size(), b2.size());
for (int i = 0; i < b1.size(); ++i) {
ASSIGN_OR_RETURN(Packet p1, AsPacket(b1.at(i)));
ASSIGN_OR_RETURN(Packet p2, AsPacket(b2.at(i)));
MP_RETURN_IF_ERROR(Equals(p1.Get<FieldType>(), p2.Get<FieldType>()));
}
return absl::OkStatus();
}
// Unit-tests for graph options feild accessors from options_field_util.
class OptionsFieldUtilTest : public ::testing::Test {
protected:
void SetUp() override {}
void TearDown() override {}
};
// Tests empty FieldPaths applied to empty options.
TEST_F(OptionsFieldUtilTest, EmptyFieldPaths) {
FieldData graph_options;
FieldData node_options;
FieldPath graph_path;
FieldPath node_path;
std::vector<FieldData> packet_data;
ASSERT_AND_ASSIGN(packet_data, GetFieldValues(graph_options, graph_path));
MP_EXPECT_OK(MergeFieldValues(node_options, node_path, packet_data));
}
// Tests GetFieldValues applied to an int field.
TEST_F(OptionsFieldUtilTest, GetFieldValuesInt) {
NightLightCalculatorOptions node_proto;
node_proto.mutable_sub_options();
node_proto.mutable_sub_options()->add_num_lights(33);
node_proto.mutable_sub_options()->add_num_lights(44);
FieldData node_data = tool::options_field_util::AsFieldData(node_proto);
// Read an entire populated repeated field.
FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
MP_EXPECT_OK(Equals<int>(GetFieldValues(node_data, path).value(),
{AsFieldData(33), AsFieldData(44)}));
// Read a specific populated repeated field index.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(44)}));
}
// Tests GetFieldValues applied to a protobuf field.
TEST_F(OptionsFieldUtilTest, GetFieldValuesProtobuf) {
using tool::options_field_util::AsFieldData;
using LightBundle = NightLightCalculatorOptions::LightBundle;
NightLightCalculatorOptions node_proto;
node_proto.mutable_sub_options();
node_proto.mutable_sub_options()->add_bundle();
*node_proto.mutable_sub_options()->mutable_bundle(0)->mutable_room_id() =
"111";
node_proto.mutable_sub_options()
->mutable_bundle(0)
->add_room_lights()
->set_frame_rate(11.1);
node_proto.mutable_sub_options()
->mutable_bundle(0)
->add_room_lights()
->set_frame_rate(22.1);
FieldData node_data = AsFieldData(node_proto);
// Read all values from a repeated protobuf field.
LightBundle expected_proto;
*expected_proto.mutable_room_id() = "111";
expected_proto.add_room_lights()->set_frame_rate(11.1);
expected_proto.add_room_lights()->set_frame_rate(22.1);
FieldData expected_data = AsFieldData(expected_proto);
FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
MP_EXPECT_OK(Equals<LightBundle>(GetFieldValues(node_data, path).value(),
{expected_data}));
// Read a specific index from a repeated protobuf field.
path = MakeFieldPath("OPTIONS/sub_options/bundle/0", node_data);
MP_EXPECT_OK(Equals<LightBundle>(GetFieldValues(node_data, path).value(),
{expected_data}));
}
// Tests SetFieldValues applied to an int field.
TEST_F(OptionsFieldUtilTest, SetFieldValuesInt) {
NightLightCalculatorOptions node_proto;
node_proto.mutable_sub_options();
FieldData node_data = tool::options_field_util::AsFieldData(node_proto);
// Replace an entire empty repeated field.
FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(33)}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(33)}));
// Replace an entire populated repeated field.
MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(44)}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(44)}));
// Replace an entire repeated field with a new list of values.
MP_ASSERT_OK(
SetFieldValues(node_data, path, {AsFieldData(33), AsFieldData(44)}));
MP_EXPECT_OK(Equals<int>(GetFieldValues(node_data, path).value(),
{AsFieldData(33), AsFieldData(44)}));
// Replace a single field index with a new list of values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
MP_ASSERT_OK(
SetFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(33), AsFieldData(55), AsFieldData(66)}));
// Replace a single field middle index with a new list of values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
MP_ASSERT_OK(
SetFieldValues(node_data, path, {AsFieldData(11), AsFieldData(12)}));
MP_EXPECT_OK(Equals<int>(
GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(33), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
// Replace field index 0 with a new value.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/0", node_data);
MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(77)}));
MP_EXPECT_OK(Equals<int>(
GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(77), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
// Replace field index 0 with an empty list of values.
MP_ASSERT_OK(SetFieldValues(node_data, path, {}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
// Replace an entire populated field with an empty list of values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
MP_ASSERT_OK(SetFieldValues(node_data, path, {}));
MP_ASSERT_OK(
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(), {}));
// Replace a missing field index with new values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
absl::Status status =
SetFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)});
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
// TODO: status.message() appears empty on KokoroGCPDocker.
// EXPECT_THAT(status.message(),
// HasSubstr("index >= 0 && index <= v.size()"));
}
// Tests SetFieldValues applied to a protobuf field.
TEST_F(OptionsFieldUtilTest, SetFieldValuesProtobuf) {
using tool::options_field_util::AsFieldData;
using LightBundle = NightLightCalculatorOptions::LightBundle;
NightLightCalculatorOptions node_proto;
node_proto.mutable_sub_options();
FieldData node_data = AsFieldData(node_proto);
// Replace an empty repeated protobuf field.
LightBundle bundle_proto;
*bundle_proto.mutable_room_id() = "222";
bundle_proto.add_room_lights()->set_frame_rate(22.1);
FieldData bundle_data = AsFieldData(bundle_proto);
FieldData expected_data = bundle_data;
FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
MP_ASSERT_OK(SetFieldValues(node_data, path, {bundle_data}));
MP_EXPECT_OK(Equals<LightBundle>(
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
// Replace a populated repeated protobuf field.
*bundle_proto.mutable_room_id() = "333";
bundle_proto.mutable_room_lights(0)->set_frame_rate(33.1);
bundle_data = AsFieldData(bundle_proto);
LightBundle expected_proto;
*expected_proto.mutable_room_id() = "333";
expected_proto.add_room_lights()->set_frame_rate(33.1);
expected_data = AsFieldData(expected_proto);
MP_ASSERT_OK(SetFieldValues(node_data, path, {bundle_data}));
MP_EXPECT_OK(Equals<LightBundle>(
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
}
// Tests MergeFieldValues applied to an int field.
TEST_F(OptionsFieldUtilTest, MergeFieldValuesInt) {
NightLightCalculatorOptions node_proto;
node_proto.mutable_sub_options();
FieldData node_data = tool::options_field_util::AsFieldData(node_proto);
// Replace an entire empty repeated field.
FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(33)}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(33)}));
// Replace an entire populated repeated field.
MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(44)}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(44)}));
// Replace an entire repeated field with a new list of values.
MP_ASSERT_OK(
MergeFieldValues(node_data, path, {AsFieldData(33), AsFieldData(44)}));
MP_EXPECT_OK(Equals<int>(GetFieldValues(node_data, path).value(),
{AsFieldData(33), AsFieldData(44)}));
// Replace a singe field index with a new list of values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
MP_ASSERT_OK(
MergeFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(33), AsFieldData(55), AsFieldData(66)}));
// Replace a single field middle index with a new list of values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
MP_ASSERT_OK(
MergeFieldValues(node_data, path, {AsFieldData(11), AsFieldData(12)}));
MP_EXPECT_OK(Equals<int>(
GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(33), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
// Replace field index 0 with a new value.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/0", node_data);
MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(77)}));
MP_EXPECT_OK(Equals<int>(
GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(77), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
// Replace field index 0 with an empty list of values.
MP_ASSERT_OK(MergeFieldValues(node_data, path, {}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
{AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
// Replace an entire populated field with an empty list of values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
MP_ASSERT_OK(MergeFieldValues(node_data, path, {}));
MP_EXPECT_OK(
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(), {}));
// Replace a missing field index with new values.
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
absl::Status status =
MergeFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)});
EXPECT_EQ(status.code(), absl::StatusCode::kOutOfRange);
EXPECT_THAT(status.message(),
HasSubstr("Missing feild value: num_lights at index: 1"));
}
// Tests MergeFieldValues applied to a protobuf field.
TEST_F(OptionsFieldUtilTest, MergeFieldValuesProtobuf) {
using tool::options_field_util::AsFieldData;
using LightBundle = NightLightCalculatorOptions::LightBundle;
NightLightCalculatorOptions node_proto;
node_proto.mutable_sub_options();
FieldData node_data = AsFieldData(node_proto);
// Merge an empty repeated protobuf field.
LightBundle bundle_proto;
*bundle_proto.mutable_room_id() = "222";
bundle_proto.add_room_lights()->set_frame_rate(22.1);
FieldData bundle_data = AsFieldData(bundle_proto);
FieldData expected_data = bundle_data;
FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
MP_ASSERT_OK(MergeFieldValues(node_data, path, {bundle_data}));
MP_EXPECT_OK(Equals<LightBundle>(
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
// Merge a populated repeated protobuf field.
// "LightBundle.room_id" merges to "333".
// "LightBundle.room_lights" merges to {{22.1}, {33.1}}.
*bundle_proto.mutable_room_id() = "333";
bundle_proto.mutable_room_lights(0)->set_frame_rate(33.1);
bundle_data = AsFieldData(bundle_proto);
LightBundle expected_proto;
*expected_proto.mutable_room_id() = "333";
expected_proto.add_room_lights()->set_frame_rate(22.1);
expected_proto.add_room_lights()->set_frame_rate(33.1);
expected_data = AsFieldData(expected_proto);
MP_ASSERT_OK(MergeFieldValues(node_data, path, {bundle_data}));
MP_EXPECT_OK(Equals<LightBundle>(
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -16,11 +16,13 @@
#include <tuple> #include <tuple>
#include "absl/strings/match.h"
#include "absl/strings/numbers.h" #include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "mediapipe/framework/port/canonical_errors.h" #include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/logging.h" #include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/tool/field_data.pb.h"
#include "mediapipe/framework/type_map.h" #include "mediapipe/framework/type_map.h"
#define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging() #define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging()
@ -37,6 +39,7 @@ using FieldAccess = ProtoUtilLite::FieldAccess;
using FieldValue = ProtoUtilLite::FieldValue; using FieldValue = ProtoUtilLite::FieldValue;
using ProtoPath = ProtoUtilLite::ProtoPath; using ProtoPath = ProtoUtilLite::ProtoPath;
using FieldType = ProtoUtilLite::FieldType; using FieldType = ProtoUtilLite::FieldType;
using mediapipe::FieldData;
// Returns true if a wire type includes a length indicator. // Returns true if a wire type includes a length indicator.
bool IsLengthDelimited(WireFormatLite::WireType wire_type) { bool IsLengthDelimited(WireFormatLite::WireType wire_type) {
@ -408,5 +411,149 @@ absl::Status ProtoUtilLite::Deserialize(
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status ProtoUtilLite::WriteValue(const FieldData& value,
FieldType field_type,
std::string* field_bytes) {
StringOutputStream sos(field_bytes);
CodedOutputStream out(&sos);
switch (field_type) {
case WireFormatLite::TYPE_INT32:
WireFormatLite::WriteInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_SINT32:
WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out);
break;
case WireFormatLite::TYPE_INT64:
WireFormatLite::WriteInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_SINT64:
WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out);
break;
case WireFormatLite::TYPE_UINT32:
WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out);
break;
case WireFormatLite::TYPE_UINT64:
WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_DOUBLE:
WireFormatLite::WriteDoubleNoTag(value.uint64_value(), &out);
break;
case WireFormatLite::TYPE_FLOAT:
WireFormatLite::WriteFloatNoTag(value.float_value(), &out);
break;
case WireFormatLite::TYPE_BOOL:
WireFormatLite::WriteBoolNoTag(value.bool_value(), &out);
break;
case WireFormatLite::TYPE_ENUM:
WireFormatLite::WriteEnumNoTag(value.enum_value(), &out);
break;
case WireFormatLite::TYPE_STRING:
out.WriteString(value.string_value());
break;
case WireFormatLite::TYPE_MESSAGE:
out.WriteString(value.message_value().value());
break;
default:
return absl::UnimplementedError(
absl::StrCat("Cannot write type: ", field_type));
}
return absl::OkStatus();
}
template <typename ValueT, FieldType kFieldType>
static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) {
ArrayInputStream ais(field_bytes.data(), field_bytes.size());
CodedInputStream input(&ais);
ValueT result;
if (!WireFormatLite::ReadPrimitive<ValueT, kFieldType>(&input, &result)) {
status->Update(absl::InvalidArgumentError(absl::StrCat(
"Bad serialized value: ", MediaPipeTypeStringOrDemangled<ValueT>(),
".")));
}
return result;
}
absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type,
absl::string_view message_type, FieldData* result) {
absl::Status status;
result->Clear();
switch (field_type) {
case WireFormatLite::TYPE_INT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_INT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT32:
result->set_int32_value(
ReadValue<int32, WireFormatLite::TYPE_SINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_INT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_INT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_SINT64:
result->set_int64_value(
ReadValue<int64, WireFormatLite::TYPE_SINT64>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT32:
result->set_uint32_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_UINT64:
result->set_uint64_value(
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
break;
case WireFormatLite::TYPE_DOUBLE:
result->set_double_value(
ReadValue<double, WireFormatLite::TYPE_DOUBLE>(field_bytes, &status));
break;
case WireFormatLite::TYPE_FLOAT:
result->set_float_value(
ReadValue<float, WireFormatLite::TYPE_FLOAT>(field_bytes, &status));
break;
case WireFormatLite::TYPE_BOOL:
result->set_bool_value(
ReadValue<bool, WireFormatLite::TYPE_BOOL>(field_bytes, &status));
break;
case WireFormatLite::TYPE_ENUM:
result->set_enum_value(
ReadValue<int32, WireFormatLite::TYPE_ENUM>(field_bytes, &status));
break;
case WireFormatLite::TYPE_STRING:
result->set_string_value(std::string(field_bytes));
break;
case WireFormatLite::TYPE_MESSAGE:
result->mutable_message_value()->set_value(std::string(field_bytes));
result->mutable_message_value()->set_type_url(
ProtoUtilLite::TypeUrl(message_type));
break;
default:
status = absl::UnimplementedError(
absl::StrCat("Cannot read type: ", field_type));
break;
}
return status;
}
absl::Status ProtoUtilLite::ReadValue(absl::string_view field_bytes,
FieldType field_type,
absl::string_view message_type,
FieldData* result) {
return mediapipe::tool::ReadValue(field_bytes, field_type, message_type,
result);
}
std::string ProtoUtilLite::TypeUrl(absl::string_view type_name) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name));
}
std::string ProtoUtilLite::ParseTypeUrl(absl::string_view type_url) {
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
if (absl::StartsWith(std::string(type_url), std::string(kTypeUrlPrefix))) {
return std::string(type_url.substr(kTypeUrlPrefix.length()));
}
return std::string(type_url);
}
} // namespace tool } // namespace tool
} // namespace mediapipe } // namespace mediapipe

View File

@ -23,10 +23,12 @@
#include "mediapipe/framework/port/integral_types.h" #include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/proto_ns.h" #include "mediapipe/framework/port/proto_ns.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/tool/field_data.pb.h"
namespace mediapipe { namespace mediapipe {
namespace tool { namespace tool {
// TODO: Replace this class with a namespace following Google style.
class ProtoUtilLite { class ProtoUtilLite {
public: public:
// Defines field types and tag formats. // Defines field types and tag formats.
@ -89,6 +91,23 @@ class ProtoUtilLite {
static absl::Status Deserialize(const std::vector<FieldValue>& field_values, static absl::Status Deserialize(const std::vector<FieldValue>& field_values,
FieldType field_type, FieldType field_type,
std::vector<std::string>* result); std::vector<std::string>* result);
// Write a protobuf field value from a typed FieldData value.
static absl::Status WriteValue(const mediapipe::FieldData& value,
FieldType field_type,
std::string* field_bytes);
// Read a protobuf field value into a typed FieldData value.
static absl::Status ReadValue(absl::string_view field_bytes,
FieldType field_type,
absl::string_view message_type,
mediapipe::FieldData* result);
// Returns the protobuf type-url for a protobuf type-name.
static std::string TypeUrl(absl::string_view type_name);
// Returns the protobuf type-name for a protobuf type-url.
static std::string ParseTypeUrl(absl::string_view type_url);
}; };
} // namespace tool } // namespace tool

View File

@ -59,7 +59,8 @@ absl::Status CombinedStatus(const std::string& general_comment,
} }
} }
if (error_code == StatusCode::kOk) return OkStatus(); if (error_code == StatusCode::kOk) return OkStatus();
Status combined = absl::Status( Status combined;
combined = absl::Status(
error_code, error_code,
absl::StrCat(general_comment, "\n", absl::StrJoin(errors, "\n"))); absl::StrCat(general_comment, "\n", absl::StrJoin(errors, "\n")));
return combined; return combined;

View File

@ -28,8 +28,11 @@ namespace mediapipe {
namespace { namespace {
using testing::ContainerEq; using testing::ContainerEq;
using testing::Eq;
using testing::HasSubstr; using testing::HasSubstr;
using testing::IsEmpty; using testing::IsEmpty;
using testing::Matches;
using testing::Pointwise;
TEST(StatusTest, StatusStopIsNotOk) { EXPECT_FALSE(tool::StatusStop().ok()); } TEST(StatusTest, StatusStopIsNotOk) { EXPECT_FALSE(tool::StatusStop().ok()); }

View File

@ -293,7 +293,7 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
if (subgraph_nodes_start == nodes->end()) break; if (subgraph_nodes_start == nodes->end()) break;
std::vector<CalculatorGraphConfig> subgraphs; std::vector<CalculatorGraphConfig> subgraphs;
for (auto it = subgraph_nodes_start; it != nodes->end(); ++it) { for (auto it = subgraph_nodes_start; it != nodes->end(); ++it) {
const auto& node = *it; auto& node = *it;
int node_id = it - nodes->begin(); int node_id = it - nodes->begin();
std::string node_name = CanonicalNodeName(*config, node_id); std::string node_name = CanonicalNodeName(*config, node_id);
MP_RETURN_IF_ERROR(ValidateSubgraphFields(node)); MP_RETURN_IF_ERROR(ValidateSubgraphFields(node));

View File

@ -16,79 +16,129 @@
#define MEDIAPIPE_FRAMEWORK_TOOL_TYPE_UTIL_H_ #define MEDIAPIPE_FRAMEWORK_TOOL_TYPE_UTIL_H_
#include <cstddef> #include <cstddef>
#include <string>
#include <typeinfo> #include <typeinfo>
#include <utility>
#include "absl/base/attributes.h"
#include "mediapipe/framework/demangle.h"
#include "mediapipe/framework/port.h" #include "mediapipe/framework/port.h"
namespace mediapipe { namespace mediapipe {
// An identifier for a type. This class is lightweight and is meant to be passed
// by value.
// To get the TypeId for SomeType, write kTypeId<SomeType>.
class TypeId {
public:
size_t hash_code() const { return impl_.hash_code(); }
std::string name() const { return impl_.name(); }
bool operator==(const TypeId& other) const { return impl_ == other.impl_; }
bool operator<(const TypeId& other) const { return impl_ < other.impl_; }
template <typename H>
friend H AbslHashValue(H h, const TypeId& r) {
return H::combine(std::move(h), r.hash_code());
}
template <class T>
static constexpr inline TypeId Of() {
return TypeId{Impl::Get<T>()};
}
private:
// This implementation uses no RTTI. It distinguishes types, but does not
// know their names.
// TODO: record compile-time type string for (some or all) types.
template <class T>
struct TypeTag {
static constexpr char dummy = 0;
};
struct NoRttiImpl {
template <class T>
static constexpr inline NoRttiImpl Get() {
return {&TypeTag<T>::dummy};
}
size_t hash_code() const { return reinterpret_cast<uintptr_t>(tag_); }
std::string name() const { return "<type name missing>"; }
bool operator==(const NoRttiImpl& other) const {
return tag_ == other.tag_;
}
bool operator<(const NoRttiImpl& other) const { return tag_ < other.tag_; }
const void* tag_;
};
#if MEDIAPIPE_HAS_RTTI
template <class T>
static const std::type_info& GetTypeInfo() {
return typeid(T);
}
// This implementation uses RTTI, and delegates all operations to
// std::type_info. In order to support constexpr construction, we don't store
// a type_info directly (which is not constexpr), but a pointer to a function
// returning it (which is). This implementation is a bit slower than the
// others. The only potential advantage would be the ability to match types
// across multiple dynamic libraries, but we don't support that setup anyway.
// This is provided for completeness.
struct FullRttiImpl {
template <class T>
static constexpr inline FullRttiImpl Get() {
return {GetTypeInfo<T>};
}
size_t hash_code() const { return get_().hash_code(); }
std::string name() const { return Demangle(get_().name()); }
bool operator==(const FullRttiImpl& other) const {
return get_ == other.get_ || get_() == other.get_();
}
bool operator<(const FullRttiImpl& other) const {
return get_().before(other.get_());
}
decltype(&GetTypeInfo<void>) get_;
};
// This implementation also stores a pointer to a std::type_info getter
// function, but it only invokes it to get the type's name. It's equivalent to
// NoRttiImpl for most operations, but it allows getting the type's name.
struct FastRttiImpl {
template <class T>
static constexpr inline FastRttiImpl Get() {
return {GetTypeInfo<T>};
}
size_t hash_code() const { return reinterpret_cast<uintptr_t>(get_); }
std::string name() const { return Demangle(get_().name()); }
bool operator==(const FastRttiImpl& other) const {
return get_ == other.get_;
}
bool operator<(const FastRttiImpl& other) const {
return reinterpret_cast<uintptr_t>(get_) <
reinterpret_cast<uintptr_t>(other.get_);
}
decltype(&GetTypeInfo<void>) get_;
};
using Impl = FastRttiImpl;
#else
using Impl = NoRttiImpl;
#endif // MEDIAPIPE_HAS_RTTI
constexpr explicit TypeId(Impl impl) : impl_(impl) {}
Impl impl_;
};
template <class T>
static constexpr TypeId kTypeId = TypeId::Of<T>();
namespace tool { namespace tool {
#if !MEDIAPIPE_HAS_RTTI // Helper method that returns a hash code of the given type.
// A unique identifier for type T. // Superseded by TypeId.
class TypeInfo {
public:
size_t hash_code() const { return reinterpret_cast<size_t>(this); }
bool operator==(const TypeInfo& other) const { return &other == this; }
bool operator<(const TypeInfo& other) const { return &other < this; }
const char* name() const { return "<unknown>"; }
template <typename T>
static const TypeInfo& Get() {
static TypeInfo* static_type_info = new TypeInfo;
return *static_type_info;
}
private:
TypeInfo() {}
TypeInfo(const TypeInfo&) = delete;
};
#else // MEDIAPIPE_HAS_RTTI
// The std unique identifier for type T.
class TypeInfo {
public:
size_t hash_code() const { return info_.hash_code(); }
bool operator==(const TypeInfo& o) const { return info_ == o.info_; }
bool operator<(const TypeInfo& o) const { return info_.before(o.info_); }
const char* name() const { return info_.name(); }
template <typename T>
static const TypeInfo& Get() {
static TypeInfo* static_type_info = new TypeInfo(typeid(T));
return *static_type_info;
}
private:
TypeInfo(const std::type_info& info) : info_(info) {}
TypeInfo(const TypeInfo&) = delete;
private:
const std::type_info& info_;
friend class TypeIndex;
};
#endif
// An associative key for TypeInfo.
class TypeIndex {
public:
TypeIndex(const TypeInfo& info) : info_(info) {}
size_t hash_code() const { return info_.hash_code(); }
bool operator==(const TypeIndex& other) const { return info_ == other.info_; }
bool operator<(const TypeIndex& other) const { return info_ < other.info_; }
private:
const TypeInfo& info_;
};
// Helper method that returns a hash code of the given type. This allows for
// typeid testing across multiple binaries, unlike FastTypeId which used a
// memory location that only works within the same binary. Moreover, we use this
// for supporting multiple .so binaries in a single Android app built using the
// same compiler and C++ libraries.
// Note that std::type_info may still generate the same hash code for different
// types, although the c++ standard recommends that implementations avoid this
// as much as possible.
template <typename T> template <typename T>
ABSL_DEPRECATED("Use TypeId directly instead.")
size_t GetTypeHash() { size_t GetTypeHash() {
return TypeInfo::Get<T>().hash_code(); return kTypeId<T>.hash_code();
} }
} // namespace tool } // namespace tool

View File

@ -361,32 +361,30 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
// End define MEDIAPIPE_REGISTER_TYPE_WITH_PROXY. // End define MEDIAPIPE_REGISTER_TYPE_WITH_PROXY.
// Helper functions's to retrieve registration data. // Helper functions's to retrieve registration data.
inline const std::string* MediaPipeTypeStringFromTypeId(const size_t type_id) { inline const std::string* MediaPipeTypeStringFromTypeId(TypeId type_id) {
const MediaPipeTypeData* value = const MediaPipeTypeData* value =
PacketTypeIdToMediaPipeTypeData::GetValue(type_id); PacketTypeIdToMediaPipeTypeData::GetValue(type_id.hash_code());
return (value) ? &value->type_string : nullptr; return (value) ? &value->type_string : nullptr;
} }
// Returns string identifier of type or NULL if not registered. // Returns string identifier of type or NULL if not registered.
template <typename T> template <typename T>
inline const std::string* MediaPipeTypeString() { inline const std::string* MediaPipeTypeString() {
return MediaPipeTypeStringFromTypeId(tool::GetTypeHash<T>()); return MediaPipeTypeStringFromTypeId(kTypeId<T>);
} }
inline std::string MediaPipeTypeStringOrDemangled( inline std::string MediaPipeTypeStringOrDemangled(TypeId type_id) {
const tool::TypeInfo& type_info) { const std::string* type_string = MediaPipeTypeStringFromTypeId(type_id);
const std::string* type_string =
MediaPipeTypeStringFromTypeId(type_info.hash_code());
if (type_string) { if (type_string) {
return *type_string; return *type_string;
} else { } else {
return mediapipe::Demangle(type_info.name()); return type_id.name();
} }
} }
template <typename T> template <typename T>
std::string MediaPipeTypeStringOrDemangled() { std::string MediaPipeTypeStringOrDemangled() {
return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get<T>()); return MediaPipeTypeStringOrDemangled(kTypeId<T>);
} }
// Returns type hash id of type identified by type_string or NULL if not // Returns type hash id of type identified by type_string or NULL if not

View File

@ -14,6 +14,8 @@
#include "mediapipe/framework/validated_graph_config.h" #include "mediapipe/framework/validated_graph_config.h"
#include <memory>
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
@ -140,35 +142,6 @@ absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) {
return absl::OkStatus(); return absl::OkStatus();
} }
absl::Status PerformBasicTransforms(
const CalculatorGraphConfig& input_graph_config,
const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager,
CalculatorGraphConfig* output_graph_config) {
*output_graph_config = input_graph_config;
MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
graph_options, service_manager));
MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));
// Populate each node with the graph level input stream handler if a
// stream handler wasn't explicitly provided.
// TODO Instead of pre-populating, handle the graph level
// default appropriately within CalculatorGraph.
if (output_graph_config->has_input_stream_handler()) {
const auto& graph_level_input_stream_handler =
output_graph_config->input_stream_handler();
for (auto& node : *output_graph_config->mutable_node()) {
if (!node.has_input_stream_handler()) {
*node.mutable_input_stream_handler() = graph_level_input_stream_handler;
}
}
}
return absl::OkStatus();
}
} // namespace } // namespace
// static // static
@ -346,8 +319,7 @@ absl::Status NodeTypeInfo::Initialize(
} }
absl::Status ValidatedGraphConfig::Initialize( absl::Status ValidatedGraphConfig::Initialize(
const CalculatorGraphConfig& input_config, CalculatorGraphConfig input_config, const GraphRegistry* graph_registry,
const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options, const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) { const GraphServiceManager* service_manager) {
RET_CHECK(!initialized_) RET_CHECK(!initialized_)
@ -358,9 +330,9 @@ absl::Status ValidatedGraphConfig::Initialize(
<< input_config.DebugString(); << input_config.DebugString();
#endif #endif
MP_RETURN_IF_ERROR(PerformBasicTransforms( config_ = std::move(input_config);
input_config, graph_registry, graph_options, service_manager, &config_)); MP_RETURN_IF_ERROR(
PerformBasicTransforms(graph_registry, graph_options, service_manager));
// Initialize the basic node information. // Initialize the basic node information.
MP_RETURN_IF_ERROR(InitializeGeneratorInfo()); MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
MP_RETURN_IF_ERROR(InitializeCalculatorInfo()); MP_RETURN_IF_ERROR(InitializeCalculatorInfo());
@ -441,7 +413,12 @@ absl::Status ValidatedGraphConfig::Initialize(
const GraphServiceManager* service_manager) { const GraphServiceManager* service_manager) {
graph_registry = graph_registry =
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry; graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
SubgraphContext subgraph_context(graph_options, service_manager); Subgraph::SubgraphOptions local_graph_options;
if (graph_options) {
local_graph_options = *graph_options;
}
SubgraphContext subgraph_context =
SubgraphContext(&local_graph_options, service_manager);
auto status_or_config = auto status_or_config =
graph_registry->CreateByName("", graph_type, &subgraph_context); graph_registry->CreateByName("", graph_type, &subgraph_context);
MP_RETURN_IF_ERROR(status_or_config.status()); MP_RETURN_IF_ERROR(status_or_config.status());
@ -466,6 +443,32 @@ absl::Status ValidatedGraphConfig::Initialize(
service_manager); service_manager);
} }
absl::Status ValidatedGraphConfig::PerformBasicTransforms(
const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager) {
MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(&config_, graph_registry,
graph_options, service_manager));
MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(&config_));
// Populate each node with the graph level input stream handler if a
// stream handler wasn't explicitly provided.
// TODO Instead of pre-populating, handle the graph level
// default appropriately within CalculatorGraph.
if (config_.has_input_stream_handler()) {
const auto& graph_level_input_stream_handler =
config_.input_stream_handler();
for (auto& node : *config_.mutable_node()) {
if (!node.has_input_stream_handler()) {
*node.mutable_input_stream_handler() = graph_level_input_stream_handler;
}
}
}
return absl::OkStatus();
}
absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() { absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {
std::vector<absl::Status> statuses; std::vector<absl::Status> statuses;
calculators_.reserve(config_.node_size()); calculators_.reserve(config_.node_size());
@ -690,6 +693,7 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
if (!need_sorting_ptr) { if (!need_sorting_ptr) {
LOG(WARNING) << "Input Stream \"" << name LOG(WARNING) << "Input Stream \"" << name
<< "\" for node with sorted index " << node_index << "\" for node with sorted index " << node_index
<< " name " << node_type_info->Contract().GetNodeName()
<< " is marked as a back edge, but its output stream is " << " is marked as a back edge, but its output stream is "
"already available. This means it was not necessary " "already available. This means it was not necessary "
"to mark it as a back edge."; "to mark it as a back edge.";
@ -701,6 +705,7 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
if (edge_info.back_edge) { if (edge_info.back_edge) {
VLOG(1) << "Encountered expected behavior: the back edge \"" << name VLOG(1) << "Encountered expected behavior: the back edge \"" << name
<< "\" for node with (possibly sorted) index " << node_index << "\" for node with (possibly sorted) index " << node_index
<< " name " << node_type_info->Contract().GetNodeName()
<< " has an output stream which we have not yet seen."; << " has an output stream which we have not yet seen.";
} else if (need_sorting_ptr) { } else if (need_sorting_ptr) {
*need_sorting_ptr = true; *need_sorting_ptr = true;
@ -709,7 +714,9 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
} else { } else {
return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "Input Stream \"" << name << "\" for node with sorted index " << "Input Stream \"" << name << "\" for node with sorted index "
<< node_index << " does not have a corresponding output stream."; << node_index << " name "
<< node_type_info->Contract().GetNodeName()
<< " does not have a corresponding output stream.";
} }
} }

View File

@ -195,7 +195,7 @@ class ValidatedGraphConfig {
// before any other functions. Subgraphs are specified through the // before any other functions. Subgraphs are specified through the
// global graph registry or an optional local graph registry. // global graph registry or an optional local graph registry.
absl::Status Initialize( absl::Status Initialize(
const CalculatorGraphConfig& input_config, CalculatorGraphConfig input_config,
const GraphRegistry* graph_registry = nullptr, const GraphRegistry* graph_registry = nullptr,
const Subgraph::SubgraphOptions* graph_options = nullptr, const Subgraph::SubgraphOptions* graph_options = nullptr,
const GraphServiceManager* service_manager = nullptr); const GraphServiceManager* service_manager = nullptr);
@ -302,6 +302,13 @@ class ValidatedGraphConfig {
} }
private: private:
// Perform transforms such as converting legacy features, expanding
// subgraphs, and popluting input stream handler.
absl::Status PerformBasicTransforms(
const GraphRegistry* graph_registry,
const Subgraph::SubgraphOptions* graph_options,
const GraphServiceManager* service_manager);
// Initialize the PacketGenerator information. // Initialize the PacketGenerator information.
absl::Status InitializeGeneratorInfo(); absl::Status InitializeGeneratorInfo();
// Initialize the Calculator information. // Initialize the Calculator information.

View File

@ -53,6 +53,12 @@ cc_library(
deps = ["//mediapipe/framework:graph_service"], deps = ["//mediapipe/framework:graph_service"],
) )
cc_library(
name = "attachments",
hdrs = ["attachments.h"],
visibility = ["//visibility:public"],
)
GL_BASE_LINK_OPTS = select({ GL_BASE_LINK_OPTS = select({
"//conditions:default": [], "//conditions:default": [],
"//mediapipe:android": [ "//mediapipe:android": [
@ -172,6 +178,7 @@ cc_library(
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":attachments",
":gl_base", ":gl_base",
":gl_thread_collector", ":gl_thread_collector",
":gpu_buffer_format", ":gpu_buffer_format",

View File

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_ #ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_
#define MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_ #define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_
#import <CoreVideo/CVMetalTextureCache.h> #import <CoreVideo/CVMetalTextureCache.h>
#import <CoreVideo/CoreVideo.h> #import <CoreVideo/CoreVideo.h>
@ -68,4 +68,4 @@ class GpuBufferMultiPool;
@end @end
#endif // MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_ #endif // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_

View File

@ -0,0 +1,64 @@
#ifndef MEDIAPIPE_GPU_ATTACHMENTS_H_
#define MEDIAPIPE_GPU_ATTACHMENTS_H_
#include <functional>
#include <memory>
namespace mediapipe {
namespace internal {
// Unique pointer with a type-erased destructor.
template <class T>
using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;
// Like make_unique.
template <class T, class... Args>
static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
MakeAttachmentPtr(Args&&... args) {
return {new T(std::forward<Args>(args)...),
[](void* ptr) { delete static_cast<T*>(ptr); }};
}
template <class Context>
class AttachmentBase {};
// An cacheable resource that can be associated with a context.
// Attachments are defined as constants.
// When access to an attachment is requested, it will be retrieved from the
// context if already created, or the factory function will be invoked to create
// it. The factory function for a given attachment is invoked at most once per
// context. The lifetime of the object it returns is managed by the context.
template <class Context, class T>
class Attachment : public AttachmentBase<Context> {
public:
using FactoryT = std::function<AttachmentPtr<T>(Context&)>;
Attachment(FactoryT factory) : factory_(factory) {}
Attachment(const Attachment&) = delete;
Attachment(Attachment&&) = delete;
Attachment& operator=(const Attachment&) = delete;
Attachment& operator=(Attachment&&) = delete;
T& Get(Context& ctx) const { return ctx.GetCachedAttachment(*this); }
const FactoryT& factory() const { return factory_; }
// Ptr and MakePtr here make it more convenient to define new types of
// attachment contexts, since you only need a using declaration for Attachment
// and can refer to Ptr from it.
using Ptr = AttachmentPtr<T>;
template <class... Args>
inline static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
MakePtr(Args&&... args) {
return MakeAttachmentPtr<T>(std::forward<Args>(args)...);
}
private:
FactoryT factory_;
};
} // namespace internal
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_ATTACHMENTS_H_

View File

@ -29,6 +29,7 @@
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/port/threadpool.h" #include "mediapipe/framework/port/threadpool.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
#include "mediapipe/gpu/attachments.h"
#include "mediapipe/gpu/gl_base.h" #include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gpu_buffer_format.h" #include "mediapipe/gpu/gpu_buffer_format.h"
@ -286,42 +287,15 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// Sets default texture filtering parameters. // Sets default texture filtering parameters.
void SetStandardTextureParams(GLenum target, GLint internal_format); void SetStandardTextureParams(GLenum target, GLint internal_format);
using AttachmentBase = internal::AttachmentBase<GlContext>;
template <class T> template <class T>
using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>; using Attachment = internal::Attachment<GlContext, T>;
template <class T, class... Args>
static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
MakeAttachmentPtr(Args&&... args) {
return {new T(std::forward<Args>(args)...),
[](void* ptr) { delete static_cast<T*>(ptr); }};
}
class AttachmentBase {};
template <class T>
class Attachment : public AttachmentBase {
public:
using FactoryT = std::function<AttachmentPtr<T>(GlContext&)>;
Attachment(FactoryT factory) : factory_(factory) {}
Attachment(const Attachment&) = delete;
Attachment(Attachment&&) = delete;
Attachment& operator=(const Attachment&) = delete;
Attachment& operator=(Attachment&&) = delete;
T& Get(GlContext& ctx) const { return ctx.GetCachedAttachment(*this); }
const FactoryT& factory() const { return factory_; }
private:
FactoryT factory_;
};
// TOOD: const result? // TOOD: const result?
template <class T> template <class T>
T& GetCachedAttachment(const Attachment<T>& attachment) { T& GetCachedAttachment(const Attachment<T>& attachment) {
DCHECK(IsCurrent()); DCHECK(IsCurrent());
AttachmentPtr<void>& entry = attachments_[&attachment]; internal::AttachmentPtr<void>& entry = attachments_[&attachment];
if (entry == nullptr) { if (entry == nullptr) {
entry = attachment.factory()(*this); entry = attachment.factory()(*this);
} }
@ -454,7 +428,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
// better mechanism? // better mechanism?
bool can_linear_filter_float_textures_; bool can_linear_filter_float_textures_;
absl::flat_hash_map<const AttachmentBase*, AttachmentPtr<void>> attachments_; absl::flat_hash_map<const AttachmentBase*, internal::AttachmentPtr<void>>
attachments_;
// Number of glFinish calls completed on the GL thread. // Number of glFinish calls completed on the GL thread.
// Changes should be guarded by mutex_. However, we use simple atomic // Changes should be guarded by mutex_. However, we use simple atomic

Some files were not shown because too many files have changed in this diff Show More