Project import generated by Copybara.
GitOrigin-RevId: 6e5aa035cd1f6a9333962df5d3ab97a05bd5744e
Parent: 4a20e9909d
Commit: c688862570
@@ -1 +1 @@
-5.0.0
+5.2.0
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-FROM ubuntu:18.04
+FROM ubuntu:20.04
 
 MAINTAINER <mediapipe@google.com>
 
@@ -42,6 +42,8 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
         software-properties-common && \
     add-apt-repository -y ppa:openjdk-r/ppa && \
     apt-get update && apt-get install -y openjdk-8-jdk && \
+    apt-get install -y mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev && \
+    apt-get install -y mesa-utils && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
@@ -50,13 +52,13 @@ RUN pip3 install --upgrade setuptools
 RUN pip3 install wheel
 RUN pip3 install future
 RUN pip3 install six==1.14.0
-RUN pip3 install tensorflow==1.14.0
+RUN pip3 install tensorflow==2.2.0
 RUN pip3 install tf_slim
 
 RUN ln -s /usr/bin/python3 /usr/bin/python
 
 # Install bazel
-ARG BAZEL_VERSION=5.0.0
+ARG BAZEL_VERSION=5.2.0
 RUN mkdir /bazel && \
     wget --no-check-certificate -O /bazel/installer.sh "https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/b\
 azel-${BAZEL_VERSION}-installer-linux-x86_64.sh" && \
@@ -35,8 +35,9 @@ http_archive(
 
 http_archive(
     name = "rules_cc",
-    strip_prefix = "rules_cc-main",
-    urls = ["https://github.com/bazelbuild/rules_cc/archive/main.zip"],
+    strip_prefix = "rules_cc-2f8c04c04462ab83c545ab14c0da68c3b4c96191",
+    # The commit can be updated if the build passes. Last updated 6/23/22.
+    urls = ["https://github.com/bazelbuild/rules_cc/archive/2f8c04c04462ab83c545ab14c0da68c3b4c96191.zip"],
 )
 
 http_archive(
@@ -244,6 +244,7 @@ cc_test(
         "//mediapipe/framework/formats:time_series_header_cc_proto",
         "//mediapipe/framework/port:gtest_main",
         "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/tool:test_util",
         "@com_google_absl//absl/flags:flag",
     ],
 )
@@ -20,8 +20,12 @@
 #include "mediapipe/framework/port/gtest.h"
 #include "mediapipe/framework/port/parse_text_proto.h"
 #include "mediapipe/framework/port/status_matchers.h"
+#include "mediapipe/framework/tool/test_util.h"
 
 namespace mediapipe {
+namespace {
+
+constexpr char kTestPackageRoot[] = "mediapipe/calculators/audio";
 
 TEST(AudioDecoderCalculatorTest, TestWAV) {
   CalculatorGraphConfig::Node node_config =
@@ -37,9 +41,8 @@ TEST(AudioDecoderCalculatorTest, TestWAV) {
       })pb");
   CalculatorRunner runner(node_config);
   runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/audio/"
-                     "testdata/sine_wave_1k_44100_mono_2_sec_wav.audio"));
+      file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                     "sine_wave_1k_44100_mono_2_sec_wav.audio"));
   MP_ASSERT_OK(runner.Run());
   MP_EXPECT_OK(runner.Outputs()
                    .Tag("AUDIO_HEADER")
@@ -68,9 +71,8 @@ TEST(AudioDecoderCalculatorTest, Test48KWAV) {
       })pb");
   CalculatorRunner runner(node_config);
   runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/audio/"
-                     "testdata/sine_wave_1k_48000_stereo_2_sec_wav.audio"));
+      file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                     "sine_wave_1k_48000_stereo_2_sec_wav.audio"));
   MP_ASSERT_OK(runner.Run());
   MP_EXPECT_OK(runner.Outputs()
                    .Tag("AUDIO_HEADER")
@@ -99,9 +101,8 @@ TEST(AudioDecoderCalculatorTest, TestMP3) {
       })pb");
   CalculatorRunner runner(node_config);
   runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/audio/"
-                     "testdata/sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
+      file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                     "sine_wave_1k_44100_stereo_2_sec_mp3.audio"));
   MP_ASSERT_OK(runner.Run());
   MP_EXPECT_OK(runner.Outputs()
                    .Tag("AUDIO_HEADER")
@@ -130,9 +131,8 @@ TEST(AudioDecoderCalculatorTest, TestAAC) {
      })pb");
   CalculatorRunner runner(node_config);
   runner.MutableSidePackets()->Tag("INPUT_FILE_PATH") = MakePacket<std::string>(
-      file::JoinPath("./",
-                     "/mediapipe/calculators/audio/"
-                     "testdata/sine_wave_1k_44100_stereo_2_sec_aac.audio"));
+      file::JoinPath(GetTestDataDir(kTestPackageRoot),
+                     "sine_wave_1k_44100_stereo_2_sec_aac.audio"));
   MP_ASSERT_OK(runner.Run());
   MP_EXPECT_OK(runner.Outputs()
                    .Tag("AUDIO_HEADER")
@@ -147,4 +147,5 @@ TEST(AudioDecoderCalculatorTest, TestAAC) {
                   std::ceil(44100.0 * 2 / 1024));
 }
 
+}  // namespace
 }  // namespace mediapipe
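Illustrative sketch (not part of this commit) of the test-data lookup pattern the hunks above switch to: GetTestDataDir() from mediapipe/framework/tool/test_util.h resolves the package's testdata directory and file::JoinPath appends the file name. The helper name and the file_path.h include are assumptions made for this sketch.

// Hypothetical helper mirroring the pattern used in the updated tests above.
#include <string>

#include "absl/strings/string_view.h"
#include "mediapipe/framework/deps/file_path.h"  // file::JoinPath (assumed include)
#include "mediapipe/framework/tool/test_util.h"  // GetTestDataDir

namespace mediapipe {

std::string AudioTestDataPath(absl::string_view filename) {
  constexpr char kTestPackageRoot[] = "mediapipe/calculators/audio";
  return file::JoinPath(GetTestDataDir(kTestPackageRoot), filename);
}

}  // namespace mediapipe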
@@ -20,24 +20,22 @@
#include <memory>
#include <string>

#include "Eigen/Core"
#include "absl/strings/string_view.h"
#include "audio/dsp/spectrogram/spectrogram.h"
#include "audio/dsp/window_functions.h"
#include "mediapipe/calculators/audio/spectrogram_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/framework/port/core_proto_inc.h"
#include "mediapipe/framework/port/integral_types.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/source_location.h"
#include "mediapipe/framework/port/status_builder.h"
#include "mediapipe/util/time_series_util.h"

namespace mediapipe {

namespace {
constexpr char kFrameDurationTag[] = "FRAME_DURATION";
constexpr char kFrameOverlapTag[] = "FRAME_OVERLAP";
}  // namespace
// MediaPipe Calculator for computing the "spectrogram" (short-time Fourier
// transform squared-magnitude, by default) of a multichannel input
// time series, including optionally overlapping frames. Options are
@@ -46,11 +44,14 @@ namespace mediapipe {
 //
 // Result is a MatrixData record (for single channel input and when the
 // allow_multichannel_input flag is false), or a vector of MatrixData records,
-// one for each channel (when the allow_multichannel_input flag is set). The
-// rows of each spectrogram matrix correspond to the n_fft/2+1 unique complex
-// values, or squared/linear/dB magnitudes, depending on the output_type option.
-// Each input packet will result in zero or one output packets, each containing
-// one Matrix for each channel of the input, where each Matrix has one or more
+// one for each channel (when the allow_multichannel_input flag is set). Each
+// waveform frame is converted to frequency by a fast Fourier transform whose
+// size, n_fft, is the smallest power of two large enough to enclose the frame
+// length of round(frame_duration_seconds * sample_rate). The rows of each
+// spectrogram matrix (result) correspond to the n_fft/2+1 unique complex values,
+// or squared/linear/dB magnitudes, depending on the output_type option. Each
+// input packet will result in zero or one output packets, each containing one
+// Matrix for each channel of the input, where each Matrix has one or more
 // columns of spectral values, one for each complete frame of input samples. If
 // the input packet contains too few samples to trigger a new output frame, no
 // output packet is generated (since zero-length packets are not legal since
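For concreteness, a small sketch (not from the calculator itself) of the frame/FFT sizing described in the updated comment above:

// Frame length is round(frame_duration_seconds * sample_rate); n_fft is the
// smallest power of two that encloses it, and each spectrogram column then
// has n_fft / 2 + 1 rows.
#include <cmath>

int FrameLengthSamples(double frame_duration_seconds, double sample_rate) {
  return static_cast<int>(std::round(frame_duration_seconds * sample_rate));
}

int SmallestEnclosingPowerOfTwo(int n) {
  int p = 1;
  while (p < n) p *= 2;
  return p;
}

// Example: 25 ms frames at 16 kHz -> 400 samples, n_fft = 512, 257 rows.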
@@ -71,6 +72,22 @@ class SpectrogramCalculator : public CalculatorBase {
         // Input stream with TimeSeriesHeader.
     );
 
+    if (cc->InputSidePackets().HasTag(kFrameDurationTag)) {
+      cc->InputSidePackets()
+          .Tag(kFrameDurationTag)
+          .Set<double>(
+              // Optional side packet for frame_duration_seconds if provided.
+          );
+    }
+
+    if (cc->InputSidePackets().HasTag(kFrameOverlapTag)) {
+      cc->InputSidePackets()
+          .Tag(kFrameOverlapTag)
+          .Set<double>(
+              // Optional side packet for frame_overlap_seconds if provided.
+          );
+    }
+
     SpectrogramCalculatorOptions spectrogram_options =
         cc->Options<SpectrogramCalculatorOptions>();
     if (!spectrogram_options.allow_multichannel_input()) {
@@ -184,27 +201,47 @@ class SpectrogramCalculator : public CalculatorBase {
   // Fixed scale factor applied to output values (regardless of type).
   double output_scale_;
 
-  static const float kLnPowerToDb;
+  static const float kLnSquaredMagnitudeToDb;
 };
 REGISTER_CALCULATOR(SpectrogramCalculator);
 
-// Factor to convert ln(magnitude_squared) to deciBels = 10.0/ln(10.0).
-const float SpectrogramCalculator::kLnPowerToDb = 4.342944819032518;
+// DECIBELS = 20*log10(LINEAR_MAGNITUDE) = 10*log10(SQUARED_MAGNITUDE)
+// = 10/ln(10)*ln(SQUARED_MAGNITUDE).
+// Factor to convert ln(SQUARED_MAGNITUDE) to deciBels = 10.0/ln(10.0).
+const float SpectrogramCalculator::kLnSquaredMagnitudeToDb = 4.342944819032518;
 
 absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
   SpectrogramCalculatorOptions spectrogram_options =
       cc->Options<SpectrogramCalculatorOptions>();
+  // Provide frame_duration_seconds and frame_overlap_seconds either from
+  // static options or dynamically from a side packet; the side packet
+  // overrides the options value if provided.
+
+  double frame_duration_seconds = 0;
+  double frame_overlap_seconds = 0;
+  if (cc->InputSidePackets().HasTag(kFrameDurationTag)) {
+    frame_duration_seconds =
+        cc->InputSidePackets().Tag(kFrameDurationTag).Get<double>();
+  } else {
+    frame_duration_seconds = spectrogram_options.frame_duration_seconds();
+  }
+
+  if (cc->InputSidePackets().HasTag(kFrameOverlapTag)) {
+    frame_overlap_seconds =
+        cc->InputSidePackets().Tag(kFrameOverlapTag).Get<double>();
+  } else {
+    frame_overlap_seconds = spectrogram_options.frame_overlap_seconds();
+  }
+
   use_local_timestamp_ = spectrogram_options.use_local_timestamp();
 
-  if (spectrogram_options.frame_duration_seconds() <= 0.0) {
+  if (frame_duration_seconds <= 0.0) {
     // TODO: return an error.
   }
-  if (spectrogram_options.frame_overlap_seconds() >=
-      spectrogram_options.frame_duration_seconds()) {
+  if (frame_overlap_seconds >= frame_duration_seconds) {
     // TODO: return an error.
   }
-  if (spectrogram_options.frame_overlap_seconds() < 0.0) {
+  if (frame_overlap_seconds < 0.0) {
     // TODO: return an error.
   }
 
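A hedged caller-side sketch (not part of this commit) of supplying the new optional side packets. It assumes a graph config that binds side packets named "frame_duration" and "frame_overlap" to the calculator's FRAME_DURATION and FRAME_OVERLAP tags, and uses the standard CalculatorGraph::StartRun overload that accepts extra side packets.

#include <map>
#include <string>

#include "mediapipe/framework/calculator_framework.h"

absl::Status RunWithDynamicFraming(mediapipe::CalculatorGraph& graph) {
  // Side-packet names here are hypothetical; they must match the graph config.
  std::map<std::string, mediapipe::Packet> side_packets;
  side_packets["frame_duration"] = mediapipe::MakePacket<double>(0.025);  // 25 ms frames
  side_packets["frame_overlap"] = mediapipe::MakePacket<double>(0.015);   // 15 ms overlap
  return graph.StartRun(side_packets);
}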
@@ -220,10 +257,8 @@ absl::Status SpectrogramCalculator::Open(CalculatorContext* cc) {
     // TODO: return an error.
   }
 
-  frame_duration_samples_ =
-      round(spectrogram_options.frame_duration_seconds() * input_sample_rate_);
-  frame_overlap_samples_ =
-      round(spectrogram_options.frame_overlap_seconds() * input_sample_rate_);
+  frame_duration_samples_ = round(frame_duration_seconds * input_sample_rate_);
+  frame_overlap_samples_ = round(frame_overlap_seconds * input_sample_rate_);
 
   pad_final_packet_ = spectrogram_options.pad_final_packet();
   output_type_ = spectrogram_options.output_type();
@@ -419,7 +454,7 @@ absl::Status SpectrogramCalculator::ProcessVector(const Matrix& input_stream,
       return ProcessVectorToOutput(
           input_stream,
           +[](const Matrix& col) -> const Matrix {
-            return kLnPowerToDb * col.array().log().matrix();
+            return kLnSquaredMagnitudeToDb * col.array().log().matrix();
           }, cc);
     }
       // clang-format on
@@ -32,7 +32,11 @@ message SpectrogramCalculatorOptions {
 
   // Duration of overlap between adjacent windows.
   // Hence, frame_rate = 1/(frame_duration_seconds - frame_overlap_seconds).
-  // Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
+  // Note the frame_rate here is not the MediaPipe packet rate; the frame here
+  // means each Fourier transform analysis waveform frame. The output MediaPipe
+  // packet rate will be the same as the input; if the frame rate is lower than
+  // the input packet rate, this will result in intermittent empty output
+  // packets. Required that 0 <= frame_overlap_seconds < frame_duration_seconds.
   optional double frame_overlap_seconds = 2 [default = 0.0];
 
   // Whether to pad the final packet with zeros. If true, guarantees that
@@ -42,6 +46,11 @@ message SpectrogramCalculatorOptions {
 
   // Output value type can be squared-magnitude, linear-magnitude,
   // deciBels (dB, = 20*log10(linear_magnitude)), or std::complex.
+  // Their relationship:
+  //   COMPLEX c = Re + Im*i;
+  //   SQUARED_MAGNITUDE = Re^2 + Im^2;
+  //   LINEAR_MAGNITUDE = sqrt(SQUARED_MAGNITUDE);
+  //   DECIBELS = 20*log10(LINEAR_MAGNITUDE) = 10*log10(SQUARED_MAGNITUDE);
   enum OutputType {
     SQUARED_MAGNITUDE = 0;
     LINEAR_MAGNITUDE = 1;
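The same relationships in math form, together with the conversion factor that appears in the calculator as kLnSquaredMagnitudeToDb:

\[
\mathrm{DECIBELS} = 20\log_{10}(\mathrm{LINEAR\_MAGNITUDE})
                  = 10\log_{10}(\mathrm{SQUARED\_MAGNITUDE})
                  = \frac{10}{\ln 10}\,\ln(\mathrm{SQUARED\_MAGNITUDE}),
\qquad \frac{10}{\ln 10} \approx 4.342944819032518 .
\]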
@@ -557,6 +557,22 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_test(
+    name = "packet_cloner_calculator_test",
+    srcs = ["packet_cloner_calculator_test.cc"],
+    deps = [
+        ":packet_cloner_calculator",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework:timestamp",
+        "//mediapipe/framework/port:gtest_main",
+        "//mediapipe/framework/port:parse_text_proto",
+        "//mediapipe/framework/port:status",
+        "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
+        "//mediapipe/framework/tool:sink",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "packet_inner_join_calculator",
     srcs = ["packet_inner_join_calculator.cc"],
@@ -73,8 +73,17 @@ typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark>
     ConcatenateLandmarkVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkVectorCalculator);
 
+typedef ConcatenateVectorCalculator<::mediapipe::LandmarkList>
+    ConcatenateLandmarkListVectorCalculator;
+MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarkListVectorCalculator);
+
 typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmarkList>
-    ConcatenateLandmarListVectorCalculator;
+    ConcatenateNormalizedLandmarkListVectorCalculator;
+MEDIAPIPE_REGISTER_NODE(ConcatenateNormalizedLandmarkListVectorCalculator);
+
+// For backwards compatibility, keep the version with the typo.
+using ConcatenateLandmarListVectorCalculator =
+    ConcatenateNormalizedLandmarkListVectorCalculator;
 MEDIAPIPE_REGISTER_NODE(ConcatenateLandmarListVectorCalculator);
 
 typedef ConcatenateVectorCalculator<mediapipe::ClassificationList>
@ -32,8 +32,8 @@ constexpr char kOptionsTag[] = "OPTIONS";
|
|||
// FlowLimiterCalculator is used to limit the number of frames in flight
|
||||
// by dropping input frames when necessary.
|
||||
//
|
||||
// The input stream "FINISH" is used to signal the FlowLimiterCalculator
|
||||
// when a frame is finished processing. Either a non-empty "FINISH" packet
|
||||
// The input stream "FINISHED" is used to signal the FlowLimiterCalculator
|
||||
// when a frame is finished processing. Either a non-empty "FINISHED" packet
|
||||
// or a timestamp bound should be received for each processed frame.
|
||||
//
|
||||
// The combination of `max_in_flight: 1` and `max_in_queue: 1` generally gives
|
||||
|
|
|
@@ -16,9 +16,10 @@
// For every packet that appears in B, outputs the most recent packet from each
// of the A_i on a separate stream.

#include <string_view>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/core/packet_cloner_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
@@ -34,7 +35,18 @@ namespace mediapipe {
 //   calculator: "PacketClonerCalculator"
 //   input_stream: "first_base_signal"
 //   input_stream: "second_base_signal"
-//   input_stream: "tick_signal"
+//   input_stream: "tick_signal"  # or input_stream: "TICK:tick_signal"
 //   output_stream: "cloned_first_base_signal"
 //   output_stream: "cloned_second_base_signal"
 // }
+//
+// Or you can use the "TICK" tag and put the corresponding input stream at any
+// location, for example at the very beginning:
+// node {
+//   calculator: "PacketClonerCalculator"
+//   input_stream: "TICK:tick_signal"
+//   input_stream: "first_base_signal"
+//   input_stream: "second_base_signal"
+//   output_stream: "cloned_first_base_signal"
+//   output_stream: "cloned_second_base_signal"
+// }
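A minimal caller-side sketch (not part of this commit) of driving the first example graph above; AddPacketToInputStream, MakePacket, and WaitUntilIdle are the standard CalculatorGraph APIs also exercised by the new test added later in this commit.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status_macros.h"  // MP_RETURN_IF_ERROR (assumed include)

absl::Status DriveCloner(mediapipe::CalculatorGraph& graph) {
  using mediapipe::MakePacket;
  using mediapipe::Timestamp;
  // Latch the base signals at ts=10000.
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "first_base_signal", MakePacket<int>(1).At(Timestamp(10000))));
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "second_base_signal", MakePacket<float>(2.5f).At(Timestamp(10000))));
  // The tick packet's payload is ignored; its timestamp triggers the clones.
  MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
      "tick_signal", MakePacket<bool>(true).At(Timestamp(10000))));
  return graph.WaitUntilIdle();
}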
@@ -46,12 +58,13 @@ namespace mediapipe {
 class PacketClonerCalculator : public CalculatorBase {
  public:
   static absl::Status GetContract(CalculatorContract* cc) {
-    const int tick_signal_index = cc->Inputs().NumEntries() - 1;
-    for (int i = 0; i < tick_signal_index; ++i) {
-      cc->Inputs().Index(i).SetAny();
-      cc->Outputs().Index(i).SetSameAs(&cc->Inputs().Index(i));
+    const Ids ids = GetIds(*cc);
+    for (const auto& in_out : ids.inputs_outputs) {
+      auto& input = cc->Inputs().Get(in_out.in);
+      input.SetAny();
+      cc->Outputs().Get(in_out.out).SetSameAs(&input);
     }
-    cc->Inputs().Index(tick_signal_index).SetAny();
+    cc->Inputs().Get(ids.tick_id).SetAny();
     return absl::OkStatus();
   }
 
@@ -65,13 +78,15 @@ class PacketClonerCalculator : public CalculatorBase {
     output_empty_packets_before_all_inputs_received_ =
         calculator_options.output_packets_only_when_all_inputs_received();
 
-    // Parse input streams.
-    tick_signal_index_ = cc->Inputs().NumEntries() - 1;
-    current_.resize(tick_signal_index_);
+    // Prepare input and output ids.
+    ids_ = GetIds(*cc);
+    current_.resize(ids_.inputs_outputs.size());
 
     // Pass along the header for each stream if present.
-    for (int i = 0; i < tick_signal_index_; ++i) {
-      if (!cc->Inputs().Index(i).Header().IsEmpty()) {
-        cc->Outputs().Index(i).SetHeader(cc->Inputs().Index(i).Header());
+    for (const auto& in_out : ids_.inputs_outputs) {
+      auto& input = cc->Inputs().Get(in_out.in);
+      if (!input.Header().IsEmpty()) {
+        cc->Outputs().Get(in_out.out).SetHeader(input.Header());
       }
     }
     return absl::OkStatus();
@@ -79,17 +94,18 @@ class PacketClonerCalculator : public CalculatorBase {
 
   absl::Status Process(CalculatorContext* cc) final {
     // Store input signals.
-    for (int i = 0; i < tick_signal_index_; ++i) {
-      if (!cc->Inputs().Index(i).Value().IsEmpty()) {
-        current_[i] = cc->Inputs().Index(i).Value();
+    for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
+      const auto& input = cc->Inputs().Get(ids_.inputs_outputs[i].in);
+      if (!input.IsEmpty()) {
+        current_[i] = input.Value();
       }
     }
 
     // Output according to the TICK signal.
-    if (!cc->Inputs().Index(tick_signal_index_).Value().IsEmpty()) {
+    if (!cc->Inputs().Get(ids_.tick_id).IsEmpty()) {
       if (output_only_when_all_inputs_received_) {
         // Return if one of the inputs is null.
-        for (int i = 0; i < tick_signal_index_; ++i) {
+        for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
           if (current_[i].IsEmpty()) {
             if (output_empty_packets_before_all_inputs_received_) {
               SetAllNextTimestampBounds(cc);
@@ -99,12 +115,12 @@ class PacketClonerCalculator : public CalculatorBase {
         }
       }
       // Output each stream.
-      for (int i = 0; i < tick_signal_index_; ++i) {
+      for (int i = 0; i < ids_.inputs_outputs.size(); ++i) {
+        auto& output = cc->Outputs().Get(ids_.inputs_outputs[i].out);
         if (!current_[i].IsEmpty()) {
-          cc->Outputs().Index(i).AddPacket(
-              current_[i].At(cc->InputTimestamp()));
+          output.AddPacket(current_[i].At(cc->InputTimestamp()));
         } else {
-          cc->Outputs().Index(i).SetNextTimestampBound(
+          output.SetNextTimestampBound(
               cc->InputTimestamp().NextAllowedInStream());
         }
       }
@@ -113,15 +129,44 @@ class PacketClonerCalculator : public CalculatorBase {
   }
 
  private:
+  struct Ids {
+    struct InputOutput {
+      CollectionItemId in;
+      CollectionItemId out;
+    };
+    CollectionItemId tick_id;
+    std::vector<InputOutput> inputs_outputs;
+  };
+
+  template <typename CC>
+  static Ids GetIds(CC& cc) {
+    Ids ids;
+    static constexpr absl::string_view kEmptyTag = "";
+    int num_inputs_to_clone = cc.Inputs().NumEntries(kEmptyTag);
+    static constexpr absl::string_view kTickTag = "TICK";
+    if (cc.Inputs().HasTag(kTickTag)) {
+      ids.tick_id = cc.Inputs().GetId(kTickTag, 0);
+    } else {
+      --num_inputs_to_clone;
+      ids.tick_id = cc.Inputs().GetId(kEmptyTag, num_inputs_to_clone);
+    }
+    for (int i = 0; i < num_inputs_to_clone; ++i) {
+      ids.inputs_outputs.push_back({.in = cc.Inputs().GetId(kEmptyTag, i),
+                                    .out = cc.Outputs().GetId(kEmptyTag, i)});
+    }
+    return ids;
+  }
+
   void SetAllNextTimestampBounds(CalculatorContext* cc) {
-    for (int j = 0; j < tick_signal_index_; ++j) {
-      cc->Outputs().Index(j).SetNextTimestampBound(
-          cc->InputTimestamp().NextAllowedInStream());
+    for (const auto& in_out : ids_.inputs_outputs) {
+      cc->Outputs()
+          .Get(in_out.out)
+          .SetNextTimestampBound(cc->InputTimestamp().NextAllowedInStream());
     }
   }
 
   std::vector<Packet> current_;
-  int tick_signal_index_;
+  Ids ids_;
   bool output_only_when_all_inputs_received_;
   bool output_empty_packets_before_all_inputs_received_;
 };
mediapipe/calculators/core/packet_cloner_calculator_test.cc (new file, 349 lines)
@@ -0,0 +1,349 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
#include "mediapipe/framework/port/gtest.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/timestamp.h"
|
||||
#include "mediapipe/framework/tool/sink.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::Eq;
|
||||
using ::testing::Value;
|
||||
|
||||
MATCHER_P2(IntPacket, value, ts, "") {
|
||||
return Value(arg.template Get<int>(), Eq(value)) &&
|
||||
Value(arg.Timestamp(), Eq(Timestamp(ts)));
|
||||
}
|
||||
|
||||
MATCHER_P2(FloatPacket, value, ts, "") {
|
||||
return Value(arg.template Get<float>(), Eq(value)) &&
|
||||
Value(arg.Timestamp(), Eq(Timestamp(ts)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
absl::Status SendPacket(const std::string& input_name, T value, int ts,
|
||||
CalculatorGraph& graph) {
|
||||
return graph.AddPacketToInputStream(input_name,
|
||||
MakePacket<T>(value).At(Timestamp(ts)));
|
||||
}
|
||||
|
||||
struct Params {
|
||||
bool use_tick_tag = false;
|
||||
};
|
||||
|
||||
class PacketClonerCalculatorTest : public testing::TestWithParam<Params> {};
|
||||
|
||||
TEST_P(PacketClonerCalculatorTest, ClonesSingleInputSameTimestamps) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
|
||||
if (GetParam().use_tick_tag) {
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'TICK:tick'
|
||||
output_stream: 'out1'
|
||||
})pb";
|
||||
}
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'tick'
|
||||
output_stream: 'out1'
|
||||
})pb";
|
||||
}());
|
||||
std::vector<Packet> out1;
|
||||
tool::AddVectorSink("out1", &graph_config, &out1);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(out1, ElementsAre(IntPacket(1, 10000)));
|
||||
}
|
||||
|
||||
TEST_P(PacketClonerCalculatorTest, ClonesSingleInputEarlierTimestamps) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
|
||||
if (GetParam().use_tick_tag) {
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'TICK:tick'
|
||||
output_stream: 'out1'
|
||||
})pb";
|
||||
}
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'tick'
|
||||
output_stream: 'out1'
|
||||
})pb";
|
||||
}());
|
||||
std::vector<Packet> out1;
|
||||
tool::AddVectorSink("out1", &graph_config, &out1);
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
// PacketClonerCalculator is non-ImmediateInputStreamHandler
|
||||
// PacketClonerCalculator waits for "in1" to arrive for ts=5000
|
||||
MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/5000, graph));
|
||||
// Newer tick at ts=10000, should NOT trigger output for ts=5000
|
||||
// PacketClonerCalculator waits for "in1" to arrive for ts=10000
|
||||
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 1001, /*ts=*/10001, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 1002, /*ts=*/10002, graph));
|
||||
// Newer "in1" at ts=15000, should trigger output for ts=10000
|
||||
MP_ASSERT_OK(SendPacket("in1", 2, /*ts=*/15000, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(out1, ElementsAre(IntPacket(1, 10000), IntPacket(1, 10001),
|
||||
IntPacket(1, 10002)));
|
||||
}
|
||||
|
||||
TEST_P(PacketClonerCalculatorTest, ClonesFiveInputs) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
|
||||
if (GetParam().use_tick_tag) {
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
input_stream: 'in3'
|
||||
input_stream: 'in4'
|
||||
input_stream: 'in5'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
input_stream: 'in3'
|
||||
input_stream: 'in4'
|
||||
input_stream: 'in5'
|
||||
output_stream: 'out1'
|
||||
output_stream: 'out2'
|
||||
output_stream: 'out3'
|
||||
input_stream: 'TICK:tick' # arbitrary location
|
||||
output_stream: 'out4'
|
||||
output_stream: 'out5'
|
||||
}
|
||||
)pb";
|
||||
}
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
input_stream: 'in3'
|
||||
input_stream: 'in4'
|
||||
input_stream: 'in5'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
input_stream: 'in3'
|
||||
input_stream: 'in4'
|
||||
input_stream: 'in5'
|
||||
input_stream: 'tick'
|
||||
output_stream: 'out1'
|
||||
output_stream: 'out2'
|
||||
output_stream: 'out3'
|
||||
output_stream: 'out4'
|
||||
output_stream: 'out5'
|
||||
}
|
||||
)pb";
|
||||
}());
|
||||
constexpr int kNumToClone = 5;
|
||||
std::array<std::vector<Packet>, kNumToClone> outs;
|
||||
for (int i = 0; i < kNumToClone; ++i) {
|
||||
tool::AddVectorSink(absl::StrCat("out", i + 1), &graph_config, &outs[i]);
|
||||
}
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
MP_ASSERT_OK(SendPacket("in1", 10, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in2", 20.0f, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in3", 30, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in4", 40.0f, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in5", 50, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
|
||||
// Below "tick" packets won't trigger output, until newer inputs are sent,
|
||||
// because inputs are missing and ImmediateInputStreamHandler is not
|
||||
// configured.
|
||||
MP_ASSERT_OK(SendPacket("tick", 1001, /*ts=*/10001, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 1002, /*ts=*/10002, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(outs, ElementsAre(ElementsAre(IntPacket(10, 10000)),
|
||||
ElementsAre(FloatPacket(20.0f, 10000)),
|
||||
ElementsAre(IntPacket(30, 10000)),
|
||||
ElementsAre(FloatPacket(40.0f, 10000)),
|
||||
ElementsAre(IntPacket(50, 10000))));
|
||||
|
||||
MP_ASSERT_OK(SendPacket("in1", 100, /*ts=*/20000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in2", 200.0f, /*ts=*/20000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in3", 300, /*ts=*/20000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in4", 400.0f, /*ts=*/20000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in5", 500, /*ts=*/20000, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 2000, /*ts=*/20000, graph));
|
||||
// Below "tick" packets won't trigger output, because inputs are missing and
|
||||
// ImmediateInputStreamHandler is not configured.
|
||||
MP_ASSERT_OK(SendPacket("tick", 2001, /*ts=*/20001, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 2002, /*ts=*/20002, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(
|
||||
outs,
|
||||
ElementsAre(
|
||||
ElementsAre(IntPacket(10, 10000), IntPacket(10, 10001),
|
||||
IntPacket(10, 10002), IntPacket(100, 20000)),
|
||||
ElementsAre(FloatPacket(20.0f, 10000), FloatPacket(20.0f, 10001),
|
||||
FloatPacket(20.0f, 10002), FloatPacket(200.0f, 20000)),
|
||||
ElementsAre(IntPacket(30, 10000), IntPacket(30, 10001),
|
||||
IntPacket(30, 10002), IntPacket(300, 20000)),
|
||||
ElementsAre(FloatPacket(40.0f, 10000), FloatPacket(40.0f, 10001),
|
||||
FloatPacket(40.0f, 10002), FloatPacket(400.0f, 20000)),
|
||||
ElementsAre(IntPacket(50, 10000), IntPacket(50, 10001),
|
||||
IntPacket(50, 10002), IntPacket(500, 20000))));
|
||||
}
|
||||
|
||||
TEST_P(PacketClonerCalculatorTest,
|
||||
ClonesTwoInputsWithImmediateInputStreamHandler) {
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>([&]() {
|
||||
if (GetParam().use_tick_tag) {
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'TICK:tick'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
output_stream: 'out1'
|
||||
output_stream: 'out2'
|
||||
input_stream_handler {
|
||||
input_stream_handler: "ImmediateInputStreamHandler"
|
||||
}
|
||||
})pb";
|
||||
}
|
||||
return R"pb(
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
input_stream: 'tick'
|
||||
node {
|
||||
calculator: 'PacketClonerCalculator'
|
||||
input_stream: 'in1'
|
||||
input_stream: 'in2'
|
||||
input_stream: 'tick'
|
||||
output_stream: 'out1'
|
||||
output_stream: 'out2'
|
||||
input_stream_handler {
|
||||
input_stream_handler: "ImmediateInputStreamHandler"
|
||||
}
|
||||
})pb";
|
||||
}());
|
||||
constexpr int kNumToClone = 2;
|
||||
std::array<std::vector<Packet>, kNumToClone> outs;
|
||||
for (int i = 0; i < kNumToClone; ++i) {
|
||||
tool::AddVectorSink(absl::StrCat("out", i + 1), &graph_config, &outs[i]);
|
||||
}
|
||||
|
||||
CalculatorGraph graph;
|
||||
MP_ASSERT_OK(graph.Initialize(graph_config, {}));
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
// No packets to clone.
|
||||
MP_ASSERT_OK(SendPacket("tick", 0, /*ts=*/0, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Cloning current packets.
|
||||
MP_ASSERT_OK(SendPacket("in1", 1, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in2", 10.0f, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 1000, /*ts=*/10000, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Cloning past packets.
|
||||
MP_ASSERT_OK(SendPacket("tick", 1500, /*ts=*/15000, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Cloning past packets.
|
||||
MP_ASSERT_OK(SendPacket("in1", 2, /*ts=*/10001, graph));
|
||||
MP_ASSERT_OK(SendPacket("in2", 20.0f, /*ts=*/10001, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 2000, /*ts=*/20000, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Cloning future packets.
|
||||
MP_ASSERT_OK(SendPacket("in1", 3, /*ts=*/30000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in2", 30.0f, /*ts=*/30000, graph));
|
||||
// Waiting to ensure newer packets (ts=30000) to clone would get into the
|
||||
// cloner before tick (ts=25000) does.
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
MP_ASSERT_OK(SendPacket("tick", 3000, /*ts=*/25000, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Cloning packets having different timestamps.
|
||||
MP_ASSERT_OK(SendPacket("in1", 4, /*ts=*/38000, graph));
|
||||
MP_ASSERT_OK(SendPacket("in2", 40.0f, /*ts=*/39000, graph));
|
||||
MP_ASSERT_OK(SendPacket("tick", 4000, /*ts=*/40000, graph));
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
EXPECT_THAT(
|
||||
outs,
|
||||
ElementsAre(
|
||||
ElementsAre(IntPacket(1, 10000), IntPacket(1, 15000),
|
||||
IntPacket(2, 20000), IntPacket(3, 25000),
|
||||
IntPacket(4, 40000)),
|
||||
ElementsAre(FloatPacket(10.0f, 10000), FloatPacket(10.0f, 15000),
|
||||
FloatPacket(20.0f, 20000), FloatPacket(30.0f, 25000),
|
||||
FloatPacket(40.0f, 40000))));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(PacketClonerCalculator, PacketClonerCalculatorTest,
|
||||
testing::ValuesIn({Params{.use_tick_tag = false},
|
||||
Params{.use_tick_tag = true}}));
|
||||
} // anonymous namespace
|
||||
} // namespace mediapipe
|
|
@@ -157,9 +157,7 @@ absl::Status PacketResamplerCalculator::Process(CalculatorContext* cc) {
     }
   }
 
-  if (absl::Status status = strategy_->Process(cc); !status.ok()) {
-    return status;  // Avoid MP_RETURN_IF_ERROR macro for external release.
-  }
+  MP_RETURN_IF_ERROR(strategy_->Process(cc));
 
   last_packet_ = cc->Inputs().Get(input_data_id_).Value();
 
@@ -626,11 +626,8 @@ cc_library(
         "//mediapipe/framework/formats:image_format_cc_proto",
         "//mediapipe/framework:calculator_framework",
         "//mediapipe/framework/formats:image_frame",
-        "//mediapipe/framework/formats:image_frame_opencv",
         "//mediapipe/framework/formats:image",
-        "//mediapipe/framework/formats:image_opencv",
         "//mediapipe/framework/port:logging",
-        "//mediapipe/framework/port:opencv_core",
         "//mediapipe/framework/port:status",
         "//mediapipe/framework/port:vector",
     ] + select({
@@ -641,6 +638,13 @@ cc_library(
             "//mediapipe/gpu:gl_quad_renderer",
             "//mediapipe/gpu:shader_util",
         ],
+    }) + select({
+        "//mediapipe/framework/port:disable_opencv": [],
+        "//conditions:default": [
+            "//mediapipe/framework/formats:image_frame_opencv",
+            "//mediapipe/framework/formats:image_opencv",
+            "//mediapipe/framework/port:opencv_core",
+        ],
     }),
     alwayslink = 1,
 )
@@ -727,7 +731,6 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":affine_transformation",
-        ":affine_transformation_runner_opencv",
         ":warp_affine_calculator_cc_proto",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
@@ -745,6 +748,9 @@ cc_library(
             "//mediapipe/gpu:gpu_buffer",
             ":affine_transformation_runner_gl",
         ],
+    }) + select({
+        "//mediapipe/framework/port:disable_opencv": [],
+        "//conditions:default": [":affine_transformation_runner_opencv"],
     }),
     alwayslink = 1,
 )
@@ -799,3 +805,21 @@ cc_test(
         "@com_google_absl//absl/strings",
     ],
 )
+
+cc_library(
+    name = "yuv_to_image_calculator",
+    srcs = ["yuv_to_image_calculator.cc"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//mediapipe/framework:calculator_context",
+        "//mediapipe/framework:calculator_framework",
+        "//mediapipe/framework/api2:node",
+        "//mediapipe/framework/formats:image",
+        "//mediapipe/framework/formats:image_frame",
+        "//mediapipe/framework/formats:yuv_image",
+        "//third_party/libyuv",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:str_format",
+    ],
+    alwayslink = 1,
+)
@@ -21,10 +21,7 @@
 #include "mediapipe/framework/formats/image.h"
 #include "mediapipe/framework/formats/image_format.pb.h"
 #include "mediapipe/framework/formats/image_frame.h"
-#include "mediapipe/framework/formats/image_frame_opencv.h"
-#include "mediapipe/framework/formats/image_opencv.h"
 #include "mediapipe/framework/port/logging.h"
-#include "mediapipe/framework/port/opencv_core_inc.h"
 #include "mediapipe/framework/port/status.h"
 #include "mediapipe/framework/port/vector.h"
 
@@ -34,6 +31,12 @@
 #include "mediapipe/gpu/shader_util.h"
 #endif  // !MEDIAPIPE_DISABLE_GPU
 
+#if !MEDIAPIPE_DISABLE_OPENCV
+#include "mediapipe/framework/formats/image_frame_opencv.h"
+#include "mediapipe/framework/formats/image_opencv.h"
+#include "mediapipe/framework/port/opencv_core_inc.h"
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
+
 namespace mediapipe {
 
 namespace {
@@ -163,7 +166,11 @@ absl::Status SegmentationSmoothingCalculator::Process(CalculatorContext* cc) {
     return absl::InternalError("GPU processing is disabled.");
 #endif  // !MEDIAPIPE_DISABLE_GPU
   } else {
+#if !MEDIAPIPE_DISABLE_OPENCV
     MP_RETURN_IF_ERROR(RenderCpu(cc));
+#else
+    return absl::InternalError("OpenCV processing is disabled.");
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
   }
 
   return absl::OkStatus();
@@ -181,6 +188,7 @@ absl::Status SegmentationSmoothingCalculator::Close(CalculatorContext* cc) {
 }
 
 absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) {
+#if !MEDIAPIPE_DISABLE_OPENCV
   // Setup source images.
   const auto& current_frame = cc->Inputs().Tag(kCurrentMaskTag).Get<Image>();
   auto current_mat = mediapipe::formats::MatView(&current_frame);
@@ -245,6 +253,7 @@ absl::Status SegmentationSmoothingCalculator::RenderCpu(CalculatorContext* cc) {
   cc->Outputs()
       .Tag(kOutputMaskTag)
      .AddPacket(MakePacket<Image>(output_frame).At(cc->InputTimestamp()));
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 
   return absl::OkStatus();
 }
@@ -24,7 +24,9 @@
 #endif  // !MEDIAPIPE_DISABLE_GPU
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#if !MEDIAPIPE_DISABLE_OPENCV
 #include "mediapipe/calculators/image/affine_transformation_runner_opencv.h"
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 #include "mediapipe/calculators/image/warp_affine_calculator.pb.h"
 #include "mediapipe/framework/api2/node.h"
 #include "mediapipe/framework/calculator_framework.h"
@@ -54,6 +56,7 @@ AffineTransformation::BorderMode GetBorderMode(
 template <typename ImageT>
 class WarpAffineRunnerHolder {};
 
+#if !MEDIAPIPE_DISABLE_OPENCV
 template <>
 class WarpAffineRunnerHolder<ImageFrame> {
  public:
@@ -69,6 +72,7 @@ class WarpAffineRunnerHolder<ImageFrame> {
  private:
   std::unique_ptr<RunnerType> runner_;
 };
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 
 #if !MEDIAPIPE_DISABLE_GPU
 template <>
@@ -113,7 +117,9 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
                                mediapipe::Image> {
  public:
   absl::Status Open(CalculatorContext* cc) {
+#if !MEDIAPIPE_DISABLE_OPENCV
     MP_RETURN_IF_ERROR(cpu_holder_.Open(cc));
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 #if !MEDIAPIPE_DISABLE_GPU
     MP_RETURN_IF_ERROR(gpu_holder_.Open(cc));
 #endif  // !MEDIAPIPE_DISABLE_GPU
@@ -133,20 +139,26 @@ class WarpAffineRunnerHolder<mediapipe::Image> {
       return absl::UnavailableError("GPU support is disabled");
 #endif  // !MEDIAPIPE_DISABLE_GPU
     }
+#if !MEDIAPIPE_DISABLE_OPENCV
     ASSIGN_OR_RETURN(auto* runner, cpu_holder_.GetRunner());
     const auto& frame_ptr = input.GetImageFrameSharedPtr();
     // Wrap image into image frame.
     const ImageFrame image_frame(frame_ptr->Format(), frame_ptr->Width(),
                                  frame_ptr->Height(), frame_ptr->WidthStep(),
                                  const_cast<uint8_t*>(frame_ptr->PixelData()),
-                                 [](uint8* data) {});
+                                 [](uint8* data){});
     ASSIGN_OR_RETURN(auto result,
                      runner->Run(image_frame, matrix, size, border_mode));
     return mediapipe::Image(std::make_shared<ImageFrame>(std::move(result)));
+#else
+    return absl::UnavailableError("OpenCV support is disabled");
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
   }
 
  private:
+#if !MEDIAPIPE_DISABLE_OPENCV
   WarpAffineRunnerHolder<ImageFrame> cpu_holder_;
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 #if !MEDIAPIPE_DISABLE_GPU
   WarpAffineRunnerHolder<mediapipe::GpuBuffer> gpu_holder_;
 #endif  // !MEDIAPIPE_DISABLE_GPU
@@ -200,8 +212,10 @@ class WarpAffineCalculatorImpl : public mediapipe::api2::NodeImpl<InterfaceT> {
 
 }  // namespace
 
+#if !MEDIAPIPE_DISABLE_OPENCV
 MEDIAPIPE_NODE_IMPLEMENTATION(
     WarpAffineCalculatorImpl<WarpAffineCalculatorCpu>);
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 #if !MEDIAPIPE_DISABLE_GPU
 MEDIAPIPE_NODE_IMPLEMENTATION(
     WarpAffineCalculatorImpl<WarpAffineCalculatorGpu>);
@@ -70,11 +70,13 @@ class WarpAffineCalculatorIntf : public mediapipe::api2::NodeIntf {
   static constexpr mediapipe::api2::Output<ImageT> kOutImage{"IMAGE"};
 };
 
+#if !MEDIAPIPE_DISABLE_OPENCV
 class WarpAffineCalculatorCpu : public WarpAffineCalculatorIntf<ImageFrame> {
  public:
   MEDIAPIPE_NODE_INTERFACE(WarpAffineCalculatorCpu, kInImage, kMatrix,
                            kOutputSize, kOutImage);
 };
+#endif  // !MEDIAPIPE_DISABLE_OPENCV
 #if !MEDIAPIPE_DISABLE_GPU
 class WarpAffineCalculatorGpu
     : public WarpAffineCalculatorIntf<mediapipe::GpuBuffer> {
mediapipe/calculators/image/yuv_to_image_calculator.cc (new file, 123 lines)
@@ -0,0 +1,123 @@
|
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "libyuv/convert_argb.h"
|
||||
#include "libyuv/video_common.h"
|
||||
#include "mediapipe/framework/api2/node.h"
|
||||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/image_frame.h"
|
||||
#include "mediapipe/framework/formats/yuv_image.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
namespace {
|
||||
|
||||
// Utility function to convert FourCC enum to string, for error messages.
|
||||
std::string FourCCToString(libyuv::FourCC fourcc) {
|
||||
char buf[5];
|
||||
buf[0] = (fourcc >> 24) & 0xff;
|
||||
buf[1] = (fourcc >> 16) & 0xff;
|
||||
buf[2] = (fourcc >> 8) & 0xff;
|
||||
buf[3] = (fourcc)&0xff;
|
||||
buf[4] = 0;
|
||||
return std::string(buf);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Converts a `YUVImage` into an RGB `Image` using libyuv.
|
||||
//
|
||||
// The input `YUVImage` is expected to be in the NV12, NV21, YV12 or I420 (aka
|
||||
// YV21) format (as per the `fourcc()` property). This covers the most commonly
// used YUV image formats on mobile devices. Other formats are not supported
// and will result in an `InvalidArgumentError`.
|
||||
class YUVToImageCalculator : public Node {
|
||||
public:
|
||||
static constexpr Input<YUVImage> kInput{"YUV_IMAGE"};
|
||||
static constexpr Output<Image> kOutput{"IMAGE"};
|
||||
|
||||
MEDIAPIPE_NODE_CONTRACT(kInput, kOutput);
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) override {
|
||||
const auto& yuv_image = *kInput(cc);
|
||||
// Check that the format is supported.
|
||||
auto format = yuv_image.fourcc();
|
||||
if (format != libyuv::FOURCC_NV12 && format != libyuv::FOURCC_NV21 &&
|
||||
format != libyuv::FOURCC_YV12 && format != libyuv::FOURCC_I420) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Unsupported YUVImage format: %s. Only NV12, NV21, "
|
||||
"YV12 and I420 (aka YV21) are supported.",
|
||||
FourCCToString(format)));
|
||||
}
|
||||
// Build a transient ImageFrameSharedPtr with default alignment to host
|
||||
// conversion results.
|
||||
ImageFrameSharedPtr image_frame = std::make_shared<ImageFrame>(
|
||||
ImageFormat::SRGB, yuv_image.width(), yuv_image.height());
|
||||
// Perform actual conversion.
|
||||
switch (format) {
|
||||
case libyuv::FOURCC_NV12:
|
||||
// 8-bit Y plane followed by an interleaved 8-bit U/V plane with 2×2
|
||||
// subsampling.
|
||||
libyuv::NV12ToRAW(
|
||||
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
|
||||
yuv_image.stride(1), image_frame->MutablePixelData(),
|
||||
image_frame->WidthStep(), yuv_image.width(), yuv_image.height());
|
||||
break;
|
||||
case libyuv::FOURCC_NV21:
|
||||
// 8-bit Y plane followed by an interleaved 8-bit V/U plane with 2×2
|
||||
// subsampling.
|
||||
libyuv::NV21ToRAW(
|
||||
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
|
||||
yuv_image.stride(1), image_frame->MutablePixelData(),
|
||||
image_frame->WidthStep(), yuv_image.width(), yuv_image.height());
|
||||
break;
|
||||
case libyuv::FOURCC_I420:
|
||||
// Also known as YV21.
|
||||
// 8-bit Y plane followed by 8-bit 2×2 subsampled U and V planes.
|
||||
libyuv::I420ToRAW(
|
||||
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(1),
|
||||
yuv_image.stride(1), yuv_image.data(2), yuv_image.stride(2),
|
||||
image_frame->MutablePixelData(), image_frame->WidthStep(),
|
||||
yuv_image.width(), yuv_image.height());
|
||||
break;
|
||||
case libyuv::FOURCC_YV12:
|
||||
// 8-bit Y plane followed by 8-bit 2×2 subsampled V and U planes.
|
||||
libyuv::I420ToRAW(
|
||||
yuv_image.data(0), yuv_image.stride(0), yuv_image.data(2),
|
||||
yuv_image.stride(2), yuv_image.data(1), yuv_image.stride(1),
|
||||
image_frame->MutablePixelData(), image_frame->WidthStep(),
|
||||
yuv_image.width(), yuv_image.height());
|
||||
break;
|
||||
default:
|
||||
// This should never happen (caught by checks above).
|
||||
return absl::InternalError("Unsupported YUVImage format.");
|
||||
}
|
||||
// Finally, build and send an Image object that takes ownership of the
|
||||
// transient ImageFrameSharedPtr object.
|
||||
kOutput(cc).Send(std::make_unique<Image>(std::move(image_frame)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(YUVToImageCalculator);
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
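A small illustrative sketch (not from the file above) of the 2x2-subsampled plane layout that the NV12/NV21 branches read, assuming tightly packed strides; the helper name is hypothetical:

#include <utility>

// Returns {y_plane_bytes, interleaved_uv_plane_bytes} for an NV12/NV21 image:
// a full-resolution 8-bit Y plane plus one U and one V byte per 2x2 block of
// Y samples, interleaved on a single chroma plane.
std::pair<int, int> Nv12PlaneSizes(int width, int height) {
  const int y_bytes = width * height;
  const int uv_bytes = ((width + 1) / 2) * ((height + 1) / 2) * 2;
  return {y_bytes, uv_bytes};
}
// Example: 640x480 -> Y plane 307200 bytes, UV plane 153600 bytes.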
|
@ -40,6 +40,63 @@ selects.config_setting_group(
|
|||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "audio_to_tensor_calculator_proto",
|
||||
srcs = ["audio_to_tensor_calculator.proto"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework:calculator_options_proto",
|
||||
"//mediapipe/framework:calculator_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "audio_to_tensor_calculator",
|
||||
srcs = ["audio_to_tensor_calculator.cc"],
|
||||
visibility = [
|
||||
"//mediapipe/framework:mediapipe_internal",
|
||||
],
|
||||
deps = [
|
||||
":audio_to_tensor_calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework/api2:node",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/api2:port",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/formats:time_series_header_cc_proto",
|
||||
"//mediapipe/util:time_series_util",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_audio_tools//audio/dsp:resampler_q",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "audio_to_tensor_calculator_test",
|
||||
srcs = ["audio_to_tensor_calculator_test.cc"],
|
||||
deps = [
|
||||
":audio_to_tensor_calculator",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:timestamp",
|
||||
"//mediapipe/framework/api2:packet",
|
||||
"//mediapipe/framework/formats:matrix",
|
||||
"//mediapipe/framework/formats:tensor",
|
||||
"//mediapipe/framework/port:gtest_main",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_audio_tools//audio/dsp:resampler_q",
|
||||
"@org_tensorflow//tensorflow/lite/c:common",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
name = "inference_calculator_proto",
|
||||
srcs = ["inference_calculator.proto"],
|
||||
|
@ -50,6 +107,14 @@ mediapipe_proto_library(
|
|||
],
|
||||
)
|
||||
|
||||
# This target defines the "InferenceCalculator" component, which looks for the available concrete
|
||||
# implementations linked into the current binary and picks the one to use.
|
||||
# You can depend on :inference_calculator instead if you want to automatically include a default
|
||||
# set of implementations tailored for the current build configuration.
|
||||
# If you want to have precise control of which implementations to include (e.g. for strict binary
|
||||
# size concerns), depend on those implementations directly, and do not depend on
|
||||
# :inference_calculator.
|
||||
# In all cases, use "InferenceCalulator" in your graphs.
|
||||
cc_library(
    name = "inference_calculator_interface",
    srcs = ["inference_calculator.cc"],

@@ -62,8 +127,9 @@ cc_library(
        ],
        "//conditions:default": [],
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":inference_calculator_cc_proto",
        ":inference_calculator_options_lib",
        "//mediapipe/framework:calculator_framework",
        "//mediapipe/framework/api2:node",
        "//mediapipe/framework/api2:packet",

@@ -85,18 +151,31 @@ cc_library(
    name = "inference_calculator_gl",
    srcs = ["inference_calculator_gl.cc"],
    tags = ["nomac"],  # config problem with cpuinfo via TF
    visibility = ["//visibility:public"],
    deps = [
        "inference_calculator_interface",
        "//mediapipe/framework/deps:file_path",
        ":inference_calculator_interface",
        "//mediapipe/gpu:gl_calculator_helper",
        "//mediapipe/gpu:gpu_buffer",
        "//mediapipe/util/tflite:config",
        "//mediapipe/util/tflite:tflite_gpu_runner",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@org_tensorflow//tensorflow/lite:framework_stable",
        "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
        "@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
    ],
    alwayslink = 1,
)

cc_library(
    name = "inference_calculator_gl_advanced",
    srcs = ["inference_calculator_gl_advanced.cc"],
    tags = ["nomac"],
    visibility = ["//visibility:public"],
    deps = [
        ":inference_calculator_interface",
        "//mediapipe/framework/deps:file_path",
        "//mediapipe/gpu:gl_calculator_helper",
        "//mediapipe/util/tflite:tflite_gpu_runner",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/status",
        "@org_tensorflow//tensorflow/lite:framework_stable",
    ],
    alwayslink = 1,
)

@@ -113,6 +192,7 @@ cc_library(
        "-framework MetalKit",
    ],
    tags = ["ios"],
    visibility = ["//visibility:public"],
    deps = [
        "inference_calculator_interface",
        "//mediapipe/gpu:MPPMetalHelper",

@@ -142,6 +222,7 @@ cc_library(
        ],
        "//conditions:default": [],
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":inference_calculator_interface",
        "@com_google_absl//absl/memory",

@@ -161,9 +242,13 @@ cc_library(

cc_library(
    name = "inference_calculator_gl_if_compute_shader_available",
    visibility = ["//visibility:public"],
    deps = selects.with_or({
        ":compute_shader_unavailable": [],
        "//conditions:default": [":inference_calculator_gl"],
        "//conditions:default": [
            ":inference_calculator_gl",
            ":inference_calculator_gl_advanced",
        ],
    }),
)

@@ -484,6 +569,7 @@ cc_library(
        "//mediapipe/framework/formats:location",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/util:label_map_cc_proto",
        "//mediapipe/util:resource_util",
    ] + select({
        "//mediapipe:android": [

@@ -506,6 +592,7 @@ mediapipe_proto_library(
    deps = [
        "//mediapipe/framework:calculator_options_proto",
        "//mediapipe/framework:calculator_proto",
        "//mediapipe/util:label_map_proto",
    ],
)

@@ -672,6 +759,7 @@ cc_library(
        ],
        "//conditions:default": [],
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":image_to_tensor_converter",
        ":image_to_tensor_utils",

@@ -858,9 +946,7 @@ cc_library(
        "@com_google_absl//absl/types:span",
        "//mediapipe/framework/formats:image",
        "//mediapipe/framework/formats:image_frame",
        "//mediapipe/framework/formats:image_opencv",
        "//mediapipe/framework/formats:tensor",
        "//mediapipe/framework/port:opencv_imgproc",
        "//mediapipe/framework/port:ret_check",
        "//mediapipe/framework:calculator_context",
        "//mediapipe/framework:calculator_framework",

@@ -890,6 +976,12 @@ cc_library(
        "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture",
        "@org_tensorflow//tensorflow/lite/delegates/gpu/gl/converters:util",
        ],
    }) + select({
        "//mediapipe/framework/port:disable_opencv": [],
        "//conditions:default": [
            "//mediapipe/framework/formats:image_opencv",
            "//mediapipe/framework/port:opencv_imgproc",
        ],
    }),
    alwayslink = 1,
)

mediapipe/calculators/tensor/audio_to_tensor_calculator.cc (new file, 401 lines)
@@ -0,0 +1,401 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <math.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "audio/dsp/resampler_q.h"
#include "mediapipe/calculators/tensor/audio_to_tensor_calculator.pb.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/formats/time_series_header.pb.h"
#include "mediapipe/util/time_series_util.h"

namespace mediapipe {
namespace api2 {

// Converts audio buffers into tensors, possibly with resampling, buffering
// and framing, according to specified inputs and options. All input audio
// buffers will be first resampled from the input sample rate to the target
// sample rate if they are not equal. The resampled audio data (with the
// buffered samples from the previous runs in the streaming mode) will be broken
// into fixed-sized, possibly overlapping frames. Finally, all frames will be
// converted to and outputted as MediaPipe Tensors. The last output tensor will
// be zero-padded if the remaining samples are insufficient.
//
// This calculator assumes that the input timestamps refer to the first
// sample in each Matrix. The output timestamps follow this same convention.
// One Process() call may output multiple tensor packets. The timestamps of
// the output packets are determined by the timestamp of the previous output
// packet, the target sample rate, and the number of samples advanced after the
// previous output.
//
// The calculator has two running modes:
// Streaming mode: when "streaming_mode" is set to true in the calculator
//   options, the calculator treats the input audio stream as a continuous
//   stream. Thus, any samples that are not consumed in the previous runs will
//   be cached in a global sample buffer. The audio data resampled from the
//   current raw audio input will be appended to the global sample buffer.
//   The calculator will process the global sample buffer and output as many
//   tensors as possible.
// Non-streaming mode: when "streaming_mode" is set to false in the calculator
//   options, the calculator treats the packets in the input audio stream as
//   a batch of unrelated audio buffers. In each Process() call, the input
//   buffer will be first resampled, and framed as fixed-sized, possibly
//   overlapping tensors. The last tensor produced by a Process() invocation
//   will be zero-padded if the remaining samples are insufficient. As the
//   calculator treats the input packets as unrelated, all samples will be
//   processed immediately and no samples will be cached in the global sample
//   buffer.
//
// Inputs:
//   AUDIO - mediapipe::Matrix
//     The audio data represented as mediapipe::Matrix.
//   SAMPLE_RATE - double @Optional
//     The sample rate of the corresponding audio data in the "AUDIO" stream.
//     If a sample rate packet is provided at Timestamp::PreStream(), the sample
//     rate will be used as the sample rate of every audio packet in the
//     "AUDIO" stream. Note that one and only one of the "AUDIO" stream's time
//     series header or the "SAMPLE_RATE" stream can exist.
//
// Outputs:
//   TENSORS - std::vector<Tensor>
//     Vector containing a single Tensor that represents a fixed-sized audio
//     frame.
//   TIMESTAMPS - std::vector<Timestamp> @Optional
//     Vector containing the output timestamps emitted by the current Process()
//     invocation. In the non-streaming mode, the vector contains all of the
//     output timestamps for an input audio buffer.
//
// Example:
// node {
//   calculator: "AudioToTensorCalculator"
//   input_stream: "AUDIO:audio"
//   output_stream: "TENSORS:tensors"
//   output_stream: "TIMESTAMPS:timestamps"
//   options {
//     [mediapipe.AudioToTensorCalculatorOptions.ext] {
//       num_channels: 2
//       num_samples: 512
//       num_overlapping_samples: 64
//       target_sample_rate: 16000
//       streaming_mode: true  # or false
//     }
//   }
// }
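//
// As a worked example using the options above (and assuming the default
// microsecond timestamp base): the frame step is num_samples -
// num_overlapping_samples = 512 - 64 = 448 samples, so consecutive output
// packets are spaced by round(448 / 16000 * Timestamp::kTimestampUnitsPerSecond)
// = 28000 timestamp units.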
class AudioToTensorCalculator : public Node {
 public:
  static constexpr Input<Matrix> kAudioIn{"AUDIO"};
  // TODO: Removes this optional input stream when the "AUDIO" stream
  // uses the new mediapipe audio data containers that carry audio metadata,
  // such as sample rate.
  static constexpr Input<double>::Optional kAudioSampleRateIn{"SAMPLE_RATE"};
  static constexpr Output<std::vector<Tensor>> kTensorsOut{"TENSORS"};
  // A vector of the output timestamps emitted by the current Process()
  // invocation. The packet timestamp is the last emitted timestamp.
  static constexpr Output<std::vector<Timestamp>>::Optional kTimestampsOut{
      "TIMESTAMPS"};
  MEDIAPIPE_NODE_CONTRACT(kAudioIn, kAudioSampleRateIn, kTensorsOut,
                          kTimestampsOut);

  static absl::Status UpdateContract(CalculatorContract* cc);
  absl::Status Open(CalculatorContext* cc);
  absl::Status Process(CalculatorContext* cc);
  absl::Status Close(CalculatorContext* cc);

 private:
  // The target number of channels.
  int num_channels_;
  // The target number of samples per channel.
  int num_samples_;
  // The number of samples per channel to advance after the current frame is
  // processed.
  int frame_step_;
  bool streaming_mode_;
  bool check_inconsistent_timestamps_;
  Timestamp initial_timestamp_ = Timestamp::Unstarted();
  int64 cumulative_input_samples_ = 0;
  Timestamp next_output_timestamp_ = Timestamp::Unstarted();

  double source_sample_rate_ = -1;
  double target_sample_rate_ = -1;
  // TODO: Configures QResamplerParams through calculator options.
  audio_dsp::QResamplerParams params_;
  // A QResampler instance to resample an audio stream.
  std::unique_ptr<audio_dsp::QResampler<float>> resampler_;
  Matrix sample_buffer_;
  int processed_buffer_cols_ = 0;

  absl::Status ProcessStreamingData(CalculatorContext* cc);
  absl::Status ProcessNonStreamingData(CalculatorContext* cc);

  absl::Status SetupStreamingResampler(double input_sample_rate_);
  void AppendToSampleBuffer(Matrix buffer_to_append);

  absl::StatusOr<std::vector<Tensor>> ConvertToTensor(
      const Matrix& frame_to_convert);
  absl::Status OutputTensors(const Matrix& buffer, bool should_flush,
                             CalculatorContext* cc);
};

absl::Status AudioToTensorCalculator::UpdateContract(CalculatorContract* cc) {
  const auto& options =
      cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
  if (!options.has_num_channels() || !options.has_num_samples() ||
      !options.has_target_sample_rate()) {
    return absl::InvalidArgumentError(
        "AudioToTensorCalculatorOptions must specify "
        "`num_channels`, `num_samples`, and `target_sample_rate`.");
  }
  if (options.streaming_mode()) {
    // Explicitly disables timestamp offset to disallow the timestamp bound
    // from the input streams to be propagated to the output streams.
    // In the streaming mode, the output timestamp bound is based on
    // next_output_timestamp_, which can be smaller than the current input
    // timestamps.
    cc->SetTimestampOffset(TimestampDiff::Unset());
  }
  return absl::OkStatus();
}

absl::Status AudioToTensorCalculator::Open(CalculatorContext* cc) {
  const auto& options =
      cc->Options<mediapipe::AudioToTensorCalculatorOptions>();
  num_channels_ = options.num_channels();
  num_samples_ = options.num_samples();
  if (options.has_num_overlapping_samples()) {
    RET_CHECK_GE(options.num_overlapping_samples(), 0);
    RET_CHECK_LT(options.num_overlapping_samples(), num_samples_);
    frame_step_ = num_samples_ - options.num_overlapping_samples();
  } else {
    frame_step_ = num_samples_;
  }
  target_sample_rate_ = options.target_sample_rate();
  streaming_mode_ = options.streaming_mode();
  if (streaming_mode_) {
    check_inconsistent_timestamps_ = options.check_inconsistent_timestamps();
    sample_buffer_.resize(num_channels_, Eigen::NoChange);
  }

  RET_CHECK(kAudioSampleRateIn(cc).IsConnected() ^
            !kAudioIn(cc).Header().IsEmpty())
      << "Must either specify the time series header of the \"AUDIO\" stream "
         "or have the \"SAMPLE_RATE\" stream connected.";
  if (!kAudioIn(cc).Header().IsEmpty()) {
    mediapipe::TimeSeriesHeader input_header;
    MP_RETURN_IF_ERROR(mediapipe::time_series_util::FillTimeSeriesHeaderIfValid(
        kAudioIn(cc).Header(), &input_header));
    if (streaming_mode_) {
      MP_RETURN_IF_ERROR(SetupStreamingResampler(input_header.sample_rate()));
    } else {
      source_sample_rate_ = input_header.sample_rate();
    }
  }
  return absl::OkStatus();
}

absl::Status AudioToTensorCalculator::Process(CalculatorContext* cc) {
  if (cc->InputTimestamp() == Timestamp::PreStream()) {
    double current_source_sample_rate = kAudioSampleRateIn(cc).Get();
    if (cc->Options<mediapipe::AudioToTensorCalculatorOptions>()
            .streaming_mode()) {
      return SetupStreamingResampler(current_source_sample_rate);
    } else {
      source_sample_rate_ = current_source_sample_rate;
      return absl::OkStatus();
    }
  }
  // Sanity checks.
  const auto& input_frame = kAudioIn(cc).Get();
  if (input_frame.rows() != num_channels_) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Audio input has %d channel(s) but the model requires %d channel(s).",
        input_frame.rows(), num_channels_));
  }
  if (num_channels_ > 1 && input_frame.IsRowMajor) {
    return absl::InvalidArgumentError(
        "The audio data should be stored in column-major.");
  }
  return streaming_mode_ ? ProcessStreamingData(cc)
                         : ProcessNonStreamingData(cc);
}

absl::Status AudioToTensorCalculator::Close(CalculatorContext* cc) {
  if (!streaming_mode_) {
    return absl::OkStatus();
  }
  if (resampler_) {
    Matrix resampled_buffer(num_channels_, 0);
    resampler_->Flush(&resampled_buffer);
    AppendToSampleBuffer(std::move(resampled_buffer));
  }
  return OutputTensors(sample_buffer_, /*should_flush=*/true, cc);
}

absl::Status AudioToTensorCalculator::ProcessStreamingData(
    CalculatorContext* cc) {
  const auto& input_buffer = kAudioIn(cc).Get();
  if (initial_timestamp_ == Timestamp::Unstarted()) {
    initial_timestamp_ = cc->InputTimestamp();
    next_output_timestamp_ = initial_timestamp_;
  }
  if (source_sample_rate_ != -1 && check_inconsistent_timestamps_) {
    mediapipe::time_series_util::LogWarningIfTimestampIsInconsistent(
        cc->InputTimestamp(), initial_timestamp_, cumulative_input_samples_,
        source_sample_rate_);
    cumulative_input_samples_ += input_buffer.cols();
  }
  if (!kAudioSampleRateIn(cc).IsEmpty()) {
    double current_source_sample_rate = kAudioSampleRateIn(cc).Get();
    if (resampler_) {
      RET_CHECK_EQ(current_source_sample_rate, source_sample_rate_);
    } else {
      MP_RETURN_IF_ERROR(SetupStreamingResampler(current_source_sample_rate));
    }
  }

  if (resampler_) {
    Matrix resampled_buffer(num_channels_, 0);
    resampler_->ProcessSamples(input_buffer, &resampled_buffer);
    AppendToSampleBuffer(std::move(resampled_buffer));
  } else {
    // Tries to consume the input matrix first to avoid extra data copy.
    auto status_or_matrix = kAudioIn(cc).packet().Consume<Matrix>();
    if (status_or_matrix.ok()) {
      Matrix local_matrix(num_channels_, 0);
      local_matrix.swap(*status_or_matrix.value());
      AppendToSampleBuffer(std::move(local_matrix));
    } else {
      AppendToSampleBuffer(input_buffer);
    }
  }

  MP_RETURN_IF_ERROR(OutputTensors(sample_buffer_, /*should_flush=*/false, cc));
  // Removes the processed samples from the global sample buffer.
  sample_buffer_ = Matrix(sample_buffer_.rightCols(sample_buffer_.cols() -
                                                   processed_buffer_cols_ - 1));
  return absl::OkStatus();
}

absl::Status AudioToTensorCalculator::ProcessNonStreamingData(
    CalculatorContext* cc) {
  initial_timestamp_ = cc->InputTimestamp();
  next_output_timestamp_ = initial_timestamp_;
  const auto& input_frame = kAudioIn(cc).Get();
  double source_sample_rate = kAudioSampleRateIn(cc).GetOr(source_sample_rate_);

  if (source_sample_rate != -1 && source_sample_rate != target_sample_rate_) {
    std::vector<float> resampled = audio_dsp::QResampleSignal<float>(
        source_sample_rate, target_sample_rate_, num_channels_, params_,
        input_frame);
    Eigen::Map<const Matrix> matrix_mapping(resampled.data(), num_channels_,
                                            resampled.size() / num_channels_);
    return OutputTensors(matrix_mapping, /*should_flush=*/true, cc);
  }
  return OutputTensors(input_frame, /*should_flush=*/true, cc);
}

absl::Status AudioToTensorCalculator::SetupStreamingResampler(
    double input_sample_rate) {
  if (input_sample_rate == source_sample_rate_) {
    return absl::OkStatus();
  }
  source_sample_rate_ = input_sample_rate;
  if (source_sample_rate_ != target_sample_rate_) {
    resampler_ = absl::make_unique<audio_dsp::QResampler<float>>(
        source_sample_rate_, target_sample_rate_, num_channels_, params_);
    if (!resampler_) {
      return absl::InternalError("Failed to initialize resampler.");
    }
  }
  return absl::OkStatus();
}

void AudioToTensorCalculator::AppendToSampleBuffer(Matrix buffer_to_append) {
  sample_buffer_.conservativeResize(
      Eigen::NoChange, sample_buffer_.cols() + buffer_to_append.cols());
  sample_buffer_.rightCols(buffer_to_append.cols()).swap(buffer_to_append);
}

absl::StatusOr<std::vector<Tensor>> AudioToTensorCalculator::ConvertToTensor(
    const Matrix& frame_to_convert) {
  Tensor tensor(Tensor::ElementType::kFloat32,
                Tensor::Shape({num_channels_, num_samples_}));
  auto buffer_view = tensor.GetCpuWriteView();
  if (frame_to_convert.size() < num_channels_ * num_samples_) {
    std::memset(buffer_view.buffer<float>(), 0, tensor.bytes());
  }
  std::memcpy(buffer_view.buffer<float>(), frame_to_convert.data(),
              frame_to_convert.size() * sizeof(float));
  std::vector<Tensor> tensor_vector;
  tensor_vector.push_back(std::move(tensor));
  return tensor_vector;
}

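// Worked example (hypothetical numbers): with num_samples_ = 4, frame_step_ = 2
// and an 8-column buffer, the loop below emits frames starting at columns 0, 2,
// and 4 and stops with next_frame_first_col = 6. processed_buffer_cols_ is then
// set to 5, so ProcessStreamingData() keeps the last two columns (6 and 7) in
// the sample buffer for the next invocation.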
absl::Status AudioToTensorCalculator::OutputTensors(const Matrix& buffer,
                                                    bool should_flush,
                                                    CalculatorContext* cc) {
  int next_frame_first_col = 0;
  std::vector<Timestamp> timestamps;
  while ((!streaming_mode_ || !should_flush) &&
         next_frame_first_col + num_samples_ <= buffer.cols()) {
    ASSIGN_OR_RETURN(auto output_tensor, ConvertToTensor(buffer.block(
                                             0, next_frame_first_col,
                                             num_channels_, num_samples_)));
    kTensorsOut(cc).Send(std::move(output_tensor), next_output_timestamp_);
    timestamps.push_back(next_output_timestamp_);
    next_output_timestamp_ += round(frame_step_ / target_sample_rate_ *
                                    Timestamp::kTimestampUnitsPerSecond);
    next_frame_first_col += frame_step_;
  }
  if (should_flush && next_frame_first_col < buffer.cols()) {
    ASSIGN_OR_RETURN(auto output_tensor,
                     ConvertToTensor(buffer.block(
                         0, next_frame_first_col, num_channels_,
                         std::min(num_samples_,
                                  (int)buffer.cols() - next_frame_first_col))));
    // In the streaming mode, the flush happens in Close() and a packet at
    // Timestamp::Max() will be emitted. In the non-streaming mode, each
    // Process() invocation will process the entire buffer completely.
    Timestamp timestamp =
        streaming_mode_ ? Timestamp::Max() : next_output_timestamp_;
    timestamps.push_back(timestamp);
    kTensorsOut(cc).Send(std::move(output_tensor), timestamp);
  }
  if (kTimestampsOut(cc).IsConnected()) {
    Timestamp timestamp = timestamps.back();
    kTimestampsOut(cc).Send(std::move(timestamps), timestamp);
  }
  processed_buffer_cols_ = next_frame_first_col - 1;
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(AudioToTensorCalculator);

}  // namespace api2
}  // namespace mediapipe
@@ -0,0 +1,46 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto2";

package mediapipe;

import "mediapipe/framework/calculator.proto";

message AudioToTensorCalculatorOptions {
  extend mediapipe.CalculatorOptions {
    optional AudioToTensorCalculatorOptions ext = 448635064;
  }

  // The required number of channels the output audio tensor has.
  optional int64 num_channels = 1;

  // The required number of samples per channel the output audio tensor has.
  optional int64 num_samples = 2;

  // The number of overlapping samples per channel the output audio tensor has.
  optional int64 num_overlapping_samples = 3 [default = 0];

  // The target number of samples per second (hertz) of the audio buffers that
  // will be converted into tensors.
  optional double target_sample_rate = 4;

  // Whether to treat the input audio stream as a continuous stream or a batch
  // of unrelated audio buffers.
  optional bool streaming_mode = 5 [default = true];

  // Set to false to disable checks for jitter in timestamp values. Useful with
  // live audio input.
  optional bool check_inconsistent_timestamps = 6 [default = true];
}
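// Illustrative example (values are hypothetical): with num_samples: 256 and
// num_overlapping_samples: 64, the calculator advances by 256 - 64 = 192
// samples per output tensor; num_overlapping_samples must stay in the range
// [0, num_samples).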

mediapipe/calculators/tensor/audio_to_tensor_calculator_test.cc (new file, 483 lines)
@@ -0,0 +1,483 @@
// Copyright 2022 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/substitute.h"
#include "audio/dsp/resampler_q.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/timestamp.h"

namespace mediapipe {
namespace {

std::unique_ptr<Matrix> CreateTestMatrix(int num_channels, int num_samples,
                                         int timestamp) {
  auto matrix = std::make_unique<Matrix>(num_channels, num_samples);
  for (int c = 0; c < num_channels; ++c) {
    for (int i = 0; i < num_samples; ++i) {
      // A float value with the sample, channel, and timestamp separated by a
      // few orders of magnitude, for easy parsing by humans.
      (*matrix)(c, i) = timestamp / 10000 + i + c / 100.0;
    }
  }
  return matrix;
}
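// For instance, CreateTestMatrix(2, 8, 2000000) fills entry (c, i) with
// 2000000 / 10000 + i + c / 100.0, so entry (1, 3) holds 203.01: the hundreds
// encode the timestamp in units of 10 ms, the units encode the sample index,
// and the fractional part encodes the channel.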

std::unique_ptr<Matrix> ResampleBuffer(const Matrix& input_matrix,
                                       double resampling_factor) {
  audio_dsp::QResamplerParams params;
  std::vector<float> resampled;
  int num_channels = input_matrix.rows();
  std::vector<float> input_data(input_matrix.data(),
                                input_matrix.data() + input_matrix.size());
  resampled = audio_dsp::QResampleSignal<float>(
      1, resampling_factor, num_channels, params, input_data);
  Matrix res = Eigen::Map<Matrix>(resampled.data(), num_channels,
                                  resampled.size() / num_channels);
  return std::make_unique<Matrix>(std::move(res));
}

class AudioToTensorCalculatorNonStreamingModeTest : public ::testing::Test {
 protected:
  void SetUp() override {}
  void Run(int num_samples, int num_overlapping_samples,
           double resampling_factor, const Matrix& input_matrix) {
    double input_sample_rate = 10000;
    double target_sample_rate = input_sample_rate * resampling_factor;
    auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
        absl::Substitute(R"(
        input_stream: "audio"
        input_stream: "sample_rate"
        output_stream: "tensors"
        output_stream: "timestamps"
        node {
          calculator: "AudioToTensorCalculator"
          input_stream: "AUDIO:audio"
          input_stream: "SAMPLE_RATE:sample_rate"
          output_stream: "TENSORS:tensors"
          output_stream: "TIMESTAMPS:timestamps"
          options {
            [mediapipe.AudioToTensorCalculatorOptions.ext] {
              num_channels: $0
              num_samples: $1
              num_overlapping_samples: $2
              target_sample_rate: $3
              streaming_mode: false
            }
          }
        }
        )",
                         /*$0=*/input_matrix.rows(),
                         /*$1=*/num_samples, /*$2=*/num_overlapping_samples,
                         /*$3=*/target_sample_rate));
    tool::AddVectorSink("tensors", &graph_config, &tensors_packets_);
    tool::AddVectorSink("timestamps", &graph_config, &timestamps_packets_);

    // Run the graph.
    MP_ASSERT_OK(graph_.Initialize(graph_config));
    MP_ASSERT_OK(graph_.StartRun({}));
    // Run with the input matrix multiple times.
    for (int i = 0; i < num_iterations_; ++i) {
      MP_ASSERT_OK(graph_.AddPacketToInputStream(
          "audio",
          MakePacket<Matrix>(input_matrix)
              .At(Timestamp(i * Timestamp::kTimestampUnitsPerSecond))));
      MP_ASSERT_OK(graph_.AddPacketToInputStream(
          "sample_rate",
          MakePacket<double>(input_sample_rate)
              .At(Timestamp(i * Timestamp::kTimestampUnitsPerSecond))));
    }
    MP_ASSERT_OK(graph_.CloseAllInputStreams());
    MP_ASSERT_OK(graph_.WaitUntilIdle());
  }

  void CheckTensorsOutputPackets(const Matrix& expected_matrix,
                                 int sample_offset, int num_tensors_per_input) {
    ASSERT_EQ(num_iterations_ * num_tensors_per_input, tensors_packets_.size());
    for (int i = 0; i < num_iterations_; ++i) {
      for (int j = 0; j < num_tensors_per_input; ++j) {
        CheckTensorsOutputPacket(
            expected_matrix, tensors_packets_[i * num_tensors_per_input + j],
            /*sample_offset*/ sample_offset * j, /*index=*/j);
      }
    }
  }

  void CheckTensorsOutputPacket(const Matrix& expected_matrix,
                                const Packet& packet, int sample_offset,
                                int index) {
    MP_ASSERT_OK(packet.ValidateAsType<std::vector<Tensor>>());
    ASSERT_EQ(1, packet.Get<std::vector<Tensor>>().size());
    const Tensor& output_tensor = packet.Get<std::vector<Tensor>>()[0];
    auto* buffer = output_tensor.GetCpuReadView().buffer<float>();
    int num_values = output_tensor.shape().num_elements();
    const std::vector<float> output_floats(buffer, buffer + num_values);
    for (int i = 0; i < num_values; ++i) {
      if (i + sample_offset >= expected_matrix.size()) {
        EXPECT_FLOAT_EQ(output_floats[i], 0);
      } else {
        EXPECT_FLOAT_EQ(output_floats[i],
                        expected_matrix.coeff((i + sample_offset) % 2,
                                              (i + sample_offset) / 2))
            << "i=" << i << ", sample_offset=" << sample_offset;
      }
    }
  }

  void CheckTimestampsOutputPackets(
      std::vector<int64> expected_timestamp_values) {
    ASSERT_EQ(num_iterations_, timestamps_packets_.size());
    for (int i = 0; i < timestamps_packets_.size(); ++i) {
      const auto& p = timestamps_packets_[i];
      MP_ASSERT_OK(p.ValidateAsType<std::vector<Timestamp>>());
      auto output_timestamps = p.Get<std::vector<Timestamp>>();
      int64 base_timestamp = i * Timestamp::kTimestampUnitsPerSecond;
      std::vector<Timestamp> expected_timestamps;
      expected_timestamps.resize(expected_timestamp_values.size());
      std::transform(
          expected_timestamp_values.begin(), expected_timestamp_values.end(),
          expected_timestamps.begin(), [base_timestamp](int64 v) -> Timestamp {
            return Timestamp(v + base_timestamp);
          });
      EXPECT_EQ(expected_timestamps, output_timestamps);
      EXPECT_EQ(p.Timestamp(), expected_timestamps.back());
    }
  }

  // Fully close graph at end, otherwise calculator+tensors are destroyed
  // after calling WaitUntilDone().
  void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); }

 private:
  CalculatorGraph graph_;
  int num_iterations_ = 10;
  std::vector<Packet> tensors_packets_;
  std::vector<Packet> timestamps_packets_;
};

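// The expected timestamp lists below follow from the 10 kHz input sample rate
// used in Run(): in the ConvertToNoOverlappingFp32Tensors case, for example,
// the frame step is 4 samples, i.e. 4 / 10000 s = 400 microseconds between
// output packets, hence the expected {0, 400}.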
TEST_F(AudioToTensorCalculatorNonStreamingModeTest,
       ConvertToNoOverlappingFp32Tensors) {
  auto input_matrix = CreateTestMatrix(2, 8, 0);
  Run(/*num_samples=*/4, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/1.0f, *input_matrix);
  CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/8,
                            /*num_tensors_per_input=*/2);
  CheckTimestampsOutputPackets({0, 400});
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorNonStreamingModeTest,
       ConvertToOverlappingFp32Tensors) {
  auto input_matrix = CreateTestMatrix(2, 8, 0);
  Run(/*num_samples=*/4, /*num_overlapping_samples=*/2,
      /*resampling_factor=*/1.0f, *input_matrix);
  CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/4,
                            /*num_tensors_per_input=*/4);
  CheckTimestampsOutputPackets({0, 200, 400, 600});
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorNonStreamingModeTest, TensorsWithZeroPadding) {
  auto input_matrix = CreateTestMatrix(2, 7, 0);
  Run(/*num_samples=*/4, /*num_overlapping_samples=*/2,
      /*resampling_factor=*/1.0f, *input_matrix);
  CheckTensorsOutputPackets(*input_matrix, /*sample_offset=*/4,
                            /*num_tensors_per_input=*/3);
  CheckTimestampsOutputPackets({0, 200, 400});
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorNonStreamingModeTest, Downsampling) {
  auto input_matrix = CreateTestMatrix(2, 1024, 0);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/0.5f, *input_matrix);
  auto expected_matrix =
      ResampleBuffer(*input_matrix, /*resampling_factor=*/0.5f);
  CheckTensorsOutputPackets(*expected_matrix, /*sample_offset=*/512,
                            /*num_tensors_per_input=*/3);
  CheckTimestampsOutputPackets({0, 51200, 102400});
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorNonStreamingModeTest,
       DownsamplingWithOverlapping) {
  auto input_matrix = CreateTestMatrix(2, 1024, 0);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
      /*resampling_factor=*/0.5f, *input_matrix);
  auto expected_matrix =
      ResampleBuffer(*input_matrix, /*resampling_factor=*/0.5f);
  CheckTensorsOutputPackets(*expected_matrix, /*sample_offset=*/384,
                            /*num_tensors_per_input=*/3);
  CheckTimestampsOutputPackets({0, 38400, 76800});
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorNonStreamingModeTest, Upsampling) {
  auto input_matrix = CreateTestMatrix(2, 1024, 0);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/2.0f, *input_matrix);
  auto expected_matrix =
      ResampleBuffer(*input_matrix, /*resampling_factor=*/2.0f);
  CheckTensorsOutputPackets(*expected_matrix,
                            /*sample_offset=*/512,
                            /*num_tensors_per_input=*/9);
  CheckTimestampsOutputPackets(
      {0, 12800, 25600, 38400, 51200, 64000, 76800, 89600, 102400});
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorNonStreamingModeTest, UpsamplingWithOverlapping) {
  auto input_matrix = CreateTestMatrix(2, 256, 0);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
      /*resampling_factor=*/2.0f, *input_matrix);
  auto expected_matrix =
      ResampleBuffer(*input_matrix, /*resampling_factor=*/2.0f);
  CheckTensorsOutputPackets(*expected_matrix,
                            /*sample_offset=*/384,
                            /*num_tensors_per_input=*/3);
  CheckTimestampsOutputPackets({0, 9600, 19200});
  CloseGraph();
}

class AudioToTensorCalculatorStreamingModeTest : public ::testing::Test {
 protected:
  void SetUp() override { sample_buffer_ = std::make_unique<Matrix>(2, 0); }

  void SetInputBufferNumSamplesPerChannel(int num_samples) {
    input_buffer_num_samples_ = num_samples;
  }

  void SetNumIterations(int num_iterations) {
    num_iterations_ = num_iterations;
  }

  int GetExpectedNumOfSamples() {
    Matrix* expected_matrix =
        resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
    return expected_matrix->cols();
  }

  void Run(int num_samples, int num_overlapping_samples,
           double resampling_factor) {
    double input_sample_rate = 10000;
    double target_sample_rate = input_sample_rate * resampling_factor;
    auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
        absl::Substitute(R"(
        input_stream: "audio"
        input_stream: "sample_rate"
        output_stream: "tensors"
        node {
          calculator: "AudioToTensorCalculator"
          input_stream: "AUDIO:audio"
          input_stream: "SAMPLE_RATE:sample_rate"
          output_stream: "TENSORS:tensors"
          options {
            [mediapipe.AudioToTensorCalculatorOptions.ext] {
              num_channels: 2
              num_samples: $0
              num_overlapping_samples: $1
              target_sample_rate: $2
              streaming_mode: true
            }
          }
        }
        )",
                         /*$0=*/num_samples, /*$1=*/num_overlapping_samples,
                         /*$2=*/target_sample_rate));
    tool::AddVectorSink("tensors", &graph_config, &tensors_packets_);

    // Run the graph.
    MP_ASSERT_OK(graph_.Initialize(graph_config));
    MP_ASSERT_OK(graph_.StartRun({}));
    for (int i = 0; i < num_iterations_; ++i) {
      Timestamp input_timestamp(Timestamp::kTimestampUnitsPerSecond * i);
      auto new_data = CreateTestMatrix(2, input_buffer_num_samples_,
                                       input_timestamp.Value());
      MP_ASSERT_OK(graph_.AddPacketToInputStream(
          "audio", MakePacket<Matrix>(*new_data).At(input_timestamp)));
      MP_ASSERT_OK(graph_.AddPacketToInputStream(
          "sample_rate",
          MakePacket<double>(input_sample_rate).At(input_timestamp)));
      sample_buffer_->conservativeResize(
          Eigen::NoChange, sample_buffer_->cols() + new_data->cols());
      sample_buffer_->rightCols(new_data->cols()).swap(*new_data);
    }
    MP_ASSERT_OK(graph_.CloseAllInputStreams());
    MP_ASSERT_OK(graph_.WaitUntilIdle());
    if (resampling_factor != 1) {
      resampled_buffer_ = ResampleBuffer(*sample_buffer_, resampling_factor);
    }
  }

  void CheckTensorsOutputPackets(int sample_offset, int num_packets,
                                 int64 timestamp_interval,
                                 bool output_last_at_close) {
    ASSERT_EQ(num_packets, tensors_packets_.size());
    for (int i = 0; i < num_packets; ++i) {
      if (i == num_packets - 1 && output_last_at_close) {
        CheckTensorsOutputPacket(sample_offset * i, i, Timestamp::Max());
      } else {
        CheckTensorsOutputPacket(sample_offset * i, i,
                                 Timestamp(timestamp_interval * i));
      }
    }
  }

  void CheckTensorsOutputPacket(int sample_offset, int index,
                                Timestamp expected_timestamp) {
    const Packet& p = tensors_packets_[index];
    MP_ASSERT_OK(p.ValidateAsType<std::vector<Tensor>>());
    const Tensor& output_tensor = p.Get<std::vector<Tensor>>()[0];
    auto buffer = output_tensor.GetCpuReadView().buffer<float>();
    int num_values = output_tensor.shape().num_elements();
    std::vector<float> output_floats(buffer, buffer + num_values);
    Matrix* expected_matrix =
        resampled_buffer_ ? resampled_buffer_.get() : sample_buffer_.get();
    for (int i = 0; i < num_values; ++i) {
      if (i + sample_offset >= expected_matrix->size()) {
        EXPECT_FLOAT_EQ(output_floats[i], 0);
      } else {
        EXPECT_NEAR(output_floats[i],
                    expected_matrix->coeff((i + sample_offset) % 2,
                                           (i + sample_offset) / 2),
                    0.001)
            << "i=" << i << ", sample_offset=" << sample_offset
            << ", packet index=" << index;
      }
    }
    EXPECT_EQ(p.Timestamp(), expected_timestamp);
  }

  // Fully close graph at end, otherwise calculator+tensors are destroyed
  // after calling WaitUntilDone().
  void CloseGraph() { MP_EXPECT_OK(graph_.WaitUntilDone()); }

 private:
  int input_buffer_num_samples_ = 10;
  int num_iterations_ = 10;
  CalculatorGraph graph_;
  std::vector<Packet> tensors_packets_;
  std::unique_ptr<Matrix> sample_buffer_;
  std::unique_ptr<Matrix> resampled_buffer_;
};

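// In the streaming tests below, each iteration feeds input_buffer_num_samples_
// samples per channel at 10 kHz. For example, OutputNoOverlappingFp32Tensors
// uses num_samples = 5 with no overlap over 10 iterations of 10 samples, so
// the 100 buffered samples yield 20 packets spaced 5 / 10000 s = 500
// microseconds apart.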
TEST_F(AudioToTensorCalculatorStreamingModeTest,
       OutputNoOverlappingFp32Tensors) {
  Run(/*num_samples=*/5, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/1.0f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/10,
      /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 5),
      /*timestamp_interval=*/500,
      /*output_last_at_close=*/false);
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputRemainingInCloseMethod) {
  Run(/*num_samples=*/6, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/1.0f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/12,
      /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 6),
      /*timestamp_interval=*/600,
      /*output_last_at_close=*/true);
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorStreamingModeTest, OutputOverlappingFp32Tensors) {
  SetInputBufferNumSamplesPerChannel(12);
  Run(/*num_samples=*/10, /*num_overlapping_samples=*/2,
      /*resampling_factor=*/1.0f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/16,
      /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 8),
      /*timestamp_interval=*/800,
      /*output_last_at_close=*/true);
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorStreamingModeTest, Downsampling) {
  SetInputBufferNumSamplesPerChannel(1000);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/0.5f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/512,
      /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256),
      /*timestamp_interval=*/51200,
      /*output_last_at_close=*/true);
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorStreamingModeTest, DownsamplingWithOverlapping) {
  SetInputBufferNumSamplesPerChannel(1024);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
      /*resampling_factor=*/0.5f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/384,
      /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192),
      /*timestamp_interval=*/38400,
      /*output_last_at_close=*/true);
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorStreamingModeTest, Upsampling) {
  SetInputBufferNumSamplesPerChannel(1000);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/2.0f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/512,
      /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 256),
      /*timestamp_interval=*/12800,
      /*output_last_at_close=*/true);
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorStreamingModeTest, UpsamplingWithOverlapping) {
  SetInputBufferNumSamplesPerChannel(1024);
  Run(/*num_samples=*/256, /*num_overlapping_samples=*/64,
      /*resampling_factor=*/2.0f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/384,
      /*num_packets=*/std::ceil((float)GetExpectedNumOfSamples() / 192),
      /*timestamp_interval=*/9600,
      /*output_last_at_close=*/true);
  CloseGraph();
}

TEST_F(AudioToTensorCalculatorStreamingModeTest,
       OnlyOutputInCloseIfNoSufficientSamples) {
  SetNumIterations(1);
  Run(/*num_samples=*/8, /*num_overlapping_samples=*/0,
      /*resampling_factor=*/0.5f);
  CheckTensorsOutputPackets(
      /*sample_offset=*/0,
      /*num_packets=*/1,
      /*timestamp_interval=*/0,
      /*output_last_at_close=*/true);
  CloseGraph();
}

}  // namespace
}  // namespace mediapipe
@@ -19,8 +19,8 @@
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/string_view.h"
#include "mediapipe/calculators/tensor/inference_calculator.pb.h"
#include "mediapipe/framework/api2/packet.h"
#include "mediapipe/framework/tool/subgraph_expansion.h"
#include "tensorflow/lite/core/api/op_resolver.h"
@@ -43,8 +43,19 @@ class InferenceCalculatorSelectorImpl
        !options.has_delegate() ||  // Use GPU delegate if not specified
        (options.has_delegate() && options.delegate().has_gpu());
    if (should_use_gpu) {
      const auto& api = options.delegate().gpu().api();
      using Gpu = ::mediapipe::InferenceCalculatorOptions::Delegate::Gpu;
      impls.emplace_back("Metal");
      impls.emplace_back("Gl");
      const bool prefer_gl_advanced =
          options.delegate().gpu().use_advanced_gpu_api() &&
          (api == Gpu::ANY || api == Gpu::OPENGL || api == Gpu::OPENCL);
      if (prefer_gl_advanced) {
        impls.emplace_back("GlAdvanced");
        impls.emplace_back("Gl");
      } else {
        impls.emplace_back("Gl");
        impls.emplace_back("GlAdvanced");
      }
    }
    impls.emplace_back("Cpu");
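    // Reading of the logic above (sketch): "Metal" is considered first; with
    // use_advanced_gpu_api and an API of ANY, OPENGL, or OPENCL, "GlAdvanced"
    // is preferred over "Gl", otherwise "Gl" over "GlAdvanced"; "Cpu" is the
    // final fallback. The loop below appends each suffix to
    // "InferenceCalculator", and the first implementation registered in the
    // current build is used.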
    for (const auto& suffix : impls) {

@@ -134,6 +134,10 @@ struct InferenceCalculatorGl : public InferenceCalculator {
  static constexpr char kCalculatorName[] = "InferenceCalculatorGl";
};

struct InferenceCalculatorGlAdvanced : public InferenceCalculator {
  static constexpr char kCalculatorName[] = "InferenceCalculatorGlAdvanced";
};

struct InferenceCalculatorMetal : public InferenceCalculator {
  static constexpr char kCalculatorName[] = "InferenceCalculatorMetal";
};
@@ -75,9 +75,10 @@ const std::vector<Param>& GetParams() {
class InferenceCalculatorTest : public testing::TestWithParam<Param> {
 protected:
  void SetDelegateForParam(mediapipe::CalculatorGraphConfig_Node* node) {
    *node->mutable_options()
         ->MutableExtension(mediapipe::InferenceCalculatorOptions::ext)
         ->mutable_delegate() = GetParam().delegate;
    auto options_map = tool::MutableOptionsMap().Initialize(*node);
    auto options = options_map.Get<mediapipe::InferenceCalculatorOptions>();
    *options.mutable_delegate() = GetParam().delegate;
    options_map.Set(options);
  }
};
@@ -20,22 +20,8 @@
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "mediapipe/calculators/tensor/inference_calculator.h"
#include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/util/tflite/config.h"

#if MEDIAPIPE_TFLITE_GL_INFERENCE
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

#if defined(MEDIAPIPE_ANDROID)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/filesystem.h"
#include "mediapipe/util/android/file/base/helpers.h"
#endif  // ANDROID

namespace mediapipe {
namespace api2 {

@@ -50,42 +36,22 @@ class InferenceCalculatorGlImpl
  absl::Status Close(CalculatorContext* cc) override;

 private:
  absl::Status ReadGpuCaches();
  absl::Status SaveGpuCaches();
  absl::Status LoadModel(CalculatorContext* cc);
  absl::Status LoadDelegate(CalculatorContext* cc);
  absl::Status LoadDelegateAndAllocateTensors(CalculatorContext* cc);
  absl::Status InitTFLiteGPURunner(CalculatorContext* cc);

  // TfLite requires us to keep the model alive as long as the interpreter is.
  Packet<TfLiteModelPtr> model_packet_;

#if MEDIAPIPE_TFLITE_GL_INFERENCE
  mediapipe::GlCalculatorHelper gpu_helper_;
  std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
  bool allow_precision_loss_ = false;
  mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
      tflite_gpu_runner_api_;
  mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
      tflite_gpu_runner_usage_;
#endif  // MEDIAPIPE_TFLITE_GL_INFERENCE

  TfLiteDelegatePtr delegate_;
  std::unique_ptr<tflite::Interpreter> interpreter_;

#if MEDIAPIPE_TFLITE_GPU_SUPPORTED
  std::vector<Tensor::Shape> output_shapes_;
  std::vector<std::unique_ptr<Tensor>> gpu_buffers_in_;
  std::vector<std::unique_ptr<Tensor>> gpu_buffers_out_;
#endif  // MEDIAPIPE_TFLITE_GPU_SUPPORTED

  bool use_advanced_gpu_api_ = false;
  bool use_gpu_delegate_ = false;

  bool use_kernel_caching_ = false;
  std::string cached_kernel_filename_;
  bool use_serialized_model_ = false;
  std::string serialized_model_path_;
};

absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {

@@ -93,8 +59,7 @@ absl::Status InferenceCalculatorGlImpl::UpdateContract(CalculatorContract* cc) {
  RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
      << "Either model as side packet or model path in options is required.";

  MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
  return absl::OkStatus();
  return mediapipe::GlCalculatorHelper::UpdateContract(cc);
}

absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {

@@ -110,46 +75,12 @@ absl::Status InferenceCalculatorGlImpl::Open(CalculatorContext* cc) {
        << "for Gpu";
    delegate.MergeFrom(input_side_packet_delegate);
  }
  const bool has_delegate = options.has_delegate() || !kDelegate(cc).IsEmpty();
  use_advanced_gpu_api_ = has_delegate && delegate.has_gpu() &&
                          delegate.gpu().use_advanced_gpu_api();
  allow_precision_loss_ = delegate.gpu().allow_precision_loss();
  tflite_gpu_runner_api_ = delegate.gpu().api();
  tflite_gpu_runner_usage_ = delegate.gpu().usage();
  use_kernel_caching_ =
      use_advanced_gpu_api_ && delegate.gpu().has_cached_kernel_path();
  use_serialized_model_ = use_advanced_gpu_api_ &&
                          delegate.gpu().has_serialized_model_dir() &&
                          delegate.gpu().has_model_token();
  use_gpu_delegate_ = !use_advanced_gpu_api_;

  if (use_kernel_caching_) {
#ifdef MEDIAPIPE_ANDROID
    cached_kernel_filename_ = delegate.gpu().cached_kernel_path() +
                              mediapipe::File::Basename(options.model_path()) +
                              ".ker";
#endif  // MEDIAPIPE_ANDROID
  }
  if (use_serialized_model_) {
#ifdef MEDIAPIPE_ANDROID
    serialized_model_path_ = mediapipe::file::JoinPath(
        delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
#endif  // MEDIAPIPE_ANDROID
  }

  // When use_advanced_gpu_api_, model loading is handled in InitTFLiteGPURunner
  // for everything.
  if (!use_advanced_gpu_api_) {
    MP_RETURN_IF_ERROR(LoadModel(cc));
  }

  MP_RETURN_IF_ERROR(LoadModel(cc));
  MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
  MP_RETURN_IF_ERROR(
      gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
        return use_advanced_gpu_api_ ? InitTFLiteGPURunner(cc)
                                     : LoadDelegateAndAllocateTensors(cc);
      }));
  return absl::OkStatus();
  return gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
    return LoadDelegateAndAllocateTensors(cc);
  });
}

absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {

@@ -160,205 +91,53 @@ absl::Status InferenceCalculatorGlImpl::Process(CalculatorContext* cc) {
  RET_CHECK(!input_tensors.empty());
  auto output_tensors = absl::make_unique<std::vector<Tensor>>();

  if (use_advanced_gpu_api_) {
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
        [this, &input_tensors, &output_tensors]() -> ::mediapipe::Status {
          for (int i = 0; i < input_tensors.size(); ++i) {
            MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
                input_tensors[i].GetOpenGlBufferReadView().name(), i));
          }
          output_tensors->reserve(output_shapes_.size());
          for (int i = 0; i < output_shapes_.size(); ++i) {
            output_tensors->emplace_back(Tensor::ElementType::kFloat32,
                                         output_shapes_[i]);
            MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
                output_tensors->back().GetOpenGlBufferWriteView().name(), i));
          }
          return absl::OkStatus();
        }));
  } else {
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
        [this, &input_tensors]() -> ::mediapipe::Status {
          // Explicitly copy input.
          for (int i = 0; i < input_tensors.size(); ++i) {
            glBindBuffer(GL_COPY_READ_BUFFER,
                         input_tensors[i].GetOpenGlBufferReadView().name());
            glBindBuffer(GL_COPY_WRITE_BUFFER,
                         gpu_buffers_in_[i]->GetOpenGlBufferWriteView().name());
            glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
                                input_tensors[i].bytes());
          }
          return absl::OkStatus();
        }));
  }
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &input_tensors]() -> ::mediapipe::Status {
        // Explicitly copy input.
        for (int i = 0; i < input_tensors.size(); ++i) {
          glBindBuffer(GL_COPY_READ_BUFFER,
                       input_tensors[i].GetOpenGlBufferReadView().name());
          glBindBuffer(GL_COPY_WRITE_BUFFER,
                       gpu_buffers_in_[i]->GetOpenGlBufferWriteView().name());
          glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
                              input_tensors[i].bytes());
        }
        return absl::OkStatus();
      }));

  // Run inference.
  if (use_advanced_gpu_api_) {
    RET_CHECK(tflite_gpu_runner_->Invoke().ok());
  } else {
    RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
  }
  RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);

  if (use_gpu_delegate_) {
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
        [this, &output_tensors]() -> ::mediapipe::Status {
          output_tensors->reserve(output_shapes_.size());
          for (int i = 0; i < output_shapes_.size(); ++i) {
            const auto& t = gpu_buffers_out_[i];
            output_tensors->emplace_back(Tensor::ElementType::kFloat32,
                                         gpu_buffers_out_[i]->shape());
            auto read_view = t->GetOpenGlBufferReadView();
            glBindBuffer(GL_COPY_READ_BUFFER, read_view.name());
            auto write_view = output_tensors->back().GetOpenGlBufferWriteView();
            glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name());
            glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
                                t->bytes());
          }
          return absl::OkStatus();
        }));
  }
  // Output tensors are already bound if use_advanced_gpu_api_ is true.
  MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
      [this, &output_tensors]() -> ::mediapipe::Status {
        output_tensors->reserve(output_shapes_.size());
        for (int i = 0; i < output_shapes_.size(); ++i) {
          const auto& t = gpu_buffers_out_[i];
          output_tensors->emplace_back(Tensor::ElementType::kFloat32,
                                       gpu_buffers_out_[i]->shape());
          auto read_view = t->GetOpenGlBufferReadView();
          glBindBuffer(GL_COPY_READ_BUFFER, read_view.name());
          auto write_view = output_tensors->back().GetOpenGlBufferWriteView();
          glBindBuffer(GL_COPY_WRITE_BUFFER, write_view.name());
          glCopyBufferSubData(GL_COPY_READ_BUFFER, GL_COPY_WRITE_BUFFER, 0, 0,
                              t->bytes());
        }
        return absl::OkStatus();
      }));

  kOutTensors(cc).Send(std::move(output_tensors));
  return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::SaveGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
  if (use_kernel_caching_) {
    // Save kernel file.
    auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
        tflite_gpu_runner_->GetSerializedBinaryCache());
    std::string cache_str(kernel_cache->begin(), kernel_cache->end());
    MP_RETURN_IF_ERROR(
        mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
  }
  if (use_serialized_model_) {
    // Save serialized model file.
    ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
                     tflite_gpu_runner_->GetSerializedModel());
    absl::string_view serialized_model(
        reinterpret_cast<char*>(serialized_model_vec.data()),
        serialized_model_vec.size());
    MP_RETURN_IF_ERROR(
        mediapipe::file::SetContents(serialized_model_path_, serialized_model));
  }
#endif  // MEDIAPIPE_ANDROID
  return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::Close(CalculatorContext* cc) {
  MP_RETURN_IF_ERROR(SaveGpuCaches());
  if (use_gpu_delegate_) {
    MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
      gpu_buffers_in_.clear();
      gpu_buffers_out_.clear();
      // Delegate must outlive the interpreter, hence the order is important.
      interpreter_ = nullptr;
      delegate_ = nullptr;
      return absl::OkStatus();
    }));
  } else {
  return gpu_helper_.RunInGlContext([this]() -> absl::Status {
    gpu_buffers_in_.clear();
    gpu_buffers_out_.clear();
    // Delegate must outlive the interpreter, hence the order is important.
    interpreter_ = nullptr;
    delegate_ = nullptr;
  }

  return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::ReadGpuCaches() {
#ifdef MEDIAPIPE_ANDROID
  if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
    // Load pre-compiled kernel file.
    std::string cache_str;
    MP_RETURN_IF_ERROR(
        mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
    std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
    tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
  }
  if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
    // Load serialized model file.
    std::string serialized_model_str;
    MP_RETURN_IF_ERROR(
        file::GetContents(serialized_model_path_, &serialized_model_str));
    std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
                                              serialized_model_str.end());
    tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
  }
#endif  // MEDIAPIPE_ANDROID
  return absl::OkStatus();
}

absl::Status InferenceCalculatorGlImpl::InitTFLiteGPURunner(
|
||||
CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
|
||||
// Create runner
|
||||
tflite::gpu::InferenceOptions options;
|
||||
options.priority1 = allow_precision_loss_
|
||||
? tflite::gpu::InferencePriority::MIN_LATENCY
|
||||
: tflite::gpu::InferencePriority::MAX_PRECISION;
|
||||
options.priority2 = tflite::gpu::InferencePriority::AUTO;
|
||||
options.priority3 = tflite::gpu::InferencePriority::AUTO;
|
||||
switch (tflite_gpu_runner_usage_) {
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
|
||||
FAST_SINGLE_ANSWER: {
|
||||
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
|
||||
SUSTAINED_SPEED: {
|
||||
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: {
|
||||
      return absl::InternalError("inference usage needs to be specified.");
|
||||
}
|
||||
}
|
||||
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
|
||||
switch (tflite_gpu_runner_api_) {
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
|
||||
// Do not need to force any specific API.
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
|
||||
tflite_gpu_runner_->ForceOpenGL();
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: {
|
||||
tflite_gpu_runner_->ForceOpenCL();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (kSideInOpResolver(cc).IsConnected()) {
|
||||
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
||||
model, op_resolver, /*allow_quant_ops=*/true));
|
||||
} else {
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
||||
model, op_resolver, /*allow_quant_ops=*/true));
|
||||
}
|
||||
|
||||
// Create and bind OpenGL buffers for outputs.
|
||||
// The buffers are created once and their ids are passed to calculator outputs
|
||||
output_shapes_.resize(tflite_gpu_runner_->outputs_size());
|
||||
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
|
||||
output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b,
|
||||
tflite_gpu_runner_->GetOutputShapes()[i].h,
|
||||
tflite_gpu_runner_->GetOutputShapes()[i].w,
|
||||
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
||||
}
|
||||
|
||||
  MP_RETURN_IF_ERROR(ReadGpuCaches());

  MP_RETURN_IF_ERROR(tflite_gpu_runner_->Build());

  return absl::OkStatus();
}
|
||||
|
||||
absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
|
||||
|
@ -375,12 +154,8 @@ absl::Status InferenceCalculatorGlImpl::LoadModel(CalculatorContext* cc) {
|
|||
}
|
||||
RET_CHECK(interpreter_);
|
||||
|
||||
#if defined(__EMSCRIPTEN__)
|
||||
interpreter_->SetNumThreads(1);
|
||||
#else
|
||||
interpreter_->SetNumThreads(
|
||||
cc->Options<mediapipe::InferenceCalculatorOptions>().cpu_num_thread());
|
||||
#endif // __EMSCRIPTEN__
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
285  mediapipe/calculators/tensor/inference_calculator_gl_advanced.cc  (new file)

@ -0,0 +1,285 @@
|||
// Copyright 2022 The MediaPipe Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "mediapipe/calculators/tensor/inference_calculator.h"
|
||||
#include "mediapipe/gpu/gl_calculator_helper.h"
|
||||
#include "mediapipe/util/tflite/tflite_gpu_runner.h"
|
||||
|
||||
#if defined(MEDIAPIPE_ANDROID)
|
||||
#include "mediapipe/framework/deps/file_path.h"
|
||||
#include "mediapipe/util/android/file/base/file.h"
|
||||
#include "mediapipe/util/android/file/base/filesystem.h"
|
||||
#include "mediapipe/util/android/file/base/helpers.h"
|
||||
#endif // ANDROID
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// Runs TFLite GPU delegate API2 directly, bypassing interpreter usage, and
|
||||
// allows choosing specific API.
|
||||
//
|
||||
// To trigger this code path:
|
||||
// [mediapipe.InferenceCalculatorOptions.ext] {
|
||||
// delegate {
|
||||
// gpu {
|
||||
// use_advanced_gpu_api: true
|
||||
// api: OPENCL # or OPENGL or ANY
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
class InferenceCalculatorGlAdvancedImpl
|
||||
: public NodeImpl<InferenceCalculatorGlAdvanced,
|
||||
InferenceCalculatorGlAdvancedImpl> {
|
||||
public:
|
||||
static absl::Status UpdateContract(CalculatorContract* cc);
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) override;
|
||||
absl::Status Process(CalculatorContext* cc) override;
|
||||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
absl::Status ReadGpuCaches();
|
||||
absl::Status SaveGpuCaches();
|
||||
absl::Status InitTFLiteGPURunner(CalculatorContext* cc);
|
||||
|
||||
// TfLite requires us to keep the model alive as long as the interpreter is.
|
||||
Packet<TfLiteModelPtr> model_packet_;
|
||||
|
||||
mediapipe::GlCalculatorHelper gpu_helper_;
|
||||
std::unique_ptr<tflite::gpu::TFLiteGPURunner> tflite_gpu_runner_;
|
||||
bool allow_precision_loss_ = false;
|
||||
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::Api
|
||||
tflite_gpu_runner_api_;
|
||||
mediapipe::InferenceCalculatorOptions::Delegate::Gpu::InferenceUsage
|
||||
tflite_gpu_runner_usage_;
|
||||
|
||||
std::vector<Tensor::Shape> output_shapes_;
|
||||
|
||||
bool use_kernel_caching_ = false;
|
||||
std::string cached_kernel_filename_;
|
||||
bool use_serialized_model_ = false;
|
||||
std::string serialized_model_path_;
|
||||
};
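For illustration only (not part of this commit): a minimal sketch of a graph node that opts into this advanced GPU path with the options described in the class comment above. The model path and stream names are hypothetical placeholders.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Builds a node config that selects the advanced GPU path and forces OpenCL.
CalculatorGraphConfig::Node MakeAdvancedGpuInferenceNode() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "InferenceCalculator"
    input_stream: "TENSORS:input_tensors"
    output_stream: "TENSORS:output_tensors"
    options {
      [mediapipe.InferenceCalculatorOptions.ext] {
        model_path: "mediapipe/models/example_model.tflite"  # hypothetical
        delegate { gpu { use_advanced_gpu_api: true api: OPENCL } }
      }
    }
  )pb");
}

}  // namespace mediapipe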
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::UpdateContract(
|
||||
CalculatorContract* cc) {
|
||||
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||
RET_CHECK(!options.model_path().empty() ^ kSideInModel(cc).IsConnected())
|
||||
<< "Either model as side packet or model path in options is required.";
|
||||
|
||||
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::Open(CalculatorContext* cc) {
|
||||
const auto& options = cc->Options<::mediapipe::InferenceCalculatorOptions>();
|
||||
mediapipe::InferenceCalculatorOptions::Delegate delegate = options.delegate();
|
||||
if (!kDelegate(cc).IsEmpty()) {
|
||||
mediapipe::InferenceCalculatorOptions::Delegate input_side_packet_delegate =
|
||||
kDelegate(cc).Get();
|
||||
CHECK(input_side_packet_delegate.has_gpu() ||
|
||||
input_side_packet_delegate.delegate_case() ==
|
||||
mediapipe::InferenceCalculatorOptions::Delegate::DELEGATE_NOT_SET)
|
||||
<< "inference_calculator_gl_advanced only supports delegate input side "
|
||||
"packet for Gpu";
|
||||
delegate.MergeFrom(input_side_packet_delegate);
|
||||
}
|
||||
allow_precision_loss_ = delegate.gpu().allow_precision_loss();
|
||||
tflite_gpu_runner_api_ = delegate.gpu().api();
|
||||
tflite_gpu_runner_usage_ = delegate.gpu().usage();
|
||||
use_kernel_caching_ = delegate.gpu().has_cached_kernel_path();
|
||||
use_serialized_model_ = delegate.gpu().has_serialized_model_dir() &&
|
||||
delegate.gpu().has_model_token();
|
||||
|
||||
if (use_kernel_caching_) {
|
||||
#ifdef MEDIAPIPE_ANDROID
|
||||
cached_kernel_filename_ = delegate.gpu().cached_kernel_path() +
|
||||
mediapipe::File::Basename(options.model_path()) +
|
||||
".ker";
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
}
|
||||
if (use_serialized_model_) {
|
||||
#ifdef MEDIAPIPE_ANDROID
|
||||
serialized_model_path_ = mediapipe::file::JoinPath(
|
||||
delegate.gpu().serialized_model_dir(), delegate.gpu().model_token());
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
|
||||
return gpu_helper_.RunInGlContext(
|
||||
[this, &cc]() -> absl::Status { return InitTFLiteGPURunner(cc); });
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::Process(CalculatorContext* cc) {
|
||||
if (kInTensors(cc).IsEmpty()) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
const auto& input_tensors = *kInTensors(cc);
|
||||
RET_CHECK(!input_tensors.empty());
|
||||
auto output_tensors = absl::make_unique<std::vector<Tensor>>();
|
||||
|
||||
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
|
||||
[this, &input_tensors, &output_tensors]() -> absl::Status {
|
||||
for (int i = 0; i < input_tensors.size(); ++i) {
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToInputTensor(
|
||||
input_tensors[i].GetOpenGlBufferReadView().name(), i));
|
||||
}
|
||||
output_tensors->reserve(output_shapes_.size());
|
||||
for (int i = 0; i < output_shapes_.size(); ++i) {
|
||||
output_tensors->emplace_back(Tensor::ElementType::kFloat32,
|
||||
output_shapes_[i]);
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->BindSSBOToOutputTensor(
|
||||
output_tensors->back().GetOpenGlBufferWriteView().name(), i));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}));
|
||||
|
||||
// Run inference.
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->Invoke());
|
||||
kOutTensors(cc).Send(std::move(output_tensors));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::SaveGpuCaches() {
|
||||
#ifdef MEDIAPIPE_ANDROID
|
||||
if (use_kernel_caching_) {
|
||||
// Save kernel file.
|
||||
auto kernel_cache = absl::make_unique<std::vector<uint8_t>>(
|
||||
tflite_gpu_runner_->GetSerializedBinaryCache());
|
||||
std::string cache_str(kernel_cache->begin(), kernel_cache->end());
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::file::SetContents(cached_kernel_filename_, cache_str));
|
||||
}
|
||||
if (use_serialized_model_) {
|
||||
// Save serialized model file.
|
||||
ASSIGN_OR_RETURN(std::vector<uint8_t> serialized_model_vec,
|
||||
tflite_gpu_runner_->GetSerializedModel());
|
||||
absl::string_view serialized_model(
|
||||
reinterpret_cast<char*>(serialized_model_vec.data()),
|
||||
serialized_model_vec.size());
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::file::SetContents(serialized_model_path_, serialized_model));
|
||||
}
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::Close(CalculatorContext* cc) {
|
||||
MP_RETURN_IF_ERROR(SaveGpuCaches());
|
||||
return gpu_helper_.RunInGlContext([this]() -> absl::Status {
|
||||
tflite_gpu_runner_.reset();
|
||||
return absl::OkStatus();
|
||||
});
|
||||
}
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::ReadGpuCaches() {
|
||||
#ifdef MEDIAPIPE_ANDROID
|
||||
if (use_kernel_caching_ && File::Exists(cached_kernel_filename_)) {
|
||||
// Load pre-compiled kernel file.
|
||||
std::string cache_str;
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::file::GetContents(cached_kernel_filename_, &cache_str));
|
||||
std::vector<uint8_t> cache_vec(cache_str.begin(), cache_str.end());
|
||||
tflite_gpu_runner_->SetSerializedBinaryCache(std::move(cache_vec));
|
||||
}
|
||||
if (use_serialized_model_ && File::Exists(serialized_model_path_)) {
|
||||
// Load serialized model file.
|
||||
std::string serialized_model_str;
|
||||
MP_RETURN_IF_ERROR(
|
||||
file::GetContents(serialized_model_path_, &serialized_model_str));
|
||||
std::vector<uint8_t> serialized_model_vec(serialized_model_str.begin(),
|
||||
serialized_model_str.end());
|
||||
tflite_gpu_runner_->SetSerializedModel(std::move(serialized_model_vec));
|
||||
}
|
||||
#endif // MEDIAPIPE_ANDROID
|
||||
return absl::OkStatus();
|
||||
}
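As a hedged usage sketch (not part of the diff), the kernel and serialized-model caches above only come into play when the delegate options name a cache location. The directories and token below are placeholders chosen for illustration.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Delegate options that enable both GPU caches on Android; paths and the
// model token are hypothetical.
CalculatorGraphConfig::Node MakeCachedGpuInferenceNode() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "InferenceCalculator"
    input_stream: "TENSORS:input_tensors"
    output_stream: "TENSORS:output_tensors"
    options {
      [mediapipe.InferenceCalculatorOptions.ext] {
        model_path: "mediapipe/models/example_model.tflite"  # hypothetical
        delegate {
          gpu {
            use_advanced_gpu_api: true
            cached_kernel_path: "/data/local/tmp/"    # kernel cache directory
            serialized_model_dir: "/data/local/tmp/"  # serialized model dir
            model_token: "example_model"              # serialized model name
          }
        }
      }
    }
  )pb");
}

}  // namespace mediapipe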
|
||||
|
||||
absl::Status InferenceCalculatorGlAdvancedImpl::InitTFLiteGPURunner(
|
||||
CalculatorContext* cc) {
|
||||
ASSIGN_OR_RETURN(model_packet_, GetModelAsPacket(cc));
|
||||
const auto& model = *model_packet_.Get();
|
||||
|
||||
// Create runner
|
||||
tflite::gpu::InferenceOptions options;
|
||||
options.priority1 = allow_precision_loss_
|
||||
? tflite::gpu::InferencePriority::MIN_LATENCY
|
||||
: tflite::gpu::InferencePriority::MAX_PRECISION;
|
||||
options.priority2 = tflite::gpu::InferencePriority::AUTO;
|
||||
options.priority3 = tflite::gpu::InferencePriority::AUTO;
|
||||
switch (tflite_gpu_runner_usage_) {
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
|
||||
FAST_SINGLE_ANSWER: {
|
||||
options.usage = tflite::gpu::InferenceUsage::FAST_SINGLE_ANSWER;
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::
|
||||
SUSTAINED_SPEED: {
|
||||
options.usage = tflite::gpu::InferenceUsage::SUSTAINED_SPEED;
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::UNSPECIFIED: {
|
||||
      return absl::InternalError("inference usage needs to be specified.");
|
||||
}
|
||||
}
|
||||
tflite_gpu_runner_ = std::make_unique<tflite::gpu::TFLiteGPURunner>(options);
|
||||
switch (tflite_gpu_runner_api_) {
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::ANY: {
|
||||
// Do not need to force any specific API.
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENGL: {
|
||||
tflite_gpu_runner_->ForceOpenGL();
|
||||
break;
|
||||
}
|
||||
case mediapipe::InferenceCalculatorOptions::Delegate::Gpu::OPENCL: {
|
||||
tflite_gpu_runner_->ForceOpenCL();
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (kSideInOpResolver(cc).IsConnected()) {
|
||||
const tflite::OpResolver& op_resolver = kSideInOpResolver(cc).Get();
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
||||
model, op_resolver, /*allow_quant_ops=*/true));
|
||||
} else {
|
||||
tflite::ops::builtin::BuiltinOpResolver op_resolver =
|
||||
kSideInCustomOpResolver(cc).GetOr(
|
||||
tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates());
|
||||
MP_RETURN_IF_ERROR(tflite_gpu_runner_->InitializeWithModel(
|
||||
model, op_resolver, /*allow_quant_ops=*/true));
|
||||
}
|
||||
|
||||
// Create and bind OpenGL buffers for outputs.
|
||||
// The buffers are created once and their ids are passed to calculator outputs
|
||||
output_shapes_.resize(tflite_gpu_runner_->outputs_size());
|
||||
for (int i = 0; i < tflite_gpu_runner_->outputs_size(); ++i) {
|
||||
output_shapes_[i] = {tflite_gpu_runner_->GetOutputShapes()[i].b,
|
||||
tflite_gpu_runner_->GetOutputShapes()[i].h,
|
||||
tflite_gpu_runner_->GetOutputShapes()[i].w,
|
||||
tflite_gpu_runner_->GetOutputShapes()[i].c};
|
||||
}
|
||||
|
||||
MP_RETURN_IF_ERROR(ReadGpuCaches());
|
||||
return tflite_gpu_runner_->Build();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
|
@ -38,61 +38,13 @@
|
|||
#endif // defined(__APPLE__)
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
void DoSmokeTest(const std::string& graph_proto) {
|
||||
const int width = 8;
|
||||
const int height = 8;
|
||||
const int channels = 3;
|
||||
// Prepare input tensor.
|
||||
auto input_vec = absl::make_unique<std::vector<Tensor>>();
|
||||
input_vec->emplace_back(Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{1, height, width, channels});
|
||||
{
|
||||
auto view1 = input_vec->back().GetCpuWriteView();
|
||||
auto tensor_buffer = view1.buffer<float>();
|
||||
ASSERT_NE(tensor_buffer, nullptr);
|
||||
for (int i = 0; i < width * height * channels - 1; i++) {
|
||||
tensor_buffer[i] = 1;
|
||||
}
|
||||
}
|
||||
constexpr int kTensorWidth = 8;
|
||||
constexpr int kTensorHeight = 8;
|
||||
constexpr int kTensorChannels = 3;
|
||||
|
||||
  // Prepare a single-calculator graph and wait for output packets.
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
|
||||
CalculatorGraph graph(graph_config);
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
// Push the tensor into the graph.
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"tensor_in", Adopt(input_vec.release()).At(Timestamp(0))));
|
||||
  // Wait until the calculator is done processing.
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<Tensor>& result_vec =
|
||||
output_packets[0].Get<std::vector<Tensor>>();
|
||||
ASSERT_EQ(1, result_vec.size());
|
||||
|
||||
const Tensor& result = result_vec[0];
|
||||
auto view = result.GetCpuReadView();
|
||||
auto result_buffer = view.buffer<float>();
|
||||
ASSERT_NE(result_buffer, nullptr);
|
||||
for (int i = 0; i < width * height * channels - 1; i++) {
|
||||
ASSERT_EQ(3, result_buffer[i]);
|
||||
}
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// Tests a simple add model that adds an input tensor to itself.
|
||||
TEST(InferenceCalculatorTest, SmokeTest) {
|
||||
std::string graph_proto = R"(
|
||||
constexpr char kGraphWithModelPathInOption[] = R"(
|
||||
input_stream: "tensor_in"
|
||||
node {
|
||||
calculator: "InferenceCalculator"
|
||||
|
@ -106,18 +58,7 @@ TEST(InferenceCalculatorTest, SmokeTest) {
|
|||
}
|
||||
}
|
||||
)";
|
||||
// Test CPU inference only.
|
||||
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||
graph_proto, {{"$delegate", "delegate { tflite {} }"}}));
|
||||
DoSmokeTest(absl::StrReplaceAll(graph_proto,
|
||||
{{"$delegate", "delegate { xnnpack {} }"}}));
|
||||
DoSmokeTest(absl::StrReplaceAll(
|
||||
graph_proto,
|
||||
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
|
||||
}
|
||||
|
||||
TEST(InferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
|
||||
std::string graph_proto = R"(
|
||||
constexpr char kGraphWithModelAsInputSidePacket[] = R"(
|
||||
input_stream: "tensor_in"
|
||||
|
||||
node {
|
||||
|
@ -154,7 +95,84 @@ TEST(InferenceCalculatorTest, SmokeTest_ModelAsInputSidePacket) {
|
|||
}
|
||||
}
|
||||
)";
|
||||
DoSmokeTest(graph_proto);
|
||||
|
||||
std::vector<Tensor> CreateInputs() {
|
||||
std::vector<Tensor> input_vec;
|
||||
// Prepare input tensor.
|
||||
input_vec.emplace_back(
|
||||
Tensor::ElementType::kFloat32,
|
||||
Tensor::Shape{1, kTensorHeight, kTensorWidth, kTensorChannels});
|
||||
{
|
||||
auto view = input_vec.back().GetCpuWriteView();
|
||||
auto num_elements = input_vec.back().shape().num_elements();
|
||||
auto tensor_buffer = view.buffer<float>();
|
||||
for (int i = 0; i < num_elements; i++) {
|
||||
tensor_buffer[i] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
return input_vec;
|
||||
}
|
||||
|
||||
void RunGraphThenClose(CalculatorGraph& graph, std::vector<Tensor> input_vec) {
|
||||
MP_ASSERT_OK(graph.StartRun({}));
|
||||
|
||||
// Push the tensor into the graph.
|
||||
MP_ASSERT_OK(graph.AddPacketToInputStream(
|
||||
"tensor_in",
|
||||
MakePacket<std::vector<Tensor>>(std::move(input_vec)).At(Timestamp(0))));
|
||||
  // Wait until the calculator is done processing.
|
||||
MP_ASSERT_OK(graph.WaitUntilIdle());
|
||||
|
||||
// Fully close graph at end, otherwise calculator+tensors are destroyed
|
||||
// after calling WaitUntilDone().
|
||||
MP_ASSERT_OK(graph.CloseInputStream("tensor_in"));
|
||||
MP_ASSERT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
void DoSmokeTest(const std::string& graph_proto) {
|
||||
auto input_vec = CreateInputs();
|
||||
|
||||
  // Prepare a single-calculator graph and wait for output packets.
|
||||
CalculatorGraphConfig graph_config =
|
||||
ParseTextProtoOrDie<CalculatorGraphConfig>(graph_proto);
|
||||
std::vector<Packet> output_packets;
|
||||
tool::AddVectorSink("tensor_out", &graph_config, &output_packets);
|
||||
CalculatorGraph graph(graph_config);
|
||||
|
||||
RunGraphThenClose(graph, std::move(input_vec));
|
||||
|
||||
ASSERT_EQ(1, output_packets.size());
|
||||
|
||||
// Get and process results.
|
||||
const std::vector<Tensor>& result_vec =
|
||||
output_packets[0].Get<std::vector<Tensor>>();
|
||||
ASSERT_EQ(1, result_vec.size());
|
||||
|
||||
const Tensor& result = result_vec[0];
|
||||
auto view = result.GetCpuReadView();
|
||||
auto result_buffer = view.buffer<float>();
|
||||
ASSERT_NE(result_buffer, nullptr);
|
||||
for (int i = 0; i < result.shape().num_elements(); i++) {
|
||||
ASSERT_EQ(3, result_buffer[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Tests a simple add model that adds an input tensor to itself.
|
||||
TEST(InferenceCalculatorTest, SmokeTest) {
|
||||
// Test CPU inference only.
|
||||
DoSmokeTest(/*graph_proto=*/absl::StrReplaceAll(
|
||||
kGraphWithModelPathInOption, {{"$delegate", "delegate { tflite {} }"}}));
|
||||
DoSmokeTest(absl::StrReplaceAll(kGraphWithModelPathInOption,
|
||||
{{"$delegate", "delegate { xnnpack {} }"}}));
|
||||
DoSmokeTest(absl::StrReplaceAll(
|
||||
kGraphWithModelPathInOption,
|
||||
{{"$delegate", "delegate { xnnpack { num_threads: 10 } }"}}));
|
||||
}
|
||||
|
||||
TEST(InferenceCalculatorTest, ModelAsInputSidePacketSmokeTest) {
|
||||
DoSmokeTest(kGraphWithModelAsInputSidePacket);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mediapipe/calculators/tensor/tensors_to_classification_calculator.pb.h"
|
||||
|
@ -25,6 +24,7 @@
|
|||
#include "mediapipe/framework/formats/classification.pb.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
#if defined(MEDIAPIPE_MOBILE)
|
||||
#include "mediapipe/util/android/file/base/file.h"
|
||||
|
@ -35,6 +35,17 @@
|
|||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
namespace {
|
||||
|
||||
void SetClassificationLabel(const LabelMapItem label_map_item,
|
||||
Classification* classification) {
|
||||
classification->set_label(label_map_item.name());
|
||||
if (label_map_item.has_display_name()) {
|
||||
classification->set_display_name(label_map_item.display_name());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Convert result tensors from classification models into MediaPipe
|
||||
// classifications.
|
||||
|
@ -54,7 +65,6 @@ namespace api2 {
|
|||
// output_stream: "CLASSIFICATIONS:classifications"
|
||||
// options: {
|
||||
// [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||
// num_classes: 1024
|
||||
// min_score_threshold: 0.1
|
||||
// label_map_path: "labelmap.txt"
|
||||
// }
|
||||
|
@ -72,22 +82,35 @@ class TensorsToClassificationCalculator : public Node {
|
|||
absl::Status Close(CalculatorContext* cc) override;
|
||||
|
||||
private:
|
||||
::mediapipe::TensorsToClassificationCalculatorOptions options_;
|
||||
int top_k_ = 0;
|
||||
absl::node_hash_map<int, std::string> label_map_;
|
||||
bool sort_by_descending_score_ = false;
|
||||
proto_ns::Map<int64, LabelMapItem> local_label_map_;
|
||||
bool label_map_loaded_ = false;
|
||||
bool is_binary_classification_ = false;
|
||||
float min_score_threshold_ = std::numeric_limits<float>::lowest();
|
||||
|
||||
// Set of allowed or ignored class indices.
|
||||
struct ClassIndexSet {
|
||||
absl::flat_hash_set<int> values;
|
||||
bool is_allowlist;
|
||||
};
|
||||
// Allowed or ignored class indices based on provided options.
|
||||
// These are used to filter out the output classification results.
|
||||
ClassIndexSet class_index_set_;
|
||||
bool IsClassIndexAllowed(int class_index);
|
||||
const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
|
||||
};
|
||||
MEDIAPIPE_REGISTER_NODE(TensorsToClassificationCalculator);
|
||||
|
||||
absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
||||
options_ =
|
||||
cc->Options<::mediapipe::TensorsToClassificationCalculatorOptions>();
|
||||
const auto& options = cc->Options<TensorsToClassificationCalculatorOptions>();
|
||||
|
||||
top_k_ = options_.top_k();
|
||||
if (options_.has_label_map_path()) {
|
||||
top_k_ = options.top_k();
|
||||
sort_by_descending_score_ = options.sort_by_descending_score();
|
||||
if (options.has_label_map_path()) {
|
||||
std::string string_path;
|
||||
ASSIGN_OR_RETURN(string_path,
|
||||
PathToResourceAsFile(options_.label_map_path()));
|
||||
PathToResourceAsFile(options.label_map_path()));
|
||||
std::string label_map_string;
|
||||
MP_RETURN_IF_ERROR(
|
||||
mediapipe::GetResourceContents(string_path, &label_map_string));
|
||||
|
@ -96,18 +119,45 @@ absl::Status TensorsToClassificationCalculator::Open(CalculatorContext* cc) {
|
|||
std::string line;
|
||||
int i = 0;
|
||||
while (std::getline(stream, line)) {
|
||||
label_map_[i++] = line;
|
||||
LabelMapItem item;
|
||||
item.set_name(line);
|
||||
local_label_map_[i++] = item;
|
||||
}
|
||||
label_map_loaded_ = true;
|
||||
} else if (options_.has_label_map()) {
|
||||
for (int i = 0; i < options_.label_map().entries_size(); ++i) {
|
||||
const auto& entry = options_.label_map().entries(i);
|
||||
RET_CHECK(!label_map_.contains(entry.id()))
|
||||
} else if (!options.label_items().empty()) {
|
||||
label_map_loaded_ = true;
|
||||
} else if (options.has_label_map()) {
|
||||
for (int i = 0; i < options.label_map().entries_size(); ++i) {
|
||||
const auto& entry = options.label_map().entries(i);
|
||||
RET_CHECK(!local_label_map_.contains(entry.id()))
|
||||
<< "Duplicate id found: " << entry.id();
|
||||
label_map_[entry.id()] = entry.label();
|
||||
LabelMapItem item;
|
||||
item.set_name(entry.label());
|
||||
local_label_map_[entry.id()] = item;
|
||||
}
|
||||
label_map_loaded_ = true;
|
||||
}
|
||||
if (options.has_min_score_threshold()) {
|
||||
min_score_threshold_ = options.min_score_threshold();
|
||||
}
|
||||
is_binary_classification_ = options.binary_classification();
|
||||
|
||||
if (is_binary_classification_) {
|
||||
RET_CHECK(options.allow_classes().empty() &&
|
||||
options.ignore_classes().empty());
|
||||
}
|
||||
if (!options.allow_classes().empty()) {
|
||||
RET_CHECK(options.ignore_classes().empty());
|
||||
class_index_set_.is_allowlist = true;
|
||||
for (int i = 0; i < options.allow_classes_size(); ++i) {
|
||||
class_index_set_.values.insert(options.allow_classes(i));
|
||||
}
|
||||
} else {
|
||||
class_index_set_.is_allowlist = false;
|
||||
for (int i = 0; i < options.ignore_classes_size(); ++i) {
|
||||
class_index_set_.values.insert(options.ignore_classes(i));
|
||||
}
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -118,19 +168,19 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
|||
|
||||
int num_classes = input_tensors[0].shape().num_elements();
|
||||
|
||||
if (options_.binary_classification()) {
|
||||
if (is_binary_classification_) {
|
||||
RET_CHECK_EQ(num_classes, 1);
|
||||
// Number of classes for binary classification.
|
||||
num_classes = 2;
|
||||
}
|
||||
if (label_map_loaded_) {
|
||||
RET_CHECK_EQ(num_classes, label_map_.size());
|
||||
RET_CHECK_EQ(num_classes, GetLabelMap(cc).size());
|
||||
}
|
||||
auto view = input_tensors[0].GetCpuReadView();
|
||||
auto raw_scores = view.buffer<float>();
|
||||
|
||||
auto classification_list = absl::make_unique<ClassificationList>();
|
||||
if (options_.binary_classification()) {
|
||||
if (is_binary_classification_) {
|
||||
Classification* class_first = classification_list->add_classification();
|
||||
Classification* class_second = classification_list->add_classification();
|
||||
class_first->set_index(0);
|
||||
|
@ -139,41 +189,48 @@ absl::Status TensorsToClassificationCalculator::Process(CalculatorContext* cc) {
|
|||
class_second->set_score(1. - raw_scores[0]);
|
||||
|
||||
if (label_map_loaded_) {
|
||||
class_first->set_label(label_map_[0]);
|
||||
class_second->set_label(label_map_[1]);
|
||||
SetClassificationLabel(GetLabelMap(cc).at(0), class_first);
|
||||
SetClassificationLabel(GetLabelMap(cc).at(1), class_second);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < num_classes; ++i) {
|
||||
if (options_.has_min_score_threshold() &&
|
||||
raw_scores[i] < options_.min_score_threshold()) {
|
||||
if (!IsClassIndexAllowed(i)) {
|
||||
continue;
|
||||
}
|
||||
if (raw_scores[i] < min_score_threshold_) {
|
||||
continue;
|
||||
}
|
||||
Classification* classification =
|
||||
classification_list->add_classification();
|
||||
classification->set_index(i);
|
||||
classification->set_score(raw_scores[i]);
|
||||
|
||||
if (label_map_loaded_) {
|
||||
classification->set_label(label_map_[i]);
|
||||
SetClassificationLabel(GetLabelMap(cc).at(i), classification);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note that partial_sort will raise error when top_k_ >
|
||||
// classification_list->classification_size().
|
||||
CHECK_GE(classification_list->classification_size(), top_k_);
|
||||
auto raw_classification_list = classification_list->mutable_classification();
|
||||
if (top_k_ > 0 && classification_list->classification_size() >= top_k_) {
|
||||
if (top_k_ > 0) {
|
||||
int desired_size =
|
||||
std::min(classification_list->classification_size(), top_k_);
|
||||
std::partial_sort(raw_classification_list->begin(),
|
||||
raw_classification_list->begin() + top_k_,
|
||||
raw_classification_list->begin() + desired_size,
|
||||
raw_classification_list->end(),
|
||||
[](const Classification a, const Classification b) {
|
||||
return a.score() > b.score();
|
||||
});
|
||||
|
||||
// Resizes the underlying list to have only top_k_ classifications.
|
||||
raw_classification_list->DeleteSubrange(
|
||||
top_k_, raw_classification_list->size() - top_k_);
|
||||
if (desired_size >= top_k_) {
|
||||
// Resizes the underlying list to have only top_k_ classifications.
|
||||
raw_classification_list->DeleteSubrange(
|
||||
top_k_, raw_classification_list->size() - top_k_);
|
||||
}
|
||||
} else if (sort_by_descending_score_) {
|
||||
std::sort(raw_classification_list->begin(), raw_classification_list->end(),
|
||||
[](const Classification a, const Classification b) {
|
||||
return a.score() > b.score();
|
||||
});
|
||||
}
|
||||
kOutClassificationList(cc).Send(std::move(classification_list));
|
||||
return absl::OkStatus();
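Purely illustrative (not part of this change): the top_k branch above keeps only the highest-scoring classifications by partially sorting and then truncating. The same idea on a plain vector of scores, with invented helper and variable names, looks like this.

#include <algorithm>
#include <vector>

// Keeps the top_k highest scores in descending order; if top_k is not
// positive, all scores are kept (sorted).
std::vector<float> KeepTopKDescending(std::vector<float> scores, int top_k) {
  const int desired_size =
      std::min<int>(scores.size(), top_k > 0 ? top_k : scores.size());
  // Place the desired_size largest elements at the front, in descending order.
  std::partial_sort(scores.begin(), scores.begin() + desired_size, scores.end(),
                    [](float a, float b) { return a > b; });
  scores.resize(desired_size);
  return scores;
}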
|
||||
|
@ -183,5 +240,24 @@ absl::Status TensorsToClassificationCalculator::Close(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
bool TensorsToClassificationCalculator::IsClassIndexAllowed(int class_index) {
|
||||
if (class_index_set_.values.empty()) {
|
||||
return true;
|
||||
}
|
||||
if (class_index_set_.is_allowlist) {
|
||||
return class_index_set_.values.contains(class_index);
|
||||
} else {
|
||||
return !class_index_set_.values.contains(class_index);
|
||||
}
|
||||
}
|
||||
|
||||
const proto_ns::Map<int64, LabelMapItem>&
|
||||
TensorsToClassificationCalculator::GetLabelMap(CalculatorContext* cc) {
|
||||
return !local_label_map_.empty()
|
||||
? local_label_map_
|
||||
: cc->Options<TensorsToClassificationCalculatorOptions>()
|
||||
.label_items();
|
||||
}
|
||||
|
||||
} // namespace api2
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -19,6 +19,7 @@ syntax = "proto2";
|
|||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/util/label_map.proto";
|
||||
|
||||
message TensorsToClassificationCalculatorOptions {
|
||||
extend .mediapipe.CalculatorOptions {
|
||||
|
@ -38,16 +39,37 @@ message TensorsToClassificationCalculatorOptions {
|
|||
// Number of highest scoring labels to output. If top_k is not positive then
|
||||
// all labels are used.
|
||||
optional int32 top_k = 2;
|
||||
// Whether results should be sorted by descending score. By default, results
|
||||
// may or may not be sorted: setting this to true guarantees that the returned
|
||||
// results will be sorted by descending score.
|
||||
optional bool sort_by_descending_score = 9;
|
||||
// Path to a label map file for getting the actual name of class ids.
|
||||
optional string label_map_path = 3;
|
||||
// Label map. (Can be used instead of label_map_path.)
|
||||
// NOTE: "label_map_path", if specified, takes precedence over "label_map".
|
||||
// NOTE: either "label_map_path" or "label_items", if specified, takes
|
||||
// precedence over "label_map".
|
||||
// Deprecated: please use `label_items` instead.
|
||||
optional LabelMap label_map = 5;
|
||||
|
||||
// Label items. (Can be used instead of label_map_path.)
|
||||
// NOTE: "label_map_path", if specified, takes precedence over "label_items".
|
||||
map<int64, LabelMapItem> label_items = 6;
|
||||
|
||||
// Whether the input is a single float for binary classification.
|
||||
// When true, only a single float is expected in the input tensor and the
|
||||
// label map, if provided, is expected to have exactly two labels.
|
||||
  // The single score (float) represents the probability of the first label,
  // and 1 - score is the probability of the second label.
|
||||
optional bool binary_classification = 4;
|
||||
|
||||
// The ids of classes that should be ignored during decoding the score for
|
||||
// each classification. If `ignore_classes` is specified, all the other
|
||||
  // classes that are not in the `ignore_classes` field will be considered during
|
||||
// decoding. `ignore_classes` and `allow_classes` are mutually exclusive.
|
||||
repeated int32 ignore_classes = 7 [packed = true];
|
||||
// The ids of classes that will be allowed during decoding the score for
|
||||
// each classification. If `allow_classes` is specified, all the other classes
|
||||
// that are not in the `allow_classes` field will be completely ignored.
|
||||
// `ignore_classes` and `allow_classes` are mutually exclusive.
|
||||
repeated int32 allow_classes = 8 [packed = true];
|
||||
}
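A hedged example (not part of this commit) combining the newer fields above: label_items, sort_by_descending_score, a score threshold, and an allowlist. The class names and ids are invented; the node config mirrors the style used in the calculator's tests.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Classification node that labels three classes but only emits classes 1 and 2.
CalculatorGraphConfig::Node MakeClassificationNode() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "TensorsToClassificationCalculator"
    input_stream: "TENSORS:tensors"
    output_stream: "CLASSIFICATIONS:classifications"
    options {
      [mediapipe.TensorsToClassificationCalculatorOptions.ext] {
        top_k: 2
        sort_by_descending_score: true
        min_score_threshold: 0.1
        label_items {
          key: 0
          value { name: "ClassA" }
        }
        label_items {
          key: 1
          value { name: "ClassB" }
        }
        label_items {
          key: 2
          value { name: "ClassC" }
        }
        allow_classes: 1
        allow_classes: 2
      }
    }
  )pb");
}

}  // namespace mediapipe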
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
|
@ -206,4 +208,119 @@ TEST_F(TensorsToClassificationCalculatorTest, CorrectOutputWithTopK) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorsToClassificationCalculatorTest,
|
||||
CorrectOutputWithSortByDescendingScore) {
|
||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToClassificationCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "CLASSIFICATIONS:classifications"
|
||||
options {
|
||||
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||
sort_by_descending_score: true
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
BuildGraph(&runner, {0, 0.5, 1});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
||||
|
||||
EXPECT_EQ(1, output_packets_.size());
|
||||
|
||||
const auto& classification_list =
|
||||
output_packets_[0].Get<ClassificationList>();
|
||||
|
||||
// Verify results are sorted by descending score.
|
||||
EXPECT_EQ(3, classification_list.classification_size());
|
||||
float score = std::numeric_limits<float>::max();
|
||||
for (int i = 0; i < classification_list.classification_size(); ++i) {
|
||||
EXPECT_LE(classification_list.classification(i).score(), score);
|
||||
score = classification_list.classification(i).score();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TensorsToClassificationCalculatorTest,
|
||||
ClassNameAllowlistWithLabelItems) {
|
||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToClassificationCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "CLASSIFICATIONS:classifications"
|
||||
options {
|
||||
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||
label_items {
|
||||
key: 0
|
||||
value { name: "ClassA" }
|
||||
}
|
||||
label_items {
|
||||
key: 1
|
||||
value { name: "ClassB" }
|
||||
}
|
||||
label_items {
|
||||
key: 2
|
||||
value { name: "ClassC" }
|
||||
}
|
||||
allow_classes: 1
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
BuildGraph(&runner, {0, 0.5, 1});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
||||
|
||||
EXPECT_EQ(1, output_packets_.size());
|
||||
|
||||
const auto& classification_list =
|
||||
output_packets_[0].Get<ClassificationList>();
|
||||
EXPECT_EQ(1, classification_list.classification_size());
|
||||
EXPECT_EQ(1, classification_list.classification(0).index());
|
||||
EXPECT_EQ(0.5, classification_list.classification(0).score());
|
||||
ASSERT_TRUE(classification_list.classification(0).has_label());
|
||||
}
|
||||
|
||||
TEST_F(TensorsToClassificationCalculatorTest,
|
||||
ClassNameIgnorelistWithLabelItems) {
|
||||
mediapipe::CalculatorRunner runner(ParseTextProtoOrDie<Node>(R"pb(
|
||||
calculator: "TensorsToClassificationCalculator"
|
||||
input_stream: "TENSORS:tensors"
|
||||
output_stream: "CLASSIFICATIONS:classifications"
|
||||
options {
|
||||
[mediapipe.TensorsToClassificationCalculatorOptions.ext] {
|
||||
label_items {
|
||||
key: 0
|
||||
value { name: "ClassA" }
|
||||
}
|
||||
label_items {
|
||||
key: 1
|
||||
value { name: "ClassB" }
|
||||
}
|
||||
label_items {
|
||||
key: 2
|
||||
value { name: "ClassC" }
|
||||
}
|
||||
ignore_classes: 1
|
||||
}
|
||||
}
|
||||
)pb"));
|
||||
|
||||
BuildGraph(&runner, {0, 0.5, 1});
|
||||
MP_ASSERT_OK(runner.Run());
|
||||
|
||||
const auto& output_packets_ = runner.Outputs().Tag("CLASSIFICATIONS").packets;
|
||||
|
||||
EXPECT_EQ(1, output_packets_.size());
|
||||
|
||||
const auto& classification_list =
|
||||
output_packets_[0].Get<ClassificationList>();
|
||||
EXPECT_EQ(2, classification_list.classification_size());
|
||||
EXPECT_EQ(0, classification_list.classification(0).index());
|
||||
EXPECT_EQ(0, classification_list.classification(0).score());
|
||||
ASSERT_TRUE(classification_list.classification(0).has_label());
|
||||
EXPECT_EQ(2, classification_list.classification(1).index());
|
||||
EXPECT_EQ(1, classification_list.classification(1).score());
|
||||
ASSERT_TRUE(classification_list.classification(1).has_label());
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -20,10 +20,8 @@
|
|||
#include "mediapipe/framework/calculator_context.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/image.h"
|
||||
#include "mediapipe/framework/formats/image_opencv.h"
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/port/statusor.h"
|
||||
#include "mediapipe/gpu/gpu_origin.pb.h"
|
||||
|
@ -37,6 +35,11 @@
|
|||
#include "mediapipe/gpu/shader_util.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
#include "mediapipe/framework/formats/image_opencv.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
#include "tensorflow/lite/delegates/gpu/gl/converters/util.h"
|
||||
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
|
||||
|
@ -159,9 +162,10 @@ class TensorsToSegmentationCalculator : public CalculatorBase {
|
|||
return options_.gpu_origin() != mediapipe::GpuOrigin_Mode_TOP_LEFT;
|
||||
}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
template <class T>
|
||||
absl::Status ApplyActivation(cv::Mat& tensor_mat, cv::Mat* small_mask_mat);
|
||||
|
||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||
::mediapipe::TensorsToSegmentationCalculatorOptions options_;
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
|
@ -283,7 +287,11 @@ absl::Status TensorsToSegmentationCalculator::Process(CalculatorContext* cc) {
|
|||
RET_CHECK_FAIL() << "GPU processing disabled.";
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
} else {
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
MP_RETURN_IF_ERROR(ProcessCpu(cc));
|
||||
#else
|
||||
RET_CHECK_FAIL() << "OpenCV processing disabled.";
|
||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -311,6 +319,7 @@ absl::Status TensorsToSegmentationCalculator::Close(CalculatorContext* cc) {
|
|||
|
||||
absl::Status TensorsToSegmentationCalculator::ProcessCpu(
|
||||
CalculatorContext* cc) {
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
// Get input streams, and dimensions.
|
||||
const auto& input_tensors =
|
||||
cc->Inputs().Tag(kTensorsTag).Get<std::vector<Tensor>>();
|
||||
|
@ -360,10 +369,12 @@ absl::Status TensorsToSegmentationCalculator::ProcessCpu(
|
|||
cv::resize(small_mask_mat, *output_mat,
|
||||
cv::Size(output_width, output_height));
|
||||
cc->Outputs().Tag(kMaskTag).Add(output_mask.release(), cc->InputTimestamp());
|
||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
#if !MEDIAPIPE_DISABLE_OPENCV
|
||||
template <class T>
|
||||
absl::Status TensorsToSegmentationCalculator::ApplyActivation(
|
||||
cv::Mat& tensor_mat, cv::Mat* small_mask_mat) {
|
||||
|
@ -411,6 +422,7 @@ absl::Status TensorsToSegmentationCalculator::ApplyActivation(
|
|||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
#endif // !MEDIAPIPE_DISABLE_OPENCV
|
||||
|
||||
// Steps:
|
||||
// 1. receive tensor
|
||||
|
|
|
@ -300,17 +300,26 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
RET_CHECK(options_.batch_size() == 1 ||
|
||||
options_.recurrent_tag_pair().empty())
|
||||
<< "To use recurrent_tag_pairs, batch_size must be 1.";
|
||||
|
||||
// Helper for StrJoin. Prints key (tag) of map<string, string>.
|
||||
auto TagFormatter =
|
||||
absl::PairFormatter(absl::StreamFormatter(), "",
|
||||
[](std::string* out, const std::string& second) {});
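A stand-alone sketch (assumptions: made-up tag names and map contents) of how the TagFormatter above renders only the keys of the map, producing error messages such as "instead found tags A, B, EXPENSIVE".

#include <iostream>
#include <map>
#include <string>

#include "absl/strings/str_join.h"

int main() {
  // Stand-in for tag_to_tensor_map_: only the keys (tags) matter here.
  std::map<std::string, std::string> tag_to_tensor_map = {
      {"A", "tensor_a"}, {"B", "tensor_b"}, {"EXPENSIVE", "tensor_e"}};
  auto tag_formatter =
      absl::PairFormatter(absl::StreamFormatter(), "",
                          [](std::string* out, const std::string& second) {});
  // Prints "A, B, EXPENSIVE": the empty second formatter drops the values.
  std::cout << absl::StrJoin(tag_to_tensor_map, ", ", tag_formatter) << "\n";
  return 0;
}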
|
||||
|
||||
for (const auto& tag_pair : options_.recurrent_tag_pair()) {
|
||||
const std::vector<std::string> tags = absl::StrSplit(tag_pair, ':');
|
||||
RET_CHECK_EQ(tags.size(), 2) << "recurrent_tag_pair must be a colon "
|
||||
"separated string with two components: "
|
||||
<< tag_pair;
|
||||
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[0]))
|
||||
<< "Can't find tag '" << tags[0] << "' in signature "
|
||||
<< options_.signature_name();
|
||||
<< options_.signature_name() << "; instead found tags "
|
||||
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tags[1]))
|
||||
<< "Can't find tag '" << tags[1] << "' in signature "
|
||||
<< options_.signature_name();
|
||||
<< options_.signature_name() << " ; instead found tags "
|
||||
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
|
||||
recurrent_feed_tags_.insert(tags[0]);
|
||||
recurrent_fetch_tags_to_feed_tags_[tags[1]] = tags[0];
|
||||
}
|
||||
|
@ -319,12 +328,14 @@ class TensorFlowInferenceCalculator : public CalculatorBase {
|
|||
for (const std::string& tag : cc->Inputs().GetTags()) {
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||
<< "Can't find tag '" << tag << "' in signature "
|
||||
<< options_.signature_name();
|
||||
<< options_.signature_name() << "; instead found tags "
|
||||
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
|
||||
}
|
||||
for (const std::string& tag : cc->Outputs().GetTags()) {
|
||||
RET_CHECK(mediapipe::ContainsKey(tag_to_tensor_map_, tag))
|
||||
<< "Can't find tag '" << tag << "' in signature "
|
||||
<< options_.signature_name();
|
||||
<< options_.signature_name() << "; instead found tags "
|
||||
<< absl::StrJoin(tag_to_tensor_map_, ", ", TagFormatter);
|
||||
}
|
||||
|
||||
{
|
||||
|
|
|
@ -38,6 +38,9 @@
|
|||
|
||||
namespace mediapipe {
|
||||
|
||||
using ::testing::AllOf;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
namespace tf = ::tensorflow;
|
||||
|
||||
namespace {
|
||||
|
@ -199,8 +202,8 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed) {
|
|||
auto run_status = runner_->Run();
|
||||
ASSERT_FALSE(run_status.ok());
|
||||
EXPECT_THAT(run_status.ToString(),
|
||||
testing::HasSubstr("TensorFlowInferenceCalculator"));
|
||||
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B"));
|
||||
HasSubstr("TensorFlowInferenceCalculator"));
|
||||
EXPECT_THAT(run_status.ToString(), HasSubstr("Tag B"));
|
||||
}
|
||||
|
||||
TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
|
||||
|
@ -238,8 +241,8 @@ TEST_F(TensorflowInferenceCalculatorTest, GetComputed_MaxInFlight) {
|
|||
auto run_status = runner_->Run();
|
||||
ASSERT_FALSE(run_status.ok());
|
||||
EXPECT_THAT(run_status.ToString(),
|
||||
testing::HasSubstr("TensorFlowInferenceCalculator"));
|
||||
EXPECT_THAT(run_status.ToString(), testing::HasSubstr("Tag B"));
|
||||
HasSubstr("TensorFlowInferenceCalculator"));
|
||||
EXPECT_THAT(run_status.ToString(), HasSubstr("Tag B"));
|
||||
}
|
||||
|
||||
TEST_F(TensorflowInferenceCalculatorTest, BadTag) {
|
||||
|
@ -255,7 +258,12 @@ TEST_F(TensorflowInferenceCalculatorTest, BadTag) {
|
|||
|
||||
runner_ = absl::make_unique<CalculatorRunner>(config);
|
||||
AddSessionInputSidePacket();
|
||||
ASSERT_FALSE(runner_->Run().ok());
|
||||
absl::Status status = runner_->Run();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
AllOf(HasSubstr("Can't find tag 'BAD' in signature"),
|
||||
HasSubstr("instead found tags A, B, EXPENSIVE, MULTIPLIED")));
|
||||
}
|
||||
|
||||
TEST_F(TensorflowInferenceCalculatorTest, GetMultiBatchComputed) {
|
||||
|
@ -740,7 +748,7 @@ TEST_F(TensorflowInferenceCalculatorTest, BatchedInputTooBigBatch) {
|
|||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(
|
||||
status.message(),
|
||||
::testing::HasSubstr(
|
||||
HasSubstr(
|
||||
"has more packets than batch capacity. batch_size: 2 packets: 3"));
|
||||
}
|
||||
|
||||
|
|
|
@ -301,6 +301,8 @@ cc_library(
|
|||
":detection_label_id_to_text_calculator_cc_proto",
|
||||
"//mediapipe/framework/formats:detection_cc_proto",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework:calculator_framework",
|
||||
"//mediapipe/framework:packet",
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/formats/detection.pb.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/proto_ns.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/util/label_map.pb.h"
|
||||
#include "mediapipe/util/resource_util.h"
|
||||
|
@ -55,9 +57,9 @@ class DetectionLabelIdToTextCalculator : public CalculatorBase {
|
|||
private:
|
||||
// Local label map built from the calculator options' `label_map_path` or
|
||||
// `label` field.
|
||||
LabelMap local_label_map_;
|
||||
proto_ns::Map<int64, LabelMapItem> local_label_map_;
|
||||
bool keep_label_id_;
|
||||
const LabelMap& GetLabelMap(CalculatorContext* cc);
|
||||
const proto_ns::Map<int64, LabelMapItem>& GetLabelMap(CalculatorContext* cc);
|
||||
};
|
||||
REGISTER_CALCULATOR(DetectionLabelIdToTextCalculator);
|
||||
|
||||
|
@ -72,13 +74,12 @@ absl::Status DetectionLabelIdToTextCalculator::GetContract(
|
|||
absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
||||
cc->SetOffset(TimestampDiff(0));
|
||||
|
||||
const auto& options =
|
||||
cc->Options<::mediapipe::DetectionLabelIdToTextCalculatorOptions>();
|
||||
const auto& options = cc->Options<DetectionLabelIdToTextCalculatorOptions>();
|
||||
|
||||
if (options.has_label_map_path()) {
|
||||
RET_CHECK(!options.has_label_map() && options.label().empty())
|
||||
RET_CHECK(options.label_items().empty() && options.label().empty())
|
||||
<< "Only can set one of the following fields in the CalculatorOptions: "
|
||||
"label_map_path, label, and label_map.";
|
||||
"label_map_path, label, and label_items.";
|
||||
std::string string_path;
|
||||
ASSIGN_OR_RETURN(string_path,
|
||||
PathToResourceAsFile(options.label_map_path()));
|
||||
|
@ -91,16 +92,16 @@ absl::Status DetectionLabelIdToTextCalculator::Open(CalculatorContext* cc) {
|
|||
while (std::getline(stream, line)) {
|
||||
LabelMapItem item;
|
||||
item.set_name(line);
|
||||
(*local_label_map_.mutable_index_to_item())[i++] = item;
|
||||
local_label_map_[i++] = item;
|
||||
}
|
||||
} else if (!options.label().empty()) {
|
||||
RET_CHECK(!options.has_label_map())
|
||||
RET_CHECK(options.label_items().empty())
|
||||
<< "Only can set one of the following fields in the CalculatorOptions: "
|
||||
"label_map_path, label, and label_map.";
|
||||
"label_map_path, label, and label_items.";
|
||||
for (int i = 0; i < options.label_size(); ++i) {
|
||||
LabelMapItem item;
|
||||
item.set_name(options.label(i));
|
||||
(*local_label_map_.mutable_index_to_item())[i] = item;
|
||||
local_label_map_[i] = item;
|
||||
}
|
||||
}
|
||||
keep_label_id_ = options.keep_label_id();
|
||||
|
@ -115,9 +116,8 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
|
|||
Detection& output_detection = output_detections.back();
|
||||
bool has_text_label = false;
|
||||
for (const int32 label_id : output_detection.label_id()) {
|
||||
if (GetLabelMap(cc).index_to_item().find(label_id) !=
|
||||
GetLabelMap(cc).index_to_item().end()) {
|
||||
auto item = GetLabelMap(cc).index_to_item().at(label_id);
|
||||
if (GetLabelMap(cc).contains(label_id)) {
|
||||
auto item = GetLabelMap(cc).at(label_id);
|
||||
output_detection.add_label(item.name());
|
||||
if (item.has_display_name()) {
|
||||
output_detection.add_display_name(item.display_name());
|
||||
|
@ -136,13 +136,12 @@ absl::Status DetectionLabelIdToTextCalculator::Process(CalculatorContext* cc) {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
const LabelMap& DetectionLabelIdToTextCalculator::GetLabelMap(
|
||||
CalculatorContext* cc) {
|
||||
return !local_label_map_.index_to_item().empty()
|
||||
const proto_ns::Map<int64, LabelMapItem>&
|
||||
DetectionLabelIdToTextCalculator::GetLabelMap(CalculatorContext* cc) {
|
||||
return !local_label_map_.empty()
|
||||
? local_label_map_
|
||||
: cc->Options<
|
||||
::mediapipe::DetectionLabelIdToTextCalculatorOptions>()
|
||||
.label_map();
|
||||
: cc->Options<DetectionLabelIdToTextCalculatorOptions>()
|
||||
.label_items();
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -38,6 +38,6 @@ message DetectionLabelIdToTextCalculatorOptions {
|
|||
// output detections.
|
||||
optional bool keep_label_id = 3;
|
||||
|
||||
// Label map.
|
||||
optional LabelMap label_map = 4;
|
||||
// Identifying information for each classification label.
|
||||
map<int64, LabelMapItem> label_items = 4;
|
||||
}
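As an illustration only (not part of this change), the new label_items map above can be populated directly in the node options; the label names, ids, and stream wiring below are invented.

#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/parse_text_proto.h"

namespace mediapipe {

// Maps detection label ids 0 and 1 to human-readable names and display names.
CalculatorGraphConfig::Node MakeLabelIdToTextNode() {
  return ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"pb(
    calculator: "DetectionLabelIdToTextCalculator"
    input_stream: "input_detections"
    output_stream: "output_detections"
    options {
      [mediapipe.DetectionLabelIdToTextCalculatorOptions.ext] {
        label_items {
          key: 0
          value { name: "person" display_name: "Person" }
        }
        label_items {
          key: 1
          value { name: "car" display_name: "Car" }
        }
      }
    }
  )pb");
}

}  // namespace mediapipe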
|
||||
|
|
|
@ -426,6 +426,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:opencv_core",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:test_util",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
@ -451,6 +452,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
"//mediapipe/framework/port:opencv_video",
|
||||
"//mediapipe/framework/port:parse_text_proto",
|
||||
"//mediapipe/framework/tool:test_util",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
],
|
||||
)
|
||||
|
@ -534,6 +536,7 @@ cc_test(
|
|||
"//mediapipe/framework/port:status",
|
||||
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
|
||||
"//mediapipe/framework/stream_handler:sync_set_input_stream_handler",
|
||||
"//mediapipe/framework/tool:test_util",
|
||||
"//mediapipe/util/tracking:box_tracker_cc_proto",
|
||||
"//mediapipe/util/tracking:tracking_cc_proto",
|
||||
"@com_google_absl//absl/flags:flag",
|
||||
|
|
|
@ -120,7 +120,7 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
|||
// back. To get correct image format, we read the first frame from the video
|
||||
// and get the number of channels.
|
||||
cv::Mat frame;
|
||||
cap_->read(frame);
|
||||
ReadFrame(frame);
|
||||
if (frame.empty()) {
|
||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "Fail to read any frames from the video file at "
|
||||
|
@ -193,13 +193,13 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
|||
Timestamp timestamp(cap_->get(cv::CAP_PROP_POS_MSEC) * 1000);
|
||||
if (format_ == ImageFormat::GRAY8) {
|
||||
cv::Mat frame = formats::MatView(image_frame.get());
|
||||
cap_->read(frame);
|
||||
ReadFrame(frame);
|
||||
if (frame.empty()) {
|
||||
return tool::StatusStop();
|
||||
}
|
||||
} else {
|
||||
cv::Mat tmp_frame;
|
||||
cap_->read(tmp_frame);
|
||||
ReadFrame(tmp_frame);
|
||||
if (tmp_frame.empty()) {
|
||||
return tool::StatusStop();
|
||||
}
|
||||
|
@ -234,6 +234,14 @@ class OpenCvVideoDecoderCalculator : public CalculatorBase {
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Sometimes an empty frame is returned even though there are more frames.
|
||||
void ReadFrame(cv::Mat& frame) {
|
||||
cap_->read(frame);
|
||||
if (frame.empty()) {
|
||||
cap_->read(frame); // Try again.
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<cv::VideoCapture> cap_;
|
||||
int width_;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "mediapipe/framework/port/opencv_core_inc.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -32,6 +33,7 @@ namespace {
|
|||
constexpr char kVideoTag[] = "VIDEO";
|
||||
constexpr char kVideoPrestreamTag[] = "VIDEO_PRESTREAM";
|
||||
constexpr char kInputFilePathTag[] = "INPUT_FILE_PATH";
|
||||
constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
|
||||
|
||||
TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
||||
CalculatorGraphConfig::Node node_config =
|
||||
|
@ -41,10 +43,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMp4Avc720pVideo) {
|
|||
output_stream: "VIDEO:video"
|
||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_MP4_AVC720P_AAC.video"));
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) =
|
||||
MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
|
||||
"format_MP4_AVC720P_AAC.video"));
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||
|
@ -87,10 +88,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestFlvH264Video) {
|
|||
output_stream: "VIDEO:video"
|
||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_FLV_H264_AAC.video"));
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) =
|
||||
MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
|
||||
"format_FLV_H264_AAC.video"));
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||
|
@ -131,10 +131,9 @@ TEST(OpenCvVideoDecoderCalculatorTest, TestMkvVp8Video) {
|
|||
output_stream: "VIDEO:video"
|
||||
output_stream: "VIDEO_PRESTREAM:video_prestream")pb");
|
||||
CalculatorRunner runner(node_config);
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_MKV_VP8_VORBIS.video"));
|
||||
runner.MutableSidePackets()->Tag(kInputFilePathTag) =
|
||||
MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
|
||||
"format_MKV_VP8_VORBIS.video"));
|
||||
MP_EXPECT_OK(runner.Run());
|
||||
|
||||
EXPECT_EQ(runner.Outputs().Tag(kVideoPrestreamTag).packets.size(), 1);
|
||||
|
|
|
@ -28,10 +28,14 @@
|
|||
#include "mediapipe/framework/port/opencv_video_inc.h"
|
||||
#include "mediapipe/framework/port/parse_text_proto.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/test_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
|
||||
|
||||
// Temporarily disable the test.
|
||||
// TODO: Investigate the "Could not open codec 'libx264'" error with
|
||||
// opencv2.
|
||||
|
@ -59,10 +63,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, DISABLED_TestMp4Avc720pVideo) {
|
|||
}
|
||||
)pb");
|
||||
std::map<std::string, Packet> input_side_packets;
|
||||
input_side_packets["input_file_path"] = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_MP4_AVC720P_AAC.video"));
|
||||
input_side_packets["input_file_path"] =
|
||||
MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
|
||||
"format_MP4_AVC720P_AAC.video"));
|
||||
const std::string output_file_path = "/tmp/tmp_video.mp4";
|
||||
DeletingFile deleting_file(output_file_path, true);
|
||||
input_side_packets["output_file_path"] =
|
||||
|
@ -120,10 +123,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestFlvH264Video) {
|
|||
}
|
||||
)pb");
|
||||
std::map<std::string, Packet> input_side_packets;
|
||||
input_side_packets["input_file_path"] = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_FLV_H264_AAC.video"));
|
||||
input_side_packets["input_file_path"] =
|
||||
MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
|
||||
"format_FLV_H264_AAC.video"));
|
||||
const std::string output_file_path = "/tmp/tmp_video.avi";
|
||||
DeletingFile deleting_file(output_file_path, true);
|
||||
input_side_packets["output_file_path"] =
|
||||
|
@ -183,10 +185,9 @@ TEST(OpenCvVideoEncoderCalculatorTest, TestMkvVp8Video) {
|
|||
}
|
||||
)pb");
|
||||
std::map<std::string, Packet> input_side_packets;
|
||||
input_side_packets["input_file_path"] = MakePacket<std::string>(
|
||||
file::JoinPath("./",
|
||||
"/mediapipe/calculators/video/"
|
||||
"testdata/format_MKV_VP8_VORBIS.video"));
|
||||
input_side_packets["input_file_path"] =
|
||||
MakePacket<std::string>(file::JoinPath(GetTestDataDir(kTestPackageRoot),
|
||||
"format_MKV_VP8_VORBIS.video"));
|
||||
const std::string output_file_path = "/tmp/tmp_video.mkv";
|
||||
DeletingFile deleting_file(output_file_path, true);
|
||||
input_side_packets["output_file_path"] =
|
||||
|
|
|
@ -33,39 +33,16 @@
|
|||
#include "mediapipe/framework/port/proto_ns.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/port/status_matchers.h"
|
||||
#include "mediapipe/framework/tool/test_util.h"
|
||||
#include "mediapipe/util/tracking/box_tracker.pb.h"
|
||||
#include "mediapipe/util/tracking/tracking.pb.h"
|
||||
|
||||
#ifdef __APPLE__
|
||||
#include <CoreFoundation/CoreFoundation.h>
|
||||
#endif // defined(__APPLE__)
|
||||
|
||||
namespace mediapipe {
|
||||
namespace {
|
||||
using ::testing::FloatNear;
|
||||
using ::testing::Test;
|
||||
|
||||
std::string GetTestDir() {
|
||||
#ifdef __APPLE__
|
||||
char path[1024];
|
||||
CFURLRef bundle_url = CFBundleCopyBundleURL(CFBundleGetMainBundle());
|
||||
CFURLGetFileSystemRepresentation(
|
||||
bundle_url, true, reinterpret_cast<UInt8*>(path), sizeof(path));
|
||||
CFRelease(bundle_url);
|
||||
return mediapipe::file::JoinPath(path, "testdata");
|
||||
#elif defined(__ANDROID__)
|
||||
char path[1024];
|
||||
getcwd(path, sizeof(path));
|
||||
return mediapipe::file::JoinPath(path,
|
||||
"mediapipe/calculators/video/testdata");
|
||||
#else
|
||||
return mediapipe::file::JoinPath(
|
||||
"./",
|
||||
// This should match the path of the output files
|
||||
// of the genrule() that generates test model files.
|
||||
"mediapipe/calculators/video/testdata");
|
||||
#endif // defined(__APPLE__)
|
||||
}
|
||||
constexpr char kTestPackageRoot[] = "mediapipe/calculators/video";
|
||||
|
||||
bool LoadBinaryTestGraph(const std::string& graph_path,
|
||||
CalculatorGraphConfig* config) {
|
||||
|
@ -85,7 +62,7 @@ class TrackingGraphTest : public Test {
|
|||
TrackingGraphTest() {}
|
||||
|
||||
void SetUp() override {
|
||||
test_dir_ = GetTestDir();
|
||||
test_dir_ = mediapipe::GetTestDataDir(kTestPackageRoot);
|
||||
const auto graph_path = file::JoinPath(test_dir_, "tracker.binarypb");
|
||||
ASSERT_TRUE(LoadBinaryTestGraph(graph_path, &config_));
|
||||
|
||||
|
|
|
@ -15,10 +15,10 @@
|
|||
package com.google.mediapipe.examples.facedetection;
|
||||
|
||||
import android.opengl.GLES20;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
import com.google.mediapipe.solutioncore.ResultGlRenderer;
|
||||
import com.google.mediapipe.solutions.facedetection.FaceDetectionResult;
|
||||
import com.google.mediapipe.solutions.facedetection.FaceKeypoint;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
|
|
|
@ -23,9 +23,9 @@ import android.graphics.Color;
|
|||
import android.graphics.Matrix;
|
||||
import android.graphics.Paint;
|
||||
import androidx.appcompat.widget.AppCompatImageView;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
import com.google.mediapipe.solutions.facedetection.FaceDetectionResult;
|
||||
import com.google.mediapipe.solutions.facedetection.FaceKeypoint;
|
||||
import com.google.mediapipe.formats.proto.DetectionProto.Detection;
|
||||
|
||||
/** An ImageView implementation for displaying {@link FaceDetectionResult}. */
|
||||
public class FaceDetectionResultImageView extends AppCompatImageView {
|
||||
|
|
|
@ -279,35 +279,45 @@ mediapipe::autoflip::RectF ShiftDetection(
|
|||
}
|
||||
absl::Status UpdateRanges(const SalientRegion& region,
|
||||
const float shift_vertical,
|
||||
const float shift_horizontal, float* xmin,
|
||||
float* xmax, float* ymin, float* ymax) {
|
||||
const float shift_horizontal,
|
||||
const float pad_vertical, const float pad_horizontal,
|
||||
float* xmin, float* xmax, float* ymin, float* ymax) {
|
||||
if (!region.has_location_normalized()) {
|
||||
return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "SalientRegion did not have location normalized set.";
|
||||
}
|
||||
auto location = ShiftDetection(region.location_normalized(), shift_vertical,
|
||||
shift_horizontal);
|
||||
*xmin = fmin(*xmin, location.x());
|
||||
*xmax = fmax(*xmax, location.x() + location.width());
|
||||
*ymin = fmin(*ymin, location.y());
|
||||
*ymax = fmax(*ymax, location.y() + location.height());
|
||||
|
||||
const float x_padding = pad_horizontal * location.width();
|
||||
const float y_padding = pad_vertical * location.height();
|
||||
|
||||
*xmin = fmin(*xmin, location.x() - x_padding);
|
||||
*xmax = fmax(*xmax, location.x() + location.width() + x_padding);
|
||||
*ymin = fmin(*ymin, location.y() - y_padding);
|
||||
*ymax = fmax(*ymax, location.y() + location.height() + y_padding);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
absl::Status UpdateRanges(const mediapipe::Detection& detection,
|
||||
const float shift_vertical,
|
||||
const float shift_horizontal, float* xmin,
|
||||
float* xmax, float* ymin, float* ymax) {
|
||||
const float shift_horizontal,
|
||||
const float pad_vertical, const float pad_horizontal,
|
||||
float* xmin, float* xmax, float* ymin, float* ymax) {
|
||||
RET_CHECK(detection.location_data().format() ==
|
||||
mediapipe::LocationData::RELATIVE_BOUNDING_BOX)
|
||||
<< "Face detection input is lacking required relative_bounding_box()";
|
||||
const auto& location =
|
||||
ShiftDetection(detection.location_data().relative_bounding_box(),
|
||||
shift_vertical, shift_horizontal);
|
||||
*xmin = fmin(*xmin, location.xmin());
|
||||
*xmax = fmax(*xmax, location.xmin() + location.width());
|
||||
*ymin = fmin(*ymin, location.ymin());
|
||||
*ymax = fmax(*ymax, location.ymin() + location.height());
|
||||
|
||||
const float x_padding = pad_horizontal * location.width();
|
||||
const float y_padding = pad_vertical * location.height();
|
||||
|
||||
*xmin = fmin(*xmin, location.xmin() - x_padding);
|
||||
*xmax = fmax(*xmax, location.xmin() + location.width() + x_padding);
|
||||
*ymin = fmin(*ymin, location.ymin() - y_padding);
|
||||
*ymax = fmax(*ymax, location.ymin() + location.height() + y_padding);
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -818,7 +828,9 @@ absl::Status ContentZoomingCalculator::GetDetectionsBox(
|
|||
*only_required_found = true;
|
||||
MP_RETURN_IF_ERROR(UpdateRanges(
|
||||
region, options_.detection_shift_vertical(),
|
||||
options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax));
|
||||
options_.detection_shift_horizontal(),
|
||||
options_.extra_vertical_padding(),
|
||||
options_.extra_horizontal_padding(), xmin, xmax, ymin, ymax));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -864,7 +876,9 @@ absl::Status ContentZoomingCalculator::GetDetectionsBox(
|
|||
*only_required_found = true;
|
||||
MP_RETURN_IF_ERROR(UpdateRanges(
|
||||
detection, options_.detection_shift_vertical(),
|
||||
options_.detection_shift_horizontal(), xmin, xmax, ymin, ymax));
|
||||
options_.detection_shift_horizontal(),
|
||||
options_.extra_vertical_padding(),
|
||||
options_.extra_horizontal_padding(), xmin, xmax, ymin, ymax));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ package mediapipe.autoflip;
|
|||
import "mediapipe/examples/desktop/autoflip/quality/kinematic_path_solver.proto";
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
// NextTag: 19
|
||||
// NextTag: 21
|
||||
message ContentZoomingCalculatorOptions {
|
||||
extend mediapipe.CalculatorOptions {
|
||||
optional ContentZoomingCalculatorOptions ext = 313091992;
|
||||
|
@ -45,12 +45,17 @@ message ContentZoomingCalculatorOptions {
|
|||
optional int64 height = 2;
|
||||
}
|
||||
optional Size target_size = 8;
|
||||
// Amount to shift an input detection as a ratio of the size (positive:
|
||||
|
||||
// Amount to shift an input detection, as a ratio of its size (positive:
|
||||
// down/right, negative: up/left). Use a negative value to increase padding
|
||||
// above/left of an object, positive to increase padding below/right of an
|
||||
// object.
|
||||
// object. (Applies to one side only)
|
||||
optional float detection_shift_vertical = 11 [default = 0.0];
|
||||
optional float detection_shift_horizontal = 12 [default = 0.0];
|
||||
// Amount to pad around an input detection, as a ratio of its size.
|
||||
// (Applies to both sides)
|
||||
optional float extra_vertical_padding = 19 [default = 0.0];
|
||||
optional float extra_horizontal_padding = 20 [default = 0.0];
|
||||
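As an illustrative aside (numbers chosen for the example, not defaults): for a detection box of normalized width 0.2, setting extra_horizontal_padding = 0.1 makes UpdateRanges() above grow the gathered range by pad_horizontal * width = 0.02 on each side, while detection_shift_horizontal = 0.1 moves the whole box right by 0.02 without widening it.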
|
||||
// Defines the smallest value in degrees the camera is permitted to zoom.
|
||||
optional float max_zoom_value_deg = 13 [default = 35];
|
||||
|
|
|
@ -35,7 +35,9 @@ objc_library(
|
|||
"CoreMedia",
|
||||
"UIKit",
|
||||
],
|
||||
visibility = ["//mediapipe:__subpackages__"],
|
||||
visibility = [
|
||||
"//mediapipe:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/objc:mediapipe_framework_ios",
|
||||
"//mediapipe/objc:mediapipe_input_sources_ios",
|
||||
|
|
|
@ -115,7 +115,10 @@ mediapipe_proto_library(
|
|||
name = "packet_test_proto",
|
||||
testonly = 1,
|
||||
srcs = ["packet_test.proto"],
|
||||
visibility = ["//mediapipe/framework:__subpackages__"],
|
||||
visibility = [
|
||||
":mediapipe_internal",
|
||||
"//mediapipe/framework:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
mediapipe_proto_library(
|
||||
|
@ -973,6 +976,7 @@ cc_library(
|
|||
],
|
||||
}),
|
||||
visibility = [
|
||||
"//fitbit/research/sensing/mobisense:__subpackages__",
|
||||
"//mediapipe/calculators:__subpackages__",
|
||||
"//mediapipe/framework:__subpackages__",
|
||||
"//mediapipe/framework/port:__pkg__",
|
||||
|
@ -1427,6 +1431,7 @@ cc_test(
|
|||
"//mediapipe/framework/stream_handler:timestamp_align_input_stream_handler",
|
||||
"//mediapipe/framework/tool:sink",
|
||||
"//mediapipe/framework/tool:status_util",
|
||||
"//mediapipe/gpu:graph_support",
|
||||
"@com_google_absl//absl/container:fixed_array",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
|
@ -149,6 +149,7 @@ cc_library(
|
|||
"//mediapipe/framework:calculator_contract",
|
||||
"//mediapipe/framework:output_side_packet",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/tool:type_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -27,41 +27,34 @@
|
|||
#include "mediapipe/framework/calculator_contract.h"
|
||||
#include "mediapipe/framework/output_side_packet.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/tool/type_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace api2 {
|
||||
|
||||
// typeid is not constexpr, but a pointer to this is.
|
||||
template <typename T>
|
||||
size_t get_type_hash() {
|
||||
return typeid(T).hash_code();
|
||||
}
|
||||
|
||||
using type_id_fptr = size_t (*)();
|
||||
|
||||
// This is a base class for various types of port. It is not meant to be used
|
||||
// directly by node code.
|
||||
class PortBase {
|
||||
public:
|
||||
constexpr PortBase(std::size_t tag_size, const char* tag,
|
||||
type_id_fptr get_type_id, bool optional, bool multiple)
|
||||
constexpr PortBase(std::size_t tag_size, const char* tag, TypeId type_id,
|
||||
bool optional, bool multiple)
|
||||
: tag_(tag_size, tag),
|
||||
optional_(optional),
|
||||
multiple_(multiple),
|
||||
type_id_getter_(get_type_id) {}
|
||||
type_id_(type_id) {}
|
||||
|
||||
bool IsOptional() const { return optional_; }
|
||||
bool IsMultiple() const { return multiple_; }
|
||||
const char* Tag() const { return tag_.data(); }
|
||||
|
||||
size_t type_id() const { return type_id_getter_(); }
|
||||
TypeId type_id() const { return type_id_; }
|
||||
|
||||
const const_str tag_;
|
||||
const bool optional_;
|
||||
const bool multiple_;
|
||||
|
||||
protected:
|
||||
type_id_fptr type_id_getter_;
|
||||
TypeId type_id_;
|
||||
};
|
||||
|
||||
// These four base classes are used to distinguish between ports of different
|
||||
|
@ -340,7 +333,7 @@ class PortCommon : public Base {
|
|||
|
||||
template <std::size_t N>
|
||||
explicit constexpr PortCommon(const char (&tag)[N])
|
||||
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV) {}
|
||||
: Base(N, tag, kTypeId<ValueT>, IsOptionalV, IsMultipleV) {}
|
||||
|
||||
using PayloadT = ActualPayloadT<ValueT>;
|
||||
|
||||
|
@ -428,7 +421,7 @@ class SideFallbackT : public Base {
|
|||
|
||||
template <std::size_t N>
|
||||
explicit constexpr SideFallbackT(const char (&tag)[N])
|
||||
: Base(N, tag, &get_type_hash<ValueT>, IsOptionalV, IsMultipleV),
|
||||
: Base(N, tag, kTypeId<ValueT>, IsOptionalV, IsMultipleV),
|
||||
stream_port(tag),
|
||||
side_port(tag) {}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ namespace {
|
|||
|
||||
TEST(PortTest, IntInput) {
|
||||
static constexpr auto port = Input<int>("FOO");
|
||||
EXPECT_EQ(port.type_id(), typeid(int).hash_code());
|
||||
EXPECT_EQ(port.type_id(), kTypeId<int>);
|
||||
}
|
||||
|
||||
TEST(PortTest, OptionalInput) {
|
||||
|
|
|
@ -59,7 +59,7 @@ class CalculatorContract {
|
|||
const CalculatorOptions& Options() const { return node_config_->options(); }
|
||||
|
||||
// Returns the name given to this node.
|
||||
const std::string& GetNodeName() { return node_name_; }
|
||||
const std::string& GetNodeName() const { return node_name_; }
|
||||
|
||||
// Returns the options given to this calculator. Template argument T must
|
||||
// be the type of the protobuf extension message or the protobuf::Any
|
||||
|
|
|
@ -120,10 +120,10 @@ CalculatorGraph::CalculatorGraph()
|
|||
counter_factory_ = absl::make_unique<BasicCounterFactory>();
|
||||
}
|
||||
|
||||
CalculatorGraph::CalculatorGraph(const CalculatorGraphConfig& config)
|
||||
CalculatorGraph::CalculatorGraph(CalculatorGraphConfig config)
|
||||
: CalculatorGraph() {
|
||||
counter_factory_ = absl::make_unique<BasicCounterFactory>();
|
||||
MEDIAPIPE_CHECK_OK(Initialize(config));
|
||||
MEDIAPIPE_CHECK_OK(Initialize(std::move(config)));
|
||||
}
|
||||
|
||||
// Defining the destructor here lets us use incomplete types in the header;
|
||||
|
@ -429,18 +429,17 @@ absl::Status CalculatorGraph::Initialize(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status CalculatorGraph::Initialize(
|
||||
const CalculatorGraphConfig& input_config) {
|
||||
return Initialize(input_config, {});
|
||||
absl::Status CalculatorGraph::Initialize(CalculatorGraphConfig input_config) {
|
||||
return Initialize(std::move(input_config), {});
|
||||
}
|
||||
|
||||
absl::Status CalculatorGraph::Initialize(
|
||||
const CalculatorGraphConfig& input_config,
|
||||
CalculatorGraphConfig input_config,
|
||||
const std::map<std::string, Packet>& side_packets) {
|
||||
auto validated_graph = absl::make_unique<ValidatedGraphConfig>();
|
||||
MP_RETURN_IF_ERROR(validated_graph->Initialize(
|
||||
input_config, /*graph_registry=*/nullptr, /*graph_options=*/nullptr,
|
||||
&service_manager_));
|
||||
std::move(input_config), /*graph_registry=*/nullptr,
|
||||
/*graph_options=*/nullptr, &service_manager_));
|
||||
return Initialize(std::move(validated_graph), side_packets);
|
||||
}
|
||||
|
||||
|
@ -675,6 +674,7 @@ absl::Status CalculatorGraph::PrepareForRun(
|
|||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
MP_RETURN_IF_ERROR(PrepareServices());
|
||||
#if !MEDIAPIPE_DISABLE_GPU
|
||||
// TODO: should we do this on each run, or only once?
|
||||
MP_RETURN_IF_ERROR(PrepareGpu());
|
||||
additional_side_packets = MaybeCreateLegacyGpuSidePacket(legacy_sp);
|
||||
#endif // !MEDIAPIPE_DISABLE_GPU
|
||||
|
@ -1251,7 +1251,9 @@ void CalculatorGraph::Resume() { scheduler_.Resume(); }
|
|||
|
||||
absl::Status CalculatorGraph::SetExecutorInternal(
|
||||
const std::string& name, std::shared_ptr<Executor> executor) {
|
||||
if (!executors_.emplace(name, executor).second) {
|
||||
auto [it, inserted] = executors_.emplace(name, executor);
|
||||
if (!inserted) {
|
||||
if (it->second == executor) return absl::OkStatus();
|
||||
return mediapipe::AlreadyExistsErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "SetExecutor must be called only once for the executor \"" << name
|
||||
<< "\"";
|
||||
|
|
|
@ -119,17 +119,17 @@ class CalculatorGraph {
|
|||
|
||||
// Initializes the graph from its proto description (using Initialize())
|
||||
// and crashes if something goes wrong.
|
||||
explicit CalculatorGraph(const CalculatorGraphConfig& config);
|
||||
explicit CalculatorGraph(CalculatorGraphConfig config);
|
||||
virtual ~CalculatorGraph();
|
||||
|
||||
// Initializes the graph from its proto description.
|
||||
// side_packets that are provided at this stage are common across all Run()
|
||||
// invocations and could be used to execute PacketGenerators immediately.
|
||||
absl::Status Initialize(const CalculatorGraphConfig& config,
|
||||
absl::Status Initialize(CalculatorGraphConfig config,
|
||||
const std::map<std::string, Packet>& side_packets);
|
||||
|
||||
// Convenience version which does not take side packets.
|
||||
absl::Status Initialize(const CalculatorGraphConfig& config);
|
||||
absl::Status Initialize(CalculatorGraphConfig config);
|
||||
|
||||
// Initializes the CalculatorGraph from the specified graph and subgraph
|
||||
// configs. Template graph and subgraph configs can be specified through
|
||||
|
@ -272,7 +272,6 @@ class CalculatorGraph {
|
|||
absl::Status CloseInputStream(const std::string& stream_name);
|
||||
|
||||
// Closes all the graph input streams.
|
||||
// TODO: deprecate this function in favor of CloseAllPacketSources.
|
||||
absl::Status CloseAllInputStreams();
|
||||
|
||||
// Closes all the graph input streams and source calculator nodes.
|
||||
|
|
|
@ -60,6 +60,7 @@
|
|||
#include "mediapipe/framework/tool/sink.h"
|
||||
#include "mediapipe/framework/tool/status_util.h"
|
||||
#include "mediapipe/framework/type_map.h"
|
||||
#include "mediapipe/gpu/graph_support.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
|
@ -2059,6 +2060,26 @@ TEST(CalculatorGraph, HandlersRun) {
|
|||
input_side_packets.at("unavailable_input_counter2")));
|
||||
}
|
||||
|
||||
TEST(CalculatorGraph, CalculatorGraphConfigCopyElision) {
|
||||
CalculatorGraph graph;
|
||||
CalculatorGraphConfig config =
|
||||
mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"pb(
|
||||
input_stream: 'in'
|
||||
node {
|
||||
calculator: 'PassThroughCalculator'
|
||||
input_stream: 'in'
|
||||
output_stream: 'out'
|
||||
}
|
||||
)pb");
|
||||
// config is consumed and never copied, which avoids copying its data.
|
||||
MP_ASSERT_OK(graph.Initialize(std::move(config)));
|
||||
MP_EXPECT_OK(graph.StartRun({}));
|
||||
MP_EXPECT_OK(
|
||||
graph.AddPacketToInputStream("in", MakePacket<int>(1).At(Timestamp(1))));
|
||||
MP_EXPECT_OK(graph.CloseInputStream("in"));
|
||||
MP_EXPECT_OK(graph.WaitUntilDone());
|
||||
}
|
||||
|
||||
// Test that calling SetOffset() in Calculator::Process() results in the
|
||||
// absl::StatusCode::kFailedPrecondition error.
|
||||
TEST(CalculatorGraph, SetOffsetInProcess) {
|
||||
|
|
|
@ -11,10 +11,6 @@
|
|||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Forked from mediapipe/framework/calculator_profile.proto.
|
||||
// The forked proto must remain identical to the original proto and should be
|
||||
// ONLY used by mediapipe open source project.
|
||||
|
||||
syntax = "proto2";
|
||||
|
||||
|
@ -24,6 +20,7 @@ import "mediapipe/framework/calculator.proto";
|
|||
|
||||
option java_package = "com.google.mediapipe.proto";
|
||||
option java_outer_classname = "CalculatorProfileProto";
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
// Stores the profiling information.
|
||||
//
|
||||
|
|
|
@ -88,7 +88,10 @@ cc_library(
|
|||
testonly = True,
|
||||
hdrs = ["message_matchers.h"],
|
||||
# Use this library through "mediapipe/framework/port:gtest_main".
|
||||
visibility = ["//mediapipe/framework/port:__pkg__"],
|
||||
visibility = [
|
||||
"//mediapipe/framework/port:__pkg__",
|
||||
"//third_party/visionai/algorithms/tracking:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:core_proto",
|
||||
"@com_google_googletest//:gtest",
|
||||
|
@ -137,7 +140,6 @@ cc_library(
|
|||
hdrs = ["image_resizer.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
"//mediapipe/framework/port:opencv_imgproc",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
#ifndef MEDIAPIPE_DEPS_IMAGE_RESIZER_H_
|
||||
#define MEDIAPIPE_DEPS_IMAGE_RESIZER_H_
|
||||
|
||||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
|
|
@ -140,9 +140,22 @@ _encode_binary_proto = rule(
|
|||
)
|
||||
|
||||
def encode_binary_proto(name, input, message_type, deps, **kwargs):
|
||||
if type(input) == type("string"):
|
||||
input_label = input
|
||||
textproto_srcs = [input]
|
||||
elif type(input) == type(dict()):
|
||||
# We cannot accept a select, as macros are unable to manipulate selects.
|
||||
input_label = select(input)
|
||||
srcs_dict = dict()
|
||||
for k, v in input.items():
|
||||
srcs_dict[k] = [v]
|
||||
textproto_srcs = select(srcs_dict)
|
||||
else:
|
||||
fail("input should be a string or a dict, got %s" % input)
|
||||
|
||||
_encode_binary_proto(
|
||||
name = name,
|
||||
input = input,
|
||||
input = input_label,
|
||||
message_type = message_type,
|
||||
deps = deps,
|
||||
**kwargs
|
||||
|
|
|
@ -448,6 +448,7 @@ cc_library(
|
|||
srcs =
|
||||
[
|
||||
"tensor.cc",
|
||||
"tensor_ahwb.cc",
|
||||
],
|
||||
hdrs = ["tensor.h"],
|
||||
copts = select({
|
||||
|
@ -463,6 +464,9 @@ cc_library(
|
|||
"-framework MetalKit",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
"//mediapipe:android": [
|
||||
"-landroid",
|
||||
],
|
||||
}),
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
|
|
|
@ -19,7 +19,9 @@ package mediapipe;
|
|||
// Joint of a 3D human model (e.g. elbow, knee, wrist). Contains 3D rotation of
|
||||
// the joint and its visibility.
|
||||
message Joint {
|
||||
// Joint rotation in 6D continuous representation.
|
||||
// Joint rotation in 6D continuous representation ordered as
|
||||
// [a1, b1, a2, b2, a3, b3].
|
||||
//
|
||||
// Such a representation is more suitable for NN model training and can be
|
||||
// converted to quaternions and Euler angles if needed. Details can be found
|
||||
// in https://arxiv.org/abs/1812.07035.
|
||||
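As an editorial aside (not part of this change): the 6D values can be mapped back to a rotation matrix with the Gram-Schmidt construction from the paper cited above. The C++ sketch below assumes, from the ordering given, that a = (a1, a2, a3) and b = (b1, b2, b3) are the two reference column vectors; that reading is an assumption, not something stated elsewhere in this change.

// Illustrative only. Converts a 6D rotation representation, ordered as
// [a1, b1, a2, b2, a3, b3], into the three orthonormal columns of a rotation
// matrix via Gram-Schmidt (https://arxiv.org/abs/1812.07035).
#include <array>
#include <cmath>

using Vec3 = std::array<float, 3>;

static Vec3 Normalize(const Vec3& v) {
  const float n = std::sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]);
  return {v[0] / n, v[1] / n, v[2] / n};
}

static float Dot(const Vec3& u, const Vec3& v) {
  return u[0] * v[0] + u[1] * v[1] + u[2] * v[2];
}

static Vec3 Cross(const Vec3& u, const Vec3& v) {
  return {u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2],
          u[0] * v[1] - u[1] * v[0]};
}

// Returns the rotation matrix columns {c1, c2, c3}.
std::array<Vec3, 3> SixDToRotationColumns(const std::array<float, 6>& r) {
  const Vec3 a = {r[0], r[2], r[4]};  // a1, a2, a3
  const Vec3 b = {r[1], r[3], r[5]};  // b1, b2, b3
  const Vec3 c1 = Normalize(a);
  const float d = Dot(c1, b);
  // Remove the component of b along c1, then normalize.
  const Vec3 c2 =
      Normalize({b[0] - d * c1[0], b[1] - d * c1[1], b[2] - d * c1[2]});
  const Vec3 c3 = Cross(c1, c2);
  return {c1, c2, c3};
}

Because the third column is recovered as a cross product, the result is a valid rotation even when the network output is not exactly orthonormal.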
|
|
|
@ -20,6 +20,9 @@
|
|||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
#include <mach/mach_init.h>
|
||||
|
@ -319,28 +322,41 @@ void Tensor::AllocateOpenGlTexture2d() const {
|
|||
Tensor::OpenGlBufferView Tensor::GetOpenGlBufferReadView() const {
|
||||
LOG_IF(FATAL, valid_ == kValidNone)
|
||||
<< "Tensor must be written prior to read from.";
|
||||
LOG_IF(FATAL, !(valid_ & (kValidCpu | kValidOpenGlBuffer)))
|
||||
<< "Tensor conversion between different GPU resources is not supported "
|
||||
"yet.";
|
||||
LOG_IF(FATAL, !(valid_ & (kValidCpu |
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
kValidAHardwareBuffer |
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
kValidOpenGlBuffer)))
|
||||
<< "Tensor conversion between different GPU resources is not supported.";
|
||||
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
|
||||
AllocateOpenGlBuffer();
|
||||
if (!(valid_ & kValidOpenGlBuffer)) {
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
|
||||
void* ptr =
|
||||
glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
|
||||
GL_MAP_INVALIDATE_BUFFER_BIT | GL_MAP_WRITE_BIT);
|
||||
std::memcpy(ptr, cpu_buffer_, bytes());
|
||||
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
|
||||
// If the call succeeds then AHWB and SSBO are synchronized, so any usage of
|
||||
// the SSBO is correct after this call.
|
||||
if (!InsertAhwbToSsboFence()) {
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
|
||||
void* ptr =
|
||||
glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, bytes(),
|
||||
GL_MAP_INVALIDATE_BUFFER_BIT | GL_MAP_WRITE_BIT);
|
||||
std::memcpy(ptr, cpu_buffer_, bytes());
|
||||
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
|
||||
}
|
||||
valid_ |= kValidOpenGlBuffer;
|
||||
}
|
||||
return {opengl_buffer_, std::move(lock)};
|
||||
return {opengl_buffer_, std::move(lock),
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
&ssbo_read_
|
||||
#else
|
||||
nullptr
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
};
|
||||
}
|
||||
|
||||
Tensor::OpenGlBufferView Tensor::GetOpenGlBufferWriteView() const {
|
||||
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
|
||||
AllocateOpenGlBuffer();
|
||||
valid_ = kValidOpenGlBuffer;
|
||||
return {opengl_buffer_, std::move(lock)};
|
||||
return {opengl_buffer_, std::move(lock), nullptr};
|
||||
}
|
||||
|
||||
void Tensor::AllocateOpenGlBuffer() const {
|
||||
|
@ -349,7 +365,10 @@ void Tensor::AllocateOpenGlBuffer() const {
|
|||
LOG_IF(FATAL, !gl_context_) << "GlContext is not bound to the thread.";
|
||||
glGenBuffers(1, &opengl_buffer_);
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, opengl_buffer_);
|
||||
glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY);
|
||||
if (!AllocateAhwbMapToSsbo()) {
|
||||
glBufferData(GL_SHADER_STORAGE_BUFFER, bytes(), NULL, GL_STREAM_COPY);
|
||||
}
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
|
||||
}
|
||||
}
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
|
@ -377,6 +396,8 @@ void Tensor::Move(Tensor* src) {
|
|||
src->metal_buffer_ = nil;
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
|
||||
MoveAhwbStuff(src);
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
gl_context_ = std::move(src->gl_context_);
|
||||
frame_buffer_ = src->frame_buffer_;
|
||||
|
@ -395,27 +416,31 @@ void Tensor::Move(Tensor* src) {
|
|||
Tensor::Tensor(ElementType element_type, const Shape& shape)
|
||||
: element_type_(element_type), shape_(shape) {}
|
||||
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
void Tensor::Invalidate() {
|
||||
absl::MutexLock lock(&view_mutex_);
|
||||
// If memory is allocated and not owned by the metal buffer.
|
||||
// TODO: Re-design cpu buffer memory management.
|
||||
if (cpu_buffer_ && !metal_buffer_) {
|
||||
DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes()));
|
||||
}
|
||||
metal_buffer_ = nil;
|
||||
cpu_buffer_ = nullptr;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void Tensor::Invalidate() {
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
GLuint cleanup_gl_tex = GL_INVALID_INDEX;
|
||||
GLuint cleanup_gl_fb = GL_INVALID_INDEX;
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
GLuint cleanup_gl_buf = GL_INVALID_INDEX;
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
{
|
||||
absl::MutexLock lock(&view_mutex_);
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
// If memory is allocated and not owned by the metal buffer.
|
||||
// TODO: Re-design cpu buffer memory management.
|
||||
if (cpu_buffer_ && !metal_buffer_) {
|
||||
DeallocateVirtualMemory(cpu_buffer_, AlignToPageSize(bytes()));
|
||||
}
|
||||
metal_buffer_ = nil;
|
||||
#else
|
||||
if (cpu_buffer_) {
|
||||
free(cpu_buffer_);
|
||||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
cpu_buffer_ = nullptr;
|
||||
ReleaseAhwbStuff();
|
||||
|
||||
// Don't need to wait for the resource to be deleted because it will be
// released on the last reference deletion inside the OpenGL driver.
|
||||
|
@ -429,28 +454,44 @@ void Tensor::Invalidate() {
|
|||
}
|
||||
// Do not hold the view mutex while invoking GlContext::RunWithoutWaiting,
|
||||
// since that method may acquire the context's own lock.
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX ||
|
||||
cleanup_gl_buf != GL_INVALID_INDEX)
|
||||
gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
,
|
||||
cleanup_gl_buf
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
]() {
|
||||
if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX ||
|
||||
cleanup_gl_buf != GL_INVALID_INDEX) {
|
||||
gl_context_->RunWithoutWaiting(
|
||||
[cleanup_gl_tex, cleanup_gl_fb, cleanup_gl_buf]() {
|
||||
glDeleteTextures(1, &cleanup_gl_tex);
|
||||
glDeleteFramebuffers(1, &cleanup_gl_fb);
|
||||
glDeleteBuffers(1, &cleanup_gl_buf);
|
||||
});
|
||||
}
|
||||
#elif MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
if (cleanup_gl_tex != GL_INVALID_INDEX || cleanup_gl_fb != GL_INVALID_INDEX) {
|
||||
gl_context_->RunWithoutWaiting([cleanup_gl_tex, cleanup_gl_fb]() {
|
||||
glDeleteTextures(1, &cleanup_gl_tex);
|
||||
glDeleteFramebuffers(1, &cleanup_gl_fb);
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
glDeleteBuffers(1, &cleanup_gl_buf);
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
});
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
}
|
||||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_31
|
||||
|
||||
if (cpu_buffer_) {
|
||||
free(cpu_buffer_);
|
||||
}
|
||||
cpu_buffer_ = nullptr;
|
||||
}
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
|
||||
Tensor::CpuReadView Tensor::GetCpuReadView() const {
|
||||
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
|
||||
LOG_IF(FATAL, valid_ == kValidNone)
|
||||
<< "Tensor must be written prior to read from.";
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
void* ptr = MapAhwbToCpuRead();
|
||||
if (ptr) {
|
||||
valid_ |= kValidCpu;
|
||||
return {ptr, ahwb_, nullptr, std::move(lock)};
|
||||
}
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
AllocateCpuBuffer();
|
||||
if (!(valid_ & kValidCpu)) {
|
||||
// GPU-to-CPU synchronization and read-back.
|
||||
|
@ -512,18 +553,33 @@ Tensor::CpuReadView Tensor::GetCpuReadView() const {
|
|||
#endif // MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
valid_ |= kValidCpu;
|
||||
}
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
return {cpu_buffer_, nullptr, nullptr, std::move(lock)};
|
||||
#else
|
||||
return {cpu_buffer_, std::move(lock)};
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
}
|
||||
|
||||
Tensor::CpuWriteView Tensor::GetCpuWriteView() const {
|
||||
auto lock = absl::make_unique<absl::MutexLock>(&view_mutex_);
|
||||
AllocateCpuBuffer();
|
||||
valid_ = kValidCpu;
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
void* ptr = MapAhwbToCpuWrite();
|
||||
if (ptr) {
|
||||
return {ptr, ahwb_, &fence_fd_, std::move(lock)};
|
||||
}
|
||||
return {cpu_buffer_, nullptr, nullptr, std::move(lock)};
|
||||
#else
|
||||
return {cpu_buffer_, std::move(lock)};
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
}
|
||||
|
||||
void Tensor::AllocateCpuBuffer() const {
|
||||
if (!cpu_buffer_) {
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
if (AllocateAHardwareBuffer()) return;
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#if MEDIAPIPE_METAL_ENABLED
|
||||
cpu_buffer_ = AllocateVirtualMemory(bytes());
|
||||
#else
|
||||
|
@ -532,4 +588,10 @@ void Tensor::AllocateCpuBuffer() const {
|
|||
}
|
||||
}
|
||||
|
||||
void Tensor::SetPreferredStorageType(StorageType type) {
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
use_ahwb_ = type == StorageType::kAhwb;
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
}
|
||||
|
||||
} // namespace mediapipe
|
||||
|
|
|
@ -30,6 +30,16 @@
|
|||
#import <Metal/Metal.h>
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#if __ANDROID_API__ < 26
|
||||
#error MEDIAPIPE_TENSOR_USE_AHWB requires NDK version 26 or higher to be specified.
|
||||
#endif // __ANDROID_API__ < 26
|
||||
#include <android/hardware_buffer.h>
|
||||
|
||||
#include "third_party/GL/gl/include/EGL/egl.h"
|
||||
#include "third_party/GL/gl/include/EGL/eglext.h"
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
#include "mediapipe/gpu/gl_context.h"
|
||||
|
@ -108,14 +118,37 @@ class Tensor {
|
|||
return static_cast<typename std::tuple_element<
|
||||
std::is_const<T>::value, std::tuple<P*, const P*> >::type>(buffer_);
|
||||
}
|
||||
CpuView(CpuView&& src) : View(std::move(src)), buffer_(src.buffer_) {
|
||||
src.buffer_ = nullptr;
|
||||
CpuView(CpuView&& src) : View(std::move(src)) {
|
||||
buffer_ = std::exchange(src.buffer_, nullptr);
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
ahwb_ = std::exchange(src.ahwb_, nullptr);
|
||||
fence_fd_ = std::exchange(src.fence_fd_, nullptr);
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
}
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
~CpuView() {
|
||||
if (ahwb_) {
|
||||
auto error = AHardwareBuffer_unlock(ahwb_, fence_fd_);
|
||||
CHECK(error == 0) << "AHardwareBuffer_unlock " << error;
|
||||
}
|
||||
}
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
protected:
|
||||
friend class Tensor;
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
CpuView(T* buffer, AHardwareBuffer* ahwb, int* fence_fd,
|
||||
std::unique_ptr<absl::MutexLock>&& lock)
|
||||
: View(std::move(lock)),
|
||||
buffer_(buffer),
|
||||
fence_fd_(fence_fd),
|
||||
ahwb_(ahwb) {}
|
||||
AHardwareBuffer* ahwb_;
|
||||
int* fence_fd_;
|
||||
#else
|
||||
CpuView(T* buffer, std::unique_ptr<absl::MutexLock>&& lock)
|
||||
: View(std::move(lock)), buffer_(buffer) {}
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
T* buffer_;
|
||||
};
|
||||
using CpuReadView = CpuView<const void>;
|
||||
|
@ -150,6 +183,60 @@ class Tensor {
|
|||
MtlBufferView GetMtlBufferWriteView(id<MTLDevice> device) const;
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
class AHardwareBufferView : public View {
|
||||
public:
|
||||
AHardwareBuffer* handle() const { return handle_; }
|
||||
AHardwareBufferView(AHardwareBufferView&& src) : View(std::move(src)) {
|
||||
handle_ = std::exchange(src.handle_, nullptr);
|
||||
file_descriptor_ = src.file_descriptor_;
|
||||
fence_fd_ = std::exchange(src.fence_fd_, nullptr);
|
||||
ahwb_written_ = std::exchange(src.ahwb_written_, nullptr);
|
||||
release_callback_ = std::exchange(src.release_callback_, nullptr);
|
||||
}
|
||||
int file_descriptor() const { return file_descriptor_; }
|
||||
void SetReadingFinishedFunc(std::function<bool()>&& func) {
|
||||
CHECK(ahwb_written_)
|
||||
<< "AHWB write view can't accept 'reading finished callback'";
|
||||
*ahwb_written_ = std::move(func);
|
||||
}
|
||||
void SetWritingFinishedFD(int fd) {
|
||||
CHECK(fence_fd_)
|
||||
<< "AHWB read view can't accept 'writing finished file descriptor'";
|
||||
*fence_fd_ = fd;
|
||||
}
|
||||
// The function is called when the tensor is released.
|
||||
void SetReleaseCallback(std::function<void()> callback) {
|
||||
*release_callback_ = std::move(callback);
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class Tensor;
|
||||
AHardwareBufferView(AHardwareBuffer* handle, int file_descriptor,
|
||||
int* fence_fd, std::function<bool()>* ahwb_written,
|
||||
std::function<void()>* release_callback,
|
||||
std::unique_ptr<absl::MutexLock>&& lock)
|
||||
: View(std::move(lock)),
|
||||
handle_(handle),
|
||||
file_descriptor_(file_descriptor),
|
||||
fence_fd_(fence_fd),
|
||||
ahwb_written_(ahwb_written),
|
||||
release_callback_(release_callback) {}
|
||||
AHardwareBuffer* handle_;
|
||||
int file_descriptor_;
|
||||
// The view sets some of the Tensor's fields. The view is released prior to the tensor.
|
||||
int* fence_fd_;
|
||||
std::function<bool()>* ahwb_written_;
|
||||
std::function<void()>* release_callback_;
|
||||
};
|
||||
AHardwareBufferView GetAHardwareBufferReadView() const;
|
||||
// size_alignment is an optional argument to tell the API to allocate
|
||||
// a buffer that is padded to multiples of size_alignment bytes.
|
||||
// size_alignment must be a power of 2, e.g. 2, 4, 8, 16, or 64.
|
||||
// If size_alignment is 0, then the buffer will not be padded.
|
||||
AHardwareBufferView GetAHardwareBufferWriteView(int size_alignment = 0) const;
|
||||
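// Illustrative usage, not part of this change: for a tensor of 1000 bytes,
// GetAHardwareBufferWriteView(/*size_alignment=*/16) requests an AHWB whose
// width is rounded up to 1008 bytes (see AlignedToPowerOf2 in tensor_ahwb.cc
// later in this change), while the default of 0 keeps the exact byte size.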
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
// TODO: Use GlTextureView instead.
|
||||
// Only float32 textures are supported with 1/2/3/4 depths.
|
||||
|
@ -188,16 +275,23 @@ class Tensor {
|
|||
class OpenGlBufferView : public View {
|
||||
public:
|
||||
GLuint name() const { return name_; }
|
||||
OpenGlBufferView(OpenGlBufferView&& src)
|
||||
: View(std::move(src)), name_(src.name_) {
|
||||
src.name_ = GL_INVALID_INDEX;
|
||||
OpenGlBufferView(OpenGlBufferView&& src) : View(std::move(src)) {
|
||||
name_ = std::exchange(src.name_, GL_INVALID_INDEX);
|
||||
ssbo_read_ = std::exchange(src.ssbo_read_, nullptr);
|
||||
}
|
||||
~OpenGlBufferView() {
|
||||
if (ssbo_read_) {
|
||||
*ssbo_read_ = glFenceSync(GL_SYNC_GPU_COMMANDS_COMPLETE, 0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
friend class Tensor;
|
||||
OpenGlBufferView(GLuint name, std::unique_ptr<absl::MutexLock>&& lock)
|
||||
: View(std::move(lock)), name_(name) {}
|
||||
OpenGlBufferView(GLuint name, std::unique_ptr<absl::MutexLock>&& lock,
|
||||
GLsync* ssbo_read)
|
||||
: View(std::move(lock)), name_(name), ssbo_read_(ssbo_read) {}
|
||||
GLuint name_;
|
||||
GLsync* ssbo_read_;
|
||||
};
|
||||
// A valid OpenGL context must be bound to the calling thread due to possible
|
||||
// GPU resource allocation.
|
||||
|
@ -223,16 +317,26 @@ class Tensor {
|
|||
}
|
||||
int bytes() const { return shape_.num_elements() * element_size(); }
|
||||
|
||||
bool ready_on_cpu() const { return valid_ & kValidCpu; }
|
||||
bool ready_on_cpu() const {
|
||||
return valid_ & (kValidAHardwareBuffer | kValidCpu);
|
||||
}
|
||||
bool ready_on_gpu() const {
|
||||
return valid_ &
|
||||
(kValidMetalBuffer | kValidOpenGlBuffer | kValidOpenGlTexture2d);
|
||||
return valid_ & (kValidMetalBuffer | kValidOpenGlBuffer |
|
||||
kValidAHardwareBuffer | kValidOpenGlTexture2d);
|
||||
}
|
||||
bool ready_as_metal_buffer() const { return valid_ & kValidMetalBuffer; }
|
||||
bool ready_as_opengl_buffer() const { return valid_ & kValidOpenGlBuffer; }
|
||||
bool ready_as_opengl_buffer() const {
|
||||
return valid_ & (kValidAHardwareBuffer | kValidOpenGlBuffer);
|
||||
}
|
||||
bool ready_as_opengl_texture_2d() const {
|
||||
return valid_ & kValidOpenGlTexture2d;
|
||||
}
|
||||
// Sets the type of underlying resource that is going to be allocated.
|
||||
enum class StorageType {
|
||||
kDefault,
|
||||
kAhwb,
|
||||
};
|
||||
static void SetPreferredStorageType(StorageType type);
|
||||
|
||||
private:
|
||||
void Move(Tensor*);
|
||||
|
@ -248,6 +352,7 @@ class Tensor {
|
|||
kValidMetalBuffer = 1 << 1,
|
||||
kValidOpenGlBuffer = 1 << 2,
|
||||
kValidOpenGlTexture2d = 1 << 3,
|
||||
kValidAHardwareBuffer = 1 << 5,
|
||||
};
|
||||
// A list of resources which are currently allocated and synchronized with
// each other: valid_ = kValidCpu | kValidMetalBuffer;
|
||||
|
@ -264,6 +369,34 @@ class Tensor {
|
|||
void AllocateMtlBuffer(id<MTLDevice> device) const;
|
||||
#endif // MEDIAPIPE_METAL_ENABLED
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
mutable AHardwareBuffer* ahwb_ = nullptr;
|
||||
// Signals when the GPU has finished writing into the SSBO so the AHWB can be
// used, or when writing into the AHWB has finished so the GPU can read from
// the SSBO. The sync and the FD are bound together.
|
||||
mutable EGLSyncKHR fence_sync_ = EGL_NO_SYNC_KHR;
|
||||
// This FD signals when the writing into the SSBO has finished.
|
||||
mutable int ssbo_written_ = -1;
|
||||
// An externally set FD that is then wrapped with an EGL sync to synchronize
|
||||
// AHWB -> OpenGL SSBO.
|
||||
mutable int fence_fd_ = -1;
|
||||
// Signals that reading from the SSBO has finished so the SSBO can be released.
|
||||
mutable GLsync ssbo_read_ = 0;
|
||||
// An externally set function that signals when it is safe to release AHWB.
|
||||
mutable std::function<bool()> ahwb_written_;
|
||||
mutable std::function<void()> release_callback_;
|
||||
bool AllocateAHardwareBuffer(int size_alignment = 0) const;
|
||||
void CreateEglSyncAndFd() const;
|
||||
// Use Ahwb for other views: OpenGL / CPU buffer.
|
||||
static inline bool use_ahwb_ = false;
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
bool AllocateAhwbMapToSsbo() const;
|
||||
bool InsertAhwbToSsboFence() const;
|
||||
void MoveAhwbStuff(Tensor* src);
|
||||
void ReleaseAhwbStuff();
|
||||
void* MapAhwbToCpuRead() const;
|
||||
void* MapAhwbToCpuWrite() const;
|
||||
|
||||
#if MEDIAPIPE_OPENGL_ES_VERSION >= MEDIAPIPE_OPENGL_ES_30
|
||||
mutable std::shared_ptr<mediapipe::GlContext> gl_context_;
|
||||
mutable GLuint opengl_texture2d_ = GL_INVALID_INDEX;
|
||||
|
|
382
mediapipe/framework/formats/tensor_ahwb.cc
Normal file
|
@ -0,0 +1,382 @@
|
|||
#include <cstdint>
|
||||
#include <utility>
|
||||
|
||||
#include "mediapipe/framework/formats/tensor.h"
|
||||
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/gpu/gl_base.h"
|
||||
#include "third_party/GL/gl/include/EGL/egl.h"
|
||||
#include "third_party/GL/gl/include/EGL/eglext.h"
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
namespace mediapipe {
|
||||
#ifdef MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
namespace {
|
||||
PFNGLBUFFERSTORAGEEXTERNALEXTPROC glBufferStorageExternalEXT;
|
||||
PFNEGLGETNATIVECLIENTBUFFERANDROIDPROC eglGetNativeClientBufferANDROID;
|
||||
PFNEGLDUPNATIVEFENCEFDANDROIDPROC eglDupNativeFenceFDANDROID;
|
||||
PFNEGLCREATESYNCKHRPROC eglCreateSyncKHR;
|
||||
PFNEGLWAITSYNCKHRPROC eglWaitSyncKHR;
|
||||
PFNEGLCLIENTWAITSYNCKHRPROC eglClientWaitSyncKHR;
|
||||
PFNEGLDESTROYSYNCKHRPROC eglDestroySyncKHR;
|
||||
|
||||
bool IsGlSupported() {
|
||||
static const bool extensions_allowed = [] {
|
||||
eglGetNativeClientBufferANDROID =
|
||||
reinterpret_cast<PFNEGLGETNATIVECLIENTBUFFERANDROIDPROC>(
|
||||
eglGetProcAddress("eglGetNativeClientBufferANDROID"));
|
||||
glBufferStorageExternalEXT =
|
||||
reinterpret_cast<PFNGLBUFFERSTORAGEEXTERNALEXTPROC>(
|
||||
eglGetProcAddress("glBufferStorageExternalEXT"));
|
||||
eglDupNativeFenceFDANDROID =
|
||||
reinterpret_cast<PFNEGLDUPNATIVEFENCEFDANDROIDPROC>(
|
||||
eglGetProcAddress("eglDupNativeFenceFDANDROID"));
|
||||
eglCreateSyncKHR = reinterpret_cast<PFNEGLCREATESYNCKHRPROC>(
|
||||
eglGetProcAddress("eglCreateSyncKHR"));
|
||||
eglWaitSyncKHR = reinterpret_cast<PFNEGLWAITSYNCKHRPROC>(
|
||||
eglGetProcAddress("eglWaitSyncKHR"));
|
||||
eglClientWaitSyncKHR = reinterpret_cast<PFNEGLCLIENTWAITSYNCKHRPROC>(
|
||||
eglGetProcAddress("eglClientWaitSyncKHR"));
|
||||
eglDestroySyncKHR = reinterpret_cast<PFNEGLDESTROYSYNCKHRPROC>(
|
||||
eglGetProcAddress("eglDestroySyncKHR"));
|
||||
return eglClientWaitSyncKHR && eglWaitSyncKHR &&
|
||||
eglGetNativeClientBufferANDROID && glBufferStorageExternalEXT &&
|
||||
eglCreateSyncKHR && eglDupNativeFenceFDANDROID && eglDestroySyncKHR;
|
||||
}();
|
||||
return extensions_allowed;
|
||||
}
|
||||
|
||||
absl::Status MapAHardwareBufferToGlBuffer(AHardwareBuffer* handle, size_t size,
|
||||
GLuint name) {
|
||||
if (!IsGlSupported()) {
|
||||
return absl::UnknownError(
|
||||
"No GL extension functions found to bind AHardwareBuffer and "
|
||||
"OpenGL buffer");
|
||||
}
|
||||
EGLClientBuffer native_buffer = eglGetNativeClientBufferANDROID(handle);
|
||||
if (!native_buffer) {
|
||||
return absl::UnknownError("Can't get native buffer");
|
||||
}
|
||||
glBufferStorageExternalEXT(GL_SHADER_STORAGE_BUFFER, 0, size, native_buffer,
|
||||
GL_MAP_READ_BIT | GL_MAP_WRITE_BIT |
|
||||
GL_MAP_COHERENT_BIT_EXT |
|
||||
GL_MAP_PERSISTENT_BIT_EXT);
|
||||
if (glGetError() == GL_NO_ERROR) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::InternalError("Error in glBufferStorageExternalEXT");
|
||||
}
|
||||
}
|
||||
|
||||
static inline int AlignedToPowerOf2(int value, int alignment) {
|
||||
// alignment must be a power of 2
|
||||
return ((value - 1) | (alignment - 1)) + 1;
|
||||
}
|
||||
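// Illustrative worked example, not part of this change: AlignedToPowerOf2(1000, 256)
// first sets the low bits, (999 | 255) == 1023, then adds one to reach 1024,
// the next multiple of the alignment at or above the requested size. An
// already-aligned value is returned unchanged, e.g. AlignedToPowerOf2(512, 256) == 512.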
|
||||
// This class keeps a tensor's resources while the tensor is in use on the GPU
// or TPU but has already been released on the CPU. When a regular OpenGL
// buffer is bound to the GPU queue for execution and then released on the
// client side, the buffer is still not freed because the GPU is using it; the
// OpenGL driver keeps track of that. When the OpenGL buffer is built on top of
// an AHWB, the tracking is done by DelayedReleaser, which keeps a record of
// all allocated AHWBs and releases each of them once it is no longer in use.
// EGL/GL fences are used to check the status of a buffer.
|
||||
class DelayedReleaser {
|
||||
public:
|
||||
// Non-copyable
|
||||
DelayedReleaser(const DelayedReleaser&) = delete;
|
||||
DelayedReleaser& operator=(const DelayedReleaser&) = delete;
|
||||
// Non-movable
|
||||
DelayedReleaser(DelayedReleaser&&) = delete;
|
||||
DelayedReleaser& operator=(DelayedReleaser&&) = delete;
|
||||
|
||||
static void Add(AHardwareBuffer* ahwb, GLuint opengl_buffer,
|
||||
EGLSyncKHR ssbo_sync, GLsync ssbo_read,
|
||||
std::function<bool()>&& ahwb_written,
|
||||
std::shared_ptr<mediapipe::GlContext> gl_context,
|
||||
std::function<void()>&& callback) {
|
||||
static absl::Mutex mutex;
|
||||
absl::MutexLock lock(&mutex);
|
||||
// Using `new` to access a non-public constructor.
|
||||
to_release_.emplace_back(absl::WrapUnique(new DelayedReleaser(
|
||||
ahwb, opengl_buffer, ssbo_sync, ssbo_read, std::move(ahwb_written),
|
||||
gl_context, std::move(callback))));
|
||||
for (auto it = to_release_.begin(); it != to_release_.end();) {
|
||||
if ((*it)->IsSignaled()) {
|
||||
it = to_release_.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
~DelayedReleaser() {
|
||||
AHardwareBuffer_release(ahwb_);
|
||||
if (release_callback_) release_callback_();
|
||||
}
|
||||
|
||||
bool IsSignaled() {
|
||||
CHECK(!(ssbo_read_ && ahwb_written_))
|
||||
<< "ssbo_read_ and ahwb_written_ cannot both be set";
|
||||
if (ahwb_written_) {
|
||||
if (!ahwb_written_()) return false;
|
||||
gl_context_->Run([this]() {
|
||||
if (fence_sync_ != EGL_NO_SYNC_KHR && IsGlSupported()) {
|
||||
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
|
||||
if (egl_display != EGL_NO_DISPLAY) {
|
||||
eglDestroySyncKHR(egl_display, fence_sync_);
|
||||
}
|
||||
fence_sync_ = EGL_NO_SYNC_KHR;
|
||||
}
|
||||
glDeleteBuffers(1, &opengl_buffer_);
|
||||
opengl_buffer_ = GL_INVALID_INDEX;
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
gl_context_->Run([this]() {
|
||||
if (ssbo_read_ != 0) {
|
||||
GLenum status = glClientWaitSync(ssbo_read_, 0,
|
||||
/* timeout ns = */ 0);
|
||||
if (status != GL_CONDITION_SATISFIED && status != GL_ALREADY_SIGNALED) {
|
||||
return;
|
||||
}
|
||||
glDeleteSync(ssbo_read_);
|
||||
ssbo_read_ = 0;
|
||||
|
||||
// Don't wait on ssbo_sync because it is ahead of ssbo_read_sync.
|
||||
if (fence_sync_ != EGL_NO_SYNC_KHR && IsGlSupported()) {
|
||||
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
|
||||
if (egl_display != EGL_NO_DISPLAY) {
|
||||
eglDestroySyncKHR(egl_display, fence_sync_);
|
||||
}
|
||||
}
|
||||
fence_sync_ = EGL_NO_SYNC_KHR;
|
||||
|
||||
glDeleteBuffers(1, &opengl_buffer_);
|
||||
opengl_buffer_ = GL_INVALID_INDEX;
|
||||
}
|
||||
});
|
||||
return opengl_buffer_ == GL_INVALID_INDEX;
|
||||
}
|
||||
|
||||
protected:
|
||||
AHardwareBuffer* ahwb_;
|
||||
GLuint opengl_buffer_;
|
||||
// TODO: use wrapper instead.
|
||||
EGLSyncKHR fence_sync_;
|
||||
// TODO: use wrapper instead.
|
||||
GLsync ssbo_read_;
|
||||
std::function<bool()> ahwb_written_;
|
||||
std::shared_ptr<mediapipe::GlContext> gl_context_;
|
||||
std::function<void()> release_callback_;
|
||||
static inline std::deque<std::unique_ptr<DelayedReleaser>> to_release_;
|
||||
|
||||
DelayedReleaser(AHardwareBuffer* ahwb, GLuint opengl_buffer,
|
||||
EGLSyncKHR fence_sync, GLsync ssbo_read,
|
||||
std::function<bool()>&& ahwb_written,
|
||||
std::shared_ptr<mediapipe::GlContext> gl_context,
|
||||
std::function<void()>&& callback)
|
||||
: ahwb_(ahwb),
|
||||
opengl_buffer_(opengl_buffer),
|
||||
fence_sync_(fence_sync),
|
||||
ssbo_read_(ssbo_read),
|
||||
ahwb_written_(std::move(ahwb_written)),
|
||||
gl_context_(gl_context),
|
||||
release_callback_(std::move(callback)) {}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Tensor::AHardwareBufferView Tensor::GetAHardwareBufferReadView() const {
|
||||
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
|
||||
CHECK(valid_ != kValidNone) << "Tensor must be written prior to read from.";
|
||||
CHECK(!(valid_ & kValidOpenGlTexture2d))
|
||||
<< "Tensor conversion between OpenGL texture and AHardwareBuffer is not "
|
||||
"supported.";
|
||||
CHECK(ahwb_ || !(valid_ & kValidOpenGlBuffer))
|
||||
<< "Interoperability bettween OpenGL buffer and AHardwareBuffer is not "
|
||||
"supported on targe system.";
|
||||
CHECK(AllocateAHardwareBuffer())
|
||||
<< "AHardwareBuffer is not supported on the target system.";
|
||||
valid_ |= kValidAHardwareBuffer;
|
||||
if (valid_ & kValidOpenGlBuffer) CreateEglSyncAndFd();
|
||||
return {ahwb_,
|
||||
ssbo_written_,
|
||||
&fence_fd_, // The FD is created for SSBO -> AHWB synchronization.
|
||||
&ahwb_written_, // Filled by SetReadingFinishedFunc.
|
||||
&release_callback_,
|
||||
std::move(lock)};
|
||||
}
|
||||
|
||||
void Tensor::CreateEglSyncAndFd() const {
|
||||
gl_context_->Run([this]() {
|
||||
if (IsGlSupported()) {
|
||||
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
|
||||
if (egl_display != EGL_NO_DISPLAY) {
|
||||
fence_sync_ = eglCreateSyncKHR(egl_display,
|
||||
EGL_SYNC_NATIVE_FENCE_ANDROID, nullptr);
|
||||
if (fence_sync_ != EGL_NO_SYNC_KHR) {
|
||||
ssbo_written_ = eglDupNativeFenceFDANDROID(egl_display, fence_sync_);
|
||||
if (ssbo_written_ == -1) {
|
||||
eglDestroySyncKHR(egl_display, fence_sync_);
|
||||
fence_sync_ = EGL_NO_SYNC_KHR;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Can't use Sync object.
|
||||
if (fence_sync_ == EGL_NO_SYNC_KHR) glFinish();
|
||||
});
|
||||
}
|
||||
|
||||
Tensor::AHardwareBufferView Tensor::GetAHardwareBufferWriteView(
|
||||
int size_alignment) const {
|
||||
auto lock(absl::make_unique<absl::MutexLock>(&view_mutex_));
|
||||
CHECK(AllocateAHardwareBuffer(size_alignment))
|
||||
<< "AHardwareBuffer is not supported on the target system.";
|
||||
valid_ = kValidAHardwareBuffer;
|
||||
return {ahwb_,
|
||||
/*ssbo_written=*/-1,
|
||||
&fence_fd_, // For SetWritingFinishedFD.
|
||||
/*ahwb_written=*/nullptr, // The lifetime is managed by SSBO.
|
||||
&release_callback_,
|
||||
std::move(lock)};
|
||||
}
|
||||
|
||||
bool Tensor::AllocateAHardwareBuffer(int size_alignment) const {
|
||||
if (!use_ahwb_) return false;
|
||||
if (ahwb_ == nullptr) {
|
||||
AHardwareBuffer_Desc desc = {};
|
||||
if (size_alignment == 0) {
|
||||
desc.width = bytes();
|
||||
} else {
|
||||
// We expect allocations to be page-aligned, implicitly satisfying any
|
||||
// requirements from Edge TPU. No need to add a check for this,
|
||||
// since Edge TPU will check for us.
|
||||
desc.width = AlignedToPowerOf2(bytes(), size_alignment);
|
||||
}
|
||||
desc.height = 1;
|
||||
desc.layers = 1;
|
||||
desc.format = AHARDWAREBUFFER_FORMAT_BLOB;
|
||||
desc.usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN |
|
||||
AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
|
||||
AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER;
|
||||
return AHardwareBuffer_allocate(&desc, &ahwb_) == 0;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Tensor::AllocateAhwbMapToSsbo() const {
|
||||
if (AllocateAHardwareBuffer()) {
|
||||
if (MapAHardwareBufferToGlBuffer(ahwb_, bytes(), opengl_buffer_).ok()) {
|
||||
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
|
||||
return true;
|
||||
}
|
||||
// Unable to make OpenGL <-> AHWB binding. Use regular SSBO instead.
|
||||
AHardwareBuffer_release(ahwb_);
|
||||
ahwb_ = nullptr;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// The SSBO is created on top of the AHWB. A fence is inserted into the GPU
// queue before the GPU task that is going to read from the SSBO, so the GPU
// reads from the SSBO only after the writing into the AHWB has finished.
|
||||
bool Tensor::InsertAhwbToSsboFence() const {
|
||||
if (!ahwb_) return false;
|
||||
if (fence_fd_ != -1) {
|
||||
// Can't wait for FD to be signaled on GPU.
|
||||
// TODO: wait on CPU instead.
|
||||
if (!IsGlSupported()) return true;
|
||||
|
||||
// Server-side fence.
|
||||
auto egl_display = eglGetDisplay(EGL_DEFAULT_DISPLAY);
|
||||
if (egl_display == EGL_NO_DISPLAY) return true;
|
||||
EGLint sync_attribs[] = {EGL_SYNC_NATIVE_FENCE_FD_ANDROID,
|
||||
(EGLint)fence_fd_, EGL_NONE};
|
||||
fence_sync_ = eglCreateSyncKHR(egl_display, EGL_SYNC_NATIVE_FENCE_ANDROID,
|
||||
sync_attribs);
|
||||
if (fence_sync_ != EGL_NO_SYNC_KHR) {
|
||||
eglWaitSyncKHR(egl_display, fence_sync_, 0);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Tensor::MoveAhwbStuff(Tensor* src) {
|
||||
ahwb_ = std::exchange(src->ahwb_, nullptr);
|
||||
fence_sync_ = std::exchange(src->fence_sync_, EGL_NO_SYNC_KHR);
|
||||
ssbo_read_ = std::exchange(src->ssbo_read_, static_cast<GLsync>(0));
|
||||
ssbo_written_ = std::exchange(src->ssbo_written_, -1);
|
||||
fence_fd_ = std::exchange(src->fence_fd_, -1);
|
||||
ahwb_written_ = std::move(src->ahwb_written_);
|
||||
release_callback_ = std::move(src->release_callback_);
|
||||
}
|
||||
|
||||
void Tensor::ReleaseAhwbStuff() {
|
||||
if (fence_fd_ != -1) {
|
||||
close(fence_fd_);
|
||||
fence_fd_ = -1;
|
||||
}
|
||||
if (ahwb_) {
|
||||
if (ssbo_read_ != 0 || fence_sync_ != EGL_NO_SYNC_KHR) {
|
||||
if (ssbo_written_ != -1) close(ssbo_written_);
|
||||
DelayedReleaser::Add(ahwb_, opengl_buffer_, fence_sync_, ssbo_read_,
|
||||
std::move(ahwb_written_), gl_context_,
|
||||
std::move(release_callback_));
|
||||
opengl_buffer_ = GL_INVALID_INDEX;
|
||||
} else {
|
||||
AHardwareBuffer_release(ahwb_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void* Tensor::MapAhwbToCpuRead() const {
|
||||
if (ahwb_) {
|
||||
if (!(valid_ & kValidCpu) && (valid_ & kValidOpenGlBuffer) &&
|
||||
ssbo_written_ == -1) {
|
||||
// EGLSync is failed. Use another synchronization method.
|
||||
// TODO: Use tflite::gpu::GlBufferSync and GlActiveSync.
|
||||
glFinish();
|
||||
}
|
||||
void* ptr;
|
||||
auto error =
|
||||
AHardwareBuffer_lock(ahwb_, AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN,
|
||||
ssbo_written_, nullptr, &ptr);
|
||||
CHECK(error == 0) << "AHardwareBuffer_lock " << error;
|
||||
close(ssbo_written_);
|
||||
ssbo_written_ = -1;
|
||||
return ptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void* Tensor::MapAhwbToCpuWrite() const {
|
||||
if (ahwb_) {
|
||||
// TODO: If previously acquired view is GPU write view then need to
|
||||
// be sure that writing is finished. That's a warning: two consequent write
|
||||
// views should be interleaved with read view.
|
||||
void* ptr;
|
||||
auto error = AHardwareBuffer_lock(
|
||||
ahwb_, AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN, -1, nullptr, &ptr);
|
||||
CHECK(error == 0) << "AHardwareBuffer_lock " << error;
|
||||
return ptr;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#else // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
bool Tensor::AllocateAhwbMapToSsbo() const { return false; }
|
||||
bool Tensor::InsertAhwbToSsboFence() const { return false; }
|
||||
void Tensor::MoveAhwbStuff(Tensor* src) {}
|
||||
void Tensor::ReleaseAhwbStuff() {}
|
||||
void* Tensor::MapAhwbToCpuRead() const { return nullptr; }
|
||||
void* Tensor::MapAhwbToCpuWrite() const { return nullptr; }
|
||||
|
||||
#endif // MEDIAPIPE_TENSOR_USE_AHWB
|
||||
|
||||
} // namespace mediapipe
|
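As an aside, here is a minimal sketch (not part of this change) of how the AHardwareBuffer read view above might be handed to a delegate. It assumes an Android build with MEDIAPIPE_TENSOR_USE_AHWB defined and uses only view methods that the new test below also exercises; the helper name is hypothetical.

// Illustrative sketch only; not part of this commit.
#include <android/hardware_buffer.h>

#include "mediapipe/framework/formats/tensor.h"

namespace {

// Acquiring the read view allocates the AHWB on demand and, when the data
// currently lives in an OpenGL SSBO, creates the EGL fence FD shown above so
// the consumer can wait for the SSBO write to finish before reading.
AHardwareBuffer* ShareTensorWithDelegate(const mediapipe::Tensor& tensor) {
  auto view = tensor.GetAHardwareBufferReadView();
  return view.handle();
}

}  // namespace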
mediapipe/framework/formats/tensor_ahwb_test.cc (new file, 59 lines)
@ -0,0 +1,59 @@
#include "mediapipe/gpu/gpu_test_base.h"
#include "testing/base/public/gmock.h"
#include "testing/base/public/gunit.h"

#ifdef MEDIAPIPE_TENSOR_USE_AHWB
#include <android/hardware_buffer.h>

#include "mediapipe/framework/formats/tensor.h"

namespace mediapipe {

#if !MEDIAPIPE_DISABLE_GPU
class TensorAhwbTest : public mediapipe::GpuTestBase {
 public:
};

TEST_F(TensorAhwbTest, TestCpuThenAHWB) {
  Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
  {
    auto ptr = tensor.GetCpuWriteView().buffer<float>();
    EXPECT_NE(ptr, nullptr);
  }
  {
    auto ahwb = tensor.GetAHardwareBufferReadView().handle();
    EXPECT_NE(ahwb, nullptr);
  }
}

TEST_F(TensorAhwbTest, TestAHWBThenCpu) {
  Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
  {
    auto ahwb = tensor.GetAHardwareBufferWriteView().handle();
    EXPECT_NE(ahwb, nullptr);
  }
  {
    auto ptr = tensor.GetCpuReadView().buffer<float>();
    EXPECT_NE(ptr, nullptr);
  }
}

TEST_F(TensorAhwbTest, TestCpuThenGl) {
  RunInGlContext([] {
    Tensor tensor(Tensor::ElementType::kFloat32, Tensor::Shape{1});
    {
      auto ptr = tensor.GetCpuWriteView().buffer<float>();
      EXPECT_NE(ptr, nullptr);
    }
    {
      auto ssbo = tensor.GetOpenGlBufferReadView().name();
      EXPECT_GT(ssbo, 0);
    }
  });
}

}  // namespace mediapipe

#endif  // !MEDIAPIPE_DISABLE_GPU

#endif  // MEDIAPIPE_TENSOR_USE_AHWB
|
@ -21,6 +21,8 @@ syntax = "proto2";
package mediapipe;

option objc_class_prefix = "MediaPipe";

// Header for a uniformly sampled time series stream. Each Packet in
// the stream is a Matrix, and each column is a (vector-valued) sample of
// the series, i.e. each column corresponds to a distinct sample in time.
|
|
@ -204,6 +204,8 @@ absl::Status InputStreamManager::SetNextTimestampBound(const Timestamp bound,
  // untimed scheduling policies.
  if (bound > next_timestamp_bound_) {
    next_timestamp_bound_ = bound;
    VLOG(3) << "Next timestamp bound for input " << name_ << " is "
            << next_timestamp_bound_;
    if (queue_.empty()) {
      // If the queue was not empty then a change to the next_timestamp_bound_
      // is not detectable by the consumer.
|
||||
|
|
|
@ -168,6 +168,8 @@ void OutputStreamManager::PropagateUpdatesToMirrors(
    if (next_timestamp_bound != Timestamp::Unset()) {
      absl::MutexLock lock(&stream_mutex_);
      next_timestamp_bound_ = next_timestamp_bound;
      VLOG(3) << "Next timestamp bound for output " << output_stream_spec_.name
              << " is " << next_timestamp_bound_;
    }
  }
  std::list<Packet>* packets_to_propagate = output_stream_shard->OutputQueue();
|
||||
|
|
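The two hunks above only add VLOG(3) traces for timestamp bounds. As a hedged aside (assuming the glog-style verbosity flag exposed through mediapipe/framework/port/logging.h), one way to surface them in a local run is:

// Sketch only: raise the verbosity so the new "Next timestamp bound ..."
// traces are printed while a graph runs.
#include "mediapipe/framework/port/logging.h"

int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);
  FLAGS_v = 3;  // VLOG(3) and below become visible.
  // ... build and run a CalculatorGraph as usual ...
  return 0;
}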
|
@ -106,19 +106,17 @@ std::string Packet::DebugString() const {
  return result;
}

absl::Status Packet::ValidateAsType(const tool::TypeInfo& type_info) const {
absl::Status Packet::ValidateAsType(TypeId type_id) const {
  if (ABSL_PREDICT_FALSE(IsEmpty())) {
    return absl::InternalError(
        absl::StrCat("Expected a Packet of type: ",
                     MediaPipeTypeStringOrDemangled(type_info),
                     ", but received an empty Packet."));
    return absl::InternalError(absl::StrCat(
        "Expected a Packet of type: ", MediaPipeTypeStringOrDemangled(type_id),
        ", but received an empty Packet."));
  }
  bool holder_is_right_type =
      holder_->GetTypeInfo().hash_code() == type_info.hash_code();
  bool holder_is_right_type = holder_->GetTypeId() == type_id;
  if (ABSL_PREDICT_FALSE(!holder_is_right_type)) {
    return absl::InvalidArgumentError(absl::StrCat(
        "The Packet stores \"", holder_->DebugTypeName(), "\", but \"",
        MediaPipeTypeStringOrDemangled(type_info), "\" was requested."));
        MediaPipeTypeStringOrDemangled(type_id), "\" was requested."));
  }
  return absl::OkStatus();
}
|
||||
|
|
|
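A short illustration of the TypeId-based validation introduced above; the wrapper name is made up, while the calls are the ones shown in this hunk and in packet.h.

// Sketch only: ValidateAsType<T>() now forwards to ValidateAsType(kTypeId<T>),
// so the check is a TypeId equality test rather than a hash_code() comparison.
#include "mediapipe/framework/packet.h"

absl::Status RequireFloatPacket(const mediapipe::Packet& packet) {
  // InternalError if the packet is empty; InvalidArgumentError if the holder's
  // TypeId differs from kTypeId<float>.
  return packet.ValidateAsType<float>();
}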
@ -21,7 +21,6 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <typeinfo>
|
||||
|
||||
#include "absl/base/macros.h"
|
||||
#include "absl/memory/memory.h"
|
||||
|
@ -69,7 +68,7 @@ absl::StatusOr<Packet> PacketFromDynamicProto(const std::string& type_name,
|
|||
// The preferred method of creating a Packet is with MakePacket<T>().
|
||||
// The Packet typically owns the object that it contains, but
|
||||
// PointToForeign allows a Packet to be constructed which does not
|
||||
// own it's data.
|
||||
// own its data.
|
||||
//
|
||||
// This class is thread compatible.
|
||||
class Packet {
|
||||
|
@ -180,7 +179,7 @@ class Packet {
|
|||
// Returns an error if the packet does not contain data of type T.
|
||||
template <typename T>
|
||||
absl::Status ValidateAsType() const {
|
||||
return ValidateAsType(tool::TypeInfo::Get<T>());
|
||||
return ValidateAsType(kTypeId<T>);
|
||||
}
|
||||
|
||||
// Returns an error if the packet is not an instance of
|
||||
|
@ -189,11 +188,7 @@ class Packet {
|
|||
|
||||
// Get the type id for the underlying type stored in the Packet.
|
||||
// Crashes if IsEmpty() == true.
|
||||
size_t GetTypeId() const { return GetTypeInfo().hash_code(); }
|
||||
|
||||
// Get the type info for the underlying type stored in the Packet.
|
||||
// Crashes if IsEmpty() == true.
|
||||
const tool::TypeInfo& GetTypeInfo() const;
|
||||
TypeId GetTypeId() const;
|
||||
|
||||
// Returns the timestamp.
|
||||
class Timestamp Timestamp() const;
|
||||
|
@ -225,7 +220,7 @@ class Packet {
|
|||
packet_internal::GetHolderShared(Packet&& packet);
|
||||
|
||||
friend class PacketType;
|
||||
absl::Status ValidateAsType(const tool::TypeInfo& type_info) const;
|
||||
absl::Status ValidateAsType(TypeId type_id) const;
|
||||
|
||||
std::shared_ptr<packet_internal::HolderBase> holder_;
|
||||
class Timestamp timestamp_;
|
||||
|
@ -369,7 +364,7 @@ class HolderBase {
|
|||
virtual ~HolderBase();
|
||||
template <typename T>
|
||||
bool PayloadIsOfType() const {
|
||||
return GetTypeInfo().hash_code() == tool::GetTypeHash<T>();
|
||||
return GetTypeId() == kTypeId<T>;
|
||||
}
|
||||
// Returns a printable string identifying the type stored in the holder.
|
||||
virtual const std::string DebugTypeName() const = 0;
|
||||
|
@ -377,7 +372,7 @@ class HolderBase {
|
|||
// empty string.
|
||||
virtual const std::string RegisteredTypeName() const = 0;
|
||||
// Get the type id of the underlying data type.
|
||||
virtual const tool::TypeInfo& GetTypeInfo() const = 0;
|
||||
virtual TypeId GetTypeId() const = 0;
|
||||
// Downcasts this to Holder<T>. Returns nullptr if deserialization
|
||||
// failed or if the requested type is not what is stored.
|
||||
template <typename T>
|
||||
|
@ -428,7 +423,7 @@ StatusOr<std::vector<const proto_ns::MessageLite*>>
|
|||
ConvertToVectorOfProtoMessageLitePtrs(const T* data,
|
||||
/*is_proto_vector=*/std::false_type) {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"The Packet stores \"", tool::TypeInfo::Get<T>().name(), "\"",
|
||||
"The Packet stores \"", kTypeId<T>.name(), "\"",
|
||||
"which is not convertible to vector<proto_ns::MessageLite*>."));
|
||||
}
|
||||
|
||||
|
@ -510,9 +505,7 @@ class Holder : public HolderBase {
|
|||
HolderSupport<T>::EnsureStaticInit();
|
||||
return *ptr_;
|
||||
}
|
||||
const tool::TypeInfo& GetTypeInfo() const final {
|
||||
return tool::TypeInfo::Get<T>();
|
||||
}
|
||||
TypeId GetTypeId() const final { return kTypeId<T>; }
|
||||
// Releases the underlying data pointer and transfers the ownership to a
|
||||
// unique pointer.
|
||||
// This method is dangerous and is only used by Packet::Consume() if the
|
||||
|
@ -748,9 +741,9 @@ inline Packet& Packet::operator=(Packet&& packet) {
|
|||
|
||||
inline bool Packet::IsEmpty() const { return holder_ == nullptr; }
|
||||
|
||||
inline const tool::TypeInfo& Packet::GetTypeInfo() const {
|
||||
inline TypeId Packet::GetTypeId() const {
|
||||
CHECK(holder_);
|
||||
return holder_->GetTypeInfo();
|
||||
return holder_->GetTypeId();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -18,6 +18,8 @@ syntax = "proto2";
|
|||
|
||||
package mediapipe;
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message PacketTestProto {
|
||||
// Tests that the tags used to encode the timestamp do not interfere with
|
||||
// proto tags.
|
||||
|
|
|
@ -127,13 +127,13 @@ bool PacketType::IsOneOf() const {
|
|||
}
|
||||
|
||||
bool PacketType::IsExactType() const {
|
||||
return absl::holds_alternative<const tool::TypeInfo*>(type_spec_);
|
||||
return absl::holds_alternative<TypeId>(type_spec_);
|
||||
}
|
||||
|
||||
const std::string* PacketType::RegisteredTypeName() const {
|
||||
if (auto* same_as = SameAsPtr()) return same_as->RegisteredTypeName();
|
||||
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec_))
|
||||
return MediaPipeTypeStringFromTypeId((**type_info).hash_code());
|
||||
if (auto* type_id = absl::get_if<TypeId>(&type_spec_))
|
||||
return MediaPipeTypeStringFromTypeId(*type_id);
|
||||
if (auto* multi_type = absl::get_if<MultiType>(&type_spec_))
|
||||
return multi_type->registered_type_name;
|
||||
return nullptr;
|
||||
|
@ -141,8 +141,8 @@ const std::string* PacketType::RegisteredTypeName() const {
|
|||
|
||||
namespace internal {
|
||||
|
||||
struct TypeInfoFormatter {
|
||||
void operator()(std::string* out, const tool::TypeInfo& t) const {
|
||||
struct TypeIdFormatter {
|
||||
void operator()(std::string* out, TypeId t) const {
|
||||
absl::StrAppend(out, MediaPipeTypeStringOrDemangled(t));
|
||||
}
|
||||
};
|
||||
|
@ -167,12 +167,9 @@ explicit QuoteFormatter(Formatter f) -> QuoteFormatter<Formatter>;
|
|||
|
||||
} // namespace internal
|
||||
|
||||
std::string PacketType::TypeNameForOneOf(TypeInfoSpan types) {
|
||||
std::string PacketType::TypeNameForOneOf(TypeIdSpan types) {
|
||||
return absl::StrCat(
|
||||
"OneOf<",
|
||||
absl::StrJoin(types, ", ",
|
||||
absl::DereferenceFormatter(internal::TypeInfoFormatter())),
|
||||
">");
|
||||
"OneOf<", absl::StrJoin(types, ", ", internal::TypeIdFormatter()), ">");
|
||||
}
|
||||
|
||||
std::string PacketType::DebugTypeName() const {
|
||||
|
@ -185,8 +182,8 @@ std::string PacketType::DebugTypeName() const {
|
|||
if (auto* special = absl::get_if<SpecialType>(&type_spec_)) {
|
||||
return special->name_;
|
||||
}
|
||||
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec_)) {
|
||||
return MediaPipeTypeStringOrDemangled(**type_info);
|
||||
if (auto* type_id = absl::get_if<TypeId>(&type_spec_)) {
|
||||
return MediaPipeTypeStringOrDemangled(*type_id);
|
||||
}
|
||||
if (auto* multi_type = absl::get_if<MultiType>(&type_spec_)) {
|
||||
return TypeNameForOneOf(multi_type->types);
|
||||
|
@ -194,11 +191,11 @@ std::string PacketType::DebugTypeName() const {
|
|||
return "[Undefined Type]";
|
||||
}
|
||||
|
||||
static bool HaveCommonType(absl::Span<const tool::TypeInfo* const> types1,
|
||||
absl::Span<const tool::TypeInfo* const> types2) {
|
||||
static bool HaveCommonType(absl::Span<const TypeId> types1,
|
||||
absl::Span<const TypeId> types2) {
|
||||
for (const auto& first : types1) {
|
||||
for (const auto& second : types2) {
|
||||
if (first->hash_code() == second->hash_code()) {
|
||||
if (first == second) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -216,35 +213,34 @@ absl::Status PacketType::Validate(const Packet& packet) const {
|
|||
// in SetSameAs().
|
||||
return GetSameAs()->Validate(packet);
|
||||
}
|
||||
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec_)) {
|
||||
return packet.ValidateAsType(**type_info);
|
||||
if (auto* type_id = absl::get_if<TypeId>(&type_spec_)) {
|
||||
return packet.ValidateAsType(*type_id);
|
||||
}
|
||||
if (packet.IsEmpty()) {
|
||||
return mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
|
||||
<< "Empty packets are not allowed for type: " << DebugTypeName();
|
||||
}
|
||||
if (auto* multi_type = absl::get_if<MultiType>(&type_spec_)) {
|
||||
auto* packet_type = &packet.GetTypeInfo();
|
||||
auto packet_type = packet.GetTypeId();
|
||||
if (HaveCommonType(multi_type->types, absl::MakeSpan(&packet_type, 1))) {
|
||||
return absl::OkStatus();
|
||||
} else {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"The Packet stores \"", packet.DebugTypeName(), "\", but one of ",
|
||||
absl::StrJoin(multi_type->types, ", ",
|
||||
absl::DereferenceFormatter(internal::QuoteFormatter(
|
||||
internal::TypeInfoFormatter()))),
|
||||
internal::QuoteFormatter(internal::TypeIdFormatter())),
|
||||
" was requested."));
|
||||
}
|
||||
}
|
||||
if (auto* special = absl::get_if<SpecialType>(&type_spec_)) {
|
||||
return special->accept_fn_(&packet.GetTypeInfo());
|
||||
return special->accept_fn_(packet.GetTypeId());
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
PacketType::TypeInfoSpan PacketType::GetTypeSpan(const TypeSpec& type_spec) {
|
||||
if (auto* type_info = absl::get_if<const tool::TypeInfo*>(&type_spec))
|
||||
return absl::MakeSpan(type_info, 1);
|
||||
PacketType::TypeIdSpan PacketType::GetTypeSpan(const TypeSpec& type_spec) {
|
||||
if (auto* type_id = absl::get_if<TypeId>(&type_spec))
|
||||
return absl::MakeSpan(type_id, 1);
|
||||
if (auto* multi_type = absl::get_if<MultiType>(&type_spec))
|
||||
return multi_type->types;
|
||||
return {};
|
||||
|
@ -254,8 +250,8 @@ bool PacketType::IsConsistentWith(const PacketType& other) const {
|
|||
const PacketType* type1 = GetSameAs();
|
||||
const PacketType* type2 = other.GetSameAs();
|
||||
|
||||
TypeInfoSpan types1 = GetTypeSpan(type1->type_spec_);
|
||||
TypeInfoSpan types2 = GetTypeSpan(type2->type_spec_);
|
||||
TypeIdSpan types1 = GetTypeSpan(type1->type_spec_);
|
||||
TypeIdSpan types2 = GetTypeSpan(type2->type_spec_);
|
||||
if (!types1.empty() && !types2.empty()) {
|
||||
return HaveCommonType(types1, types2);
|
||||
}
|
||||
|
|
|
@ -121,15 +121,15 @@ class PacketType {
|
|||
// We don't do union-find optimizations in order to avoid a mutex.
|
||||
const PacketType* other;
|
||||
};
|
||||
using TypeInfoSpan = absl::Span<const tool::TypeInfo* const>;
|
||||
using TypeIdSpan = absl::Span<const TypeId>;
|
||||
struct MultiType {
|
||||
TypeInfoSpan types;
|
||||
TypeIdSpan types;
|
||||
// TODO: refactor RegisteredTypeName, remove.
|
||||
const std::string* registered_type_name;
|
||||
};
|
||||
struct SpecialType;
|
||||
using TypeSpec = absl::variant<absl::monostate, const tool::TypeInfo*,
|
||||
MultiType, SameAs, SpecialType>;
|
||||
using TypeSpec =
|
||||
absl::variant<absl::monostate, TypeId, MultiType, SameAs, SpecialType>;
|
||||
typedef absl::Status (*AcceptsTypeFn)(const TypeSpec& type);
|
||||
struct SpecialType {
|
||||
std::string name_;
|
||||
|
@ -140,8 +140,8 @@ class PacketType {
|
|||
static absl::Status AcceptNone(const TypeSpec& type);
|
||||
|
||||
const PacketType* SameAsPtr() const;
|
||||
static TypeInfoSpan GetTypeSpan(const TypeSpec& type_spec);
|
||||
static std::string TypeNameForOneOf(TypeInfoSpan types);
|
||||
static TypeIdSpan GetTypeSpan(const TypeSpec& type_spec);
|
||||
static std::string TypeNameForOneOf(TypeIdSpan types);
|
||||
|
||||
TypeSpec type_spec_;
|
||||
|
||||
|
@ -259,14 +259,13 @@ absl::Status ValidatePacketTypeSet(const PacketTypeSet& packet_type_set);
|
|||
|
||||
template <typename T>
|
||||
PacketType& PacketType::Set() {
|
||||
type_spec_ = &tool::TypeInfo::Get<T>();
|
||||
type_spec_ = kTypeId<T>;
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename... T>
|
||||
PacketType& PacketType::SetOneOf() {
|
||||
static const NoDestructor<std::vector<const tool::TypeInfo*>> types{
|
||||
{&tool::TypeInfo::Get<T>()...}};
|
||||
static const NoDestructor<std::vector<TypeId>> types{{kTypeId<T>...}};
|
||||
static const NoDestructor<std::string> name{TypeNameForOneOf(*types)};
|
||||
type_spec_ = MultiType{*types, &*name};
|
||||
return *this;
|
||||
|
|
|
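For context, a minimal sketch of how the TypeId-backed PacketType API above is typically used when declaring a contract; the tag names are hypothetical.

// Sketch only: Set<T>() stores kTypeId<T> in type_spec_; SetOneOf<T...>()
// stores a static span of TypeIds (the MultiType case).
#include "mediapipe/framework/packet_type.h"

void DeclareInputTypes(mediapipe::PacketTypeSet* inputs) {
  inputs->Tag("SCALE").Set<float>();
  inputs->Tag("VALUE").SetOneOf<int, float>();
}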
@ -43,7 +43,7 @@ const int kDefaultLogFileCount = 2;
|
|||
const char kDefaultLogFilePrefix[] = "mediapipe_trace_";
|
||||
|
||||
// The number of recent timestamps tracked for each input stream.
|
||||
const int kPacketInfoRecentCount = 100;
|
||||
const int kPacketInfoRecentCount = 400;
|
||||
|
||||
std::string PacketIdToString(const PacketId& packet_id) {
|
||||
return absl::Substitute("stream_name: $0, timestamp_usec: $1",
|
||||
|
@ -507,8 +507,8 @@ int64 GraphProfiler::AddInputStreamTimeSamples(
|
|||
// This is a condition rather than a failure CHECK because
|
||||
// under certain conditions the consumer calculator's Process()
|
||||
// can start before the producer calculator's Process() is finished.
|
||||
LOG_EVERY_N(WARNING, 100) << "Expected packet info is missing for: "
|
||||
<< PacketIdToString(packet_id);
|
||||
LOG_FIRST_N(WARNING, 10) << "Expected packet info is missing for: "
|
||||
<< PacketIdToString(packet_id);
|
||||
continue;
|
||||
}
|
||||
AddTimeSample(
|
||||
|
|
|
@ -36,7 +36,7 @@ class SubgraphContext {
|
|||
public:
|
||||
SubgraphContext() : SubgraphContext(nullptr, nullptr) {}
|
||||
// @node and/or @service_manager can be nullptr.
|
||||
SubgraphContext(const CalculatorGraphConfig::Node* node,
|
||||
SubgraphContext(CalculatorGraphConfig::Node* node,
|
||||
const GraphServiceManager* service_manager)
|
||||
: default_node_(node ? absl::nullopt
|
||||
: absl::optional<CalculatorGraphConfig::Node>(
|
||||
|
@ -48,14 +48,19 @@ class SubgraphContext {
|
|||
: absl::optional<GraphServiceManager>(GraphServiceManager())),
|
||||
service_manager_(service_manager ? *service_manager
|
||||
: default_service_manager_.value()),
|
||||
options_map_(std::move(tool::OptionsMap().Initialize(original_node_))) {
|
||||
}
|
||||
options_map_(
|
||||
std::move(tool::MutableOptionsMap().Initialize(original_node_))) {}
|
||||
|
||||
template <typename T>
|
||||
const T& Options() {
|
||||
return options_map_.Get<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* MutableOptions() {
|
||||
return options_map_.GetMutable<T>();
|
||||
}
|
||||
|
||||
const CalculatorGraphConfig::Node& OriginalNode() const {
|
||||
return original_node_;
|
||||
}
|
||||
|
@ -67,16 +72,16 @@ class SubgraphContext {
|
|||
|
||||
private:
|
||||
// Populated if node is not provided during construction.
|
||||
const absl::optional<CalculatorGraphConfig::Node> default_node_;
|
||||
absl::optional<CalculatorGraphConfig::Node> default_node_;
|
||||
|
||||
const CalculatorGraphConfig::Node& original_node_;
|
||||
CalculatorGraphConfig::Node& original_node_;
|
||||
|
||||
// Populated if service manager is not provided during construction.
|
||||
const absl::optional<GraphServiceManager> default_service_manager_;
|
||||
|
||||
const GraphServiceManager& service_manager_;
|
||||
|
||||
tool::OptionsMap options_map_;
|
||||
tool::MutableOptionsMap options_map_;
|
||||
};
|
||||
|
||||
// Instances of this class are responsible for providing a subgraph config.
|
||||
|
|
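A hedged sketch of the new mutable options accessor in a Subgraph. MySubgraphOptions and the field it sets are hypothetical, and the GetConfig signature is assumed from the surrounding framework rather than shown in this hunk.

// Sketch only: reading and adjusting subgraph options via the
// MutableOptionsMap-backed accessors added above. MySubgraphOptions is a
// hypothetical options proto.
#include "mediapipe/framework/subgraph.h"

class MySubgraph : public mediapipe::Subgraph {
 public:
  absl::StatusOr<mediapipe::CalculatorGraphConfig> GetConfig(
      mediapipe::SubgraphContext* sc) override {
    const auto& opts = sc->Options<MySubgraphOptions>();      // read-only view
    sc->MutableOptions<MySubgraphOptions>()->set_num_lanes(   // in-place edit
        opts.num_lanes() > 0 ? opts.num_lanes() : 1);
    return mediapipe::CalculatorGraphConfig();
  }
};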
|
@ -22,6 +22,8 @@ package mediapipe;
|
|||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
|
||||
option objc_class_prefix = "MediaPipe";
|
||||
|
||||
message RandomMatrixCalculatorOptions {
|
||||
extend CalculatorOptions {
|
||||
optional RandomMatrixCalculatorOptions ext = 52056136;
|
||||
|
|
|
@ -198,6 +198,7 @@ cc_library(
|
|||
":name_util",
|
||||
":options_registry",
|
||||
":proto_util_lite",
|
||||
":type_util",
|
||||
"//mediapipe/framework:calculator_cc_proto",
|
||||
"//mediapipe/framework:packet",
|
||||
"//mediapipe/framework:packet_type",
|
||||
|
@ -277,9 +278,12 @@ cc_library(
|
|||
hdrs = ["options_registry.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":field_data_cc_proto",
|
||||
":proto_util_lite",
|
||||
"//mediapipe/framework/deps:registration",
|
||||
"//mediapipe/framework/port:advanced_proto",
|
||||
"//mediapipe/framework/port:logging",
|
||||
"//mediapipe/framework/port:ret_check",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
|
@ -334,6 +338,7 @@ cc_library(
|
|||
hdrs = ["proto_util_lite.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":field_data_cc_proto",
|
||||
"//mediapipe/framework:type_map",
|
||||
"//mediapipe/framework/port:advanced_proto_lite",
|
||||
"//mediapipe/framework/port:integral_types",
|
||||
|
@ -518,9 +523,11 @@ cc_library(
|
|||
cc_library(
|
||||
name = "type_util",
|
||||
hdrs = ["type_util.h"],
|
||||
visibility = ["//mediapipe/framework:mediapipe_internal"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//mediapipe/framework:demangle",
|
||||
"//mediapipe/framework:port",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ syntax = "proto2";
|
|||
package mediapipe;
|
||||
|
||||
import "mediapipe/framework/calculator.proto";
|
||||
import "mediapipe/framework/calculator_options.proto";
|
||||
import "mediapipe/framework/deps/proto_descriptor.proto";
|
||||
|
||||
option java_package = "com.google.mediapipe.proto";
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/packet.h"
|
||||
|
@ -18,6 +19,7 @@
|
|||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/tool/name_util.h"
|
||||
#include "mediapipe/framework/tool/proto_util_lite.h"
|
||||
#include "mediapipe/framework/tool/type_util.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tool {
|
||||
|
@ -41,165 +43,39 @@ FieldType AsFieldType(proto_ns::FieldDescriptorProto::Type type) {
|
|||
return static_cast<FieldType>(type);
|
||||
}
|
||||
|
||||
absl::Status WriteValue(const FieldData& value, FieldType field_type,
|
||||
std::string* field_bytes) {
|
||||
StringOutputStream sos(field_bytes);
|
||||
CodedOutputStream out(&sos);
|
||||
switch (field_type) {
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
WireFormatLite::WriteInt32NoTag(value.int32_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
WireFormatLite::WriteInt64NoTag(value.int64_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
WireFormatLite::WriteDoubleNoTag(value.uint64_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
WireFormatLite::WriteFloatNoTag(value.float_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
WireFormatLite::WriteBoolNoTag(value.bool_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
WireFormatLite::WriteEnumNoTag(value.enum_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
out.WriteString(value.string_value());
|
||||
break;
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
out.WriteString(value.message_value().value());
|
||||
break;
|
||||
default:
|
||||
return absl::UnimplementedError(
|
||||
absl::StrCat("Cannot write type: ", field_type));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Serializes a packet value.
|
||||
absl::Status WriteField(const FieldData& packet, const FieldDescriptor* field,
|
||||
std::string* result) {
|
||||
FieldType field_type = AsFieldType(field->type());
|
||||
return WriteValue(packet, field_type, result);
|
||||
}
|
||||
|
||||
template <typename ValueT, FieldType kFieldType>
|
||||
static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) {
|
||||
ArrayInputStream ais(field_bytes.data(), field_bytes.size());
|
||||
CodedInputStream input(&ais);
|
||||
ValueT result;
|
||||
if (!WireFormatLite::ReadPrimitive<ValueT, kFieldType>(&input, &result)) {
|
||||
status->Update(mediapipe::InvalidArgumentError(absl::StrCat(
|
||||
"Bad serialized value: ", MediaPipeTypeStringOrDemangled<ValueT>(),
|
||||
".")));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type,
|
||||
absl::string_view message_type, FieldData* result) {
|
||||
absl::Status status;
|
||||
result->Clear();
|
||||
switch (field_type) {
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
result->set_int32_value(
|
||||
ReadValue<int32, WireFormatLite::TYPE_INT32>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
result->set_int32_value(
|
||||
ReadValue<int32, WireFormatLite::TYPE_SINT32>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
result->set_int64_value(
|
||||
ReadValue<int64, WireFormatLite::TYPE_INT64>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
result->set_int64_value(
|
||||
ReadValue<int64, WireFormatLite::TYPE_SINT64>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
result->set_uint32_value(
|
||||
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
result->set_uint64_value(
|
||||
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
result->set_double_value(
|
||||
ReadValue<double, WireFormatLite::TYPE_DOUBLE>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
result->set_float_value(
|
||||
ReadValue<float, WireFormatLite::TYPE_FLOAT>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
result->set_bool_value(
|
||||
ReadValue<bool, WireFormatLite::TYPE_BOOL>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
result->set_enum_value(
|
||||
ReadValue<int32, WireFormatLite::TYPE_ENUM>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
result->set_string_value(std::string(field_bytes));
|
||||
break;
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
result->mutable_message_value()->set_value(std::string(field_bytes));
|
||||
result->mutable_message_value()->set_type_url(TypeUrl(message_type));
|
||||
break;
|
||||
default:
|
||||
status = absl::UnimplementedError(
|
||||
absl::StrCat("Cannot read type: ", field_type));
|
||||
break;
|
||||
}
|
||||
return status;
|
||||
return ProtoUtilLite::WriteValue(packet, field->type(), result);
|
||||
}
|
||||
|
||||
// Deserializes a packet from a protobuf field.
|
||||
absl::Status ReadField(absl::string_view bytes, const FieldDescriptor* field,
|
||||
absl::Status ReadField(absl::string_view bytes, const FieldDescriptor& field,
|
||||
FieldData* result) {
|
||||
RET_CHECK_NE(field, nullptr);
|
||||
FieldType field_type = AsFieldType(field->type());
|
||||
std::string message_type = (field_type == WireFormatLite::TYPE_MESSAGE)
|
||||
? field->message_type()->full_name()
|
||||
std::string message_type = (field.type() == WireFormatLite::TYPE_MESSAGE)
|
||||
? field.message_type()->full_name()
|
||||
: "";
|
||||
return ReadValue(bytes, field_type, message_type, result);
|
||||
return ProtoUtilLite::ReadValue(bytes, field.type(), message_type, result);
|
||||
}
|
||||
|
||||
// Reads all values from a repeated field.
|
||||
absl::Status GetFieldValues(const FieldData& message_data,
|
||||
const FieldDescriptor& field,
|
||||
std::vector<FieldData>* result) {
|
||||
absl::StatusOr<std::vector<FieldData>> GetFieldValues(
|
||||
const FieldData& message_data, const FieldDescriptor& field) {
|
||||
std::vector<FieldData> result;
|
||||
const std::string& message_bytes = message_data.message_value().value();
|
||||
FieldType field_type = AsFieldType(field.type());
|
||||
ProtoUtilLite proto_util;
|
||||
ProtoUtilLite::ProtoPath proto_path = {{field.number(), 0}};
|
||||
int count;
|
||||
MP_RETURN_IF_ERROR(
|
||||
proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count));
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(message_bytes, proto_path,
|
||||
field.type(), &count));
|
||||
std::vector<std::string> field_values;
|
||||
MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, count,
|
||||
field_type, &field_values));
|
||||
for (int i = 0; i < count; ++i) {
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
|
||||
message_bytes, proto_path, count, field.type(), &field_values));
|
||||
for (int i = 0; i < field_values.size(); ++i) {
|
||||
FieldData r;
|
||||
MP_RETURN_IF_ERROR(ReadField(field_values[i], &field, &r));
|
||||
result->push_back(std::move(r));
|
||||
MP_RETURN_IF_ERROR(ReadField(field_values[i], field, &r));
|
||||
result.push_back(std::move(r));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
return result;
|
||||
}
|
||||
|
||||
// Reads one value from a field.
|
||||
|
@ -207,42 +83,70 @@ absl::Status GetFieldValue(const FieldData& message_data,
|
|||
const FieldPathEntry& entry, FieldData* result) {
|
||||
RET_CHECK_NE(entry.field, nullptr);
|
||||
const std::string& message_bytes = message_data.message_value().value();
|
||||
FieldType field_type = AsFieldType(entry.field->type());
|
||||
ProtoUtilLite proto_util;
|
||||
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}};
|
||||
FieldType field_type = entry.field->type();
|
||||
int index = std::max(0, entry.index);
|
||||
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), index}};
|
||||
std::vector<std::string> field_values;
|
||||
MP_RETURN_IF_ERROR(proto_util.GetFieldRange(message_bytes, proto_path, 1,
|
||||
field_type, &field_values));
|
||||
MP_RETURN_IF_ERROR(ReadField(field_values[0], entry.field, result));
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(message_bytes, proto_path, 1,
|
||||
field_type, &field_values));
|
||||
MP_RETURN_IF_ERROR(ReadField(field_values[0], *entry.field, result));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Writes one value to a field.
|
||||
absl::Status SetFieldValue(const FieldPathEntry& entry, const FieldData& value,
|
||||
FieldData* result) {
|
||||
std::vector<FieldData> field_values;
|
||||
ProtoUtilLite proto_util;
|
||||
FieldType field_type = AsFieldType(entry.field->type());
|
||||
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), entry.index}};
|
||||
std::string* message_bytes = result->mutable_message_value()->mutable_value();
|
||||
absl::Status SetFieldValue(FieldData& result, const FieldPathEntry& entry,
|
||||
const FieldData& value) {
|
||||
int index = std::max(0, entry.index);
|
||||
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), index}};
|
||||
std::string* message_bytes = result.mutable_message_value()->mutable_value();
|
||||
int field_count;
|
||||
MP_RETURN_IF_ERROR(proto_util.GetFieldCount(*message_bytes, proto_path,
|
||||
field_type, &field_count));
|
||||
if (entry.index > field_count) {
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(
|
||||
*message_bytes, proto_path, entry.field->type(), &field_count));
|
||||
if (index > field_count) {
|
||||
return absl::OutOfRangeError(
|
||||
absl::StrCat("Option field index out of range: ", entry.index));
|
||||
absl::StrCat("Option field index out of range: ", index));
|
||||
}
|
||||
int replace_length = entry.index < field_count ? 1 : 0;
|
||||
int replace_length = index < field_count ? 1 : 0;
|
||||
std::string field_value;
|
||||
MP_RETURN_IF_ERROR(WriteField(value, entry.field, &field_value));
|
||||
MP_RETURN_IF_ERROR(proto_util.ReplaceFieldRange(
|
||||
message_bytes, proto_path, replace_length, field_type, {field_value}));
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::ReplaceFieldRange(
|
||||
message_bytes, proto_path, replace_length, entry.field->type(),
|
||||
{field_value}));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Writes several values to a repeated field.
|
||||
// The specified |values| replace the specified |entry| index,
|
||||
// or if no index is specified all field values are replaced.
|
||||
absl::Status SetFieldValues(FieldData& result, const FieldPathEntry& entry,
|
||||
const std::vector<FieldData>& values) {
|
||||
if (entry.field == nullptr) {
|
||||
return absl::InvalidArgumentError("Field not found.");
|
||||
}
|
||||
FieldType field_type = entry.field->type();
|
||||
ProtoUtilLite::ProtoPath proto_path = {{entry.field->number(), 0}};
|
||||
std::string* message_bytes = result.mutable_message_value()->mutable_value();
|
||||
int field_count;
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(*message_bytes, proto_path,
|
||||
field_type, &field_count));
|
||||
int replace_start = 0, replace_length = field_count;
|
||||
if (entry.index > -1) {
|
||||
replace_start = entry.index;
|
||||
replace_length = 1;
|
||||
}
|
||||
std::vector<std::string> field_values(values.size());
|
||||
for (int i = 0; i < values.size(); ++i) {
|
||||
MP_RETURN_IF_ERROR(WriteField(values[i], entry.field, &field_values[i]));
|
||||
}
|
||||
proto_path = {{entry.field->number(), replace_start}};
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::ReplaceFieldRange(
|
||||
message_bytes, proto_path, replace_length, field_type, field_values));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Returns true for a field of type "google.protobuf.Any".
|
||||
bool IsProtobufAny(const FieldDescriptor* field) {
|
||||
return AsFieldType(field->type()) == FieldType::TYPE_MESSAGE &&
|
||||
return field->type() == FieldType::TYPE_MESSAGE &&
|
||||
field->message_type()->full_name() == kGoogleProtobufAny;
|
||||
}
|
||||
|
||||
|
@ -275,9 +179,7 @@ StatusOr<int> FindExtensionIndex(const FieldData& message_data,
|
|||
}
|
||||
std::string& extension_type = entry->extension_type;
|
||||
std::vector<FieldData> field_values;
|
||||
RET_CHECK_NE(entry->field, nullptr);
|
||||
MP_RETURN_IF_ERROR(
|
||||
GetFieldValues(message_data, *entry->field, &field_values));
|
||||
ASSIGN_OR_RETURN(field_values, GetFieldValues(message_data, *entry->field));
|
||||
for (int i = 0; i < field_values.size(); ++i) {
|
||||
FieldData extension = ParseProtobufAny(field_values[i]);
|
||||
if (extension_type == "*" ||
|
||||
|
@ -290,9 +192,9 @@ StatusOr<int> FindExtensionIndex(const FieldData& message_data,
|
|||
|
||||
// Returns true if the value of a field is available.
|
||||
bool HasField(const FieldPath& field_path, const FieldData& message_data) {
|
||||
FieldData value;
|
||||
return GetField(field_path, message_data, &value).ok() &&
|
||||
value.value_case() != mediapipe::FieldData::VALUE_NOT_SET;
|
||||
auto value = GetField(message_data, field_path);
|
||||
return value.ok() &&
|
||||
value->value_case() != mediapipe::FieldData::VALUE_NOT_SET;
|
||||
}
|
||||
|
||||
// Returns the extension field containing the specified extension-type.
|
||||
|
@ -330,43 +232,24 @@ void SetOptionsMessage(
|
|||
*options_any->mutable_value() = node_options.message_value().value();
|
||||
}
|
||||
|
||||
// Returns the count of values in a repeated field.
|
||||
int FieldCount(const FieldData& message_data, const FieldDescriptor* field) {
|
||||
const std::string& message_bytes = message_data.message_value().value();
|
||||
FieldType field_type = AsFieldType(field->type());
|
||||
ProtoUtilLite proto_util;
|
||||
ProtoUtilLite::ProtoPath proto_path = {{field->number(), 0}};
|
||||
int count;
|
||||
if (proto_util.GetFieldCount(message_bytes, proto_path, field_type, &count)
|
||||
.ok()) {
|
||||
return count;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Deserializes a packet containing a MessageLite value.
|
||||
absl::Status ReadMessage(const std::string& value, const std::string& type_name,
|
||||
Packet* result) {
|
||||
auto packet = packet_internal::PacketFromDynamicProto(type_name, value);
|
||||
if (packet.ok()) {
|
||||
*result = *packet;
|
||||
}
|
||||
return packet.status();
|
||||
absl::StatusOr<Packet> ReadMessage(const std::string& value,
|
||||
const std::string& type_name) {
|
||||
return packet_internal::PacketFromDynamicProto(type_name, value);
|
||||
}
|
||||
|
||||
// Merge two options FieldData values.
|
||||
absl::Status MergeMessages(const FieldData& base, const FieldData& over,
|
||||
FieldData* result) {
|
||||
absl::StatusOr<FieldData> MergeMessages(const FieldData& base,
|
||||
const FieldData& over) {
|
||||
FieldData result;
|
||||
absl::Status status;
|
||||
if (over.value_case() == FieldData::VALUE_NOT_SET) {
|
||||
*result = base;
|
||||
return status;
|
||||
return base;
|
||||
}
|
||||
if (base.value_case() == FieldData::VALUE_NOT_SET) {
|
||||
*result = over;
|
||||
return status;
|
||||
return over;
|
||||
}
|
||||
if (over.value_case() != base.value_case()) {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
|
@ -382,10 +265,9 @@ absl::Status MergeMessages(const FieldData& base, const FieldData& over,
|
|||
absl::Cord merged_value;
|
||||
merged_value.Append(base.message_value().value());
|
||||
merged_value.Append(over.message_value().value());
|
||||
result->mutable_message_value()->set_type_url(
|
||||
base.message_value().type_url());
|
||||
result->mutable_message_value()->set_value(std::string(merged_value));
|
||||
return status;
|
||||
result.mutable_message_value()->set_type_url(base.message_value().type_url());
|
||||
result.mutable_message_value()->set_value(std::string(merged_value));
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns either the extension field or the repeated protobuf.Any field index
|
||||
|
@ -439,51 +321,48 @@ FieldPath GetExtensionPath(const std::string& parent_type,
|
|||
}
|
||||
|
||||
// Returns the requested options protobuf for a graph node.
|
||||
absl::Status GetNodeOptions(const FieldData& message_data,
|
||||
const std::string& extension_type,
|
||||
FieldData* result) {
|
||||
absl::StatusOr<FieldData> GetNodeOptions(const FieldData& message_data,
|
||||
const std::string& extension_type) {
|
||||
constexpr char kOptionsName[] = "options";
|
||||
constexpr char kNodeOptionsName[] = "node_options";
|
||||
std::string parent_type = options_field_util::ParseTypeUrl(
|
||||
std::string(message_data.message_value().type_url()));
|
||||
FieldPath path;
|
||||
Status status;
|
||||
absl::Status status;
|
||||
path = GetExtensionPath(parent_type, extension_type, kOptionsName, false);
|
||||
status = GetField(path, message_data, result);
|
||||
if (status.ok()) {
|
||||
return status;
|
||||
auto result = GetField(message_data, path);
|
||||
if (result.ok()) {
|
||||
return result;
|
||||
}
|
||||
path = GetExtensionPath(parent_type, extension_type, kNodeOptionsName, true);
|
||||
status = GetField(path, message_data, result);
|
||||
return status;
|
||||
return GetField(message_data, path);
|
||||
}
|
||||
|
||||
// Returns the requested options protobuf for a graph.
|
||||
absl::Status GetGraphOptions(const FieldData& message_data,
|
||||
const std::string& extension_type,
|
||||
FieldData* result) {
|
||||
absl::StatusOr<FieldData> GetGraphOptions(const FieldData& message_data,
|
||||
const std::string& extension_type) {
|
||||
constexpr char kOptionsName[] = "options";
|
||||
constexpr char kGraphOptionsName[] = "graph_options";
|
||||
std::string parent_type = options_field_util::ParseTypeUrl(
|
||||
std::string(message_data.message_value().type_url()));
|
||||
FieldPath path;
|
||||
Status status;
|
||||
absl::Status status;
|
||||
path = GetExtensionPath(parent_type, extension_type, kOptionsName, false);
|
||||
status = GetField(path, message_data, result);
|
||||
if (status.ok()) {
|
||||
return status;
|
||||
auto result = GetField(message_data, path);
|
||||
if (result.ok()) {
|
||||
return result;
|
||||
}
|
||||
path = GetExtensionPath(parent_type, extension_type, kGraphOptionsName, true);
|
||||
status = GetField(path, message_data, result);
|
||||
return status;
|
||||
return GetField(message_data, path);
|
||||
}
|
||||
|
||||
// Reads a FieldData value from a protobuf field.
|
||||
absl::Status GetField(const FieldPath& field_path,
|
||||
const FieldData& message_data, FieldData* result) {
|
||||
// Reads the FieldData values from a protobuf field.
|
||||
absl::StatusOr<std::vector<FieldData>> GetFieldValues(
|
||||
const FieldData& message_data, const FieldPath& field_path) {
|
||||
std::vector<FieldData> results;
|
||||
if (field_path.empty()) {
|
||||
*result->mutable_message_value() = message_data.message_value();
|
||||
return absl::OkStatus();
|
||||
results.push_back(message_data);
|
||||
return results;
|
||||
}
|
||||
FieldPathEntry head = field_path.front();
|
||||
FieldPath tail = field_path;
|
||||
|
@ -491,65 +370,101 @@ absl::Status GetField(const FieldPath& field_path,
|
|||
if (!head.extension_type.empty()) {
|
||||
MP_RETURN_IF_ERROR(FindExtension(message_data, &head));
|
||||
}
|
||||
if (tail.empty() && FieldCount(message_data, head.field) == 0) {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
MP_RETURN_IF_ERROR(GetFieldValue(message_data, head, result));
|
||||
RET_CHECK_NE(head.field, nullptr);
|
||||
ASSIGN_OR_RETURN(results, GetFieldValues(message_data, *head.field));
|
||||
if (IsProtobufAny(head.field)) {
|
||||
*result = ParseProtobufAny(*result);
|
||||
for (int i = 0; i < results.size(); ++i) {
|
||||
results[i] = ParseProtobufAny(results[i]);
|
||||
}
|
||||
}
|
||||
int index = tail.empty() ? head.index : std::max(0, head.index);
|
||||
if ((int)results.size() <= index) {
|
||||
return absl::OutOfRangeError(absl::StrCat(
|
||||
"Missing feild value: ", head.field ? head.field->name() : "#",
|
||||
" at index: ", index));
|
||||
}
|
||||
if (!tail.empty()) {
|
||||
FieldData child = *result;
|
||||
MP_RETURN_IF_ERROR(GetField(tail, child, result));
|
||||
FieldData child = results.at(index);
|
||||
ASSIGN_OR_RETURN(results, GetFieldValues(child, tail));
|
||||
} else if (index > -1) {
|
||||
FieldData child = results.at(index);
|
||||
results.clear();
|
||||
results.push_back(child);
|
||||
}
|
||||
return absl::OkStatus();
|
||||
return results;
|
||||
}
|
||||
|
||||
// Writes a FieldData value into protobuf field.
|
||||
absl::Status SetField(const FieldPath& field_path, const FieldData& value,
|
||||
FieldData* message_data) {
|
||||
// Reads a FieldData value from a protobuf field.
|
||||
absl::StatusOr<FieldData> GetField(const FieldData& message_data,
|
||||
const FieldPath& field_path) {
|
||||
std::vector<FieldData> results;
|
||||
ASSIGN_OR_RETURN(results, GetFieldValues(message_data, field_path));
|
||||
if (results.empty()) {
|
||||
FieldPathEntry tail = field_path.back();
|
||||
return absl::OutOfRangeError(absl::StrCat(
|
||||
"Missing feild value: ", tail.field ? tail.field->name() : "##",
|
||||
" at index: ", tail.index));
|
||||
}
|
||||
return results[0];
|
||||
}
|
||||
|
||||
// Writes FieldData values into protobuf field.
|
||||
absl::Status SetFieldValues(FieldData& message_data,
|
||||
const FieldPath& field_path,
|
||||
const std::vector<FieldData>& values) {
|
||||
if (field_path.empty()) {
|
||||
*message_data->mutable_message_value() = value.message_value();
|
||||
if (values.empty()) {
|
||||
return absl::InvalidArgumentError("Missing feild value.");
|
||||
}
|
||||
message_data = values[0];
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
FieldPathEntry head = field_path.front();
|
||||
FieldPath tail = field_path;
|
||||
tail.erase(tail.begin());
|
||||
if (!head.extension_type.empty()) {
|
||||
MP_RETURN_IF_ERROR(FindExtension(*message_data, &head));
|
||||
MP_RETURN_IF_ERROR(FindExtension(message_data, &head));
|
||||
}
|
||||
if (tail.empty()) {
|
||||
MP_RETURN_IF_ERROR(SetFieldValue(head, value, message_data));
|
||||
} else {
|
||||
FieldData child;
|
||||
MP_RETURN_IF_ERROR(GetFieldValue(*message_data, head, &child));
|
||||
MP_RETURN_IF_ERROR(SetField(tail, value, &child));
|
||||
if (IsProtobufAny(head.field)) {
|
||||
child = SerializeProtobufAny(child);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(SetFieldValue(head, child, message_data));
|
||||
MP_RETURN_IF_ERROR(SetFieldValues(message_data, head, values));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
FieldData child;
|
||||
MP_RETURN_IF_ERROR(GetFieldValue(message_data, head, &child));
|
||||
MP_RETURN_IF_ERROR(SetFieldValues(child, tail, values));
|
||||
if (IsProtobufAny(head.field)) {
|
||||
child = SerializeProtobufAny(child);
|
||||
}
|
||||
MP_RETURN_IF_ERROR(SetFieldValue(message_data, head, child));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Merges a packet value into nested protobuf Message.
|
||||
absl::Status MergeField(const FieldPath& field_path, const FieldData& value,
|
||||
FieldData* message_data) {
|
||||
// Writes a FieldData value into protobuf field.
|
||||
absl::Status SetField(FieldData& message_data, const FieldPath& field_path,
|
||||
const FieldData& value) {
|
||||
return SetFieldValues(message_data, field_path, {value});
|
||||
}
|
||||
|
||||
// Merges FieldData values into nested protobuf Message.
|
||||
// For each new field index, any previous value is merged with the new value.
|
||||
absl::Status MergeFieldValues(FieldData& message_data,
|
||||
const FieldPath& field_path,
|
||||
const std::vector<FieldData>& values) {
|
||||
absl::Status status;
|
||||
FieldType field_type = field_path.empty()
|
||||
? FieldType::TYPE_MESSAGE
|
||||
: AsFieldType(field_path.back().field->type());
|
||||
std::string message_type =
|
||||
(value.has_message_value())
|
||||
? ParseTypeUrl(std::string(value.message_value().type_url()))
|
||||
: "";
|
||||
FieldData v = value;
|
||||
FieldType field_type = field_path.empty() ? FieldType::TYPE_MESSAGE
|
||||
: field_path.back().field->type();
|
||||
std::vector<FieldData> results = values;
|
||||
std::vector<FieldData> prevs;
|
||||
ASSIGN_OR_RETURN(prevs, GetFieldValues(message_data, field_path));
|
||||
if (field_type == FieldType::TYPE_MESSAGE) {
|
||||
FieldData b;
|
||||
status.Update(GetField(field_path, *message_data, &b));
|
||||
status.Update(MergeMessages(b, v, &v));
|
||||
for (int i = 0; i < std::min(values.size(), prevs.size()); ++i) {
|
||||
FieldData& v = results[i];
|
||||
FieldData& b = prevs[i];
|
||||
ASSIGN_OR_RETURN(v, MergeMessages(b, v));
|
||||
}
|
||||
}
|
||||
status.Update(SetField(field_path, v, message_data));
|
||||
status.Update(SetFieldValues(message_data, field_path, results));
|
||||
return status;
|
||||
}
|
||||
|
||||
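A small hedged sketch of the StatusOr-style field access used above; building the FieldPath is elided, the wrapper function is illustrative, and the status macros are assumed to be available as in this file.

// Sketch only: read a nested option value and write a modified copy back,
// using the GetField/SetField signatures introduced in this change.
absl::Status RewriteOption(
    const mediapipe::tool::options_field_util::FieldPath& path,
    mediapipe::FieldData& node_options) {
  mediapipe::FieldData value;
  ASSIGN_OR_RETURN(
      value, mediapipe::tool::options_field_util::GetField(node_options, path));
  // ... adjust `value` here ...
  return mediapipe::tool::options_field_util::SetField(node_options, path,
                                                       value);
}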
|
@ -576,34 +491,35 @@ struct ProtoEnum {
|
|||
int32 value;
|
||||
};
|
||||
|
||||
absl::Status AsPacket(const FieldData& data, Packet* result) {
|
||||
absl::StatusOr<Packet> AsPacket(const FieldData& data) {
|
||||
Packet result;
|
||||
switch (data.value_case()) {
|
||||
case FieldData::ValueCase::kInt32Value:
|
||||
*result = MakePacket<int32>(data.int32_value());
|
||||
result = MakePacket<int32>(data.int32_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kInt64Value:
|
||||
*result = MakePacket<int64>(data.int64_value());
|
||||
result = MakePacket<int64>(data.int64_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kUint32Value:
|
||||
*result = MakePacket<uint32>(data.uint32_value());
|
||||
result = MakePacket<uint32>(data.uint32_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kUint64Value:
|
||||
*result = MakePacket<uint64>(data.uint64_value());
|
||||
result = MakePacket<uint64>(data.uint64_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kDoubleValue:
|
||||
*result = MakePacket<double>(data.double_value());
|
||||
result = MakePacket<double>(data.double_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kFloatValue:
|
||||
*result = MakePacket<float>(data.float_value());
|
||||
result = MakePacket<float>(data.float_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kBoolValue:
|
||||
*result = MakePacket<bool>(data.bool_value());
|
||||
result = MakePacket<bool>(data.bool_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kEnumValue:
|
||||
*result = MakePacket<ProtoEnum>(data.enum_value());
|
||||
result = MakePacket<ProtoEnum>(data.enum_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kStringValue:
|
||||
*result = MakePacket<std::string>(data.string_value());
|
||||
result = MakePacket<std::string>(data.string_value());
|
||||
break;
|
||||
case FieldData::ValueCase::kMessageValue: {
|
||||
auto r = packet_internal::PacketFromDynamicProto(
|
||||
|
@ -612,32 +528,33 @@ absl::Status AsPacket(const FieldData& data, Packet* result) {
|
|||
if (!r.ok()) {
|
||||
return r.status();
|
||||
}
|
||||
*result = r.value();
|
||||
result = r.value();
|
||||
break;
|
||||
}
|
||||
case FieldData::VALUE_NOT_SET:
|
||||
*result = Packet();
|
||||
result = Packet();
|
||||
}
|
||||
return absl::OkStatus();
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status AsFieldData(Packet packet, FieldData* result) {
|
||||
static const auto* kTypeIds = new std::map<size_t, int32>{
|
||||
{tool::GetTypeHash<int32>(), WireFormatLite::CPPTYPE_INT32},
|
||||
{tool::GetTypeHash<int64>(), WireFormatLite::CPPTYPE_INT64},
|
||||
{tool::GetTypeHash<uint32>(), WireFormatLite::CPPTYPE_UINT32},
|
||||
{tool::GetTypeHash<uint64>(), WireFormatLite::CPPTYPE_UINT64},
|
||||
{tool::GetTypeHash<double>(), WireFormatLite::CPPTYPE_DOUBLE},
|
||||
{tool::GetTypeHash<float>(), WireFormatLite::CPPTYPE_FLOAT},
|
||||
{tool::GetTypeHash<bool>(), WireFormatLite::CPPTYPE_BOOL},
|
||||
{tool::GetTypeHash<ProtoEnum>(), WireFormatLite::CPPTYPE_ENUM},
|
||||
{tool::GetTypeHash<std::string>(), WireFormatLite::CPPTYPE_STRING},
|
||||
absl::StatusOr<FieldData> AsFieldData(Packet packet) {
|
||||
static const auto* kTypeIds = new std::map<TypeId, int32>{
|
||||
{kTypeId<int32>, WireFormatLite::CPPTYPE_INT32},
|
||||
{kTypeId<int64>, WireFormatLite::CPPTYPE_INT64},
|
||||
{kTypeId<uint32>, WireFormatLite::CPPTYPE_UINT32},
|
||||
{kTypeId<uint64>, WireFormatLite::CPPTYPE_UINT64},
|
||||
{kTypeId<double>, WireFormatLite::CPPTYPE_DOUBLE},
|
||||
{kTypeId<float>, WireFormatLite::CPPTYPE_FLOAT},
|
||||
{kTypeId<bool>, WireFormatLite::CPPTYPE_BOOL},
|
||||
{kTypeId<ProtoEnum>, WireFormatLite::CPPTYPE_ENUM},
|
||||
{kTypeId<std::string>, WireFormatLite::CPPTYPE_STRING},
|
||||
};
|
||||
|
||||
FieldData result;
|
||||
if (packet.ValidateAsProtoMessageLite().ok()) {
|
||||
result->mutable_message_value()->set_value(
|
||||
result.mutable_message_value()->set_value(
|
||||
packet.GetProtoMessageLite().SerializeAsString());
|
||||
result->mutable_message_value()->set_type_url(
|
||||
result.mutable_message_value()->set_type_url(
|
||||
TypeUrl(packet.GetProtoMessageLite().GetTypeName()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
@ -649,48 +566,42 @@ absl::Status AsFieldData(Packet packet, FieldData* result) {
|
|||
|
||||
switch (kTypeIds->at(packet.GetTypeId())) {
|
||||
case WireFormatLite::CPPTYPE_INT32:
|
||||
result->set_int32_value(packet.Get<int32>());
|
||||
result.set_int32_value(packet.Get<int32>());
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_INT64:
|
||||
result->set_int64_value(packet.Get<int64>());
|
||||
result.set_int64_value(packet.Get<int64>());
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_UINT32:
|
||||
result->set_uint32_value(packet.Get<uint32>());
|
||||
result.set_uint32_value(packet.Get<uint32>());
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_UINT64:
|
||||
result->set_uint64_value(packet.Get<uint64>());
|
||||
result.set_uint64_value(packet.Get<uint64>());
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_DOUBLE:
|
||||
result->set_double_value(packet.Get<double>());
|
||||
result.set_double_value(packet.Get<double>());
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_FLOAT:
|
||||
result->set_float_value(packet.Get<float>());
|
||||
result.set_float_value(packet.Get<float>());
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_BOOL:
|
||||
result->set_bool_value(packet.Get<bool>());
|
||||
result.set_bool_value(packet.Get<bool>());
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_ENUM:
|
||||
result->set_enum_value(packet.Get<ProtoEnum>().value);
|
||||
result.set_enum_value(packet.Get<ProtoEnum>().value);
|
||||
break;
|
||||
case WireFormatLite::CPPTYPE_STRING:
|
||||
result->set_string_value(packet.Get<std::string>());
|
||||
result.set_string_value(packet.Get<std::string>());
|
||||
break;
|
||||
}
|
||||
return absl::OkStatus();
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string TypeUrl(absl::string_view type_name) {
|
||||
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
|
||||
return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name));
|
||||
return ProtoUtilLite::TypeUrl(type_name);
|
||||
}
|
||||
|
||||
std::string ParseTypeUrl(absl::string_view type_url) {
|
||||
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
|
||||
if (std::string(type_url).rfind(kTypeUrlPrefix, 0) == 0) {
|
||||
return std::string(
|
||||
type_url.substr(kTypeUrlPrefix.length(), std::string::npos));
|
||||
}
|
||||
return std::string(type_url);
|
||||
return ProtoUtilLite::ParseTypeUrl(type_url);
|
||||
}
|
||||
|
||||
} // namespace options_field_util
|
||||
|
|
|
@@ -34,30 +34,38 @@ absl::Status SetField(const FieldPath& field_path, const FieldData& value,
|
|||
FieldData* message_data);
|
||||
|
||||
// Reads a field value from a protobuf field.
|
||||
absl::Status GetField(const FieldPath& field_path,
|
||||
const FieldData& message_data, FieldData* result);
|
||||
absl::StatusOr<FieldData> GetField(const FieldData& message_data,
|
||||
const FieldPath& field_path);
|
||||
|
||||
// Merges a field value into nested protobuf Message.
|
||||
absl::Status MergeField(const FieldPath& field_path, const FieldData& value,
|
||||
FieldData* message_data);
|
||||
// Reads one or all FieldData values from a protobuf field.
|
||||
absl::StatusOr<std::vector<FieldData>> GetFieldValues(
|
||||
const FieldData& message_data, const FieldPath& field_path);
|
||||
|
||||
// Writes FieldData values into a protobuf field.
|
||||
absl::Status SetFieldValues(FieldData& message_data,
|
||||
const FieldPath& field_path,
|
||||
const std::vector<FieldData>& values);
|
||||
|
||||
// Merges FieldData values into a protobuf field.
|
||||
absl::Status MergeFieldValues(FieldData& message_data,
|
||||
const FieldPath& field_path,
|
||||
const std::vector<FieldData>& values);
|
||||
|
||||
// Deserializes a packet containing a MessageLite value.
|
||||
absl::Status ReadMessage(const std::string& value, const std::string& type_name,
|
||||
Packet* result);
|
||||
absl::StatusOr<Packet> ReadMessage(const std::string& value,
|
||||
const std::string& type_name);
|
||||
|
||||
// Merge two options protobuf field values.
|
||||
absl::Status MergeMessages(const FieldData& base, const FieldData& over,
|
||||
FieldData* result);
|
||||
absl::StatusOr<FieldData> MergeMessages(const FieldData& base,
|
||||
const FieldData& over);
|
||||
|
||||
// Returns the requested options protobuf for a graph.
|
||||
absl::Status GetNodeOptions(const FieldData& message_data,
|
||||
const std::string& extension_type,
|
||||
FieldData* result);
|
||||
absl::StatusOr<FieldData> GetNodeOptions(const FieldData& message_data,
|
||||
const std::string& extension_type);
|
||||
|
||||
// Returns the requested options protobuf for a graph node.
|
||||
absl::Status GetGraphOptions(const FieldData& message_data,
|
||||
const std::string& extension_type,
|
||||
FieldData* result);
|
||||
absl::StatusOr<FieldData> GetGraphOptions(const FieldData& message_data,
|
||||
const std::string& extension_type);
|
||||
|
||||
// Sets the node_options field in a Node, and clears the options field.
|
||||
void SetOptionsMessage(const FieldData& node_options,
|
||||
|
@@ -67,10 +75,10 @@ void SetOptionsMessage(const FieldData& node_options,
|
|||
FieldData AsFieldData(const proto_ns::MessageLite& message);
|
||||
|
||||
// Constructs a Packet for a FieldData proto.
|
||||
absl::Status AsPacket(const FieldData& data, Packet* result);
|
||||
absl::StatusOr<Packet> AsPacket(const FieldData& data);
|
||||
|
||||
// Constructs a FieldData proto for a Packet.
|
||||
absl::Status AsFieldData(Packet packet, FieldData* result);
|
||||
absl::StatusOr<FieldData> AsFieldData(Packet packet);
|
||||
|
||||
// Returns the protobuf type-url for a protobuf type-name.
|
||||
std::string TypeUrl(absl::string_view type_name);
|
||||
|
|
|
@@ -25,11 +25,12 @@ constexpr char kDescriptorContents[] =
|
|||
#include "{{DESCRIPTOR_INC_FILE_PATH}}"
|
||||
; // NOLINT(whitespace/semicolon)
|
||||
|
||||
mediapipe::proto_ns::FileDescriptorSet ParseFileDescriptorSet(
|
||||
const std::string& pb) {
|
||||
mediapipe::proto_ns::FileDescriptorSet files;
|
||||
files.ParseFromString(pb);
|
||||
return files;
|
||||
mediapipe::FieldData ReadFileDescriptorSet(const std::string& pb) {
|
||||
mediapipe::FieldData result;
|
||||
*result.mutable_message_value()->mutable_type_url() =
|
||||
"proto2.FileDescriptorSet";
|
||||
*result.mutable_message_value()->mutable_value() = pb;
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@@ -39,6 +40,6 @@ namespace mediapipe {
|
|||
template <>
|
||||
const RegistrationToken tool::OptionsRegistry::registration_token<
|
||||
MP_OPTION_TYPE_NS::MP_OPTION_TYPE_NAME> =
|
||||
tool::OptionsRegistry::Register(ParseFileDescriptorSet(
|
||||
tool::OptionsRegistry::Register(ReadFileDescriptorSet(
|
||||
std::string(kDescriptorContents, sizeof(kDescriptorContents) - 1)));
|
||||
} // namespace mediapipe
|
||||
|
|
|
@@ -30,15 +30,26 @@ struct IsExtension {
|
|||
|
||||
template <class T,
|
||||
typename std::enable_if<IsExtension<T>::value, int>::type = 0>
|
||||
void GetExtension(const CalculatorOptions& options, T* result) {
|
||||
T* GetExtension(CalculatorOptions& options) {
|
||||
if (options.HasExtension(T::ext)) {
|
||||
*result = options.GetExtension(T::ext);
|
||||
return options.MutableExtension(T::ext);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T,
|
||||
typename std::enable_if<!IsExtension<T>::value, int>::type = 0>
|
||||
void GetExtension(const CalculatorOptions& options, T* result) {}
|
||||
T* GetExtension(const CalculatorOptions& options) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void GetExtension(const CalculatorOptions& options, T* result) {
|
||||
T* r = GetExtension<T>(*const_cast<CalculatorOptions*>(&options));
|
||||
if (r) {
|
||||
*result = *r;
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void GetNodeOptions(const CalculatorGraphConfig::Node& node_config, T* result) {
|
||||
|
@@ -53,23 +64,39 @@ void GetNodeOptions(const CalculatorGraphConfig::Node& node_config, T* result) {
|
|||
#endif
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void SetNodeOptions(CalculatorGraphConfig::Node& node_config, const T& value) {
|
||||
#if defined(MEDIAPIPE_PROTO_LITE) && defined(MEDIAPIPE_PROTO_THIRD_PARTY)
|
||||
// protobuf::Any is unavailable with third_party/protobuf:protobuf-lite.
|
||||
#else
|
||||
for (mediapipe::protobuf::Any& options :
|
||||
*node_config.mutable_node_options()) {
|
||||
if (options.Is<T>()) {
|
||||
options.PackFrom(value);
|
||||
return;
|
||||
}
|
||||
}
|
||||
node_config.add_node_options()->PackFrom(value);
|
||||
#endif
|
||||
}
|
||||
|
||||
// A map from object type to object.
|
||||
class TypeMap {
|
||||
public:
|
||||
template <class T>
|
||||
bool Has() const {
|
||||
return content_.count(TypeInfo::Get<T>()) > 0;
|
||||
return content_.count(kTypeId<T>) > 0;
|
||||
}
|
||||
template <class T>
|
||||
T* Get() const {
|
||||
if (!Has<T>()) {
|
||||
content_[TypeInfo::Get<T>()] = std::make_shared<T>();
|
||||
content_[kTypeId<T>] = std::make_shared<T>();
|
||||
}
|
||||
return static_cast<T*>(content_[TypeInfo::Get<T>()].get());
|
||||
return static_cast<T*>(content_[kTypeId<T>].get());
|
||||
}
|
||||
|
||||
private:
|
||||
mutable std::map<TypeIndex, std::shared_ptr<void>> content_;
|
||||
mutable std::map<TypeId, std::shared_ptr<void>> content_;
|
||||
};
|
||||
|
||||
// Extracts the options message of a specified type from a
|
||||
|
@@ -77,7 +104,7 @@ class TypeMap {
|
|||
class OptionsMap {
|
||||
public:
|
||||
OptionsMap& Initialize(const CalculatorGraphConfig::Node& node_config) {
|
||||
node_config_ = &node_config;
|
||||
node_config_ = const_cast<CalculatorGraphConfig::Node*>(&node_config);
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
@@ -97,10 +124,40 @@ class OptionsMap {
|
|||
return *result;
|
||||
}
|
||||
|
||||
const CalculatorGraphConfig::Node* node_config_;
|
||||
CalculatorGraphConfig::Node* node_config_;
|
||||
TypeMap options_;
|
||||
};
|
||||
|
||||
class MutableOptionsMap : public OptionsMap {
|
||||
public:
|
||||
MutableOptionsMap& Initialize(CalculatorGraphConfig::Node& node_config) {
|
||||
node_config_ = &node_config;
|
||||
return *this;
|
||||
}
|
||||
template <class T>
|
||||
void Set(const T& value) const {
|
||||
*options_.Get<T>() = value;
|
||||
if (node_config_->has_options()) {
|
||||
*GetExtension<T>(*node_config_->mutable_options()) = value;
|
||||
} else {
|
||||
SetNodeOptions(*node_config_, value);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T* GetMutable() const {
|
||||
if (options_.Has<T>()) {
|
||||
return options_.Get<T>();
|
||||
}
|
||||
if (node_config_->has_options()) {
|
||||
return GetExtension<T>(*node_config_->mutable_options());
|
||||
}
|
||||
T* result = options_.Get<T>();
|
||||
GetNodeOptions(*node_config_, result);
|
||||
return result;
|
||||
}
|
||||
};
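// A minimal usage sketch of MutableOptionsMap (illustrative only; the node
// config contents and the options type "ExampleCalculatorOptions" are
// placeholders, not part of this change):
//
//   CalculatorGraphConfig::Node node_config;
//   MutableOptionsMap options_map;
//   options_map.Initialize(node_config);
//   ExampleCalculatorOptions* opts =
//       options_map.GetMutable<ExampleCalculatorOptions>();
//   opts->set_value(42);        // edit the cached copy in place
//   options_map.Set(*opts);     // write the options back into node_config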
|
||||
|
||||
} // namespace tool
|
||||
} // namespace mediapipe
|
||||
|
||||
|
|
|
@@ -1,6 +1,11 @@
|
|||
#include "mediapipe/framework/tool/options_registry.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/tool/proto_util_lite.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tool {
|
||||
|
@@ -9,37 +14,135 @@ namespace {
|
|||
|
||||
// Returns a canonical message type name, with any leading "." removed.
|
||||
std::string CanonicalTypeName(const std::string& type_name) {
|
||||
return (type_name.rfind('.', 0) == 0) ? type_name.substr(1) : type_name;
|
||||
return (absl::StartsWith(type_name, ".")) ? type_name.substr(1) : type_name;
|
||||
}
|
||||
|
||||
// Returns the values from a protobuf field as typed FieldData.
|
||||
absl::StatusOr<std::vector<FieldData>> GetFieldValues(
|
||||
const FieldData& message_data, std::string field_name) {
|
||||
std::string type_name =
|
||||
ProtoUtilLite::ParseTypeUrl(message_data.message_value().type_url());
|
||||
const Descriptor* descriptor =
|
||||
OptionsRegistry::GetProtobufDescriptor(type_name);
|
||||
RET_CHECK_NE(descriptor, nullptr);
|
||||
const FieldDescriptor* field = descriptor->FindFieldByName(field_name);
|
||||
if (field == nullptr) {
|
||||
return std::vector<FieldData>();
|
||||
}
|
||||
ProtoUtilLite::ProtoPath proto_path = {{field->number(), 0}};
|
||||
ProtoUtilLite::FieldValue message_bytes = message_data.message_value().value();
|
||||
int count;
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldCount(message_bytes, proto_path,
|
||||
field->type(), &count));
|
||||
std::vector<std::string> field_values;
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::GetFieldRange(
|
||||
message_bytes, proto_path, count, field->type(), &field_values));
|
||||
std::vector<FieldData> result;
|
||||
for (int i = 0; i < field_values.size(); ++i) {
|
||||
FieldData r;
|
||||
std::string message_type =
|
||||
field->message_type() ? field->message_type()->full_name() : "";
|
||||
MP_RETURN_IF_ERROR(ProtoUtilLite::ReadValue(field_values[i], field->type(),
|
||||
message_type, &r));
|
||||
result.push_back(std::move(r));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns a single value from a protobuf string field.
|
||||
std::string GetFieldString(const FieldData& message_data,
|
||||
std::string field_name) {
|
||||
auto values = GetFieldValues(message_data, field_name);
|
||||
if (!values->empty()) {
|
||||
return values->front().string_value();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
// Registers the descriptors for the descriptor protobufs. These four
|
||||
// descriptors are required to deserialize descriptors for other protobufs.
|
||||
// This implementation avoids a code size problem introduced by
|
||||
// proto_ns::DescriptorProto.
|
||||
void RegisterDescriptorProtos(
|
||||
absl::flat_hash_map<std::string, Descriptor>& result) {
|
||||
std::vector<Descriptor> descriptors = {
|
||||
{"proto2.FileDescriptorSet",
|
||||
{
|
||||
{"file", 1, FieldType::TYPE_MESSAGE, "proto2.FileDescriptorProto"},
|
||||
}},
|
||||
{"proto2.FileDescriptorProto",
|
||||
{
|
||||
{"package", 2, FieldType::TYPE_STRING, ""},
|
||||
{"message_type", 4, FieldType::TYPE_MESSAGE,
|
||||
"proto2.DescriptorProto"},
|
||||
}},
|
||||
{"proto2.DescriptorProto",
|
||||
{
|
||||
{"name", 1, FieldType::TYPE_STRING, ""},
|
||||
{"field", 2, FieldType::TYPE_MESSAGE, "proto2.FieldDescriptorProto"},
|
||||
{"extension", 6, FieldType::TYPE_MESSAGE,
|
||||
"proto2.FieldDescriptorProto"},
|
||||
{"nested_type", 3, FieldType::TYPE_MESSAGE,
|
||||
"proto2.DescriptorProto"},
|
||||
}},
|
||||
{"proto2.FieldDescriptorProto",
|
||||
{
|
||||
{"name", 1, FieldType::TYPE_STRING, ""},
|
||||
{"number", 3, FieldType::TYPE_INT32, ""},
|
||||
{"type", 5, FieldType::TYPE_ENUM, ""},
|
||||
{"type_name", 6, FieldType::TYPE_STRING, ""},
|
||||
{"extendee", 2, FieldType::TYPE_STRING, ""},
|
||||
}},
|
||||
};
|
||||
for (const auto& descriptor : descriptors) {
|
||||
result[descriptor.full_name()] = descriptor;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
RegistrationToken OptionsRegistry::Register(
|
||||
const proto_ns::FileDescriptorSet& files) {
|
||||
absl::MutexLock lock(&mutex());
|
||||
for (auto& file : files.file()) {
|
||||
for (auto& message_type : file.message_type()) {
|
||||
Register(message_type, file.package());
|
||||
const FieldData& file_descriptor_set) {
|
||||
auto files = GetFieldValues(file_descriptor_set, "file");
|
||||
for (auto& file : *files) {
|
||||
std::string package_name = GetFieldString(file, "package");
|
||||
auto message_types = GetFieldValues(file, "message_type");
|
||||
for (auto& message_type : *message_types) {
|
||||
Register(message_type, package_name);
|
||||
}
|
||||
}
|
||||
return RegistrationToken([]() {});
|
||||
}
|
||||
|
||||
void OptionsRegistry::Register(const proto_ns::DescriptorProto& message_type,
|
||||
void OptionsRegistry::Register(const FieldData& message_type,
|
||||
const std::string& parent_name) {
|
||||
auto full_name = absl::StrCat(parent_name, ".", message_type.name());
|
||||
descriptors()[full_name] = Descriptor(message_type, full_name);
|
||||
for (auto& nested : message_type.nested_type()) {
|
||||
std::string name = GetFieldString(message_type, "name");
|
||||
std::string full_name = absl::StrCat(parent_name, ".", name);
|
||||
Descriptor descriptor(full_name, message_type);
|
||||
{
|
||||
absl::MutexLock lock(&mutex());
|
||||
descriptors()[full_name] = descriptor;
|
||||
}
|
||||
auto nested_types = GetFieldValues(message_type, "nested_type");
|
||||
for (auto& nested : *nested_types) {
|
||||
Register(nested, full_name);
|
||||
}
|
||||
for (auto& extension : message_type.extension()) {
|
||||
extensions()[CanonicalTypeName(extension.extendee())].push_back(
|
||||
FieldDescriptor(extension));
|
||||
auto exts = GetFieldValues(message_type, "extension");
|
||||
for (auto& extension : *exts) {
|
||||
FieldDescriptor field(extension);
|
||||
std::string extendee = GetFieldString(extension, "extendee");
|
||||
{
|
||||
absl::MutexLock lock(&mutex());
|
||||
extensions()[CanonicalTypeName(extendee)].push_back(field);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Descriptor* OptionsRegistry::GetProtobufDescriptor(
|
||||
const std::string& type_name) {
|
||||
if (descriptors().count("proto2.DescriptorProto") == 0) {
|
||||
RegisterDescriptorProtos(descriptors());
|
||||
}
|
||||
absl::ReaderMutexLock lock(&mutex());
|
||||
auto it = descriptors().find(CanonicalTypeName(type_name));
|
||||
return (it == descriptors().end()) ? nullptr : &it->second;
|
||||
|
@@ -73,11 +176,21 @@ absl::Mutex& OptionsRegistry::mutex() {
|
|||
return *mutex;
|
||||
}
|
||||
|
||||
Descriptor::Descriptor(const proto_ns::DescriptorProto& proto,
|
||||
const std::string& full_name)
|
||||
Descriptor::Descriptor(const std::string& full_name,
|
||||
const FieldData& descriptor_proto)
|
||||
: full_name_(full_name) {
|
||||
for (auto& field : proto.field()) {
|
||||
fields_[field.name()] = FieldDescriptor(field);
|
||||
auto fields = GetFieldValues(descriptor_proto, "field");
|
||||
for (const auto& field : *fields) {
|
||||
FieldDescriptor f(field);
|
||||
fields_[f.name()] = f;
|
||||
}
|
||||
}
|
||||
|
||||
Descriptor::Descriptor(const std::string& full_name,
|
||||
const std::vector<FieldDescriptor>& fields)
|
||||
: full_name_(full_name) {
|
||||
for (const auto& field : fields) {
|
||||
fields_[field.name()] = field;
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -89,20 +202,22 @@ const FieldDescriptor* Descriptor::FindFieldByName(
|
|||
return (it != fields_.end()) ? &it->second : nullptr;
|
||||
}
|
||||
|
||||
FieldDescriptor::FieldDescriptor(const proto_ns::FieldDescriptorProto& proto) {
|
||||
name_ = proto.name();
|
||||
message_type_ = CanonicalTypeName(proto.type_name());
|
||||
type_ = proto.type();
|
||||
number_ = proto.number();
|
||||
FieldDescriptor::FieldDescriptor(const FieldData& field_proto) {
|
||||
name_ = GetFieldString(field_proto, "name");
|
||||
number_ = GetFieldValues(field_proto, "number")->front().int32_value();
|
||||
type_ = (FieldType)GetFieldValues(field_proto, "type")->front().enum_value();
|
||||
message_type_ = CanonicalTypeName(GetFieldString(field_proto, "type_name"));
|
||||
}
|
||||
|
||||
FieldDescriptor::FieldDescriptor(std::string name, int number, FieldType type,
|
||||
std::string message_type)
|
||||
: name_(name), number_(number), type_(type), message_type_(message_type) {}
|
||||
|
||||
const std::string& FieldDescriptor::name() const { return name_; }
|
||||
|
||||
int FieldDescriptor::number() const { return number_; }
|
||||
|
||||
proto_ns::FieldDescriptorProto::Type FieldDescriptor::type() const {
|
||||
return type_;
|
||||
}
|
||||
FieldType FieldDescriptor::type() const { return type_; }
|
||||
|
||||
const Descriptor* FieldDescriptor::message_type() const {
|
||||
return OptionsRegistry::GetProtobufDescriptor(message_type_);
|
||||
|
|
|
@@ -1,15 +1,20 @@
|
|||
#ifndef MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
|
||||
#define MEDIAPIPE_FRAMEWORK_TOOL_OPTIONS_REGISTRY_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "mediapipe/framework/deps/registration.h"
|
||||
#include "mediapipe/framework/port/advanced_proto_inc.h"
|
||||
#include "mediapipe/framework/tool/field_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tool {
|
||||
|
||||
class Descriptor;
|
||||
class FieldDescriptor;
|
||||
using FieldType = mediapipe::proto_ns::internal::WireFormatLite::FieldType;
|
||||
using mediapipe::FieldData;
|
||||
|
||||
// A static registry that stores descriptors for protobufs used in MediaPipe
|
||||
// calculator options. Lite-proto builds do not normally include descriptors.
|
||||
|
@@ -17,8 +22,8 @@ class FieldDescriptor;
|
|||
// referenced and specified separately within CalculatorGraphConfigs.
|
||||
class OptionsRegistry {
|
||||
public:
|
||||
// Registers the protobuf descriptors for a MessageLite.
|
||||
static RegistrationToken Register(const proto_ns::FileDescriptorSet& files);
|
||||
// Registers the protobuf descriptors for a FileDescriptorSet.
|
||||
static RegistrationToken Register(const FieldData& file_descriptor_set);
|
||||
|
||||
// Finds the descriptor for a protobuf.
|
||||
static const Descriptor* GetProtobufDescriptor(const std::string& type_name);
|
||||
|
@@ -28,8 +33,8 @@ class OptionsRegistry {
|
|||
std::vector<const FieldDescriptor*>* result);
|
||||
|
||||
private:
|
||||
// Registers protobuf descriptors a MessageLite and nested types.
|
||||
static void Register(const proto_ns::DescriptorProto& message_type,
|
||||
// Registers protobuf descriptors for a message type and nested types.
|
||||
static void Register(const FieldData& message_type,
|
||||
const std::string& parent_name);
|
||||
|
||||
static absl::flat_hash_map<std::string, Descriptor>& descriptors();
|
||||
|
@@ -46,9 +51,10 @@ class OptionsRegistry {
|
|||
// avoids a code size problem introduced by proto_ns::FieldDescriptor.
|
||||
class Descriptor {
|
||||
public:
|
||||
Descriptor() {}
|
||||
Descriptor(const proto_ns::DescriptorProto& proto,
|
||||
const std::string& full_name);
|
||||
Descriptor() = default;
|
||||
Descriptor(const std::string& full_name, const FieldData& descriptor_proto);
|
||||
Descriptor(const std::string& full_name,
|
||||
const std::vector<FieldDescriptor>& fields);
|
||||
const std::string& full_name() const;
|
||||
const FieldDescriptor* FindFieldByName(const std::string& name) const;
|
||||
|
||||
|
@@ -61,18 +67,20 @@ class Descriptor {
|
|||
// avoids a code size problem introduced by proto_ns::FieldDescriptor.
|
||||
class FieldDescriptor {
|
||||
public:
|
||||
FieldDescriptor() {}
|
||||
FieldDescriptor(const proto_ns::FieldDescriptorProto& proto);
|
||||
FieldDescriptor() = default;
|
||||
FieldDescriptor(const FieldData& field_proto);
|
||||
FieldDescriptor(std::string name, int number, FieldType type,
|
||||
std::string message_type);
|
||||
const std::string& name() const;
|
||||
int number() const;
|
||||
proto_ns::FieldDescriptorProto::Type type() const;
|
||||
FieldType type() const;
|
||||
const Descriptor* message_type() const;
|
||||
|
||||
private:
|
||||
std::string name_;
|
||||
std::string message_type_;
|
||||
proto_ns::FieldDescriptorProto::Type type_;
|
||||
int number_;
|
||||
FieldType type_;
|
||||
std::string message_type_;
|
||||
};
|
||||
|
||||
} // namespace tool
|
||||
|
|
|
@@ -91,8 +91,7 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper {
|
|||
int index;
|
||||
if (absl::SimpleAtoi(option_name, &index)) {
|
||||
result.back().index = index;
|
||||
}
|
||||
if (!ExtensionType(option_name).empty()) {
|
||||
} else if (!ExtensionType(option_name).empty()) {
|
||||
std::string extension_type = std::string(ExtensionType(option_name));
|
||||
result.push_back({nullptr, 0, extension_type});
|
||||
descriptor = OptionsRegistry::GetProtobufDescriptor(extension_type);
|
||||
|
@@ -102,7 +101,7 @@ class OptionsSyntaxUtil::OptionsSyntaxHelper {
|
|||
}
|
||||
auto field = descriptor->FindFieldByName(std::string(option_name));
|
||||
descriptor = field ? field->message_type() : nullptr;
|
||||
result.push_back({std::move(field), 0});
|
||||
result.push_back({std::move(field), -1});
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
|
|
@@ -26,10 +26,9 @@ namespace mediapipe {
|
|||
namespace tool {
|
||||
|
||||
using options_field_util::FieldPath;
|
||||
using options_field_util::GetField;
|
||||
using options_field_util::GetGraphOptions;
|
||||
using options_field_util::GetNodeOptions;
|
||||
using options_field_util::MergeField;
|
||||
using options_field_util::MergeFieldValues;
|
||||
using options_field_util::MergeMessages;
|
||||
|
||||
// Returns the type for the root options message if specified.
|
||||
|
@@ -56,10 +55,19 @@ std::string MessageType(FieldData message) {
|
|||
std::string(message.message_value().type_url()));
|
||||
}
|
||||
|
||||
// Assigns the value from a StatusOr if available.
|
||||
#define ASSIGN_IF_OK(lhs, rexpr) \
|
||||
{ \
|
||||
auto statusor = (rexpr); \
|
||||
if (statusor.ok()) { \
|
||||
lhs = statusor.value(); \
|
||||
} \
|
||||
}
|
||||
|
||||
// Copy literal options from graph_options to node_options.
|
||||
absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
|
||||
CalculatorGraphConfig* config) {
|
||||
Status status;
|
||||
absl::Status status;
|
||||
FieldData graph_data = options_field_util::AsFieldData(*config);
|
||||
FieldData parent_data = options_field_util::AsFieldData(parent_node);
|
||||
|
||||
|
@@ -75,25 +83,26 @@ absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
|
|||
std::string node_tag = syntax_util.OptionFieldsTag(tag_and_name[0]);
|
||||
std::string node_extension_type = ExtensionType(node_tag);
|
||||
FieldData graph_options;
|
||||
GetGraphOptions(graph_data, graph_extension_type, &graph_options)
|
||||
.IgnoreError();
|
||||
ASSIGN_IF_OK(graph_options,
|
||||
GetGraphOptions(graph_data, graph_extension_type));
|
||||
FieldData parent_options;
|
||||
GetNodeOptions(parent_data, graph_extension_type, &parent_options)
|
||||
.IgnoreError();
|
||||
status.Update(
|
||||
MergeMessages(graph_options, parent_options, &graph_options));
|
||||
ASSIGN_IF_OK(parent_options,
|
||||
GetNodeOptions(parent_data, graph_extension_type));
|
||||
ASSIGN_OR_RETURN(graph_options,
|
||||
MergeMessages(graph_options, parent_options));
|
||||
FieldData node_options;
|
||||
status.Update(
|
||||
GetNodeOptions(node_data, node_extension_type, &node_options));
|
||||
ASSIGN_OR_RETURN(node_options,
|
||||
GetNodeOptions(node_data, node_extension_type));
|
||||
if (!node_options.has_message_value() ||
|
||||
!graph_options.has_message_value()) {
|
||||
continue;
|
||||
}
|
||||
FieldPath graph_path = GetPath(graph_tag, MessageType(graph_options));
|
||||
FieldPath node_path = GetPath(node_tag, MessageType(node_options));
|
||||
FieldData packet_data;
|
||||
status.Update(GetField(graph_path, graph_options, &packet_data));
|
||||
status.Update(MergeField(node_path, packet_data, &node_options));
|
||||
std::vector<FieldData> packet_data;
|
||||
ASSIGN_OR_RETURN(packet_data, GetFieldValues(graph_options, graph_path));
|
||||
MP_RETURN_IF_ERROR(
|
||||
MergeFieldValues(node_options, node_path, packet_data));
|
||||
options_field_util::SetOptionsMessage(node_options, &node);
|
||||
}
|
||||
node.clear_option_value();
|
||||
|
@@ -105,7 +114,7 @@ absl::Status CopyLiteralOptions(CalculatorGraphConfig::Node parent_node,
|
|||
absl::Status DefineGraphOptions(const CalculatorGraphConfig::Node& parent_node,
|
||||
CalculatorGraphConfig* config) {
|
||||
MP_RETURN_IF_ERROR(CopyLiteralOptions(parent_node, config));
|
||||
return mediapipe::OkStatus();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
} // namespace tool
|
||||
|
|
|
@@ -13,8 +13,10 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mediapipe/framework/calculator_framework.h"
|
||||
#include "mediapipe/framework/port/gmock.h"
|
||||
|
@@ -30,23 +32,27 @@
|
|||
namespace mediapipe {
|
||||
namespace {
|
||||
|
||||
using ::mediapipe::proto_ns::FieldDescriptorProto;
|
||||
using FieldType = ::mediapipe::proto_ns::FieldDescriptorProto::Type;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
// Asserts that a StatusOr is OK and assigns its value.
|
||||
#define ASSERT_AND_ASSIGN(lhs, rexpr) \
|
||||
{ \
|
||||
auto statusor = (rexpr); \
|
||||
MP_ASSERT_OK(statusor); \
|
||||
lhs = statusor.value(); \
|
||||
}
|
||||
|
||||
// A test Calculator using DeclareOptions and DefineOptions.
|
||||
class NightLightCalculator : public CalculatorBase {
|
||||
public:
|
||||
static absl::Status GetContract(CalculatorContract* cc) {
|
||||
return mediapipe::OkStatus();
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Open(CalculatorContext* cc) final {
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
absl::Status Open(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||
|
||||
absl::Status Process(CalculatorContext* cc) final {
|
||||
return mediapipe::OkStatus();
|
||||
}
|
||||
absl::Status Process(CalculatorContext* cc) final { return absl::OkStatus(); }
|
||||
|
||||
private:
|
||||
NightLightCalculatorOptions options_;
|
||||
|
@@ -124,7 +130,7 @@ TEST_F(OptionsUtilTest, CopyLiteralOptions) {
|
|||
|
||||
CalculatorGraph graph;
|
||||
graph_config.set_num_threads(4);
|
||||
MP_EXPECT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {}));
|
||||
MP_ASSERT_OK(graph.Initialize({subgraph_config, graph_config}, {}, {}));
|
||||
|
||||
CalculatorGraphConfig expanded_config = graph.Config();
|
||||
expanded_config.clear_executor();
|
||||
|
@@ -236,8 +242,8 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
|
|||
tool::options_field_util::FieldPath field_path =
|
||||
syntax_util.OptionFieldPath(split[1], descriptor);
|
||||
EXPECT_EQ(field_path.size(), 2);
|
||||
EXPECT_TRUE(Equals(field_path[0], "sub_options", 0, ""));
|
||||
EXPECT_TRUE(Equals(field_path[1], "num_lights", 0, ""));
|
||||
EXPECT_TRUE(Equals(field_path[0], "sub_options", -1, ""));
|
||||
EXPECT_TRUE(Equals(field_path[1], "num_lights", -1, ""));
|
||||
|
||||
{
|
||||
// NightLightCalculatorOptions in Node.options.
|
||||
|
@@ -252,11 +258,11 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
|
|||
auto path = field_path;
|
||||
std::string node_extension_type = ExtensionType(std::string(split[1]));
|
||||
FieldData node_options;
|
||||
MP_EXPECT_OK(tool::options_field_util::GetNodeOptions(
|
||||
node_data, node_extension_type, &node_options));
|
||||
ASSERT_AND_ASSIGN(node_options, tool::options_field_util::GetNodeOptions(
|
||||
node_data, node_extension_type));
|
||||
FieldData packet_data;
|
||||
MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options,
|
||||
&packet_data));
|
||||
ASSERT_AND_ASSIGN(packet_data, tool::options_field_util::GetField(
|
||||
node_options, field_path));
|
||||
EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value);
|
||||
EXPECT_EQ(packet_data.int32_value(), 33);
|
||||
}
|
||||
|
@@ -273,11 +279,11 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
|
|||
auto path = field_path;
|
||||
std::string node_extension_type = ExtensionType(std::string(split[1]));
|
||||
FieldData node_options;
|
||||
MP_EXPECT_OK(tool::options_field_util::GetNodeOptions(
|
||||
node_data, node_extension_type, &node_options));
|
||||
ASSERT_AND_ASSIGN(node_options, tool::options_field_util::GetNodeOptions(
|
||||
node_data, node_extension_type));
|
||||
FieldData packet_data;
|
||||
MP_EXPECT_OK(tool::options_field_util::GetField(field_path, node_options,
|
||||
&packet_data));
|
||||
ASSERT_AND_ASSIGN(packet_data, tool::options_field_util::GetField(
|
||||
node_options, field_path));
|
||||
EXPECT_EQ(packet_data.value_case(), FieldData::kInt32Value);
|
||||
EXPECT_EQ(packet_data.int32_value(), 33);
|
||||
}
|
||||
|
@@ -285,5 +291,333 @@ TEST_F(OptionsUtilTest, FindOptionsMessage) {
|
|||
// TODO: Test with specified extension_type.
|
||||
}
|
||||
|
||||
// Constructs the field path for a string of field names.
|
||||
FieldPath MakeFieldPath(std::string tag, FieldData message_data) {
|
||||
tool::OptionsSyntaxUtil syntax_util;
|
||||
const tool::Descriptor* descriptor =
|
||||
tool::OptionsRegistry::GetProtobufDescriptor(
|
||||
tool::options_field_util::ParseTypeUrl(
|
||||
message_data.message_value().type_url()));
|
||||
return syntax_util.OptionFieldPath(tag, descriptor);
|
||||
}
|
||||
|
||||
// Returns the field path addressing the entire specified field.
|
||||
FieldPath EntireField(FieldPath field_path) {
|
||||
field_path.back().index = -1;
|
||||
return field_path;
|
||||
}
|
||||
|
||||
// Converts an int to a FieldData record.
|
||||
FieldData AsFieldData(int v) {
|
||||
return tool::options_field_util::AsFieldData(MakePacket<int>(v)).value();
|
||||
}
|
||||
|
||||
// Equality comparison for field contents.
|
||||
template <typename T>
|
||||
absl::Status Equals(const T& v1, const T& v2) {
|
||||
RET_CHECK_EQ(v1, v2);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Equality comparison for protobuf field contents.
|
||||
// The generic Equals() fails because MessageLite lacks operator==().
|
||||
// The comparison is performed on the serialized protobuf contents instead.
|
||||
using LightBundle = NightLightCalculatorOptions::LightBundle;
|
||||
template <>
|
||||
absl::Status Equals<LightBundle>(const LightBundle& v1, const LightBundle& v2) {
|
||||
std::string s_1, s_2;
|
||||
v1.SerializeToString(&s_1);
|
||||
v2.SerializeToString(&s_2);
|
||||
RET_CHECK(s_1 == s_2);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Equality comparison for FieldData vectors.
|
||||
template <typename FieldType>
|
||||
absl::Status Equals(std::vector<FieldData> b1, std::vector<FieldData> b2) {
|
||||
using tool::options_field_util::AsPacket;
|
||||
RET_CHECK_EQ(b1.size(), b2.size());
|
||||
for (int i = 0; i < b1.size(); ++i) {
|
||||
ASSIGN_OR_RETURN(Packet p1, AsPacket(b1.at(i)));
|
||||
ASSIGN_OR_RETURN(Packet p2, AsPacket(b2.at(i)));
|
||||
MP_RETURN_IF_ERROR(Equals(p1.Get<FieldType>(), p2.Get<FieldType>()));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
// Unit tests for the graph options field accessors from options_field_util.
|
||||
class OptionsFieldUtilTest : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
// Tests empty FieldPaths applied to empty options.
|
||||
TEST_F(OptionsFieldUtilTest, EmptyFieldPaths) {
|
||||
FieldData graph_options;
|
||||
FieldData node_options;
|
||||
FieldPath graph_path;
|
||||
FieldPath node_path;
|
||||
std::vector<FieldData> packet_data;
|
||||
ASSERT_AND_ASSIGN(packet_data, GetFieldValues(graph_options, graph_path));
|
||||
MP_EXPECT_OK(MergeFieldValues(node_options, node_path, packet_data));
|
||||
}
|
||||
|
||||
// Tests GetFieldValues applied to an int field.
|
||||
TEST_F(OptionsFieldUtilTest, GetFieldValuesInt) {
|
||||
NightLightCalculatorOptions node_proto;
|
||||
node_proto.mutable_sub_options();
|
||||
node_proto.mutable_sub_options()->add_num_lights(33);
|
||||
node_proto.mutable_sub_options()->add_num_lights(44);
|
||||
FieldData node_data = tool::options_field_util::AsFieldData(node_proto);
|
||||
|
||||
// Read an entire populated repeated field.
|
||||
FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
|
||||
MP_EXPECT_OK(Equals<int>(GetFieldValues(node_data, path).value(),
|
||||
{AsFieldData(33), AsFieldData(44)}));
|
||||
|
||||
// Read a specific populated repeated field index.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(44)}));
|
||||
}
|
||||
|
||||
// Tests GetFieldValues applied to a protobuf field.
|
||||
TEST_F(OptionsFieldUtilTest, GetFieldValuesProtobuf) {
|
||||
using tool::options_field_util::AsFieldData;
|
||||
using LightBundle = NightLightCalculatorOptions::LightBundle;
|
||||
NightLightCalculatorOptions node_proto;
|
||||
node_proto.mutable_sub_options();
|
||||
node_proto.mutable_sub_options()->add_bundle();
|
||||
*node_proto.mutable_sub_options()->mutable_bundle(0)->mutable_room_id() =
|
||||
"111";
|
||||
node_proto.mutable_sub_options()
|
||||
->mutable_bundle(0)
|
||||
->add_room_lights()
|
||||
->set_frame_rate(11.1);
|
||||
node_proto.mutable_sub_options()
|
||||
->mutable_bundle(0)
|
||||
->add_room_lights()
|
||||
->set_frame_rate(22.1);
|
||||
FieldData node_data = AsFieldData(node_proto);
|
||||
|
||||
// Read all values from a repeated protobuf field.
|
||||
LightBundle expected_proto;
|
||||
*expected_proto.mutable_room_id() = "111";
|
||||
expected_proto.add_room_lights()->set_frame_rate(11.1);
|
||||
expected_proto.add_room_lights()->set_frame_rate(22.1);
|
||||
FieldData expected_data = AsFieldData(expected_proto);
|
||||
FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
|
||||
MP_EXPECT_OK(Equals<LightBundle>(GetFieldValues(node_data, path).value(),
|
||||
{expected_data}));
|
||||
|
||||
// Read a specific index from a repeated protobuf field.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/bundle/0", node_data);
|
||||
MP_EXPECT_OK(Equals<LightBundle>(GetFieldValues(node_data, path).value(),
|
||||
{expected_data}));
|
||||
}
|
||||
|
||||
// Tests SetFieldValues applied to an int field.
|
||||
TEST_F(OptionsFieldUtilTest, SetFieldValuesInt) {
|
||||
NightLightCalculatorOptions node_proto;
|
||||
node_proto.mutable_sub_options();
|
||||
FieldData node_data = tool::options_field_util::AsFieldData(node_proto);
|
||||
|
||||
// Replace an entire empty repeated field.
|
||||
FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
|
||||
MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(33)}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(33)}));
|
||||
|
||||
// Replace an entire populated repeated field.
|
||||
MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(44)}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(44)}));
|
||||
|
||||
// Replace an entire repeated field with a new list of values.
|
||||
MP_ASSERT_OK(
|
||||
SetFieldValues(node_data, path, {AsFieldData(33), AsFieldData(44)}));
|
||||
MP_EXPECT_OK(Equals<int>(GetFieldValues(node_data, path).value(),
|
||||
{AsFieldData(33), AsFieldData(44)}));
|
||||
|
||||
// Replace a single field index with a new list of values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
|
||||
MP_ASSERT_OK(
|
||||
SetFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(33), AsFieldData(55), AsFieldData(66)}));
|
||||
|
||||
// Replace a single field middle index with a new list of values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
|
||||
MP_ASSERT_OK(
|
||||
SetFieldValues(node_data, path, {AsFieldData(11), AsFieldData(12)}));
|
||||
MP_EXPECT_OK(Equals<int>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(33), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
|
||||
|
||||
// Replace field index 0 with a new value.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/0", node_data);
|
||||
MP_ASSERT_OK(SetFieldValues(node_data, path, {AsFieldData(77)}));
|
||||
MP_EXPECT_OK(Equals<int>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(77), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
|
||||
|
||||
// Replace field index 0 with an empty list of values.
|
||||
MP_ASSERT_OK(SetFieldValues(node_data, path, {}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
|
||||
|
||||
// Replace an entire populated field with an empty list of values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
|
||||
MP_ASSERT_OK(SetFieldValues(node_data, path, {}));
|
||||
MP_ASSERT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(), {}));
|
||||
|
||||
// Replace a missing field index with new values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
|
||||
absl::Status status =
|
||||
SetFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)});
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kInternal);
|
||||
// TODO: status.message() appears empty on KokoroGCPDocker.
|
||||
// EXPECT_THAT(status.message(),
|
||||
// HasSubstr("index >= 0 && index <= v.size()"));
|
||||
}
|
||||
|
||||
// Tests SetFieldValues applied to a protobuf field.
|
||||
TEST_F(OptionsFieldUtilTest, SetFieldValuesProtobuf) {
|
||||
using tool::options_field_util::AsFieldData;
|
||||
using LightBundle = NightLightCalculatorOptions::LightBundle;
|
||||
NightLightCalculatorOptions node_proto;
|
||||
node_proto.mutable_sub_options();
|
||||
FieldData node_data = AsFieldData(node_proto);
|
||||
|
||||
// Replace an empty repeated protobuf field.
|
||||
LightBundle bundle_proto;
|
||||
*bundle_proto.mutable_room_id() = "222";
|
||||
bundle_proto.add_room_lights()->set_frame_rate(22.1);
|
||||
FieldData bundle_data = AsFieldData(bundle_proto);
|
||||
FieldData expected_data = bundle_data;
|
||||
FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
|
||||
MP_ASSERT_OK(SetFieldValues(node_data, path, {bundle_data}));
|
||||
MP_EXPECT_OK(Equals<LightBundle>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
|
||||
|
||||
// Replace a populated repeated protobuf field.
|
||||
*bundle_proto.mutable_room_id() = "333";
|
||||
bundle_proto.mutable_room_lights(0)->set_frame_rate(33.1);
|
||||
bundle_data = AsFieldData(bundle_proto);
|
||||
LightBundle expected_proto;
|
||||
*expected_proto.mutable_room_id() = "333";
|
||||
expected_proto.add_room_lights()->set_frame_rate(33.1);
|
||||
expected_data = AsFieldData(expected_proto);
|
||||
MP_ASSERT_OK(SetFieldValues(node_data, path, {bundle_data}));
|
||||
MP_EXPECT_OK(Equals<LightBundle>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
|
||||
}
|
||||
|
||||
// Tests MergeFieldValues applied to an int field.
|
||||
TEST_F(OptionsFieldUtilTest, MergeFieldValuesInt) {
|
||||
NightLightCalculatorOptions node_proto;
|
||||
node_proto.mutable_sub_options();
|
||||
FieldData node_data = tool::options_field_util::AsFieldData(node_proto);
|
||||
|
||||
// Replace an entire empty repeated field.
|
||||
FieldPath path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
|
||||
MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(33)}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(33)}));
|
||||
|
||||
// Replace an entire populated repeated field.
|
||||
MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(44)}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, path).value(), {AsFieldData(44)}));
|
||||
|
||||
// Replace an entire repeated field with a new list of values.
|
||||
MP_ASSERT_OK(
|
||||
MergeFieldValues(node_data, path, {AsFieldData(33), AsFieldData(44)}));
|
||||
MP_EXPECT_OK(Equals<int>(GetFieldValues(node_data, path).value(),
|
||||
{AsFieldData(33), AsFieldData(44)}));
|
||||
|
||||
// Replace a single field index with a new list of values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
|
||||
MP_ASSERT_OK(
|
||||
MergeFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(33), AsFieldData(55), AsFieldData(66)}));
|
||||
|
||||
// Replace a single field middle index with a new list of values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
|
||||
MP_ASSERT_OK(
|
||||
MergeFieldValues(node_data, path, {AsFieldData(11), AsFieldData(12)}));
|
||||
MP_EXPECT_OK(Equals<int>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(33), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
|
||||
|
||||
// Replace field index 0 with a new value.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/0", node_data);
|
||||
MP_ASSERT_OK(MergeFieldValues(node_data, path, {AsFieldData(77)}));
|
||||
MP_EXPECT_OK(Equals<int>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(77), AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
|
||||
|
||||
// Replace field index 0 with an empty list of values.
|
||||
MP_ASSERT_OK(MergeFieldValues(node_data, path, {}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(),
|
||||
{AsFieldData(11), AsFieldData(12), AsFieldData(66)}));
|
||||
|
||||
// Replace an entire populated field with an empty list of values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights", node_data);
|
||||
MP_ASSERT_OK(MergeFieldValues(node_data, path, {}));
|
||||
MP_EXPECT_OK(
|
||||
Equals<int>(GetFieldValues(node_data, EntireField(path)).value(), {}));
|
||||
|
||||
// Replace a missing field index with new values.
|
||||
path = MakeFieldPath("OPTIONS/sub_options/num_lights/1", node_data);
|
||||
absl::Status status =
|
||||
MergeFieldValues(node_data, path, {AsFieldData(55), AsFieldData(66)});
|
||||
EXPECT_EQ(status.code(), absl::StatusCode::kOutOfRange);
|
||||
EXPECT_THAT(status.message(),
|
||||
HasSubstr("Missing feild value: num_lights at index: 1"));
|
||||
}
|
||||
|
||||
// Tests MergeFieldValues applied to a protobuf field.
|
||||
TEST_F(OptionsFieldUtilTest, MergeFieldValuesProtobuf) {
|
||||
using tool::options_field_util::AsFieldData;
|
||||
using LightBundle = NightLightCalculatorOptions::LightBundle;
|
||||
NightLightCalculatorOptions node_proto;
|
||||
node_proto.mutable_sub_options();
|
||||
FieldData node_data = AsFieldData(node_proto);
|
||||
|
||||
// Merge an empty repeated protobuf field.
|
||||
LightBundle bundle_proto;
|
||||
*bundle_proto.mutable_room_id() = "222";
|
||||
bundle_proto.add_room_lights()->set_frame_rate(22.1);
|
||||
FieldData bundle_data = AsFieldData(bundle_proto);
|
||||
FieldData expected_data = bundle_data;
|
||||
FieldPath path = MakeFieldPath("OPTIONS/sub_options/bundle", node_data);
|
||||
MP_ASSERT_OK(MergeFieldValues(node_data, path, {bundle_data}));
|
||||
MP_EXPECT_OK(Equals<LightBundle>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
|
||||
|
||||
// Merge a populated repeated protobuf field.
|
||||
// "LightBundle.room_id" merges to "333".
|
||||
// "LightBundle.room_lights" merges to {{22.1}, {33.1}}.
|
||||
*bundle_proto.mutable_room_id() = "333";
|
||||
bundle_proto.mutable_room_lights(0)->set_frame_rate(33.1);
|
||||
bundle_data = AsFieldData(bundle_proto);
|
||||
LightBundle expected_proto;
|
||||
*expected_proto.mutable_room_id() = "333";
|
||||
expected_proto.add_room_lights()->set_frame_rate(22.1);
|
||||
expected_proto.add_room_lights()->set_frame_rate(33.1);
|
||||
expected_data = AsFieldData(expected_proto);
|
||||
MP_ASSERT_OK(MergeFieldValues(node_data, path, {bundle_data}));
|
||||
MP_EXPECT_OK(Equals<LightBundle>(
|
||||
GetFieldValues(node_data, EntireField(path)).value(), {expected_data}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace mediapipe
|
||||
|
|
|
@@ -16,11 +16,13 @@
|
|||
|
||||
#include <tuple>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "mediapipe/framework/port/canonical_errors.h"
|
||||
#include "mediapipe/framework/port/logging.h"
|
||||
#include "mediapipe/framework/port/ret_check.h"
|
||||
#include "mediapipe/framework/tool/field_data.pb.h"
|
||||
#include "mediapipe/framework/type_map.h"
|
||||
|
||||
#define RET_CHECK_NO_LOG(cond) RET_CHECK(cond).SetNoLogging()
|
||||
|
@@ -37,6 +39,7 @@ using FieldAccess = ProtoUtilLite::FieldAccess;
|
|||
using FieldValue = ProtoUtilLite::FieldValue;
|
||||
using ProtoPath = ProtoUtilLite::ProtoPath;
|
||||
using FieldType = ProtoUtilLite::FieldType;
|
||||
using mediapipe::FieldData;
|
||||
|
||||
// Returns true if a wire type includes a length indicator.
|
||||
bool IsLengthDelimited(WireFormatLite::WireType wire_type) {
|
||||
|
@@ -408,5 +411,149 @@ absl::Status ProtoUtilLite::Deserialize(
|
|||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status ProtoUtilLite::WriteValue(const FieldData& value,
|
||||
FieldType field_type,
|
||||
std::string* field_bytes) {
|
||||
StringOutputStream sos(field_bytes);
|
||||
CodedOutputStream out(&sos);
|
||||
switch (field_type) {
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
WireFormatLite::WriteInt32NoTag(value.int32_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
WireFormatLite::WriteSInt32NoTag(value.int32_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
WireFormatLite::WriteInt64NoTag(value.int64_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
WireFormatLite::WriteSInt64NoTag(value.int64_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
WireFormatLite::WriteUInt32NoTag(value.uint32_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
WireFormatLite::WriteUInt64NoTag(value.uint64_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
WireFormatLite::WriteDoubleNoTag(value.double_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
WireFormatLite::WriteFloatNoTag(value.float_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
WireFormatLite::WriteBoolNoTag(value.bool_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
WireFormatLite::WriteEnumNoTag(value.enum_value(), &out);
|
||||
break;
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
out.WriteString(value.string_value());
|
||||
break;
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
out.WriteString(value.message_value().value());
|
||||
break;
|
||||
default:
|
||||
return absl::UnimplementedError(
|
||||
absl::StrCat("Cannot write type: ", field_type));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <typename ValueT, FieldType kFieldType>
|
||||
static ValueT ReadValue(absl::string_view field_bytes, absl::Status* status) {
|
||||
ArrayInputStream ais(field_bytes.data(), field_bytes.size());
|
||||
CodedInputStream input(&ais);
|
||||
ValueT result;
|
||||
if (!WireFormatLite::ReadPrimitive<ValueT, kFieldType>(&input, &result)) {
|
||||
status->Update(absl::InvalidArgumentError(absl::StrCat(
|
||||
"Bad serialized value: ", MediaPipeTypeStringOrDemangled<ValueT>(),
|
||||
".")));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::Status ReadValue(absl::string_view field_bytes, FieldType field_type,
|
||||
absl::string_view message_type, FieldData* result) {
|
||||
absl::Status status;
|
||||
result->Clear();
|
||||
switch (field_type) {
|
||||
case WireFormatLite::TYPE_INT32:
|
||||
result->set_int32_value(
|
||||
ReadValue<int32, WireFormatLite::TYPE_INT32>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT32:
|
||||
result->set_int32_value(
|
||||
ReadValue<int32, WireFormatLite::TYPE_SINT32>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_INT64:
|
||||
result->set_int64_value(
|
||||
ReadValue<int64, WireFormatLite::TYPE_INT64>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_SINT64:
|
||||
result->set_int64_value(
|
||||
ReadValue<int64, WireFormatLite::TYPE_SINT64>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT32:
|
||||
result->set_uint32_value(
|
||||
ReadValue<uint32, WireFormatLite::TYPE_UINT32>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_UINT64:
|
||||
result->set_uint64_value(
|
||||
ReadValue<uint64, WireFormatLite::TYPE_UINT64>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_DOUBLE:
|
||||
result->set_double_value(
|
||||
ReadValue<double, WireFormatLite::TYPE_DOUBLE>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_FLOAT:
|
||||
result->set_float_value(
|
||||
ReadValue<float, WireFormatLite::TYPE_FLOAT>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_BOOL:
|
||||
result->set_bool_value(
|
||||
ReadValue<bool, WireFormatLite::TYPE_BOOL>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_ENUM:
|
||||
result->set_enum_value(
|
||||
ReadValue<int32, WireFormatLite::TYPE_ENUM>(field_bytes, &status));
|
||||
break;
|
||||
case WireFormatLite::TYPE_STRING:
|
||||
result->set_string_value(std::string(field_bytes));
|
||||
break;
|
||||
case WireFormatLite::TYPE_MESSAGE:
|
||||
result->mutable_message_value()->set_value(std::string(field_bytes));
|
||||
result->mutable_message_value()->set_type_url(
|
||||
ProtoUtilLite::TypeUrl(message_type));
|
||||
break;
|
||||
default:
|
||||
status = absl::UnimplementedError(
|
||||
absl::StrCat("Cannot read type: ", field_type));
|
||||
break;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
absl::Status ProtoUtilLite::ReadValue(absl::string_view field_bytes,
|
||||
FieldType field_type,
|
||||
absl::string_view message_type,
|
||||
FieldData* result) {
|
||||
return mediapipe::tool::ReadValue(field_bytes, field_type, message_type,
|
||||
result);
|
||||
}
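// A minimal round-trip sketch for WriteValue/ReadValue (illustrative only):
// an int32 FieldData is serialized to field bytes and then read back.
//
//   mediapipe::FieldData in, out;
//   in.set_int32_value(7);
//   std::string bytes;
//   MP_RETURN_IF_ERROR(
//       ProtoUtilLite::WriteValue(in, WireFormatLite::TYPE_INT32, &bytes));
//   MP_RETURN_IF_ERROR(
//       ProtoUtilLite::ReadValue(bytes, WireFormatLite::TYPE_INT32, "", &out));
//   // out.int32_value() is now 7.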
|
||||
|
||||
std::string ProtoUtilLite::TypeUrl(absl::string_view type_name) {
|
||||
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
|
||||
return absl::StrCat(std::string(kTypeUrlPrefix), std::string(type_name));
|
||||
}
|
||||
|
||||
std::string ProtoUtilLite::ParseTypeUrl(absl::string_view type_url) {
|
||||
constexpr std::string_view kTypeUrlPrefix = "type.googleapis.com/";
|
||||
if (absl::StartsWith(std::string(type_url), std::string(kTypeUrlPrefix))) {
|
||||
return std::string(type_url.substr(kTypeUrlPrefix.length()));
|
||||
}
|
||||
return std::string(type_url);
|
||||
}
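// TypeUrl and ParseTypeUrl are inverses of each other, e.g. (illustrative):
//
//   ProtoUtilLite::TypeUrl("mediapipe.NightLightCalculatorOptions")
//     returns "type.googleapis.com/mediapipe.NightLightCalculatorOptions";
//   ProtoUtilLite::ParseTypeUrl(
//       "type.googleapis.com/mediapipe.NightLightCalculatorOptions")
//     returns "mediapipe.NightLightCalculatorOptions".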
|
||||
|
||||
} // namespace tool
|
||||
} // namespace mediapipe
|
||||
|
|
|
@@ -23,10 +23,12 @@
|
|||
#include "mediapipe/framework/port/integral_types.h"
|
||||
#include "mediapipe/framework/port/proto_ns.h"
|
||||
#include "mediapipe/framework/port/status.h"
|
||||
#include "mediapipe/framework/tool/field_data.pb.h"
|
||||
|
||||
namespace mediapipe {
|
||||
namespace tool {
|
||||
|
||||
// TODO: Replace this class with a namespace following Google style.
|
||||
class ProtoUtilLite {
|
||||
public:
|
||||
// Defines field types and tag formats.
|
||||
|
@@ -89,6 +91,23 @@ class ProtoUtilLite {
|
|||
static absl::Status Deserialize(const std::vector<FieldValue>& field_values,
|
||||
FieldType field_type,
|
||||
std::vector<std::string>* result);
|
||||
|
||||
// Write a protobuf field value from a typed FieldData value.
|
||||
static absl::Status WriteValue(const mediapipe::FieldData& value,
|
||||
FieldType field_type,
|
||||
std::string* field_bytes);
|
||||
|
||||
// Read a protobuf field value into a typed FieldData value.
|
||||
static absl::Status ReadValue(absl::string_view field_bytes,
|
||||
FieldType field_type,
|
||||
absl::string_view message_type,
|
||||
mediapipe::FieldData* result);
|
||||
|
||||
// Returns the protobuf type-url for a protobuf type-name.
|
||||
static std::string TypeUrl(absl::string_view type_name);
|
||||
|
||||
// Returns the protobuf type-name for a protobuf type-url.
|
||||
static std::string ParseTypeUrl(absl::string_view type_url);
|
||||
};
|
||||
|
||||
} // namespace tool
|
||||
|
|
|
@@ -59,7 +59,8 @@ absl::Status CombinedStatus(const std::string& general_comment,
|
|||
}
|
||||
}
|
||||
if (error_code == StatusCode::kOk) return OkStatus();
|
||||
Status combined = absl::Status(
|
||||
Status combined;
|
||||
combined = absl::Status(
|
||||
error_code,
|
||||
absl::StrCat(general_comment, "\n", absl::StrJoin(errors, "\n")));
|
||||
return combined;
|
||||
|
|
|
@@ -28,8 +28,11 @@ namespace mediapipe {
|
|||
namespace {
|
||||
|
||||
using testing::ContainerEq;
|
||||
using testing::Eq;
|
||||
using testing::HasSubstr;
|
||||
using testing::IsEmpty;
|
||||
using testing::Matches;
|
||||
using testing::Pointwise;
|
||||
|
||||
TEST(StatusTest, StatusStopIsNotOk) { EXPECT_FALSE(tool::StatusStop().ok()); }
|
||||
|
||||
|
|
|
@@ -293,7 +293,7 @@ absl::Status ExpandSubgraphs(CalculatorGraphConfig* config,
|
|||
if (subgraph_nodes_start == nodes->end()) break;
|
||||
std::vector<CalculatorGraphConfig> subgraphs;
|
||||
for (auto it = subgraph_nodes_start; it != nodes->end(); ++it) {
|
||||
const auto& node = *it;
|
||||
auto& node = *it;
|
||||
int node_id = it - nodes->begin();
|
||||
std::string node_name = CanonicalNodeName(*config, node_id);
|
||||
MP_RETURN_IF_ERROR(ValidateSubgraphFields(node));
|
||||
|
|
|
@@ -16,79 +16,129 @@
|
|||
#define MEDIAPIPE_FRAMEWORK_TOOL_TYPE_UTIL_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <typeinfo>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/base/attributes.h"
|
||||
#include "mediapipe/framework/demangle.h"
|
||||
#include "mediapipe/framework/port.h"
|
||||
|
||||
namespace mediapipe {
|
||||
|
||||
// An identifier for a type. This class is lightweight and is meant to be passed
|
||||
// by value.
|
||||
// To get the TypeId for SomeType, write kTypeId<SomeType>.
|
||||
class TypeId {
|
||||
public:
|
||||
size_t hash_code() const { return impl_.hash_code(); }
|
||||
std::string name() const { return impl_.name(); }
|
||||
bool operator==(const TypeId& other) const { return impl_ == other.impl_; }
|
||||
bool operator<(const TypeId& other) const { return impl_ < other.impl_; }
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const TypeId& r) {
|
||||
return H::combine(std::move(h), r.hash_code());
|
||||
}
|
||||
|
||||
template <class T>
|
||||
static constexpr inline TypeId Of() {
|
||||
return TypeId{Impl::Get<T>()};
|
||||
}
|
||||
|
||||
private:
|
||||
// This implementation uses no RTTI. It distinguishes types, but does not
|
||||
// know their names.
|
||||
// TODO: record compile-time type string for (some or all) types.
|
||||
template <class T>
|
||||
struct TypeTag {
|
||||
static constexpr char dummy = 0;
|
||||
};
|
||||
struct NoRttiImpl {
|
||||
template <class T>
|
||||
static constexpr inline NoRttiImpl Get() {
|
||||
return {&TypeTag<T>::dummy};
|
||||
}
|
||||
size_t hash_code() const { return reinterpret_cast<uintptr_t>(tag_); }
|
||||
std::string name() const { return "<type name missing>"; }
|
||||
bool operator==(const NoRttiImpl& other) const {
|
||||
return tag_ == other.tag_;
|
||||
}
|
||||
bool operator<(const NoRttiImpl& other) const { return tag_ < other.tag_; }
|
||||
|
||||
const void* tag_;
|
||||
};
|
||||
|
||||
#if MEDIAPIPE_HAS_RTTI
|
||||
template <class T>
|
||||
static const std::type_info& GetTypeInfo() {
|
||||
return typeid(T);
|
||||
}
|
||||
// This implementation uses RTTI, and delegates all operations to
|
||||
// std::type_info. In order to support constexpr construction, we don't store
|
||||
// a type_info directly (which is not constexpr), but a pointer to a function
|
||||
// returning it (which is). This implementation is a bit slower than the
|
||||
// others. The only potential advantage would be the ability to match types
|
||||
// across multiple dynamic libraries, but we don't support that setup anyway.
|
||||
// This is provided for completeness.
|
||||
struct FullRttiImpl {
|
||||
template <class T>
|
||||
static constexpr inline FullRttiImpl Get() {
|
||||
return {GetTypeInfo<T>};
|
||||
}
|
||||
size_t hash_code() const { return get_().hash_code(); }
|
||||
std::string name() const { return Demangle(get_().name()); }
|
||||
bool operator==(const FullRttiImpl& other) const {
|
||||
return get_ == other.get_ || get_() == other.get_();
|
||||
}
|
||||
bool operator<(const FullRttiImpl& other) const {
|
||||
return get_().before(other.get_());
|
||||
}
|
||||
|
||||
decltype(&GetTypeInfo<void>) get_;
|
||||
};
|
||||
|
||||
// This implementation also stores a pointer to a std::type_info getter
|
||||
// function, but it only invokes it to get the type's name. It's equivalent to
|
||||
// NoRttiImpl for most operations, but it allows getting the type's name.
|
||||
struct FastRttiImpl {
|
||||
template <class T>
|
||||
static constexpr inline FastRttiImpl Get() {
|
||||
return {GetTypeInfo<T>};
|
||||
}
|
||||
size_t hash_code() const { return reinterpret_cast<uintptr_t>(get_); }
|
||||
std::string name() const { return Demangle(get_().name()); }
|
||||
bool operator==(const FastRttiImpl& other) const {
|
||||
return get_ == other.get_;
|
||||
}
|
||||
bool operator<(const FastRttiImpl& other) const {
|
||||
return reinterpret_cast<uintptr_t>(get_) <
|
||||
reinterpret_cast<uintptr_t>(other.get_);
|
||||
}
|
||||
|
||||
decltype(&GetTypeInfo<void>) get_;
|
||||
};
|
||||
|
||||
using Impl = FastRttiImpl;
|
||||
#else
|
||||
using Impl = NoRttiImpl;
|
||||
#endif // MEDIAPIPE_HAS_RTTI
|
||||
constexpr explicit TypeId(Impl impl) : impl_(impl) {}
|
||||
|
||||
Impl impl_;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
static constexpr TypeId kTypeId = TypeId::Of<T>();
|
||||
|
||||
namespace tool {
|
||||
|
||||
#if !MEDIAPIPE_HAS_RTTI
|
||||
// A unique identifier for type T.
|
||||
class TypeInfo {
|
||||
public:
|
||||
size_t hash_code() const { return reinterpret_cast<size_t>(this); }
|
||||
bool operator==(const TypeInfo& other) const { return &other == this; }
|
||||
bool operator<(const TypeInfo& other) const { return &other < this; }
|
||||
const char* name() const { return "<unknown>"; }
|
||||
template <typename T>
|
||||
static const TypeInfo& Get() {
|
||||
static TypeInfo* static_type_info = new TypeInfo;
|
||||
return *static_type_info;
|
||||
}
|
||||
|
||||
private:
|
||||
TypeInfo() {}
|
||||
TypeInfo(const TypeInfo&) = delete;
|
||||
};
|
||||
|
||||
#else // MEDIAPIPE_HAS_RTTI
|
||||
// The std unique identifier for type T.
|
||||
class TypeInfo {
|
||||
public:
|
||||
size_t hash_code() const { return info_.hash_code(); }
|
||||
bool operator==(const TypeInfo& o) const { return info_ == o.info_; }
|
||||
bool operator<(const TypeInfo& o) const { return info_.before(o.info_); }
|
||||
const char* name() const { return info_.name(); }
|
||||
template <typename T>
|
||||
static const TypeInfo& Get() {
|
||||
static TypeInfo* static_type_info = new TypeInfo(typeid(T));
|
||||
return *static_type_info;
|
||||
}
|
||||
|
||||
private:
|
||||
TypeInfo(const std::type_info& info) : info_(info) {}
|
||||
TypeInfo(const TypeInfo&) = delete;
|
||||
|
||||
private:
|
||||
const std::type_info& info_;
|
||||
friend class TypeIndex;
|
||||
};
|
||||
#endif
|
||||
|
||||
// An associative key for TypeInfo.
|
||||
class TypeIndex {
|
||||
public:
|
||||
TypeIndex(const TypeInfo& info) : info_(info) {}
|
||||
size_t hash_code() const { return info_.hash_code(); }
|
||||
bool operator==(const TypeIndex& other) const { return info_ == other.info_; }
|
||||
bool operator<(const TypeIndex& other) const { return info_ < other.info_; }
|
||||
|
||||
private:
|
||||
const TypeInfo& info_;
|
||||
};
|
||||
|
||||
// Helper method that returns a hash code of the given type. This allows for
|
||||
// typeid testing across multiple binaries, unlike FastTypeId which used a
|
||||
// memory location that only works within the same binary. Moreover, we use this
|
||||
// for supporting multiple .so binaries in a single Android app built using the
|
||||
// same compiler and C++ libraries.
|
||||
// Note that std::type_info may still generate the same hash code for different
|
||||
// types, although the c++ standard recommends that implementations avoid this
|
||||
// as much as possible.
|
||||
// Helper method that returns a hash code of the given type.
|
||||
// Superseded by TypeId.
|
||||
template <typename T>
|
||||
ABSL_DEPRECATED("Use TypeId directly instead.")
|
||||
size_t GetTypeHash() {
|
||||
return TypeInfo::Get<T>().hash_code();
|
||||
return kTypeId<T>.hash_code();
|
||||
}
|
||||
|
||||
} // namespace tool
|
||||
|
|
|
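For orientation, a minimal usage sketch of the new TypeId API introduced above; Widget is a hypothetical placeholder type and this snippet is not part of the change:

#include "mediapipe/framework/tool/type_util.h"

struct Widget {};

void TypeIdExample() {
  // kTypeId<T> is a constexpr, copyable identifier for T.
  mediapipe::TypeId id = mediapipe::kTypeId<Widget>;
  size_t hash = id.hash_code();                    // replaces tool::GetTypeHash<Widget>()
  bool same = (id == mediapipe::kTypeId<Widget>);  // compares equal for the same T
  std::string name = id.name();                    // demangled name when RTTI is enabled
  (void)hash; (void)same; (void)name;
}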
@@ -361,32 +361,30 @@ DEFINE_MEDIAPIPE_TYPE_MAP(PacketTypeStringToMediaPipeTypeData, std::string);
// End define MEDIAPIPE_REGISTER_TYPE_WITH_PROXY.

// Helper functions to retrieve registration data.
inline const std::string* MediaPipeTypeStringFromTypeId(const size_t type_id) {
inline const std::string* MediaPipeTypeStringFromTypeId(TypeId type_id) {
  const MediaPipeTypeData* value =
      PacketTypeIdToMediaPipeTypeData::GetValue(type_id);
      PacketTypeIdToMediaPipeTypeData::GetValue(type_id.hash_code());
  return (value) ? &value->type_string : nullptr;
}

// Returns string identifier of type or NULL if not registered.
template <typename T>
inline const std::string* MediaPipeTypeString() {
  return MediaPipeTypeStringFromTypeId(tool::GetTypeHash<T>());
  return MediaPipeTypeStringFromTypeId(kTypeId<T>);
}

inline std::string MediaPipeTypeStringOrDemangled(
    const tool::TypeInfo& type_info) {
  const std::string* type_string =
      MediaPipeTypeStringFromTypeId(type_info.hash_code());
inline std::string MediaPipeTypeStringOrDemangled(TypeId type_id) {
  const std::string* type_string = MediaPipeTypeStringFromTypeId(type_id);
  if (type_string) {
    return *type_string;
  } else {
    return mediapipe::Demangle(type_info.name());
    return type_id.name();
  }
}

template <typename T>
std::string MediaPipeTypeStringOrDemangled() {
  return MediaPipeTypeStringOrDemangled(tool::TypeInfo::Get<T>());
  return MediaPipeTypeStringOrDemangled(kTypeId<T>);
}

// Returns type hash id of type identified by type_string or NULL if not
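A hedged sketch of how these helpers behave after the change; MyStruct is a hypothetical unregistered type used only for illustration:

#include "mediapipe/framework/type_map.h"

struct MyStruct {};

std::string DescribeMyStruct() {
  // Returns the registered MediaPipe type string if MyStruct was registered
  // (e.g. via MEDIAPIPE_REGISTER_TYPE); otherwise falls back to the name
  // reported by kTypeId<MyStruct>.
  return mediapipe::MediaPipeTypeStringOrDemangled<MyStruct>();
}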
@@ -14,6 +14,8 @@

#include "mediapipe/framework/validated_graph_config.h"

#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@@ -140,35 +142,6 @@ absl::Status AddPredefinedExecutorConfigs(CalculatorGraphConfig* graph_config) {
  return absl::OkStatus();
}

absl::Status PerformBasicTransforms(
    const CalculatorGraphConfig& input_graph_config,
    const GraphRegistry* graph_registry,
    const Subgraph::SubgraphOptions* graph_options,
    const GraphServiceManager* service_manager,
    CalculatorGraphConfig* output_graph_config) {
  *output_graph_config = input_graph_config;
  MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(output_graph_config, graph_registry,
                                           graph_options, service_manager));

  MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(output_graph_config));

  // Populate each node with the graph level input stream handler if a
  // stream handler wasn't explicitly provided.
  // TODO Instead of pre-populating, handle the graph level
  // default appropriately within CalculatorGraph.
  if (output_graph_config->has_input_stream_handler()) {
    const auto& graph_level_input_stream_handler =
        output_graph_config->input_stream_handler();
    for (auto& node : *output_graph_config->mutable_node()) {
      if (!node.has_input_stream_handler()) {
        *node.mutable_input_stream_handler() = graph_level_input_stream_handler;
      }
    }
  }

  return absl::OkStatus();
}

}  // namespace

// static
@@ -346,8 +319,7 @@ absl::Status NodeTypeInfo::Initialize(
}

absl::Status ValidatedGraphConfig::Initialize(
    const CalculatorGraphConfig& input_config,
    const GraphRegistry* graph_registry,
    CalculatorGraphConfig input_config, const GraphRegistry* graph_registry,
    const Subgraph::SubgraphOptions* graph_options,
    const GraphServiceManager* service_manager) {
  RET_CHECK(!initialized_)
@@ -358,9 +330,9 @@ absl::Status ValidatedGraphConfig::Initialize(
      << input_config.DebugString();
#endif

  MP_RETURN_IF_ERROR(PerformBasicTransforms(
      input_config, graph_registry, graph_options, service_manager, &config_));

  config_ = std::move(input_config);
  MP_RETURN_IF_ERROR(
      PerformBasicTransforms(graph_registry, graph_options, service_manager));
  // Initialize the basic node information.
  MP_RETURN_IF_ERROR(InitializeGeneratorInfo());
  MP_RETURN_IF_ERROR(InitializeCalculatorInfo());
|
|||
const GraphServiceManager* service_manager) {
|
||||
graph_registry =
|
||||
graph_registry ? graph_registry : &GraphRegistry::global_graph_registry;
|
||||
SubgraphContext subgraph_context(graph_options, service_manager);
|
||||
Subgraph::SubgraphOptions local_graph_options;
|
||||
if (graph_options) {
|
||||
local_graph_options = *graph_options;
|
||||
}
|
||||
SubgraphContext subgraph_context =
|
||||
SubgraphContext(&local_graph_options, service_manager);
|
||||
auto status_or_config =
|
||||
graph_registry->CreateByName("", graph_type, &subgraph_context);
|
||||
MP_RETURN_IF_ERROR(status_or_config.status());
|
||||
|
@@ -466,6 +443,32 @@ absl::Status ValidatedGraphConfig::Initialize(
                    service_manager);
}

absl::Status ValidatedGraphConfig::PerformBasicTransforms(
    const GraphRegistry* graph_registry,
    const Subgraph::SubgraphOptions* graph_options,
    const GraphServiceManager* service_manager) {
  MP_RETURN_IF_ERROR(tool::ExpandSubgraphs(&config_, graph_registry,
                                           graph_options, service_manager));

  MP_RETURN_IF_ERROR(AddPredefinedExecutorConfigs(&config_));

  // Populate each node with the graph level input stream handler if a
  // stream handler wasn't explicitly provided.
  // TODO Instead of pre-populating, handle the graph level
  // default appropriately within CalculatorGraph.
  if (config_.has_input_stream_handler()) {
    const auto& graph_level_input_stream_handler =
        config_.input_stream_handler();
    for (auto& node : *config_.mutable_node()) {
      if (!node.has_input_stream_handler()) {
        *node.mutable_input_stream_handler() = graph_level_input_stream_handler;
      }
    }
  }

  return absl::OkStatus();
}

absl::Status ValidatedGraphConfig::InitializeCalculatorInfo() {
  std::vector<absl::Status> statuses;
  calculators_.reserve(config_.node_size());
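To illustrate the stream-handler population step above, a small hypothetical config built in C++; FixedSizeInputStreamHandler and PassThroughCalculator are used only as example names:

mediapipe::CalculatorGraphConfig config;
config.mutable_input_stream_handler()->set_input_stream_handler(
    "FixedSizeInputStreamHandler");
auto* node = config.add_node();
node->set_calculator("PassThroughCalculator");
node->add_input_stream("in");
node->add_output_stream("out");
// After PerformBasicTransforms, this node inherits the graph-level handler
// because it does not set input_stream_handler itself.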
@@ -690,6 +693,7 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
    if (!need_sorting_ptr) {
      LOG(WARNING) << "Input Stream \"" << name
                   << "\" for node with sorted index " << node_index
                   << " name " << node_type_info->Contract().GetNodeName()
                   << " is marked as a back edge, but its output stream is "
                      "already available. This means it was not necessary "
                      "to mark it as a back edge.";
@@ -701,6 +705,7 @@ absl::Status ValidatedGraphConfig::AddInputStreamsForNode(
    if (edge_info.back_edge) {
      VLOG(1) << "Encountered expected behavior: the back edge \"" << name
              << "\" for node with (possibly sorted) index " << node_index
              << " name " << node_type_info->Contract().GetNodeName()
              << " has an output stream which we have not yet seen.";
    } else if (need_sorting_ptr) {
      *need_sorting_ptr = true;
@@ -709,7 +714,9 @@
    } else {
      return mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
             << "Input Stream \"" << name << "\" for node with sorted index "
             << node_index << " does not have a corresponding output stream.";
             << node_index << " name "
             << node_type_info->Contract().GetNodeName()
             << " does not have a corresponding output stream.";
    }
  }
@@ -195,7 +195,7 @@ class ValidatedGraphConfig {
  // before any other functions. Subgraphs are specified through the
  // global graph registry or an optional local graph registry.
  absl::Status Initialize(
      const CalculatorGraphConfig& input_config,
      CalculatorGraphConfig input_config,
      const GraphRegistry* graph_registry = nullptr,
      const Subgraph::SubgraphOptions* graph_options = nullptr,
      const GraphServiceManager* service_manager = nullptr);
@@ -302,6 +302,13 @@ class ValidatedGraphConfig {
  }

 private:
  // Perform transforms such as converting legacy features, expanding
  // subgraphs, and populating the input stream handler.
  absl::Status PerformBasicTransforms(
      const GraphRegistry* graph_registry,
      const Subgraph::SubgraphOptions* graph_options,
      const GraphServiceManager* service_manager);

  // Initialize the PacketGenerator information.
  absl::Status InitializeGeneratorInfo();
  // Initialize the Calculator information.
@@ -53,6 +53,12 @@ cc_library(
    deps = ["//mediapipe/framework:graph_service"],
)

cc_library(
    name = "attachments",
    hdrs = ["attachments.h"],
    visibility = ["//visibility:public"],
)

GL_BASE_LINK_OPTS = select({
    "//conditions:default": [],
    "//mediapipe:android": [
@@ -172,6 +178,7 @@ cc_library(
    }),
    visibility = ["//visibility:public"],
    deps = [
        ":attachments",
        ":gl_base",
        ":gl_thread_collector",
        ":gpu_buffer_format",
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_
#define MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_
#ifndef MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_
#define MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_

#import <CoreVideo/CVMetalTextureCache.h>
#import <CoreVideo/CoreVideo.h>
@@ -68,4 +68,4 @@ class GpuBufferMultiPool;

@end

#endif  // MEDIAPIPE_GPU_DRISHTIGRAPHGPUDATA_H_
#endif  // MEDIAPIPE_GPU_MPPGRAPHGPUDATA_H_
mediapipe/gpu/attachments.h (new file, 64 lines)
@@ -0,0 +1,64 @@
#ifndef MEDIAPIPE_GPU_ATTACHMENTS_H_
#define MEDIAPIPE_GPU_ATTACHMENTS_H_

#include <functional>
#include <memory>

namespace mediapipe {
namespace internal {

// Unique pointer with a type-erased destructor.
template <class T>
using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;

// Like make_unique.
template <class T, class... Args>
static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
MakeAttachmentPtr(Args&&... args) {
  return {new T(std::forward<Args>(args)...),
          [](void* ptr) { delete static_cast<T*>(ptr); }};
}

template <class Context>
class AttachmentBase {};

// A cacheable resource that can be associated with a context.
// Attachments are defined as constants.
// When access to an attachment is requested, it will be retrieved from the
// context if already created, or the factory function will be invoked to create
// it. The factory function for a given attachment is invoked at most once per
// context. The lifetime of the object it returns is managed by the context.
template <class Context, class T>
class Attachment : public AttachmentBase<Context> {
 public:
  using FactoryT = std::function<AttachmentPtr<T>(Context&)>;
  Attachment(FactoryT factory) : factory_(factory) {}

  Attachment(const Attachment&) = delete;
  Attachment(Attachment&&) = delete;
  Attachment& operator=(const Attachment&) = delete;
  Attachment& operator=(Attachment&&) = delete;

  T& Get(Context& ctx) const { return ctx.GetCachedAttachment(*this); }

  const FactoryT& factory() const { return factory_; }

  // Ptr and MakePtr here make it more convenient to define new types of
  // attachment contexts, since you only need a using declaration for Attachment
  // and can refer to Ptr from it.
  using Ptr = AttachmentPtr<T>;

  template <class... Args>
  inline static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
  MakePtr(Args&&... args) {
    return MakeAttachmentPtr<T>(std::forward<Args>(args)...);
  }

 private:
  FactoryT factory_;
};

}  // namespace internal
}  // namespace mediapipe

#endif  // MEDIAPIPE_GPU_ATTACHMENTS_H_
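For orientation, a minimal sketch of how a context type is expected to cooperate with Attachment; FooContext and FooCache are hypothetical names, and a real context would keep one cached entry per attachment rather than a static local:

#include "mediapipe/gpu/attachments.h"

class FooCache { /* ... */ };

class FooContext {
 public:
  template <class T>
  T& GetCachedAttachment(
      const mediapipe::internal::Attachment<FooContext, T>& attachment) {
    // Simplified: one static slot instead of a per-context map keyed by
    // &attachment.
    static mediapipe::internal::AttachmentPtr<T> cached =
        attachment.factory()(*this);
    return *cached;
  }
};

// Defined as a constant; the factory runs at most once per context.
const mediapipe::internal::Attachment<FooContext, FooCache> kFooCache(
    [](FooContext&) {
      return mediapipe::internal::MakeAttachmentPtr<FooCache>();
    });

void UseFooCache(FooContext& ctx) {
  FooCache& cache = kFooCache.Get(ctx);  // created on first access
  (void)cache;
}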
@@ -29,6 +29,7 @@
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/port/threadpool.h"
#include "mediapipe/framework/timestamp.h"
#include "mediapipe/gpu/attachments.h"
#include "mediapipe/gpu/gl_base.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
@@ -286,42 +287,15 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
  // Sets default texture filtering parameters.
  void SetStandardTextureParams(GLenum target, GLint internal_format);

  using AttachmentBase = internal::AttachmentBase<GlContext>;
  template <class T>
  using AttachmentPtr = std::unique_ptr<T, std::function<void(void*)>>;

  template <class T, class... Args>
  static std::enable_if_t<!std::is_array<T>::value, AttachmentPtr<T>>
  MakeAttachmentPtr(Args&&... args) {
    return {new T(std::forward<Args>(args)...),
            [](void* ptr) { delete static_cast<T*>(ptr); }};
  }

  class AttachmentBase {};

  template <class T>
  class Attachment : public AttachmentBase {
   public:
    using FactoryT = std::function<AttachmentPtr<T>(GlContext&)>;
    Attachment(FactoryT factory) : factory_(factory) {}

    Attachment(const Attachment&) = delete;
    Attachment(Attachment&&) = delete;
    Attachment& operator=(const Attachment&) = delete;
    Attachment& operator=(Attachment&&) = delete;

    T& Get(GlContext& ctx) const { return ctx.GetCachedAttachment(*this); }

    const FactoryT& factory() const { return factory_; }

   private:
    FactoryT factory_;
  };
  using Attachment = internal::Attachment<GlContext, T>;

  // TODO: const result?
  template <class T>
  T& GetCachedAttachment(const Attachment<T>& attachment) {
    DCHECK(IsCurrent());
    AttachmentPtr<void>& entry = attachments_[&attachment];
    internal::AttachmentPtr<void>& entry = attachments_[&attachment];
    if (entry == nullptr) {
      entry = attachment.factory()(*this);
    }
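A hypothetical sketch of defining a per-GlContext attachment with the aliases above; MyGlResource and kMyResource are illustrative names only and assume the Attachment alias is publicly accessible:

struct MyGlResource {
  GLuint texture = 0;
};

// The factory runs at most once per GlContext and is invoked with the
// context current (GetCachedAttachment DCHECKs IsCurrent()).
const mediapipe::GlContext::Attachment<MyGlResource> kMyResource(
    [](mediapipe::GlContext& ctx)
        -> mediapipe::GlContext::Attachment<MyGlResource>::Ptr {
      auto res = mediapipe::GlContext::Attachment<MyGlResource>::MakePtr();
      glGenTextures(1, &res->texture);
      return res;
    });

void UseOnGlThread(mediapipe::GlContext& ctx) {
  MyGlResource& res = ctx.GetCachedAttachment(kMyResource);
  (void)res;
}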
@@ -454,7 +428,8 @@ class GlContext : public std::enable_shared_from_this<GlContext> {
  // better mechanism?
  bool can_linear_filter_float_textures_;

  absl::flat_hash_map<const AttachmentBase*, AttachmentPtr<void>> attachments_;
  absl::flat_hash_map<const AttachmentBase*, internal::AttachmentPtr<void>>
      attachments_;

  // Number of glFinish calls completed on the GL thread.
  // Changes should be guarded by mutex_. However, we use simple atomic
Some files were not shown because too many files have changed in this diff.