Project import generated by Copybara.

GitOrigin-RevId: 796203faee20d7aae2876aac8ca5a1827dee4fe3
This commit is contained in:
MediaPipe Team 2019-09-30 10:18:09 -07:00 committed by jqtang
parent 412ab42d1f
commit a2a63e3876
122 changed files with 7330 additions and 2016 deletions

View File

@ -64,6 +64,12 @@ http_archive(
sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8", sha256 = "267103f8a1e9578978aa1dc256001e6529ef593e5aea38193d31c2872ee025e8",
strip_prefix = "glog-0.3.5", strip_prefix = "glog-0.3.5",
build_file = "@//third_party:glog.BUILD", build_file = "@//third_party:glog.BUILD",
patches = [
"@//third_party:com_github_glog_glog_9779e5ea6ef59562b030248947f787d1256132ae.diff"
],
patch_args = [
"-p1",
],
) )
# libyuv # libyuv

View File

@ -61,7 +61,9 @@ class AudioDecoderCalculator : public CalculatorBase {
::mediapipe::Status AudioDecoderCalculator::GetContract( ::mediapipe::Status AudioDecoderCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>(); cc->InputSidePackets().Tag("INPUT_FILE_PATH").Set<std::string>();
if (cc->InputSidePackets().HasTag("OPTIONS")) {
cc->InputSidePackets().Tag("OPTIONS").Set<mediapipe::AudioDecoderOptions>();
}
cc->Outputs().Tag("AUDIO").Set<Matrix>(); cc->Outputs().Tag("AUDIO").Set<Matrix>();
if (cc->Outputs().HasTag("AUDIO_HEADER")) { if (cc->Outputs().HasTag("AUDIO_HEADER")) {
cc->Outputs().Tag("AUDIO_HEADER").SetNone(); cc->Outputs().Tag("AUDIO_HEADER").SetNone();
@ -72,7 +74,9 @@ class AudioDecoderCalculator : public CalculatorBase {
::mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) { ::mediapipe::Status AudioDecoderCalculator::Open(CalculatorContext* cc) {
const std::string& input_file_path = const std::string& input_file_path =
cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>(); cc->InputSidePackets().Tag("INPUT_FILE_PATH").Get<std::string>();
const auto& decoder_options = cc->Options<mediapipe::AudioDecoderOptions>(); const auto& decoder_options =
tool::RetrieveOptions(cc->Options<mediapipe::AudioDecoderOptions>(),
cc->InputSidePackets(), "OPTIONS");
decoder_ = absl::make_unique<AudioDecoder>(); decoder_ = absl::make_unique<AudioDecoder>();
MP_RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options)); MP_RETURN_IF_ERROR(decoder_->Initialize(input_file_path, decoder_options));
std::unique_ptr<mediapipe::TimeSeriesHeader> header = std::unique_ptr<mediapipe::TimeSeriesHeader> header =

View File

@ -75,8 +75,13 @@ class StabilizedLogCalculator : public CalculatorBase {
::mediapipe::Status Process(CalculatorContext* cc) override { ::mediapipe::Status Process(CalculatorContext* cc) override {
auto input_matrix = cc->Inputs().Index(0).Get<Matrix>(); auto input_matrix = cc->Inputs().Index(0).Get<Matrix>();
if (input_matrix.array().isNaN().any()) {
return ::mediapipe::InvalidArgumentError("NaN input to log operation.");
}
if (check_nonnegativity_) { if (check_nonnegativity_) {
CHECK_GE(input_matrix.minCoeff(), 0); if (input_matrix.minCoeff() < 0.0) {
return ::mediapipe::OutOfRangeError("Negative input to log operation.");
}
} }
std::unique_ptr<Matrix> output_frame(new Matrix( std::unique_ptr<Matrix> output_frame(new Matrix(
output_scale_ * (input_matrix.array() + stabilizer_).log().matrix())); output_scale_ * (input_matrix.array() + stabilizer_).log().matrix()));

View File

@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cmath>
#include "Eigen/Core" #include "Eigen/Core"
#include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h" #include "mediapipe/calculators/audio/stabilized_log_calculator.pb.h"
@ -108,13 +109,22 @@ TEST_F(StabilizedLogCalculatorTest, ZerosAreStabilized) {
runner_->Outputs().Index(0).packets[0].Get<Matrix>()); runner_->Outputs().Index(0).packets[0].Get<Matrix>());
} }
TEST_F(StabilizedLogCalculatorTest, NegativeValuesCheckFail) { TEST_F(StabilizedLogCalculatorTest, NanValuesReturnError) {
InitializeGraph();
FillInputHeader();
AppendInputPacket(
new Matrix(Matrix::Constant(kNumChannels, kNumSamples, std::nanf(""))),
0 /* timestamp */);
ASSERT_FALSE(RunGraph().ok());
}
TEST_F(StabilizedLogCalculatorTest, NegativeValuesReturnError) {
InitializeGraph(); InitializeGraph();
FillInputHeader(); FillInputHeader();
AppendInputPacket( AppendInputPacket(
new Matrix(Matrix::Constant(kNumChannels, kNumSamples, -1.0)), new Matrix(Matrix::Constant(kNumChannels, kNumSamples, -1.0)),
0 /* timestamp */); 0 /* timestamp */);
ASSERT_DEATH(RunGraphNoReturn(), ""); ASSERT_FALSE(RunGraph().ok());
} }
TEST_F(StabilizedLogCalculatorTest, NegativeValuesDoNotCheckFailIfCheckIsOff) { TEST_F(StabilizedLogCalculatorTest, NegativeValuesDoNotCheckFailIfCheckIsOff) {

View File

@ -56,6 +56,14 @@ namespace mediapipe {
// If pad_final_packet is true, all input samples will be emitted and the final // If pad_final_packet is true, all input samples will be emitted and the final
// packet will be zero padded as necessary. If pad_final_packet is false, some // packet will be zero padded as necessary. If pad_final_packet is false, some
// samples may be dropped at the end of the stream. // samples may be dropped at the end of the stream.
//
// If use_local_timestamp is true, the output packet's timestamp is based on the
// last sample of the packet. The timestamp of this sample is inferred by
// input_packet_timesamp + local_sample_index / sampling_rate_. If false, the
// output packet's timestamp is based on the cumulative timestamping, which is
// done by adopting the timestamp of the first sample of the packet and this
// sample's timestamp is inferred by initial_input_timestamp_ +
// cumulative_completed_samples / sample_rate_.
class TimeSeriesFramerCalculator : public CalculatorBase { class TimeSeriesFramerCalculator : public CalculatorBase {
public: public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) { static ::mediapipe::Status GetContract(CalculatorContract* cc) {
@ -86,11 +94,26 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
void FrameOutput(CalculatorContext* cc); void FrameOutput(CalculatorContext* cc);
Timestamp CurrentOutputTimestamp() { Timestamp CurrentOutputTimestamp() {
if (use_local_timestamp_) {
return current_timestamp_;
}
return CumulativeOutputTimestamp();
}
Timestamp CumulativeOutputTimestamp() {
return initial_input_timestamp_ + return initial_input_timestamp_ +
round(cumulative_completed_samples_ / sample_rate_ * round(cumulative_completed_samples_ / sample_rate_ *
Timestamp::kTimestampUnitsPerSecond); Timestamp::kTimestampUnitsPerSecond);
} }
// Returns the timestamp of a sample on a base, which is usually the time
// stamp of a packet.
Timestamp CurrentSampleTimestamp(const Timestamp& timestamp_base,
int64 number_of_samples) {
return timestamp_base + round(number_of_samples / sample_rate_ *
Timestamp::kTimestampUnitsPerSecond);
}
// The number of input samples to advance after the current output frame is // The number of input samples to advance after the current output frame is
// emitted. // emitted.
int next_frame_step_samples() const { int next_frame_step_samples() const {
@ -118,14 +141,18 @@ class TimeSeriesFramerCalculator : public CalculatorBase {
// any overlap). // any overlap).
int64 cumulative_completed_samples_; int64 cumulative_completed_samples_;
Timestamp initial_input_timestamp_; Timestamp initial_input_timestamp_;
// The current timestamp is updated along with the incoming packets.
Timestamp current_timestamp_;
int num_channels_; int num_channels_;
// Each entry in this deque consists of a single sample, i.e. a // Each entry in this deque consists of a single sample, i.e. a
// single column vector. // single column vector, and its timestamp.
std::deque<Matrix> sample_buffer_; std::deque<std::pair<Matrix, Timestamp>> sample_buffer_;
bool use_window_; bool use_window_;
Matrix window_; Matrix window_;
bool use_local_timestamp_;
}; };
REGISTER_CALCULATOR(TimeSeriesFramerCalculator); REGISTER_CALCULATOR(TimeSeriesFramerCalculator);
@ -133,7 +160,8 @@ void TimeSeriesFramerCalculator::EnqueueInput(CalculatorContext* cc) {
const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>(); const Matrix& input_frame = cc->Inputs().Index(0).Get<Matrix>();
for (int i = 0; i < input_frame.cols(); ++i) { for (int i = 0; i < input_frame.cols(); ++i) {
sample_buffer_.emplace_back(input_frame.col(i)); sample_buffer_.emplace_back(std::make_pair(
input_frame.col(i), CurrentSampleTimestamp(cc->InputTimestamp(), i)));
} }
cumulative_input_samples_ += input_frame.cols(); cumulative_input_samples_ += input_frame.cols();
@ -151,14 +179,16 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
new Matrix(num_channels_, frame_duration_samples_)); new Matrix(num_channels_, frame_duration_samples_));
for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_); for (int i = 0; i < std::min(frame_step_samples, frame_duration_samples_);
++i) { ++i) {
output_frame->col(i) = sample_buffer_.front(); output_frame->col(i) = sample_buffer_.front().first;
current_timestamp_ = sample_buffer_.front().second;
sample_buffer_.pop_front(); sample_buffer_.pop_front();
} }
const int frame_overlap_samples = const int frame_overlap_samples =
frame_duration_samples_ - frame_step_samples; frame_duration_samples_ - frame_step_samples;
if (frame_overlap_samples > 0) { if (frame_overlap_samples > 0) {
for (int i = 0; i < frame_overlap_samples; ++i) { for (int i = 0; i < frame_overlap_samples; ++i) {
output_frame->col(i + frame_step_samples) = sample_buffer_[i]; output_frame->col(i + frame_step_samples) = sample_buffer_[i].first;
current_timestamp_ = sample_buffer_[i].second;
} }
} else { } else {
samples_still_to_drop_ = -frame_overlap_samples; samples_still_to_drop_ = -frame_overlap_samples;
@ -178,6 +208,7 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
::mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status TimeSeriesFramerCalculator::Process(CalculatorContext* cc) {
if (initial_input_timestamp_ == Timestamp::Unstarted()) { if (initial_input_timestamp_ == Timestamp::Unstarted()) {
initial_input_timestamp_ = cc->InputTimestamp(); initial_input_timestamp_ = cc->InputTimestamp();
current_timestamp_ = initial_input_timestamp_;
} }
EnqueueInput(cc); EnqueueInput(cc);
@ -195,7 +226,8 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
std::unique_ptr<Matrix> output_frame(new Matrix); std::unique_ptr<Matrix> output_frame(new Matrix);
output_frame->setZero(num_channels_, frame_duration_samples_); output_frame->setZero(num_channels_, frame_duration_samples_);
for (int i = 0; i < sample_buffer_.size(); ++i) { for (int i = 0; i < sample_buffer_.size(); ++i) {
output_frame->col(i) = sample_buffer_[i]; output_frame->col(i) = sample_buffer_[i].first;
current_timestamp_ = sample_buffer_[i].second;
} }
cc->Outputs().Index(0).Add(output_frame.release(), cc->Outputs().Index(0).Add(output_frame.release(),
@ -258,6 +290,7 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
cumulative_output_frames_ = 0; cumulative_output_frames_ = 0;
samples_still_to_drop_ = 0; samples_still_to_drop_ = 0;
initial_input_timestamp_ = Timestamp::Unstarted(); initial_input_timestamp_ = Timestamp::Unstarted();
current_timestamp_ = Timestamp::Unstarted();
std::vector<double> window_vector; std::vector<double> window_vector;
use_window_ = false; use_window_ = false;
@ -282,6 +315,7 @@ void TimeSeriesFramerCalculator::FrameOutput(CalculatorContext* cc) {
frame_duration_samples_) frame_duration_samples_)
.cast<float>(); .cast<float>();
} }
use_local_timestamp_ = framer_options.use_local_timestamp();
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -62,4 +62,11 @@ message TimeSeriesFramerCalculatorOptions {
HANN = 2; HANN = 2;
} }
optional WindowFunction window_function = 4 [default = NONE]; optional WindowFunction window_function = 4 [default = NONE];
// If use_local_timestamp is true, the output packet's timestamp is based on
// the last sample of the packet and it's inferred from the latest input
// packet's timestamp. If false, the output packet's timestamp is based on
// the cumulative timestamping, which is inferred from the intial input
// timestamp and the cumulative number of samples.
optional bool use_local_timestamp = 6 [default = false];
} }

View File

@ -35,6 +35,8 @@ namespace mediapipe {
namespace { namespace {
const int kInitialTimestampOffsetMicroseconds = 4; const int kInitialTimestampOffsetMicroseconds = 4;
const int kGapBetweenPacketsInSeconds = 1;
const int kUniversalInputPacketSize = 50;
class TimeSeriesFramerCalculatorTest class TimeSeriesFramerCalculatorTest
: public TimeSeriesCalculatorTest<TimeSeriesFramerCalculatorOptions> { : public TimeSeriesCalculatorTest<TimeSeriesFramerCalculatorOptions> {
@ -391,5 +393,93 @@ TEST_F(TimeSeriesFramerCalculatorWindowingSanityTest, HannWindowSanityCheck) {
RunAndTestSinglePacketAverage(0.5f); RunAndTestSinglePacketAverage(0.5f);
} }
} // anonymous namespace // A simple test class that checks the local packet time stamp. This class
// generate a series of packets with and without gaps between packets and tests
// the behavior with cumulative timestamping and local packet timestamping.
class TimeSeriesFramerCalculatorTimestampingTest
: public TimeSeriesFramerCalculatorTest {
protected:
// Creates test input and saves a reference copy.
void InitializeInputForTimeStampingTest() {
concatenated_input_samples_.resize(0, num_input_channels_);
num_input_samples_ = 0;
for (int i = 0; i < 10; ++i) {
// This range of packet sizes was chosen such that some input
// packets will be smaller than the output packet size and other
// input packets will be larger.
int packet_size = kUniversalInputPacketSize;
double timestamp_seconds = kInitialTimestampOffsetMicroseconds * 1.0e-6 +
num_input_samples_ / input_sample_rate_;
if (options_.use_local_timestamp()) {
timestamp_seconds += kGapBetweenPacketsInSeconds * i;
}
Matrix* data_frame =
NewTestFrame(num_input_channels_, packet_size, timestamp_seconds);
AppendInputPacket(data_frame, round(timestamp_seconds *
Timestamp::kTimestampUnitsPerSecond));
num_input_samples_ += packet_size;
}
}
void CheckOutputTimestamps() {
int num_full_packets = output().packets.size();
if (options_.pad_final_packet()) {
num_full_packets -= 1;
}
int64 num_samples = 0;
for (int packet_num = 0; packet_num < num_full_packets; ++packet_num) {
const Packet& packet = output().packets[packet_num];
num_samples += FrameDurationSamples();
double expected_timestamp =
options_.use_local_timestamp()
? GetExpectedLocalTimestampForSample(num_samples - 1)
: GetExpectedCumulativeTimestamp(num_samples - 1);
ASSERT_NEAR(packet.Timestamp().Seconds(), expected_timestamp, 1e-10);
}
}
::mediapipe::Status RunTimestampTest() {
InitializeGraph();
InitializeInputForTimeStampingTest();
FillInputHeader();
return RunGraph();
}
private:
// Returns the timestamp in seconds based on local timestamping.
double GetExpectedLocalTimestampForSample(int sample_index) {
return kInitialTimestampOffsetMicroseconds * 1.0e-6 +
sample_index / input_sample_rate_ +
(sample_index / kUniversalInputPacketSize) *
kGapBetweenPacketsInSeconds;
}
// Returns the timestamp inseconds based on cumulative timestamping.
double GetExpectedCumulativeTimestamp(int sample_index) {
return kInitialTimestampOffsetMicroseconds * 1.0e-6 +
sample_index / FrameDurationSamples() * FrameDurationSamples() /
input_sample_rate_;
}
};
TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseLocalTimeStamp) {
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
options_.set_use_local_timestamp(true);
MP_ASSERT_OK(RunTimestampTest());
CheckOutputTimestamps();
}
TEST_F(TimeSeriesFramerCalculatorTimestampingTest, UseCumulativeTimeStamp) {
options_.set_frame_duration_seconds(100.0 / input_sample_rate_);
options_.set_use_local_timestamp(false);
MP_ASSERT_OK(RunTimestampTest());
CheckOutputTimestamps();
}
} // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -166,7 +166,13 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
], ] + select({
"//mediapipe/gpu:disable_gpu": [],
"//mediapipe:ios": [],
"//conditions:default": [
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
],
}),
alwayslink = 1, alwayslink = 1,
) )

View File

@ -19,6 +19,10 @@
#include "mediapipe/framework/formats/landmark.pb.h" #include "mediapipe/framework/formats/landmark.pb.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
// Example config: // Example config:
@ -45,4 +49,11 @@ REGISTER_CALCULATOR(ConcatenateTfLiteTensorVectorCalculator);
typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark> typedef ConcatenateVectorCalculator<::mediapipe::NormalizedLandmark>
ConcatenateLandmarkVectorCalculator; ConcatenateLandmarkVectorCalculator;
REGISTER_CALCULATOR(ConcatenateLandmarkVectorCalculator); REGISTER_CALCULATOR(ConcatenateLandmarkVectorCalculator);
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
typedef ConcatenateVectorCalculator<::tflite::gpu::gl::GlBuffer>
ConcatenateGlBufferVectorCalculator;
REGISTER_CALCULATOR(ConcatenateGlBufferVectorCalculator);
#endif
} // namespace mediapipe } // namespace mediapipe

View File

@ -15,6 +15,7 @@
#ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_ #ifndef MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_ #define MEDIAPIPE_CALCULATORS_CORE_CONCATENATE_VECTOR_CALCULATOR_H_
#include <type_traits>
#include <vector> #include <vector>
#include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h" #include "mediapipe/calculators/core/concatenate_vector_calculator.pb.h"
@ -59,16 +60,58 @@ class ConcatenateVectorCalculator : public CalculatorBase {
if (cc->Inputs().Index(i).IsEmpty()) return ::mediapipe::OkStatus(); if (cc->Inputs().Index(i).IsEmpty()) return ::mediapipe::OkStatus();
} }
} }
auto output = absl::make_unique<std::vector<T>>();
return ConcatenateVectors<T>(std::is_copy_constructible<T>(), cc);
}
template <typename U>
::mediapipe::Status ConcatenateVectors(std::true_type,
CalculatorContext* cc) {
auto output = absl::make_unique<std::vector<U>>();
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) { for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
if (cc->Inputs().Index(i).IsEmpty()) continue; if (cc->Inputs().Index(i).IsEmpty()) continue;
const std::vector<T>& input = cc->Inputs().Index(i).Get<std::vector<T>>(); const std::vector<U>& input = cc->Inputs().Index(i).Get<std::vector<U>>();
output->insert(output->end(), input.begin(), input.end()); output->insert(output->end(), input.begin(), input.end());
} }
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp()); cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
template <typename U>
::mediapipe::Status ConcatenateVectors(std::false_type,
CalculatorContext* cc) {
return ConsumeAndConcatenateVectors<T>(std::is_move_constructible<U>(), cc);
}
template <typename U>
::mediapipe::Status ConsumeAndConcatenateVectors(std::true_type,
CalculatorContext* cc) {
auto output = absl::make_unique<std::vector<U>>();
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
if (cc->Inputs().Index(i).IsEmpty()) continue;
::mediapipe::StatusOr<std::unique_ptr<std::vector<U>>> input_status =
cc->Inputs().Index(i).Value().Consume<std::vector<U>>();
if (input_status.ok()) {
std::unique_ptr<std::vector<U>> input_vector =
std::move(input_status).ValueOrDie();
output->insert(output->end(),
std::make_move_iterator(input_vector->begin()),
std::make_move_iterator(input_vector->end()));
} else {
return input_status.status();
}
}
cc->Outputs().Index(0).Add(output.release(), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
template <typename U>
::mediapipe::Status ConsumeAndConcatenateVectors(std::false_type,
CalculatorContext* cc) {
return ::mediapipe::InternalError(
"Cannot copy or move input vectors to concatenate them");
}
private: private:
bool only_emit_if_all_present_; bool only_emit_if_all_present_;
}; };

View File

@ -235,4 +235,167 @@ TEST(ConcatenateFloatVectorCalculatorTest, OneEmptyStreamNoOutput) {
EXPECT_EQ(0, outputs.size()); EXPECT_EQ(0, outputs.size());
} }
typedef ConcatenateVectorCalculator<std::unique_ptr<int>>
TestConcatenateUniqueIntPtrCalculator;
REGISTER_CALCULATOR(TestConcatenateUniqueIntPtrCalculator);
TEST(TestConcatenateUniqueIntVectorCalculatorTest, ConsumeOneTimestamp) {
/* Note: We don't use CalculatorRunner for this test because it keeps copies
* of input packets, so packets sent to the graph don't have sole ownership.
* The test needs to send packets that own the data.
*/
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in_1"
input_stream: "in_2"
input_stream: "in_3"
node {
calculator: "TestConcatenateUniqueIntPtrCalculator"
input_stream: "in_1"
input_stream: "in_2"
input_stream: "in_3"
output_stream: "out"
}
)");
std::vector<Packet> outputs;
tool::AddVectorSink("out", &graph_config, &outputs);
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config));
MP_EXPECT_OK(graph.StartRun({}));
// input1 : {0, 1, 2}
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
for (int i = 0; i < 3; ++i) {
input_1->at(i) = absl::make_unique<int>(i);
}
// input2: {3}
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_2 =
absl::make_unique<std::vector<std::unique_ptr<int>>>(1);
input_2->at(0) = absl::make_unique<int>(3);
// input3: {4, 5}
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_3 =
absl::make_unique<std::vector<std::unique_ptr<int>>>(2);
input_3->at(0) = absl::make_unique<int>(4);
input_3->at(1) = absl::make_unique<int>(5);
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in_1", Adopt(input_1.release()).At(Timestamp(1))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in_2", Adopt(input_2.release()).At(Timestamp(1))));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in_3", Adopt(input_3.release()).At(Timestamp(1))));
MP_EXPECT_OK(graph.WaitUntilIdle());
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
const std::vector<std::unique_ptr<int>>& result =
outputs[0].Get<std::vector<std::unique_ptr<int>>>();
EXPECT_EQ(6, result.size());
for (int i = 0; i < 6; ++i) {
const std::unique_ptr<int>& v = result[i];
EXPECT_EQ(i, *v);
}
}
TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamStillOutput) {
/* Note: We don't use CalculatorRunner for this test because it keeps copies
* of input packets, so packets sent to the graph don't have sole ownership.
* The test needs to send packets that own the data.
*/
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in_1"
input_stream: "in_2"
node {
calculator: "TestConcatenateUniqueIntPtrCalculator"
input_stream: "in_1"
input_stream: "in_2"
output_stream: "out"
}
)");
std::vector<Packet> outputs;
tool::AddVectorSink("out", &graph_config, &outputs);
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config));
MP_EXPECT_OK(graph.StartRun({}));
// input1 : {0, 1, 2}
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
for (int i = 0; i < 3; ++i) {
input_1->at(i) = absl::make_unique<int>(i);
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in_1", Adopt(input_1.release()).At(Timestamp(1))));
MP_EXPECT_OK(graph.WaitUntilIdle());
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_EQ(1, outputs.size());
EXPECT_EQ(Timestamp(1), outputs[0].Timestamp());
const std::vector<std::unique_ptr<int>>& result =
outputs[0].Get<std::vector<std::unique_ptr<int>>>();
EXPECT_EQ(3, result.size());
for (int i = 0; i < 3; ++i) {
const std::unique_ptr<int>& v = result[i];
EXPECT_EQ(i, *v);
}
}
TEST(TestConcatenateUniqueIntVectorCalculatorTest, OneEmptyStreamNoOutput) {
/* Note: We don't use CalculatorRunner for this test because it keeps copies
* of input packets, so packets sent to the graph don't have sole ownership.
* The test needs to send packets that own the data.
*/
CalculatorGraphConfig graph_config =
ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "in_1"
input_stream: "in_2"
node {
calculator: "TestConcatenateUniqueIntPtrCalculator"
input_stream: "in_1"
input_stream: "in_2"
output_stream: "out"
options {
[mediapipe.ConcatenateVectorCalculatorOptions.ext] {
only_emit_if_all_present: true
}
}
}
)");
std::vector<Packet> outputs;
tool::AddVectorSink("out", &graph_config, &outputs);
CalculatorGraph graph;
MP_EXPECT_OK(graph.Initialize(graph_config));
MP_EXPECT_OK(graph.StartRun({}));
// input1 : {0, 1, 2}
std::unique_ptr<std::vector<std::unique_ptr<int>>> input_1 =
absl::make_unique<std::vector<std::unique_ptr<int>>>(3);
for (int i = 0; i < 3; ++i) {
input_1->at(i) = absl::make_unique<int>(i);
}
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in_1", Adopt(input_1.release()).At(Timestamp(1))));
MP_EXPECT_OK(graph.WaitUntilIdle());
MP_EXPECT_OK(graph.CloseAllPacketSources());
MP_EXPECT_OK(graph.WaitUntilDone());
EXPECT_EQ(0, outputs.size());
}
} // namespace mediapipe } // namespace mediapipe

View File

@ -19,7 +19,6 @@ package(default_visibility = ["//visibility:private"])
exports_files(["LICENSE"]) exports_files(["LICENSE"])
load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library") load("//mediapipe/framework/port:build_config.bzl", "mediapipe_cc_proto_library")
load("@bazel_skylib//lib:selects.bzl", "selects")
proto_library( proto_library(
name = "opencv_image_encoder_calculator_proto", name = "opencv_image_encoder_calculator_proto",
@ -227,19 +226,13 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
], ],
"//mediapipe:ios": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -263,13 +256,13 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
], ],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -322,14 +315,14 @@ cc_library(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
] + selects.with_or({ ] + select({
("//mediapipe:android", "//mediapipe:ios"): [ "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
], ],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -363,14 +356,15 @@ cc_library(
"//mediapipe/framework/port:opencv_imgproc", "//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
] + selects.with_or({ "//mediapipe/gpu:gpu_buffer",
("//mediapipe:android", "//mediapipe:ios"): [ ] + select({
"//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gl_quad_renderer", "//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
], ],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -415,19 +409,13 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/util:color_cc_proto", "//mediapipe/util:color_cc_proto",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gl_quad_renderer",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
], ],
"//mediapipe:ios": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -486,11 +474,11 @@ cc_library(
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
] + selects.with_or({ ] + select({
("//mediapipe:android", "//mediapipe:ios"): [ "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
], ],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )

View File

@ -27,11 +27,11 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
@ -101,11 +101,11 @@ class BilateralFilterCalculator : public CalculatorBase {
bool use_gpu_ = false; bool use_gpu_ = false;
bool gpu_initialized_ = false; bool gpu_initialized_ = false;
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0; GLuint program_ = 0;
GLuint program_joint_ = 0; GLuint program_joint_ = 0;
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(BilateralFilterCalculator); REGISTER_CALCULATOR(BilateralFilterCalculator);
@ -122,39 +122,46 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
return ::mediapipe::InternalError("GPU output must have GPU input."); return ::mediapipe::InternalError("GPU output must have GPU input.");
} }
bool use_gpu = false;
// Input image to filter. // Input image to filter.
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag(kInputFrameTagGpu)) { if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kInputFrameTag)) { if (cc->Inputs().HasTag(kInputFrameTag)) {
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>(); cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
} }
// Input guide image mask (optional) // Input guide image mask (optional)
#if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag(kInputGuideTagGpu)) { if (cc->Inputs().HasTag(kInputGuideTagGpu)) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
cc->Inputs().Tag(kInputGuideTagGpu).Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputGuideTagGpu).Set<mediapipe::GpuBuffer>();
#endif // __ANDROID__ || __EMSCRIPTEN__ use_gpu |= true;
} }
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kInputGuideTag)) { if (cc->Inputs().HasTag(kInputGuideTag)) {
cc->Inputs().Tag(kInputGuideTag).Set<ImageFrame>(); cc->Inputs().Tag(kInputGuideTag).Set<ImageFrame>();
} }
// Output image. // Output image.
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>(); cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag(kOutputFrameTag)) { if (cc->Outputs().HasTag(kOutputFrameTag)) {
cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>(); cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
} }
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) if (use_gpu) {
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // __ANDROID__ || __EMSCRIPTEN__ MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -166,11 +173,11 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
if (cc->Inputs().HasTag(kInputFrameTagGpu) && if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().HasTag(kOutputFrameTagGpu)) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
use_gpu_ = true; use_gpu_ = true;
#else #else
RET_CHECK_FAIL() << "GPU processing on non-Android not supported yet."; RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif
} }
sigma_color_ = options_.sigma_color(); sigma_color_ = options_.sigma_color();
@ -180,9 +187,9 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
if (!use_gpu_) sigma_color_ *= 255.0; if (!use_gpu_) sigma_color_ *= 255.0;
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif #endif // !MEDIAPIPE_DISABLE_GPU
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -190,7 +197,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
::mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status BilateralFilterCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
if (!gpu_initialized_) { if (!gpu_initialized_) {
@ -200,7 +207,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
MP_RETURN_IF_ERROR(RenderGpu(cc)); MP_RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
MP_RETURN_IF_ERROR(RenderCpu(cc)); MP_RETURN_IF_ERROR(RenderCpu(cc));
} }
@ -209,14 +216,14 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
} }
::mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status BilateralFilterCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_); if (program_) glDeleteProgram(program_);
program_ = 0; program_ = 0;
if (program_joint_) glDeleteProgram(program_joint_); if (program_joint_) glDeleteProgram(program_joint_);
program_joint_ = 0; program_joint_ = 0;
}); });
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -263,7 +270,7 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
const auto& input_frame = const auto& input_frame =
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
auto input_texture = gpu_helper_.CreateSourceTexture(input_frame); auto input_texture = gpu_helper_.CreateSourceTexture(input_frame);
@ -321,13 +328,13 @@ REGISTER_CALCULATOR(BilateralFilterCalculator);
// Cleanup // Cleanup
input_texture.Release(); input_texture.Release();
output_texture.Release(); output_texture.Release();
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
void BilateralFilterCalculator::GlRender(CalculatorContext* cc) { void BilateralFilterCalculator::GlRender(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
static const GLfloat square_vertices[] = { static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left -1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right 1.0f, -1.0f, // bottom right
@ -373,11 +380,11 @@ void BilateralFilterCalculator::GlRender(CalculatorContext* cc) {
glDeleteVertexArrays(1, &vao); glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo); glDeleteBuffers(2, vbo);
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
} }
::mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) { ::mediapipe::Status BilateralFilterCalculator::GlSetup(CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__EMSCRIPTEN__) #if !defined(MEDIAPIPE_DISABLE_GPU)
const GLint attr_location[NUM_ATTRIBUTES] = { const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX, ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION, ATTRIB_TEXTURE_POSITION,
@ -545,7 +552,7 @@ void BilateralFilterCalculator::GlRender(CalculatorContext* cc) {
glUniform1i(glGetUniformLocation(program_joint_, "input_frame"), 1); glUniform1i(glGetUniformLocation(program_joint_, "input_frame"), 1);
glUniform1i(glGetUniformLocation(program_joint_, "guide_frame"), 2); glUniform1i(glGetUniformLocation(program_joint_, "guide_frame"), 2);
#endif // __ANDROID__ || __EMSCRIPTEN__ #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -24,12 +24,12 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
namespace { namespace {
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
@ -37,9 +37,20 @@ enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
namespace mediapipe { namespace mediapipe {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) namespace {
#endif // __ANDROID__ or iOS #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // !MEDIAPIPE_DISABLE_GPU
constexpr char kRectTag[] = "RECT";
constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kHeightTag[] = "HEIGHT";
constexpr char kImageTag[] = "IMAGE";
constexpr char kImageGpuTag[] = "IMAGE_GPU";
constexpr char kWidthTag[] = "WIDTH";
} // namespace
// Crops the input texture to the given rectangle region. The rectangle can // Crops the input texture to the given rectangle region. The rectangle can
// be at arbitrary location on the image with rotation. If there's rotation, the // be at arbitrary location on the image with rotation. If there's rotation, the
@ -91,48 +102,55 @@ class ImageCroppingCalculator : public CalculatorBase {
bool use_gpu_ = false; bool use_gpu_ = false;
// Output texture corners (4) after transoformation in normalized coordinates. // Output texture corners (4) after transoformation in normalized coordinates.
float transformed_points_[8]; float transformed_points_[8];
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
bool gpu_initialized_ = false; bool gpu_initialized_ = false;
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0; GLuint program_ = 0;
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(ImageCroppingCalculator); REGISTER_CALCULATOR(ImageCroppingCalculator);
::mediapipe::Status ImageCroppingCalculator::GetContract( ::mediapipe::Status ImageCroppingCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU")); RET_CHECK(cc->Inputs().HasTag(kImageTag) ^ cc->Inputs().HasTag(kImageGpuTag));
RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU")); RET_CHECK(cc->Outputs().HasTag(kImageTag) ^
cc->Outputs().HasTag(kImageGpuTag));
if (cc->Inputs().HasTag("IMAGE")) { bool use_gpu = false;
RET_CHECK(cc->Outputs().HasTag("IMAGE"));
cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
}
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX)
if (cc->Inputs().HasTag("IMAGE_GPU")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU"));
cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
}
#endif // __ANDROID__ or iOS
if (cc->Inputs().HasTag("RECT")) { if (cc->Inputs().HasTag(kImageTag)) {
cc->Inputs().Tag("RECT").Set<Rect>(); RET_CHECK(cc->Outputs().HasTag(kImageTag));
cc->Inputs().Tag(kImageTag).Set<ImageFrame>();
cc->Outputs().Tag(kImageTag).Set<ImageFrame>();
} }
if (cc->Inputs().HasTag("NORM_RECT")) { #if !defined(MEDIAPIPE_DISABLE_GPU)
cc->Inputs().Tag("NORM_RECT").Set<NormalizedRect>(); if (cc->Inputs().HasTag(kImageGpuTag)) {
RET_CHECK(cc->Outputs().HasTag(kImageGpuTag));
cc->Inputs().Tag(kImageGpuTag).Set<GpuBuffer>();
cc->Outputs().Tag(kImageGpuTag).Set<GpuBuffer>();
use_gpu |= true;
} }
if (cc->Inputs().HasTag("WIDTH")) { #endif // !MEDIAPIPE_DISABLE_GPU
cc->Inputs().Tag("WIDTH").Set<int>();
RET_CHECK(cc->Inputs().HasTag(kRectTag) ^ cc->Inputs().HasTag(kNormRectTag));
if (cc->Inputs().HasTag(kRectTag)) {
cc->Inputs().Tag(kRectTag).Set<Rect>();
} }
if (cc->Inputs().HasTag("HEIGHT")) { if (cc->Inputs().HasTag(kNormRectTag)) {
cc->Inputs().Tag("HEIGHT").Set<int>(); cc->Inputs().Tag(kNormRectTag).Set<NormalizedRect>();
}
if (cc->Inputs().HasTag(kWidthTag)) {
cc->Inputs().Tag(kWidthTag).Set<int>();
}
if (cc->Inputs().HasTag(kHeightTag)) {
cc->Inputs().Tag(kHeightTag).Set<int>();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (use_gpu) {
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // __ANDROID__ or iOS MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -140,26 +158,35 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
::mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) { ::mediapipe::Status ImageCroppingCalculator::Open(CalculatorContext* cc) {
cc->SetOffset(TimestampDiff(0)); cc->SetOffset(TimestampDiff(0));
if (cc->Inputs().HasTag("IMAGE_GPU")) { if (cc->Inputs().HasTag(kImageGpuTag)) {
use_gpu_ = true; use_gpu_ = true;
} }
options_ = cc->Options<mediapipe::ImageCroppingCalculatorOptions>(); options_ = cc->Options<mediapipe::ImageCroppingCalculatorOptions>();
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing is for Android and iOS only.";
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status ImageCroppingCalculator::Process(CalculatorContext* cc) {
if (cc->Inputs().HasTag(kRectTag) && cc->Inputs().Tag(kRectTag).IsEmpty()) {
VLOG(1) << "RECT is empty for timestamp: " << cc->InputTimestamp();
return ::mediapipe::OkStatus();
}
if (cc->Inputs().HasTag(kNormRectTag) &&
cc->Inputs().Tag(kNormRectTag).IsEmpty()) {
VLOG(1) << "NORM_RECT is empty for timestamp: " << cc->InputTimestamp();
return ::mediapipe::OkStatus();
}
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
if (!gpu_initialized_) { if (!gpu_initialized_) {
@ -169,7 +196,7 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
MP_RETURN_IF_ERROR(RenderGpu(cc)); MP_RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
MP_RETURN_IF_ERROR(RenderCpu(cc)); MP_RETURN_IF_ERROR(RenderCpu(cc));
} }
@ -177,19 +204,22 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
} }
::mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status ImageCroppingCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_); if (program_) glDeleteProgram(program_);
program_ = 0; program_ = 0;
}); });
gpu_initialized_ = false; gpu_initialized_ = false;
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) { ::mediapipe::Status ImageCroppingCalculator::RenderCpu(CalculatorContext* cc) {
const auto& input_img = cc->Inputs().Tag("IMAGE").Get<ImageFrame>(); if (cc->Inputs().Tag(kImageTag).IsEmpty()) {
return ::mediapipe::OkStatus();
}
const auto& input_img = cc->Inputs().Tag(kImageTag).Get<ImageFrame>();
cv::Mat input_mat = formats::MatView(&input_img); cv::Mat input_mat = formats::MatView(&input_img);
float rect_center_x = input_img.Width() / 2.0f; float rect_center_x = input_img.Width() / 2.0f;
@ -197,8 +227,8 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
float rotation = 0.0f; float rotation = 0.0f;
int target_width = input_img.Width(); int target_width = input_img.Width();
int target_height = input_img.Height(); int target_height = input_img.Height();
if (cc->Inputs().HasTag("RECT")) { if (cc->Inputs().HasTag(kRectTag)) {
const auto& rect = cc->Inputs().Tag("RECT").Get<Rect>(); const auto& rect = cc->Inputs().Tag(kRectTag).Get<Rect>();
if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 &&
rect.y_center() >= 0) { rect.y_center() >= 0) {
rect_center_x = rect.x_center(); rect_center_x = rect.x_center();
@ -207,8 +237,8 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
target_height = rect.height(); target_height = rect.height();
rotation = rect.rotation(); rotation = rect.rotation();
} }
} else if (cc->Inputs().HasTag("NORM_RECT")) { } else if (cc->Inputs().HasTag(kNormRectTag)) {
const auto& rect = cc->Inputs().Tag("NORM_RECT").Get<NormalizedRect>(); const auto& rect = cc->Inputs().Tag(kNormRectTag).Get<NormalizedRect>();
if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 && if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 &&
rect.y_center() >= 0.0) { rect.y_center() >= 0.0) {
rect_center_x = std::round(rect.x_center() * input_img.Width()); rect_center_x = std::round(rect.x_center() * input_img.Width());
@ -218,9 +248,9 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
rotation = rect.rotation(); rotation = rect.rotation();
} }
} else { } else {
if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) { if (cc->Inputs().HasTag(kWidthTag) && cc->Inputs().HasTag(kHeightTag)) {
target_width = cc->Inputs().Tag("WIDTH").Get<int>(); target_width = cc->Inputs().Tag(kWidthTag).Get<int>();
target_height = cc->Inputs().Tag("HEIGHT").Get<int>(); target_height = cc->Inputs().Tag(kHeightTag).Get<int>();
} else if (options_.has_width() && options_.has_height()) { } else if (options_.has_width() && options_.has_height()) {
target_width = options_.width(); target_width = options_.width();
target_height = options_.height(); target_height = options_.height();
@ -253,16 +283,17 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
input_img.Format(), cropped_image.cols, cropped_image.rows)); input_img.Format(), cropped_image.cols, cropped_image.rows));
cv::Mat output_mat = formats::MatView(output_frame.get()); cv::Mat output_mat = formats::MatView(output_frame.get());
cropped_image.copyTo(output_mat); cropped_image.copyTo(output_mat);
cc->Outputs().Tag("IMAGE").Add(output_frame.release(), cc->InputTimestamp()); cc->Outputs().Tag(kImageTag).Add(output_frame.release(),
cc->InputTimestamp());
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) { ::mediapipe::Status ImageCroppingCalculator::RenderGpu(CalculatorContext* cc) {
if (cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) { if (cc->Inputs().Tag(kImageGpuTag).IsEmpty()) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value(); const Packet& input_packet = cc->Inputs().Tag(kImageGpuTag).Value();
const auto& input_buffer = input_packet.Get<mediapipe::GpuBuffer>(); const auto& input_buffer = input_packet.Get<mediapipe::GpuBuffer>();
auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer); auto src_tex = gpu_helper_.CreateSourceTexture(input_buffer);
@ -287,18 +318,18 @@ REGISTER_CALCULATOR(ImageCroppingCalculator);
// Send result image in GPU packet. // Send result image in GPU packet.
auto output = dst_tex.GetFrame<mediapipe::GpuBuffer>(); auto output = dst_tex.GetFrame<mediapipe::GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); cc->Outputs().Tag(kImageGpuTag).Add(output.release(), cc->InputTimestamp());
// Cleanup // Cleanup
src_tex.Release(); src_tex.Release();
dst_tex.Release(); dst_tex.Release();
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
void ImageCroppingCalculator::GlRender() { void ImageCroppingCalculator::GlRender() {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
static const GLfloat square_vertices[] = { static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left -1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right 1.0f, -1.0f, // bottom right
@ -342,11 +373,11 @@ void ImageCroppingCalculator::GlRender() {
glDeleteVertexArrays(1, &vao); glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo); glDeleteBuffers(2, vbo);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
::mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) { ::mediapipe::Status ImageCroppingCalculator::InitGpu(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
const GLint attr_location[NUM_ATTRIBUTES] = { const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX, ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION, ATTRIB_TEXTURE_POSITION,
@ -392,7 +423,7 @@ void ImageCroppingCalculator::GlRender() {
// Parameters // Parameters
glUseProgram(program_); glUseProgram(program_);
glUniform1i(glGetUniformLocation(program_, "input_frame"), 1); glUniform1i(glGetUniformLocation(program_, "input_frame"), 1);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -410,8 +441,8 @@ void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc,
int y_center = src_height / 2; int y_center = src_height / 2;
// Get the rotation of the cropping box. // Get the rotation of the cropping box.
float rotation = 0.0f; float rotation = 0.0f;
if (cc->Inputs().HasTag("RECT")) { if (cc->Inputs().HasTag(kRectTag)) {
const auto& rect = cc->Inputs().Tag("RECT").Get<Rect>(); const auto& rect = cc->Inputs().Tag(kRectTag).Get<Rect>();
// Only use the rect if it is valid. // Only use the rect if it is valid.
if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 && if (rect.width() > 0 && rect.height() > 0 && rect.x_center() >= 0 &&
rect.y_center() >= 0) { rect.y_center() >= 0) {
@ -421,8 +452,8 @@ void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc,
crop_height = rect.height(); crop_height = rect.height();
rotation = rect.rotation(); rotation = rect.rotation();
} }
} else if (cc->Inputs().HasTag("NORM_RECT")) { } else if (cc->Inputs().HasTag(kNormRectTag)) {
const auto& rect = cc->Inputs().Tag("NORM_RECT").Get<NormalizedRect>(); const auto& rect = cc->Inputs().Tag(kNormRectTag).Get<NormalizedRect>();
// Only use the rect if it is valid. // Only use the rect if it is valid.
if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 && if (rect.width() > 0.0 && rect.height() > 0.0 && rect.x_center() >= 0.0 &&
rect.y_center() >= 0.0) { rect.y_center() >= 0.0) {
@ -433,9 +464,9 @@ void ImageCroppingCalculator::GetOutputDimensions(CalculatorContext* cc,
rotation = rect.rotation(); rotation = rect.rotation();
} }
} else { } else {
if (cc->Inputs().HasTag("WIDTH") && cc->Inputs().HasTag("HEIGHT")) { if (cc->Inputs().HasTag(kWidthTag) && cc->Inputs().HasTag(kHeightTag)) {
crop_width = cc->Inputs().Tag("WIDTH").Get<int>(); crop_width = cc->Inputs().Tag(kWidthTag).Get<int>();
crop_height = cc->Inputs().Tag("HEIGHT").Get<int>(); crop_height = cc->Inputs().Tag(kHeightTag).Get<int>();
} else if (options_.has_width() && options_.has_height()) { } else if (options_.has_width() && options_.has_height()) {
crop_width = options_.width(); crop_width = options_.width();
crop_height = options_.height(); crop_height = options_.height();

View File

@ -15,9 +15,9 @@
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
@ -44,11 +44,11 @@ class ImagePropertiesCalculator : public CalculatorBase {
if (cc->Inputs().HasTag("IMAGE")) { if (cc->Inputs().HasTag("IMAGE")) {
cc->Inputs().Tag("IMAGE").Set<ImageFrame>(); cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("IMAGE_GPU")) { if (cc->Inputs().HasTag("IMAGE_GPU")) {
cc->Inputs().Tag("IMAGE_GPU").Set<::mediapipe::GpuBuffer>(); cc->Inputs().Tag("IMAGE_GPU").Set<::mediapipe::GpuBuffer>();
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag("SIZE")) { if (cc->Outputs().HasTag("SIZE")) {
cc->Outputs().Tag("SIZE").Set<std::pair<int, int>>(); cc->Outputs().Tag("SIZE").Set<std::pair<int, int>>();
@ -71,7 +71,7 @@ class ImagePropertiesCalculator : public CalculatorBase {
width = image.Width(); width = image.Width();
height = image.Height(); height = image.Height();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("IMAGE_GPU") && if (cc->Inputs().HasTag("IMAGE_GPU") &&
!cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) { !cc->Inputs().Tag("IMAGE_GPU").IsEmpty()) {
const auto& image = const auto& image =
@ -79,7 +79,7 @@ class ImagePropertiesCalculator : public CalculatorBase {
width = image.width(); width = image.width();
height = image.height(); height = image.height();
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
cc->Outputs().Tag("SIZE").AddPacket( cc->Outputs().Tag("SIZE").AddPacket(
MakePacket<std::pair<int, int>>(width, height) MakePacket<std::pair<int, int>>(width, height)

View File

@ -22,12 +22,12 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/scale_mode.pb.h" #include "mediapipe/gpu/scale_mode.pb.h"
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_quad_renderer.h" #include "mediapipe/gpu/gl_quad_renderer.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
#if defined(__ANDROID__) #if defined(__ANDROID__)
// The size of Java arrays is dynamic, which makes it difficult to // The size of Java arrays is dynamic, which makes it difficult to
@ -42,9 +42,9 @@ typedef int DimensionsPacketType[2];
namespace mediapipe { namespace mediapipe {
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
namespace { namespace {
int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) { int RotationModeToDegrees(mediapipe::RotationMode_Mode rotation) {
@ -170,12 +170,12 @@ class ImageTransformationCalculator : public CalculatorBase {
mediapipe::ScaleMode_Mode scale_mode_; mediapipe::ScaleMode_Mode scale_mode_;
bool use_gpu_ = false; bool use_gpu_ = false;
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
GlCalculatorHelper helper_; GlCalculatorHelper helper_;
std::unique_ptr<QuadRenderer> rgb_renderer_; std::unique_ptr<QuadRenderer> rgb_renderer_;
std::unique_ptr<QuadRenderer> yuv_renderer_; std::unique_ptr<QuadRenderer> yuv_renderer_;
std::unique_ptr<QuadRenderer> ext_rgb_renderer_; std::unique_ptr<QuadRenderer> ext_rgb_renderer_;
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(ImageTransformationCalculator); REGISTER_CALCULATOR(ImageTransformationCalculator);
@ -185,18 +185,22 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU")); RET_CHECK(cc->Inputs().HasTag("IMAGE") ^ cc->Inputs().HasTag("IMAGE_GPU"));
RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU")); RET_CHECK(cc->Outputs().HasTag("IMAGE") ^ cc->Outputs().HasTag("IMAGE_GPU"));
bool use_gpu = false;
if (cc->Inputs().HasTag("IMAGE")) { if (cc->Inputs().HasTag("IMAGE")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE")); RET_CHECK(cc->Outputs().HasTag("IMAGE"));
cc->Inputs().Tag("IMAGE").Set<ImageFrame>(); cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
cc->Outputs().Tag("IMAGE").Set<ImageFrame>(); cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
} }
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("IMAGE_GPU")) { if (cc->Inputs().HasTag("IMAGE_GPU")) {
RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU")); RET_CHECK(cc->Outputs().HasTag("IMAGE_GPU"));
cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>(); cc->Inputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>(); cc->Outputs().Tag("IMAGE_GPU").Set<GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag("ROTATION_DEGREES")) { if (cc->Inputs().HasTag("ROTATION_DEGREES")) {
cc->Inputs().Tag("ROTATION_DEGREES").Set<int>(); cc->Inputs().Tag("ROTATION_DEGREES").Set<int>();
} }
@ -212,9 +216,11 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
cc->Outputs().Tag("LETTERBOX_PADDING").Set<std::array<float, 4>>(); cc->Outputs().Tag("LETTERBOX_PADDING").Set<std::array<float, 4>>();
} }
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX if (use_gpu) {
MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // __ANDROID__ || iOS MP_RETURN_IF_ERROR(GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -250,12 +256,12 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE); scale_mode_ = ParseScaleMode(options_.scale_mode(), DEFAULT_SCALE_MODE);
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
// Let the helper access the GL context information. // Let the helper access the GL context information.
MP_RETURN_IF_ERROR(helper_.Open(cc)); MP_RETURN_IF_ERROR(helper_.Open(cc));
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -264,10 +270,10 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::Status ImageTransformationCalculator::Process( ::mediapipe::Status ImageTransformationCalculator::Process(
CalculatorContext* cc) { CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
return helper_.RunInGlContext( return helper_.RunInGlContext(
[this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); }); [this, cc]() -> ::mediapipe::Status { return RenderGpu(cc); });
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
return RenderCpu(cc); return RenderCpu(cc);
} }
@ -277,7 +283,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::Status ImageTransformationCalculator::Close( ::mediapipe::Status ImageTransformationCalculator::Close(
CalculatorContext* cc) { CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
QuadRenderer* rgb_renderer = rgb_renderer_.release(); QuadRenderer* rgb_renderer = rgb_renderer_.release();
QuadRenderer* yuv_renderer = yuv_renderer_.release(); QuadRenderer* yuv_renderer = yuv_renderer_.release();
QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release(); QuadRenderer* ext_rgb_renderer = ext_rgb_renderer_.release();
@ -295,8 +301,9 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
delete yuv_renderer; delete yuv_renderer;
} }
}); });
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -371,7 +378,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
::mediapipe::Status ImageTransformationCalculator::RenderGpu( ::mediapipe::Status ImageTransformationCalculator::RenderGpu(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) || defined(__APPLE__) && !TARGET_OS_OSX #if !defined(MEDIAPIPE_DISABLE_GPU)
int input_width = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().width(); int input_width = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().width();
int input_height = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().height(); int input_height = cc->Inputs().Tag("IMAGE_GPU").Get<GpuBuffer>().height();
@ -408,7 +415,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
#endif // iOS #endif // iOS
{ {
src1 = helper_.CreateSourceTexture(input); src1 = helper_.CreateSourceTexture(input);
#if defined(__ANDROID__) #if defined(TEXTURE_EXTERNAL_OES)
if (src1.target() == GL_TEXTURE_EXTERNAL_OES) { if (src1.target() == GL_TEXTURE_EXTERNAL_OES) {
if (!ext_rgb_renderer_) { if (!ext_rgb_renderer_) {
ext_rgb_renderer_ = absl::make_unique<QuadRenderer>(); ext_rgb_renderer_ = absl::make_unique<QuadRenderer>();
@ -417,7 +424,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
} }
renderer = ext_rgb_renderer_.get(); renderer = ext_rgb_renderer_.get();
} else // NOLINT(readability/braces) } else // NOLINT(readability/braces)
#endif // __ANDROID__ #endif // TEXTURE_EXTERNAL_OES
{ {
if (!rgb_renderer_) { if (!rgb_renderer_) {
rgb_renderer_ = absl::make_unique<QuadRenderer>(); rgb_renderer_ = absl::make_unique<QuadRenderer>();
@ -460,7 +467,7 @@ REGISTER_CALCULATOR(ImageTransformationCalculator);
auto output = dst.GetFrame<GpuBuffer>(); auto output = dst.GetFrame<GpuBuffer>();
cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp()); cc->Outputs().Tag("IMAGE_GPU").Add(output.release(), cc->InputTimestamp());
#endif // __ANDROID__ || iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -21,12 +21,11 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/util/color.pb.h" #include "mediapipe/util/color.pb.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
namespace { namespace {
enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES }; enum { ATTRIB_VERTEX, ATTRIB_TEXTURE_POSITION, NUM_ATTRIBUTES };
@ -95,10 +94,10 @@ class RecolorCalculator : public CalculatorBase {
mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_; mediapipe::RecolorCalculatorOptions::MaskChannel mask_channel_;
bool use_gpu_ = false; bool use_gpu_ = false;
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0; GLuint program_ = 0;
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(RecolorCalculator); REGISTER_CALCULATOR(RecolorCalculator);
@ -107,36 +106,43 @@ REGISTER_CALCULATOR(RecolorCalculator);
RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty());
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) bool use_gpu = false;
#if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("IMAGE_GPU")) { if (cc->Inputs().HasTag("IMAGE_GPU")) {
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag("IMAGE")) { if (cc->Inputs().HasTag("IMAGE")) {
cc->Inputs().Tag("IMAGE").Set<ImageFrame>(); cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("MASK_GPU")) { if (cc->Inputs().HasTag("MASK_GPU")) {
cc->Inputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag("MASK")) { if (cc->Inputs().HasTag("MASK")) {
cc->Inputs().Tag("MASK").Set<ImageFrame>(); cc->Inputs().Tag("MASK").Set<ImageFrame>();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Outputs().HasTag("IMAGE_GPU")) { if (cc->Outputs().HasTag("IMAGE_GPU")) {
cc->Outputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>(); cc->Outputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag("IMAGE")) { if (cc->Outputs().HasTag("IMAGE")) {
cc->Outputs().Tag("IMAGE").Set<ImageFrame>(); cc->Outputs().Tag("IMAGE").Set<ImageFrame>();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (use_gpu) {
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // __ANDROID__ or iOS MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -146,9 +152,9 @@ REGISTER_CALCULATOR(RecolorCalculator);
if (cc->Inputs().HasTag("IMAGE_GPU")) { if (cc->Inputs().HasTag("IMAGE_GPU")) {
use_gpu_ = true; use_gpu_ = true;
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
MP_RETURN_IF_ERROR(LoadOptions(cc)); MP_RETURN_IF_ERROR(LoadOptions(cc));
@ -158,7 +164,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status RecolorCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, &cc]() -> ::mediapipe::Status {
if (!initialized_) { if (!initialized_) {
@ -168,7 +174,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
MP_RETURN_IF_ERROR(RenderGpu(cc)); MP_RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
MP_RETURN_IF_ERROR(RenderCpu(cc)); MP_RETURN_IF_ERROR(RenderCpu(cc));
} }
@ -176,12 +182,12 @@ REGISTER_CALCULATOR(RecolorCalculator);
} }
::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status RecolorCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_); if (program_) glDeleteProgram(program_);
program_ = 0; program_ = 0;
}); });
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -194,7 +200,7 @@ REGISTER_CALCULATOR(RecolorCalculator);
if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) { if (cc->Inputs().Tag("MASK_GPU").IsEmpty()) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
// Get inputs and setup output. // Get inputs and setup output.
const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value(); const Packet& input_packet = cc->Inputs().Tag("IMAGE_GPU").Value();
const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value(); const Packet& mask_packet = cc->Inputs().Tag("MASK_GPU").Value();
@ -233,13 +239,13 @@ REGISTER_CALCULATOR(RecolorCalculator);
img_tex.Release(); img_tex.Release();
mask_tex.Release(); mask_tex.Release();
dst_tex.Release(); dst_tex.Release();
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
void RecolorCalculator::GlRender() { void RecolorCalculator::GlRender() {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
static const GLfloat square_vertices[] = { static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left -1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right 1.0f, -1.0f, // bottom right
@ -287,7 +293,7 @@ void RecolorCalculator::GlRender() {
glBindVertexArray(0); glBindVertexArray(0);
glDeleteVertexArrays(1, &vao); glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo); glDeleteBuffers(2, vbo);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) { ::mediapipe::Status RecolorCalculator::LoadOptions(CalculatorContext* cc) {
@ -305,7 +311,7 @@ void RecolorCalculator::GlRender() {
} }
::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) { ::mediapipe::Status RecolorCalculator::InitGpu(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
const GLint attr_location[NUM_ATTRIBUTES] = { const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX, ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION, ATTRIB_TEXTURE_POSITION,
@ -374,7 +380,7 @@ void RecolorCalculator::GlRender() {
glUniform1i(glGetUniformLocation(program_, "mask"), 2); glUniform1i(glGetUniformLocation(program_, "mask"), 2);
glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1], glUniform3f(glGetUniformLocation(program_, "recolor"), color_[0], color_[1],
color_[2]); color_[2]);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -25,12 +25,11 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/vector.h" #include "mediapipe/framework/port/vector.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
@ -107,16 +106,18 @@ class SetAlphaCalculator : public CalculatorBase {
bool use_gpu_ = false; bool use_gpu_ = false;
bool gpu_initialized_ = false; bool gpu_initialized_ = false;
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0; GLuint program_ = 0;
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(SetAlphaCalculator); REGISTER_CALCULATOR(SetAlphaCalculator);
::mediapipe::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) { ::mediapipe::Status SetAlphaCalculator::GetContract(CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1); CHECK_GE(cc->Inputs().NumEntries(), 1);
bool use_gpu = false;
if (cc->Inputs().HasTag(kInputFrameTag) && if (cc->Inputs().HasTag(kInputFrameTag) &&
cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().HasTag(kInputFrameTagGpu)) {
return ::mediapipe::InternalError("Cannot have multiple input images."); return ::mediapipe::InternalError("Cannot have multiple input images.");
@ -127,38 +128,43 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
} }
// Input image to add/edit alpha channel. // Input image to add/edit alpha channel.
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag(kInputFrameTagGpu)) { if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kInputFrameTag)) { if (cc->Inputs().HasTag(kInputFrameTag)) {
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>(); cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
} }
// Input alpha image mask (optional) // Input alpha image mask (optional)
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag(kInputAlphaTagGpu)) { if (cc->Inputs().HasTag(kInputAlphaTagGpu)) {
cc->Inputs().Tag(kInputAlphaTagGpu).Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputAlphaTagGpu).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kInputAlphaTag)) { if (cc->Inputs().HasTag(kInputAlphaTag)) {
cc->Inputs().Tag(kInputAlphaTag).Set<ImageFrame>(); cc->Inputs().Tag(kInputAlphaTag).Set<ImageFrame>();
} }
// RGBA output image. // RGBA output image.
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>(); cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag(kOutputFrameTag)) { if (cc->Outputs().HasTag(kOutputFrameTag)) {
cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>(); cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (use_gpu) {
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // __ANDROID__ or iOS MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -170,11 +176,11 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
if (cc->Inputs().HasTag(kInputFrameTagGpu) && if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().HasTag(kOutputFrameTagGpu)) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
use_gpu_ = true; use_gpu_ = true;
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
// Get global value from options (-1 if not set). // Get global value from options (-1 if not set).
@ -187,17 +193,17 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
RET_CHECK_FAIL() << "Must use either image mask or options alpha value."; RET_CHECK_FAIL() << "Must use either image mask or options alpha value.";
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif #endif
} } // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status SetAlphaCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
if (!gpu_initialized_) { if (!gpu_initialized_) {
@ -207,7 +213,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
MP_RETURN_IF_ERROR(RenderGpu(cc)); MP_RETURN_IF_ERROR(RenderGpu(cc));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
MP_RETURN_IF_ERROR(RenderCpu(cc)); MP_RETURN_IF_ERROR(RenderCpu(cc));
} }
@ -216,12 +222,12 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
} }
::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status SetAlphaCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_); if (program_) glDeleteProgram(program_);
program_ = 0; program_ = 0;
}); });
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -295,7 +301,7 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) { if (cc->Inputs().Tag(kInputFrameTagGpu).IsEmpty()) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
// Setup source texture. // Setup source texture.
const auto& input_frame = const auto& input_frame =
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
@ -348,13 +354,13 @@ REGISTER_CALCULATOR(SetAlphaCalculator);
// Cleanup // Cleanup
input_texture.Release(); input_texture.Release();
output_texture.Release(); output_texture.Release();
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
void SetAlphaCalculator::GlRender(CalculatorContext* cc) { void SetAlphaCalculator::GlRender(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
static const GLfloat square_vertices[] = { static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left -1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right 1.0f, -1.0f, // bottom right
@ -403,11 +409,11 @@ void SetAlphaCalculator::GlRender(CalculatorContext* cc) {
glDeleteVertexArrays(1, &vao); glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo); glDeleteBuffers(2, vbo);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) { ::mediapipe::Status SetAlphaCalculator::GlSetup(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
const GLint attr_location[NUM_ATTRIBUTES] = { const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX, ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION, ATTRIB_TEXTURE_POSITION,
@ -460,7 +466,7 @@ void SetAlphaCalculator::GlRender(CalculatorContext* cc) {
glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2); glUniform1i(glGetUniformLocation(program_, "alpha_mask"), 2);
glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_); glUniform1f(glGetUniformLocation(program_, "alpha_value"), alpha_value_);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -255,6 +255,7 @@ mediapipe_cc_proto_library(
cc_deps = [ cc_deps = [
"//mediapipe/calculators/core:packet_resampler_calculator_cc_proto", "//mediapipe/calculators/core:packet_resampler_calculator_cc_proto",
"//mediapipe/framework:calculator_cc_proto", "//mediapipe/framework:calculator_cc_proto",
"//mediapipe/util:audio_decoder_cc_proto",
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [":unpack_media_sequence_calculator_proto"], deps = [":unpack_media_sequence_calculator_proto"],
@ -653,6 +654,7 @@ cc_library(
"//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/util:audio_decoder_cc_proto",
"//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/core:protos_all_cc", "@org_tensorflow//tensorflow/core:protos_all_cc",
@ -769,7 +771,6 @@ cc_test(
"//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:opencv_imgcodecs", "//mediapipe/framework/port:opencv_imgcodecs",
"//mediapipe/framework/port:status",
"//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
@ -971,6 +972,7 @@ cc_test(
"//mediapipe/framework/formats:location", "//mediapipe/framework/formats:location",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:rectangle", "//mediapipe/framework/port:rectangle",
"//mediapipe/util:audio_decoder_cc_proto",
"//mediapipe/util/sequence:media_sequence", "//mediapipe/util/sequence:media_sequence",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -285,6 +285,10 @@ class PackMediaSequenceCalculator : public CalculatorBase {
} }
::mediapipe::Status Process(CalculatorContext* cc) override { ::mediapipe::Status Process(CalculatorContext* cc) override {
int image_height = -1;
int image_width = -1;
// Because the tag order may vary, we need to loop through tags to get
// image information before processing other tag types.
for (const auto& tag : cc->Inputs().GetTags()) { for (const auto& tag : cc->Inputs().GetTags()) {
if (!cc->Inputs().Tag(tag).IsEmpty()) { if (!cc->Inputs().Tag(tag).IsEmpty()) {
features_present_[tag] = true; features_present_[tag] = true;
@ -306,14 +310,21 @@ class PackMediaSequenceCalculator : public CalculatorBase {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "No encoded image"; << "No encoded image";
} }
image_height = image.height();
image_width = image.width();
mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(), mpms::AddImageTimestamp(key, cc->InputTimestamp().Value(),
sequence_.get()); sequence_.get());
mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get()); mpms::AddImageEncoded(key, image.encoded_image(), sequence_.get());
} }
}
for (const auto& tag : cc->Inputs().GetTags()) {
if (!cc->Inputs().Tag(tag).IsEmpty()) {
features_present_[tag] = true;
}
if (absl::StartsWith(tag, kKeypointsTag) && if (absl::StartsWith(tag, kKeypointsTag) &&
!cc->Inputs().Tag(tag).IsEmpty()) { !cc->Inputs().Tag(tag).IsEmpty()) {
std::string key = ""; std::string key = "";
if (tag != kImageTag) { if (tag != kKeypointsTag) {
int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1; int tag_length = sizeof(kKeypointsTag) / sizeof(*kKeypointsTag) - 1;
if (tag[tag_length] == '_') { if (tag[tag_length] == '_') {
key = tag.substr(tag_length + 1); key = tag.substr(tag_length + 1);
@ -363,11 +374,20 @@ class PackMediaSequenceCalculator : public CalculatorBase {
LocationData::BOUNDING_BOX || LocationData::BOUNDING_BOX ||
detection.location_data().format() == detection.location_data().format() ==
LocationData::RELATIVE_BOUNDING_BOX) { LocationData::RELATIVE_BOUNDING_BOX) {
int height = mpms::GetImageHeight(*sequence_); if (mpms::HasImageHeight(*sequence_) &&
int width = mpms::GetImageWidth(*sequence_); mpms::HasImageWidth(*sequence_)) {
image_height = mpms::GetImageHeight(*sequence_);
image_width = mpms::GetImageWidth(*sequence_);
}
if (image_height == -1 || image_width == -1) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "Images must be provided with bounding boxes or the "
"image "
<< "height and width must already be in the example.";
}
Location relative_bbox = Location::CreateRelativeBBoxLocation( Location relative_bbox = Location::CreateRelativeBBoxLocation(
Location(detection.location_data()) Location(detection.location_data())
.ConvertToRelativeBBox(width, height)); .ConvertToRelativeBBox(image_width, image_height));
predicted_locations.push_back(relative_bbox); predicted_locations.push_back(relative_bbox);
if (detection.label_size() > 0) { if (detection.label_size() > 0) {
predicted_class_strings.push_back(detection.label(0)); predicted_class_strings.push_back(detection.label(0));

View File

@ -357,6 +357,148 @@ TEST_F(PackMediaSequenceCalculatorTest, PacksTwoBBoxDetections) {
} }
} }
TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithoutImageDims) {
SetUpCalculator({"BBOX_PREDICTED:detections"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
int height = 480;
int width = 640;
int num_vectors = 2;
for (int i = 0; i < num_vectors; ++i) {
auto detections = ::absl::make_unique<::std::vector<Detection>>();
Detection detection;
detection.add_label("absolute bbox");
detection.add_label_id(0);
detection.add_score(0.5);
Location::CreateBBoxLocation(0, height / 2, width / 2, height / 2)
.ConvertToProto(detection.mutable_location_data());
detections->push_back(detection);
detection = Detection();
detection.add_label("relative bbox");
detection.add_label_id(1);
detection.add_score(0.75);
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
.ConvertToProto(detection.mutable_location_data());
detections->push_back(detection);
// The mask detection should be ignored in the output.
detection = Detection();
detection.add_label("mask");
detection.add_score(1.0);
cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0));
Location::CreateCvMaskLocation<uint8>(image).ConvertToProto(
detection.mutable_location_data());
detections->push_back(detection);
runner_->MutableInputs()
->Tag("BBOX_PREDICTED")
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
auto status = runner_->Run();
EXPECT_EQ(::mediapipe::StatusCode::kInvalidArgument, status.code());
}
TEST_F(PackMediaSequenceCalculatorTest, PacksBBoxWithImages) {
SetUpCalculator({"BBOX_PREDICTED:detections", "IMAGE:images"}, {}, false,
true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>();
std::string test_video_id = "test_video_id";
mpms::SetClipMediaId(test_video_id, input_sequence.get());
int height = 480;
int width = 640;
int num_vectors = 2;
for (int i = 0; i < num_vectors; ++i) {
auto detections = ::absl::make_unique<::std::vector<Detection>>();
Detection detection;
detection.add_label("absolute bbox");
detection.add_label_id(0);
detection.add_score(0.5);
Location::CreateBBoxLocation(0, height / 2, width / 2, height / 2)
.ConvertToProto(detection.mutable_location_data());
detections->push_back(detection);
detection = Detection();
detection.add_label("relative bbox");
detection.add_label_id(1);
detection.add_score(0.75);
Location::CreateRelativeBBoxLocation(0, 0.5, 0.5, 0.5)
.ConvertToProto(detection.mutable_location_data());
detections->push_back(detection);
// The mask detection should be ignored in the output.
detection = Detection();
detection.add_label("mask");
detection.add_score(1.0);
cv::Mat image(2, 3, CV_8UC1, cv::Scalar(0));
Location::CreateCvMaskLocation<uint8>(image).ConvertToProto(
detection.mutable_location_data());
detections->push_back(detection);
runner_->MutableInputs()
->Tag("BBOX_PREDICTED")
.packets.push_back(Adopt(detections.release()).At(Timestamp(i)));
}
cv::Mat image(height, width, CV_8UC3, cv::Scalar(0, 0, 255));
std::vector<uchar> bytes;
ASSERT_TRUE(cv::imencode(".jpg", image, bytes, {80}));
std::string test_image_string(bytes.begin(), bytes.end());
OpenCvImageEncoderCalculatorResults encoded_image;
encoded_image.set_encoded_image(test_image_string);
encoded_image.set_width(width);
encoded_image.set_height(height);
int num_images = 2;
for (int i = 0; i < num_images; ++i) {
auto image_ptr =
::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
runner_->MutableInputs()->Tag("IMAGE").packets.push_back(
Adopt(image_ptr.release()).At(Timestamp(i)));
}
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(input_sequence.release());
MP_ASSERT_OK(runner_->Run());
const std::vector<Packet>& output_packets =
runner_->Outputs().Tag("SEQUENCE_EXAMPLE").packets;
ASSERT_EQ(1, output_packets.size());
const tf::SequenceExample& output_sequence =
output_packets[0].Get<tf::SequenceExample>();
ASSERT_EQ(test_video_id, mpms::GetClipMediaId(output_sequence));
ASSERT_EQ(height, mpms::GetImageHeight(output_sequence));
ASSERT_EQ(width, mpms::GetImageWidth(output_sequence));
ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxSize(output_sequence));
ASSERT_EQ(num_vectors, mpms::GetPredictedBBoxTimestampSize(output_sequence));
ASSERT_EQ(0, mpms::GetClassSegmentationEncodedSize(output_sequence));
ASSERT_EQ(0, mpms::GetClassSegmentationTimestampSize(output_sequence));
for (int i = 0; i < num_vectors; ++i) {
ASSERT_EQ(i, mpms::GetPredictedBBoxTimestampAt(output_sequence, i));
auto bboxes = mpms::GetPredictedBBoxAt(output_sequence, i);
ASSERT_EQ(2, bboxes.size());
for (int j = 0; j < bboxes.size(); ++j) {
auto rect = bboxes[j].GetRelativeBBox();
ASSERT_NEAR(0, rect.xmin(), 0.001);
ASSERT_NEAR(0.5, rect.ymin(), 0.001);
ASSERT_NEAR(0.5, rect.xmax(), 0.001);
ASSERT_NEAR(1.0, rect.ymax(), 0.001);
}
auto class_strings =
mpms::GetPredictedBBoxLabelStringAt(output_sequence, i);
ASSERT_EQ("absolute bbox", class_strings[0]);
ASSERT_EQ("relative bbox", class_strings[1]);
auto class_indices = mpms::GetPredictedBBoxLabelIndexAt(output_sequence, i);
ASSERT_EQ(0, class_indices[0]);
ASSERT_EQ(1, class_indices[1]);
}
}
TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) { TEST_F(PackMediaSequenceCalculatorTest, PacksTwoKeypoints) {
SetUpCalculator({"KEYPOINTS_TEST:keypoints"}, {}, false, true); SetUpCalculator({"KEYPOINTS_TEST:keypoints"}, {}, false, true);
auto input_sequence = ::absl::make_unique<tf::SequenceExample>(); auto input_sequence = ::absl::make_unique<tf::SequenceExample>();

View File

@ -19,6 +19,7 @@
#include "mediapipe/framework/formats/location.h" #include "mediapipe/framework/formats/location.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/util/audio_decoder.pb.h"
#include "mediapipe/util/sequence/media_sequence.h" #include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/example/feature.pb.h"
@ -37,6 +38,7 @@ const char kDatasetRootDirTag[] = "DATASET_ROOT";
const char kDataPath[] = "DATA_PATH"; const char kDataPath[] = "DATA_PATH";
const char kPacketResamplerOptions[] = "RESAMPLER_OPTIONS"; const char kPacketResamplerOptions[] = "RESAMPLER_OPTIONS";
const char kImagesFrameRateTag[] = "IMAGE_FRAME_RATE"; const char kImagesFrameRateTag[] = "IMAGE_FRAME_RATE";
const char kAudioDecoderOptions[] = "AUDIO_DECODER_OPTIONS";
namespace tf = ::tensorflow; namespace tf = ::tensorflow;
namespace mpms = ::mediapipe::mediasequence; namespace mpms = ::mediapipe::mediasequence;
@ -126,6 +128,11 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
if (cc->OutputSidePackets().HasTag(kDataPath)) { if (cc->OutputSidePackets().HasTag(kDataPath)) {
cc->OutputSidePackets().Tag(kDataPath).Set<std::string>(); cc->OutputSidePackets().Tag(kDataPath).Set<std::string>();
} }
if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions)) {
cc->OutputSidePackets()
.Tag(kAudioDecoderOptions)
.Set<AudioDecoderOptions>();
}
if (cc->OutputSidePackets().HasTag(kImagesFrameRateTag)) { if (cc->OutputSidePackets().HasTag(kImagesFrameRateTag)) {
cc->OutputSidePackets().Tag(kImagesFrameRateTag).Set<double>(); cc->OutputSidePackets().Tag(kImagesFrameRateTag).Set<double>();
} }
@ -136,10 +143,11 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
} }
if ((options.has_padding_before_label() || if ((options.has_padding_before_label() ||
options.has_padding_after_label()) && options.has_padding_after_label()) &&
!(cc->OutputSidePackets().HasTag(kPacketResamplerOptions))) { !(cc->OutputSidePackets().HasTag(kAudioDecoderOptions) ||
cc->OutputSidePackets().HasTag(kPacketResamplerOptions))) {
return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC) return ::mediapipe::InvalidArgumentErrorBuilder(MEDIAPIPE_LOC)
<< "If specifying padding, must output " << "If specifying padding, must output " << kPacketResamplerOptions
<< kPacketResamplerOptions; << "or" << kAudioDecoderOptions;
} }
// Optional streams. // Optional streams.
@ -260,7 +268,8 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
// Set the start and end of the clip in the appropriate options protos. // Set the start and end of the clip in the appropriate options protos.
double start_time = 0; double start_time = 0;
double end_time = 0; double end_time = 0;
if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions) ||
cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) {
if (mpms::HasClipStartTimestamp(sequence)) { if (mpms::HasClipStartTimestamp(sequence)) {
start_time = start_time =
Timestamp(mpms::GetClipStartTimestamp(sequence)).Seconds() - Timestamp(mpms::GetClipStartTimestamp(sequence)).Seconds() -
@ -271,6 +280,27 @@ class UnpackMediaSequenceCalculator : public CalculatorBase {
options.padding_after_label(); options.padding_after_label();
} }
} }
if (cc->OutputSidePackets().HasTag(kAudioDecoderOptions)) {
auto audio_decoder_options = absl::make_unique<AudioDecoderOptions>(
options.base_audio_decoder_options());
if (mpms::HasClipStartTimestamp(sequence)) {
if (options.force_decoding_from_start_of_media()) {
audio_decoder_options->set_start_time(0);
} else {
audio_decoder_options->set_start_time(
start_time - options.extra_padding_from_media_decoder());
}
}
if (mpms::HasClipEndTimestamp(sequence)) {
audio_decoder_options->set_end_time(
end_time + options.extra_padding_from_media_decoder());
}
LOG(INFO) << "Created AudioDecoderOptions:\n"
<< audio_decoder_options->DebugString();
cc->OutputSidePackets()
.Tag(kAudioDecoderOptions)
.Set(Adopt(audio_decoder_options.release()));
}
if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) { if (cc->OutputSidePackets().HasTag(kPacketResamplerOptions)) {
auto resampler_options = absl::make_unique<CalculatorOptions>(); auto resampler_options = absl::make_unique<CalculatorOptions>();
*(resampler_options->MutableExtension( *(resampler_options->MutableExtension(

View File

@ -18,6 +18,7 @@ package mediapipe;
import "mediapipe/calculators/core/packet_resampler_calculator.proto"; import "mediapipe/calculators/core/packet_resampler_calculator.proto";
import "mediapipe/framework/calculator.proto"; import "mediapipe/framework/calculator.proto";
import "mediapipe/util/audio_decoder.proto";
message UnpackMediaSequenceCalculatorOptions { message UnpackMediaSequenceCalculatorOptions {
extend mediapipe.CalculatorOptions { extend mediapipe.CalculatorOptions {
@ -49,4 +50,10 @@ message UnpackMediaSequenceCalculatorOptions {
// parameters for the MediaDecoderCalculator. End time parameters are still // parameters for the MediaDecoderCalculator. End time parameters are still
// respected. // respected.
optional bool force_decoding_from_start_of_media = 7; optional bool force_decoding_from_start_of_media = 7;
// Stores the audio decoder settings for the graph. (e.g. which audio
// stream to pull from the video.) The sequence's metadata overrides
// the clip start and end times and outputs these for the
// AudioDecoderCalculator to consume.
optional AudioDecoderOptions base_audio_decoder_options = 9;
} }

View File

@ -23,6 +23,7 @@
#include "mediapipe/framework/port/gtest.h" #include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/rectangle.h" #include "mediapipe/framework/port/rectangle.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/util/audio_decoder.pb.h"
#include "mediapipe/util/sequence/media_sequence.h" #include "mediapipe/util/sequence/media_sequence.h"
#include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/example.pb.h"
@ -459,6 +460,62 @@ TEST_F(UnpackMediaSequenceCalculatorTest, GetDatasetFromExample) {
data_path_); data_path_);
} }
TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptions) {
CalculatorOptions options;
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_padding_before_label(1);
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_padding_after_label(2);
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
&options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.ValidateAsType<AudioDecoderOptions>());
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Get<AudioDecoderOptions>()
.start_time(),
2.0, 1e-5);
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Get<AudioDecoderOptions>()
.end_time(),
7.0, 1e-5);
}
TEST_F(UnpackMediaSequenceCalculatorTest, GetAudioDecoderOptionsOverride) {
CalculatorOptions options;
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_padding_before_label(1);
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_padding_after_label(2);
options.MutableExtension(UnpackMediaSequenceCalculatorOptions::ext)
->set_force_decoding_from_start_of_media(true);
SetUpCalculator({}, {"AUDIO_DECODER_OPTIONS:audio_decoder_options"}, {},
&options);
runner_->MutableSidePackets()->Tag("SEQUENCE_EXAMPLE") =
Adopt(sequence_.release());
MP_ASSERT_OK(runner_->Run());
MP_EXPECT_OK(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.ValidateAsType<AudioDecoderOptions>());
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Get<AudioDecoderOptions>()
.start_time(),
0.0, 1e-5);
EXPECT_NEAR(runner_->OutputSidePackets()
.Tag("AUDIO_DECODER_OPTIONS")
.Get<AudioDecoderOptions>()
.end_time(),
7.0, 1e-5);
}
TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) { TEST_F(UnpackMediaSequenceCalculatorTest, GetPacketResamplingOptions) {
// TODO: Suport proto3 proto.Any in CalculatorOptions. // TODO: Suport proto3 proto.Any in CalculatorOptions.
// TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS". // TODO: Avoid proto2 extensions in "RESAMPLER_OPTIONS".

View File

@ -195,6 +195,12 @@ cc_test(
], ],
) )
cc_library(
name = "util",
hdrs = ["util.h"],
alwayslink = 1,
)
cc_library( cc_library(
name = "tflite_inference_calculator", name = "tflite_inference_calculator",
srcs = ["tflite_inference_calculator.cc"], srcs = ["tflite_inference_calculator.cc"],
@ -214,6 +220,7 @@ cc_library(
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":util",
":tflite_inference_calculator_cc_proto", ":tflite_inference_calculator_cc_proto",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
@ -222,20 +229,25 @@ cc_library(
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
"@org_tensorflow//tensorflow/lite/delegates/gpu/metal:buffer_convert",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/common:shape",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
], ],
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -259,33 +271,33 @@ cc_library(
}), }),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":util",
":tflite_converter_calculator_cc_proto", ":tflite_converter_calculator_cc_proto",
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"//mediapipe/framework:calculator_framework", "//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame", "//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:matrix", "//mediapipe/framework/formats:matrix",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler", "//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/tool:status_util",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops", "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
],
"//mediapipe:ios": [ "//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:MPPMetalHelper", "//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios", "//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
], ],
"//conditions:default": [], "//conditions:default": [
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gl_calculator_helper",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -295,6 +307,7 @@ cc_library(
srcs = ["tflite_tensors_to_segmentation_calculator.cc"], srcs = ["tflite_tensors_to_segmentation_calculator.cc"],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":util",
":tflite_tensors_to_segmentation_calculator_cc_proto", ":tflite_tensors_to_segmentation_calculator_cc_proto",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
@ -308,7 +321,9 @@ cc_library(
"//mediapipe/util:resource_util", "//mediapipe/util:resource_util",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//mediapipe:ios": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
@ -319,7 +334,6 @@ cc_library(
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_texture",
], ],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -346,8 +360,23 @@ cc_test(
cc_library( cc_library(
name = "tflite_tensors_to_detections_calculator", name = "tflite_tensors_to_detections_calculator",
srcs = ["tflite_tensors_to_detections_calculator.cc"], srcs = ["tflite_tensors_to_detections_calculator.cc"],
copts = select({
"//mediapipe:ios": [
"-x objective-c++",
"-fobjc-arc", # enable reference-counting
],
"//conditions:default": [],
}),
linkopts = select({
"//mediapipe:ios": [
"-framework CoreVideo",
"-framework MetalKit",
],
"//conditions:default": [],
}),
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
":util",
":tflite_tensors_to_detections_calculator_cc_proto", ":tflite_tensors_to_detections_calculator_cc_proto",
"//mediapipe/framework/formats:detection_cc_proto", "//mediapipe/framework/formats:detection_cc_proto",
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
@ -359,14 +388,21 @@ cc_library(
"//mediapipe/framework/port:ret_check", "//mediapipe/framework/port:ret_check",
"@org_tensorflow//tensorflow/lite:framework", "@org_tensorflow//tensorflow/lite:framework",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//mediapipe:ios": [
"//mediapipe/gpu:MPPMetalUtil",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:MPPMetalHelper",
"//mediapipe/objc:mediapipe_framework_ios",
"@org_tensorflow//tensorflow/lite/delegates/gpu:metal_delegate",
],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate", "@org_tensorflow//tensorflow/lite/delegates/gpu:gl_delegate",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_buffer",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_program",
"@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader", "@org_tensorflow//tensorflow/lite/delegates/gpu/gl:gl_shader",
], ],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )

View File

@ -16,23 +16,23 @@
#include <vector> #include <vector>
#include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_converter_calculator.pb.h"
#include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/matrix.h" #include "mediapipe/framework/formats/matrix.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/error_reporter.h" #include "tensorflow/lite/error_reporter.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS #if defined(__APPLE__) && !TARGET_OS_OSX // iOS
#import <CoreVideo/CoreVideo.h> #import <CoreVideo/CoreVideo.h>
@ -40,11 +40,12 @@
#import <MetalKit/MetalKit.h> #import <MetalKit/MetalKit.h>
#import "mediapipe/gpu/MPPMetalHelper.h" #import "mediapipe/gpu/MPPMetalHelper.h"
#include "mediapipe/gpu/MPPMetalUtil.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS #endif // iOS
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor; typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor; typedef id<MTLBuffer> GpuTensor;
@ -66,26 +67,27 @@ typedef Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>
namespace mediapipe { namespace mediapipe {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader; using ::tflite::gpu::gl::GlShader;
struct GPUData { struct GPUData {
int elements = 1; int elements = 1;
GlBuffer buffer; GpuTensor buffer;
GlShader shader; GlShader shader;
GlProgram program; GlProgram program;
}; };
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
struct GPUData { struct GPUData {
int elements = 1; int elements = 1;
id<MTLBuffer> buffer; GpuTensor buffer;
id<MTLComputePipelineState> pipeline_state; id<MTLComputePipelineState> pipeline_state;
}; };
#endif #endif
// Calculator for normalizing and converting an ImageFrame or Matrix // Calculator for normalizing and converting an ImageFrame or Matrix
// into a TfLiteTensor (float 32) or a GpuBuffer to a tflite::gpu::GlBuffer. // into a TfLiteTensor (float 32) or a GpuBuffer to a tflite::gpu::GlBuffer
// or MTLBuffer.
// //
// This calculator is designed to be used with the TfLiteInferenceCalcualtor, // This calculator is designed to be used with the TfLiteInferenceCalcualtor,
// as a pre-processing step for calculator inputs. // as a pre-processing step for calculator inputs.
@ -102,7 +104,7 @@ struct GPUData {
// Output: // Output:
// One of the following tags: // One of the following tags:
// TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8. // TENSORS - Vector of TfLiteTensor of type kTfLiteFloat32, or kTfLiteUint8.
// TENSORS_GPU - vector of GlBuffer. // TENSORS_GPU - vector of GlBuffer or MTLBuffer.
// //
// Example use: // Example use:
// node { // node {
@ -144,7 +146,7 @@ class TfLiteConverterCalculator : public CalculatorBase {
std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr; std::unique_ptr<tflite::Interpreter> interpreter_ = nullptr;
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_out_; std::unique_ptr<GPUData> gpu_data_out_;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -175,25 +177,33 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
RET_CHECK(cc->Outputs().HasTag("TENSORS") ^ RET_CHECK(cc->Outputs().HasTag("TENSORS") ^
cc->Outputs().HasTag("TENSORS_GPU")); cc->Outputs().HasTag("TENSORS_GPU"));
bool use_gpu = false;
if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>(); if (cc->Inputs().HasTag("IMAGE")) cc->Inputs().Tag("IMAGE").Set<ImageFrame>();
if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>(); if (cc->Inputs().HasTag("MATRIX")) cc->Inputs().Tag("MATRIX").Set<Matrix>();
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("IMAGE_GPU")) if (cc->Inputs().HasTag("IMAGE_GPU")) {
cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag("IMAGE_GPU").Set<mediapipe::GpuBuffer>();
#endif use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag("TENSORS")) if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Outputs().HasTag("TENSORS_GPU")) if (cc->Outputs().HasTag("TENSORS_GPU")) {
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>(); cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
#endif use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
#if defined(__ANDROID__) if (use_gpu) {
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif #endif
}
// Assign this calculator's default InputStreamHandler. // Assign this calculator's default InputStreamHandler.
cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
@ -208,10 +218,10 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
if (cc->Inputs().HasTag("IMAGE_GPU") || if (cc->Inputs().HasTag("IMAGE_GPU") ||
cc->Outputs().HasTag("IMAGE_OUT_GPU")) { cc->Outputs().HasTag("IMAGE_OUT_GPU")) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
use_gpu_ = true; use_gpu_ = true;
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif #endif
} }
@ -221,7 +231,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
cc->Outputs().HasTag("TENSORS_GPU")); cc->Outputs().HasTag("TENSORS_GPU"));
// Cannot use quantization. // Cannot use quantization.
use_quantized_tensors_ = false; use_quantized_tensors_ = false;
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
@ -238,6 +248,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status TfLiteConverterCalculator::Process(CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
// GpuBuffer to tflite::gpu::GlBuffer conversion.
if (!initialized_) { if (!initialized_) {
MP_RETURN_IF_ERROR(InitGpu(cc)); MP_RETURN_IF_ERROR(InitGpu(cc));
initialized_ = true; initialized_ = true;
@ -253,7 +264,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
} }
::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status TfLiteConverterCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); }); gpu_helper_.RunInGlContext([this] { gpu_data_out_.reset(); });
#endif #endif
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS #if defined(__APPLE__) && !TARGET_OS_OSX // iOS
@ -372,7 +383,7 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
::mediapipe::Status TfLiteConverterCalculator::ProcessGPU( ::mediapipe::Status TfLiteConverterCalculator::ProcessGPU(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
// GpuBuffer to tflite::gpu::GlBuffer conversion. // GpuBuffer to tflite::gpu::GlBuffer conversion.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>(); const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
@ -381,17 +392,11 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
auto src = gpu_helper_.CreateSourceTexture(input); auto src = gpu_helper_.CreateSourceTexture(input);
glActiveTexture(GL_TEXTURE0 + 0); glActiveTexture(GL_TEXTURE0 + 0);
glBindTexture(GL_TEXTURE_2D, src.name()); glBindTexture(GL_TEXTURE_2D, src.name());
auto status = gpu_data_out_->buffer.BindToIndex(1); RET_CHECK_CALL(gpu_data_out_->buffer.BindToIndex(1));
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
const tflite::gpu::uint3 workgroups = { const tflite::gpu::uint3 workgroups = {
NumGroups(input.width(), kWorkgroupSize), NumGroups(input.width(), kWorkgroupSize),
NumGroups(input.height(), kWorkgroupSize), 1}; NumGroups(input.height(), kWorkgroupSize), 1};
status = gpu_data_out_->program.Dispatch(workgroups); RET_CHECK_CALL(gpu_data_out_->program.Dispatch(workgroups));
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0); glBindBuffer(GL_SHADER_STORAGE_BUFFER, 0);
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
src.Release(); src.Release();
@ -400,17 +405,17 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
// Copy into outputs. // Copy into outputs.
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>(); auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
output_tensors->resize(1); MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
{ [this, &output_tensors]() -> ::mediapipe::Status {
GlBuffer& tensor = output_tensors->at(0); output_tensors->resize(1);
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; {
auto status = CreateReadWriteShaderStorageBuffer<float>( GpuTensor& tensor = output_tensors->at(0);
gpu_data_out_->elements, &tensor); RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
if (!status.ok()) { gpu_data_out_->elements, &tensor));
return ::mediapipe::InternalError(status.error_message()); RET_CHECK_CALL(CopyBuffer(gpu_data_out_->buffer, tensor));
} }
tflite::gpu::gl::CopyBuffer(gpu_data_out_->buffer, tensor); return ::mediapipe::OkStatus();
} }));
cc->Outputs() cc->Outputs()
.Tag("TENSORS_GPU") .Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp()); .Add(output_tensors.release(), cc->InputTimestamp());
@ -438,66 +443,60 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
} }
// Copy into outputs. // Copy into outputs.
// TODO Avoid this copy.
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>(); auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
output_tensors->resize(1);
{ {
id<MTLDevice> device = gpu_helper_.mtlDevice; id<MTLDevice> device = gpu_helper_.mtlDevice;
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer]; output_tensors->at(0) =
command_buffer.label = @"TfLiteConverterCalculatorCopy";
id<MTLBuffer> tensor =
[device newBufferWithLength:gpu_data_out_->elements * sizeof(float) [device newBufferWithLength:gpu_data_out_->elements * sizeof(float)
options:MTLResourceStorageModeShared]; options:MTLResourceStorageModeShared];
id<MTLBlitCommandEncoder> blit_command = [MPPMetalUtil blitMetalBufferTo:output_tensors->at(0)
[command_buffer blitCommandEncoder]; from:gpu_data_out_->buffer
[blit_command copyFromBuffer:gpu_data_out_->buffer blocking:true
sourceOffset:0 commandBuffer:[gpu_helper_ commandBuffer]];
toBuffer:tensor
destinationOffset:0
size:gpu_data_out_->elements * sizeof(float)];
[blit_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
output_tensors->push_back(tensor);
} }
cc->Outputs() cc->Outputs()
.Tag("TENSORS_GPU") .Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp()); .Add(output_tensors.release(), cc->InputTimestamp());
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing is not enabled.";
#endif #endif
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) { ::mediapipe::Status TfLiteConverterCalculator::InitGpu(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
// Configure inputs. // Get input image sizes.
const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>(); const auto& input = cc->Inputs().Tag("IMAGE_GPU").Get<mediapipe::GpuBuffer>();
mediapipe::ImageFormat::Format format = mediapipe::ImageFormat::Format format =
mediapipe::ImageFormatForGpuBufferFormat(input.format()); mediapipe::ImageFormatForGpuBufferFormat(input.format());
gpu_data_out_ = absl::make_unique<GPUData>(); gpu_data_out_ = absl::make_unique<GPUData>();
gpu_data_out_->elements = input.height() * input.width() * max_num_channels_; gpu_data_out_->elements = input.height() * input.width() * max_num_channels_;
const bool include_alpha = (max_num_channels_ == 4); const bool include_alpha = (max_num_channels_ == 4);
if (!(format == mediapipe::ImageFormat::SRGB || const bool single_channel = (max_num_channels_ == 1);
if (!(format == mediapipe::ImageFormat::GRAY8 ||
format == mediapipe::ImageFormat::SRGB ||
format == mediapipe::ImageFormat::SRGBA)) format == mediapipe::ImageFormat::SRGBA))
RET_CHECK_FAIL() << "Unsupported GPU input format."; RET_CHECK_FAIL() << "Unsupported GPU input format.";
if (include_alpha && (format != mediapipe::ImageFormat::SRGBA)) if (include_alpha && (format != mediapipe::ImageFormat::SRGBA))
RET_CHECK_FAIL() << "Num input channels is less than desired output."; RET_CHECK_FAIL() << "Num input channels is less than desired output.";
#endif #endif // !MEDIAPIPE_DISABLE_GPU
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
// Device memory. MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>( [this, &include_alpha, &input, &single_channel]() -> ::mediapipe::Status {
gpu_data_out_->elements, &gpu_data_out_->buffer); // Device memory.
if (!status.ok()) { RET_CHECK_CALL(
return ::mediapipe::InternalError(status.error_message()); ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
} gpu_data_out_->elements, &gpu_data_out_->buffer));
// Shader to convert GL Texture to Shader Storage Buffer Object (SSBO), // Shader to convert GL Texture to Shader Storage Buffer Object (SSBO),
// with normalization to either: [0,1] or [-1,1]. // with normalization to either: [0,1] or [-1,1].
const std::string shader_source = absl::Substitute( const std::string shader_source = absl::Substitute(
R"( #version 310 es R"( #version 310 es
layout(local_size_x = $0, local_size_y = $0) in; layout(local_size_x = $0, local_size_y = $0) in;
layout(binding = 0) uniform sampler2D input_texture; layout(binding = 0) uniform sampler2D input_texture;
layout(std430, binding = 1) buffer Output {float elements[];} output_data; layout(std430, binding = 1) buffer Output {float elements[];} output_data;
@ -505,33 +504,31 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
void main() { void main() {
ivec2 gid = ivec2(gl_GlobalInvocationID.xy); ivec2 gid = ivec2(gl_GlobalInvocationID.xy);
if (gid.x >= width_height.x || gid.y >= width_height.y) return; if (gid.x >= width_height.x || gid.y >= width_height.y) return;
$5 // pixel fetch vec4 pixel = texelFetch(input_texture, gid, 0);
$3 // normalize [-1,1] $3 // normalize [-1,1]
int linear_index = $7 * ($4 * width_height.x + gid.x); int linear_index = $7 * ($4 * width_height.x + gid.x);
output_data.elements[linear_index + 0] = pixel.x; output_data.elements[linear_index + 0] = pixel.x; // r channel
output_data.elements[linear_index + 1] = pixel.y; $5 // g & b channels
output_data.elements[linear_index + 2] = pixel.z;
$6 // alpha channel $6 // alpha channel
})", })",
/*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(), /*$0=*/kWorkgroupSize, /*$1=*/input.width(), /*$2=*/input.height(),
/*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "", /*$3=*/zero_center_ ? "pixel = (pixel - 0.5) * 2.0;" : "",
/*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y", /*$4=*/flip_vertically_ ? "(width_height.y - 1 - gid.y)" : "gid.y",
/*$5=*/ /*$5=*/
include_alpha ? "vec4 pixel = texelFetch(input_texture, gid, 0);" single_channel
: "vec3 pixel = texelFetch(input_texture, gid, 0).xyz;", ? ""
/*$6=*/ : R"(output_data.elements[linear_index + 1] = pixel.y;
include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;" : "", output_data.elements[linear_index + 2] = pixel.z;)",
/*$7=*/include_alpha ? 4 : 3); /*$6=*/
status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source, include_alpha ? "output_data.elements[linear_index + 3] = pixel.w;"
&gpu_data_out_->shader); : "",
if (!status.ok()) { /*$7=*/max_num_channels_);
return ::mediapipe::InternalError(status.error_message()); RET_CHECK_CALL(GlShader::CompileShader(GL_COMPUTE_SHADER, shader_source,
} &gpu_data_out_->shader));
status = GlProgram::CreateWithShader(gpu_data_out_->shader, RET_CHECK_CALL(GlProgram::CreateWithShader(gpu_data_out_->shader,
&gpu_data_out_->program); &gpu_data_out_->program));
if (!status.ok()) { return ::mediapipe::OkStatus();
return ::mediapipe::InternalError(status.error_message()); }));
}
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
RET_CHECK(include_alpha) RET_CHECK(include_alpha)
<< "iOS GPU inference currently accepts only RGBA input."; << "iOS GPU inference currently accepts only RGBA input.";
@ -546,8 +543,6 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
// with normalization to either: [0,1] or [-1,1]. // with normalization to either: [0,1] or [-1,1].
const std::string shader_source = absl::Substitute( const std::string shader_source = absl::Substitute(
R"( R"(
#include <simd/simd.h>
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
@ -612,9 +607,9 @@ REGISTER_CALCULATOR(TfLiteConverterCalculator);
// Get desired way to handle input channels. // Get desired way to handle input channels.
max_num_channels_ = options.max_num_channels(); max_num_channels_ = options.max_num_channels();
// Currently only alpha channel toggling is suppored. CHECK_GE(max_num_channels_, 1);
CHECK_GE(max_num_channels_, 3);
CHECK_LE(max_num_channels_, 4); CHECK_LE(max_num_channels_, 4);
CHECK_NE(max_num_channels_, 2);
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS #if defined(__APPLE__) && !TARGET_OS_OSX // iOS
if (cc->Inputs().HasTag("IMAGE_GPU")) if (cc->Inputs().HasTag("IMAGE_GPU"))
// Currently on iOS, tflite gpu input tensor must be 4 channels, // Currently on iOS, tflite gpu input tensor must be 4 channels,

View File

@ -36,8 +36,7 @@ message TfLiteConverterCalculatorOptions {
optional bool flip_vertically = 2 [default = false]; optional bool flip_vertically = 2 [default = false];
// Controls how many channels of the input image get passed through to the // Controls how many channels of the input image get passed through to the
// tensor. Currently this only controls whether or not to ignore alpha // tensor. Valid values are 1,3,4 only. Ignored for iOS GPU.
// channel, so it must be 3 or 4.
optional int32 max_num_channels = 3 [default = 3]; optional int32 max_num_channels = 3 [default = 3];
// The calculator expects Matrix inputs to be in column-major order. Set // The calculator expects Matrix inputs to be in column-major order. Set

View File

@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cstring>
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_inference_calculator.pb.h"
#include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
@ -24,14 +27,15 @@
#include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h" #include "tensorflow/lite/model.h"
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS #if defined(__APPLE__) && !TARGET_OS_OSX // iOS
#import <CoreVideo/CoreVideo.h> #import <CoreVideo/CoreVideo.h>
@ -39,33 +43,42 @@
#import <MetalKit/MetalKit.h> #import <MetalKit/MetalKit.h>
#import "mediapipe/gpu/MPPMetalHelper.h" #import "mediapipe/gpu/MPPMetalHelper.h"
#include "mediapipe/gpu/MPPMetalUtil.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/metal/buffer_convert.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h" #include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS #endif // iOS
#if defined(__ANDROID__) namespace {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor; typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor; typedef id<MTLBuffer> GpuTensor;
#endif #endif
// Round up n to next multiple of m.
size_t RoundUp(size_t n, size_t m) { return ((n + m - 1) / m) * m; } // NOLINT
} // namespace
// TfLiteInferenceCalculator File Layout: // TfLiteInferenceCalculator File Layout:
// * Header // * Header
// * Core // * Core
// * Aux // * Aux
namespace mediapipe { namespace mediapipe {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlBuffer;
using ::tflite::gpu::gl::GlProgram; #endif
using ::tflite::gpu::gl::GlShader;
#if !defined(MEDIAPIPE_DISABLE_GPU)
struct GPUData { struct GPUData {
int elements = 1; int elements = 1;
GlBuffer buffer; GpuTensor buffer;
}; ::tflite::gpu::BHWC shape;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
struct GPUData {
int elements = 1;
id<MTLBuffer> buffer;
}; };
#endif #endif
@ -134,7 +147,7 @@ class TfLiteInferenceCalculator : public CalculatorBase {
std::unique_ptr<tflite::FlatBufferModel> model_; std::unique_ptr<tflite::FlatBufferModel> model_;
TfLiteDelegate* delegate_ = nullptr; TfLiteDelegate* delegate_ = nullptr;
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GPUData> gpu_data_in_; std::unique_ptr<GPUData> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_; std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
@ -142,6 +155,7 @@ class TfLiteInferenceCalculator : public CalculatorBase {
MPPMetalHelper* gpu_helper_ = nullptr; MPPMetalHelper* gpu_helper_ = nullptr;
std::unique_ptr<GPUData> gpu_data_in_; std::unique_ptr<GPUData> gpu_data_in_;
std::vector<std::unique_ptr<GPUData>> gpu_data_out_; std::vector<std::unique_ptr<GPUData>> gpu_data_out_;
TFLBufferConvert* converter_from_BPHWC4_ = nil;
#endif #endif
std::string model_path_ = ""; std::string model_path_ = "";
@ -161,19 +175,25 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
RET_CHECK(cc->Outputs().HasTag("TENSORS") ^ RET_CHECK(cc->Outputs().HasTag("TENSORS") ^
cc->Outputs().HasTag("TENSORS_GPU")); cc->Outputs().HasTag("TENSORS_GPU"));
bool use_gpu = false;
if (cc->Inputs().HasTag("TENSORS")) if (cc->Inputs().HasTag("TENSORS"))
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("TENSORS_GPU")) if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>(); cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
#endif use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag("TENSORS")) if (cc->Outputs().HasTag("TENSORS"))
cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Outputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Outputs().HasTag("TENSORS_GPU")) if (cc->Outputs().HasTag("TENSORS_GPU")) {
cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>(); cc->Outputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
#endif use_gpu |= true;
}
#endif // !MEDIAPIPE_DISABLE_GPU
if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) { if (cc->InputSidePackets().HasTag("CUSTOM_OP_RESOLVER")) {
cc->InputSidePackets() cc->InputSidePackets()
@ -181,11 +201,17 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
.Set<tflite::ops::builtin::BuiltinOpResolver>(); .Set<tflite::ops::builtin::BuiltinOpResolver>();
} }
#if defined(__ANDROID__) const auto& options =
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); cc->Options<::mediapipe::TfLiteInferenceCalculatorOptions>();
use_gpu |= options.use_gpu();
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]); MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif #endif
}
// Assign this calculator's default InputStreamHandler. // Assign this calculator's default InputStreamHandler.
cc->SetInputStreamHandler("FixedSizeInputStreamHandler"); cc->SetInputStreamHandler("FixedSizeInputStreamHandler");
@ -199,35 +225,41 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
MP_RETURN_IF_ERROR(LoadOptions(cc)); MP_RETURN_IF_ERROR(LoadOptions(cc));
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
gpu_input_ = true; gpu_input_ = true;
gpu_inference_ = true; // Inference must be on GPU also. gpu_inference_ = true; // Inference must be on GPU also.
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU"))
#endif << "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
} }
if (cc->Outputs().HasTag("TENSORS_GPU")) { if (cc->Outputs().HasTag("TENSORS_GPU")) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
gpu_output_ = true; gpu_output_ = true;
RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU")) RET_CHECK(cc->Inputs().HasTag("TENSORS_GPU"))
<< "GPU output must also have GPU Input."; << "GPU output must also have GPU Input.";
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK(!cc->Inputs().HasTag("TENSORS_GPU"))
#endif << "GPU processing not enabled.";
#endif // !MEDIAPIPE_DISABLE_GPU
} }
MP_RETURN_IF_ERROR(LoadModel(cc)); MP_RETURN_IF_ERROR(LoadModel(cc));
if (gpu_inference_) { if (gpu_inference_) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc]; gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_); RET_CHECK(gpu_helper_);
#endif #endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &cc]() -> ::mediapipe::Status { return LoadDelegate(cc); }));
#else
MP_RETURN_IF_ERROR(LoadDelegate(cc)); MP_RETURN_IF_ERROR(LoadDelegate(cc));
#endif
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -237,35 +269,27 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 1. Receive pre-processed tensor inputs. // 1. Receive pre-processed tensor inputs.
if (gpu_input_) { if (gpu_input_) {
// Read GPU input into SSBO. // Read GPU input into SSBO.
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>(); cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK_EQ(input_tensors.size(), 1);
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
[this, &input_tensors]() -> ::mediapipe::Status { [this, &input_tensors]() -> ::mediapipe::Status {
// Explicit copy input. // Explicit copy input.
tflite::gpu::gl::CopyBuffer(input_tensors[0], gpu_data_in_->buffer); RET_CHECK_CALL(CopyBuffer(input_tensors[0], gpu_data_in_->buffer));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>(); cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_EQ(input_tensors.size(), 1); RET_CHECK_EQ(input_tensors.size(), 1);
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteInferenceCalculatorInput";
id<MTLBlitCommandEncoder> blit_command =
[command_buffer blitCommandEncoder];
// Explicit copy input. // Explicit copy input.
[blit_command copyFromBuffer:input_tensors[0] [MPPMetalUtil blitMetalBufferTo:gpu_data_in_->buffer
sourceOffset:0 from:input_tensors[0]
toBuffer:gpu_data_in_->buffer blocking:true
destinationOffset:0 commandBuffer:[gpu_helper_ commandBuffer]];
size:gpu_data_in_->elements * sizeof(float)];
[blit_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif #endif
} else { } else {
// Read CPU input into tensors. // Read CPU input into tensors.
@ -278,18 +302,20 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
if (use_quantized_tensors_) { if (use_quantized_tensors_) {
const uint8* input_tensor_buffer = input_tensor->data.uint8; const uint8* input_tensor_buffer = input_tensor->data.uint8;
uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i); uint8* local_tensor_buffer = interpreter_->typed_input_tensor<uint8>(i);
memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes);
} else { } else {
const float* input_tensor_buffer = input_tensor->data.f; const float* input_tensor_buffer = input_tensor->data.f;
float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i); float* local_tensor_buffer = interpreter_->typed_input_tensor<float>(i);
memcpy(local_tensor_buffer, input_tensor_buffer, input_tensor->bytes); std::memcpy(local_tensor_buffer, input_tensor_buffer,
input_tensor->bytes);
} }
} }
} }
// 2. Run inference. // 2. Run inference.
if (gpu_inference_) { if (gpu_inference_) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this]() -> ::mediapipe::Status {
RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk); RET_CHECK_EQ(interpreter_->Invoke(), kTfLiteOk);
@ -304,52 +330,51 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// 3. Output processed tensors. // 3. Output processed tensors.
if (gpu_output_) { if (gpu_output_) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
// Output result tensors (GPU). // Output result tensors (GPU).
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>(); auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
output_tensors->resize(gpu_data_out_.size()); MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
for (int i = 0; i < gpu_data_out_.size(); ++i) { [this, &output_tensors]() -> ::mediapipe::Status {
GlBuffer& tensor = output_tensors->at(i); output_tensors->resize(gpu_data_out_.size());
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; for (int i = 0; i < gpu_data_out_.size(); ++i) {
auto status = CreateReadWriteShaderStorageBuffer<float>( GpuTensor& tensor = output_tensors->at(i);
gpu_data_out_[i]->elements, &tensor); RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
if (!status.ok()) { gpu_data_out_[i]->elements, &tensor));
return ::mediapipe::InternalError(status.error_message()); RET_CHECK_CALL(CopyBuffer(gpu_data_out_[i]->buffer, tensor));
} }
tflite::gpu::gl::CopyBuffer(gpu_data_out_[i]->buffer, tensor); return ::mediapipe::OkStatus();
} }));
cc->Outputs() cc->Outputs()
.Tag("TENSORS_GPU") .Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp()); .Add(output_tensors.release(), cc->InputTimestamp());
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
// Output result tensors (GPU). // Output result tensors (GPU).
auto output_tensors = absl::make_unique<std::vector<GpuTensor>>(); auto output_tensors = absl::make_unique<std::vector<GpuTensor>>();
output_tensors->resize(gpu_data_out_.size());
id<MTLDevice> device = gpu_helper_.mtlDevice; id<MTLDevice> device = gpu_helper_.mtlDevice;
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer]; id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteInferenceCalculatorOutput"; command_buffer.label = @"TfLiteInferenceBPHWC4Convert";
id<MTLComputeCommandEncoder> convert_command =
[command_buffer computeCommandEncoder];
for (int i = 0; i < gpu_data_out_.size(); ++i) { for (int i = 0; i < gpu_data_out_.size(); ++i) {
id<MTLBuffer> tensor = output_tensors->at(i) =
[device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float) [device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float)
options:MTLResourceStorageModeShared]; options:MTLResourceStorageModeShared];
id<MTLBlitCommandEncoder> blit_command = // Reshape tensor.
[command_buffer blitCommandEncoder]; [converter_from_BPHWC4_ convertWithEncoder:convert_command
// Explicit copy input. shape:gpu_data_out_[i]->shape
[blit_command copyFromBuffer:gpu_data_out_[i]->buffer sourceBuffer:gpu_data_out_[i]->buffer
sourceOffset:0 convertedBuffer:output_tensors->at(i)];
toBuffer:tensor
destinationOffset:0
size:gpu_data_out_[i]->elements * sizeof(float)];
[blit_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
output_tensors->push_back(tensor);
} }
[convert_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
cc->Outputs() cc->Outputs()
.Tag("TENSORS_GPU") .Tag("TENSORS_GPU")
.Add(output_tensors.release(), cc->InputTimestamp()); .Add(output_tensors.release(), cc->InputTimestamp());
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
// Output result tensors (CPU). // Output result tensors (CPU).
const auto& tensor_indexes = interpreter_->outputs(); const auto& tensor_indexes = interpreter_->outputs();
@ -367,7 +392,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status TfLiteInferenceCalculator::Close(CalculatorContext* cc) {
if (delegate_) { if (delegate_) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status { MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]() -> Status {
TfLiteGpuDelegateDelete(delegate_); TfLiteGpuDelegateDelete(delegate_);
gpu_data_in_.reset(); gpu_data_in_.reset();
@ -446,7 +471,7 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate( ::mediapipe::Status TfLiteInferenceCalculator::LoadDelegate(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
// Configure and create the delegate. // Configure and create the delegate.
TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault(); TfLiteGpuDelegateOptions options = TfLiteGpuDelegateOptionsDefault();
options.compile_options.precision_loss_allowed = 1; options.compile_options.precision_loss_allowed = 1;
@ -466,15 +491,12 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
for (int d = 0; d < tensor->dims->size; ++d) { for (int d = 0; d < tensor->dims->size; ++d) {
gpu_data_in_->elements *= tensor->dims->data[d]; gpu_data_in_->elements *= tensor->dims->data[d];
} }
// Input to model can be either RGB/RGBA only. CHECK_GE(tensor->dims->data[3], 1);
RET_CHECK_GE(tensor->dims->data[3], 3); CHECK_LE(tensor->dims->data[3], 4);
RET_CHECK_LE(tensor->dims->data[3], 4); CHECK_NE(tensor->dims->data[3], 2);
// Create and bind input buffer. // Create and bind input buffer.
auto status = ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>( RET_CHECK_CALL(::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer<float>(
gpu_data_in_->elements, &gpu_data_in_->buffer); gpu_data_in_->elements, &gpu_data_in_->buffer));
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor( RET_CHECK_EQ(TfLiteGpuDelegateBindBufferToTensor(
delegate_, gpu_data_in_->buffer.id(), delegate_, gpu_data_in_->buffer.id(),
interpreter_->inputs()[0]), // First tensor only interpreter_->inputs()[0]), // First tensor only
@ -496,12 +518,8 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Create and bind output buffers. // Create and bind output buffers.
interpreter_->SetAllowBufferHandleOutput(true); interpreter_->SetAllowBufferHandleOutput(true);
for (int i = 0; i < gpu_data_out_.size(); ++i) { for (int i = 0; i < gpu_data_out_.size(); ++i) {
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
auto status = CreateReadWriteShaderStorageBuffer<float>( gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer));
gpu_data_out_[i]->elements, &gpu_data_out_[i]->buffer);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
RET_CHECK_EQ( RET_CHECK_EQ(
TfLiteGpuDelegateBindBufferToTensor( TfLiteGpuDelegateBindBufferToTensor(
delegate_, gpu_data_out_[i]->buffer.id(), output_indices[i]), delegate_, gpu_data_out_[i]->buffer.id(), output_indices[i]),
@ -511,14 +529,15 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
// Must call this last. // Must call this last.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk);
#endif // __ANDROID__ #endif // OpenGL
#if defined(__APPLE__) && !TARGET_OS_OSX // iOS #if defined(__APPLE__) && !TARGET_OS_OSX // iOS
// Configure and create the delegate. // Configure and create the delegate.
GpuDelegateOptions options; GpuDelegateOptions options;
options.allow_precision_loss = false; // Must match converter, F=float/T=half options.allow_precision_loss = false; // Must match converter, F=float/T=half
options.wait_type = GpuDelegateOptions::WaitType::kActive; options.wait_type = GpuDelegateOptions::WaitType::kPassive;
if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options); if (!delegate_) delegate_ = TFLGpuDelegateCreate(&options);
id<MTLDevice> device = gpu_helper_.mtlDevice;
if (gpu_input_) { if (gpu_input_) {
// Get input image sizes. // Get input image sizes.
@ -539,11 +558,9 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
LOG(WARNING) << "Please ensure input GPU tensor is 4 channels."; LOG(WARNING) << "Please ensure input GPU tensor is 4 channels.";
} }
// Create and bind input buffer. // Create and bind input buffer.
id<MTLDevice> device = gpu_helper_.mtlDevice;
gpu_data_in_->buffer = gpu_data_in_->buffer =
[device newBufferWithLength:gpu_data_in_->elements * sizeof(float) [device newBufferWithLength:gpu_data_in_->elements * sizeof(float)
options:MTLResourceStorageModeShared]; options:MTLResourceStorageModeShared];
// Must call this before TFLGpuDelegateBindMetalBufferToTensor.
RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk); RET_CHECK_EQ(interpreter_->ModifyGraphWithDelegate(delegate_), kTfLiteOk);
RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor( RET_CHECK_EQ(TFLGpuDelegateBindMetalBufferToTensor(
delegate_, delegate_,
@ -561,12 +578,33 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
gpu_data_out_[i]->elements = 1; gpu_data_out_[i]->elements = 1;
// TODO handle *2 properly on some dialated models // TODO handle *2 properly on some dialated models
for (int d = 0; d < tensor->dims->size; ++d) { for (int d = 0; d < tensor->dims->size; ++d) {
gpu_data_out_[i]->elements *= tensor->dims->data[d]; // Pad each dim for BHWC4 conversion inside delegate.
gpu_data_out_[i]->elements *= RoundUp(tensor->dims->data[d], 4);
}
// Save dimensions for reshaping back later.
gpu_data_out_[i]->shape.b = tensor->dims->data[0];
switch (tensor->dims->size) {
case 2:
gpu_data_out_[i]->shape.h = 1;
gpu_data_out_[i]->shape.w = 1;
gpu_data_out_[i]->shape.c = tensor->dims->data[1];
break;
case 3:
gpu_data_out_[i]->shape.h = 1;
gpu_data_out_[i]->shape.w = tensor->dims->data[1];
gpu_data_out_[i]->shape.c = tensor->dims->data[2];
break;
case 4:
gpu_data_out_[i]->shape.h = tensor->dims->data[1];
gpu_data_out_[i]->shape.w = tensor->dims->data[2];
gpu_data_out_[i]->shape.c = tensor->dims->data[3];
break;
default:
return mediapipe::InternalError("Unsupported tensor shape.");
} }
} }
// Create and bind output buffers. // Create and bind output buffers.
interpreter_->SetAllowBufferHandleOutput(true); interpreter_->SetAllowBufferHandleOutput(true);
id<MTLDevice> device = gpu_helper_.mtlDevice;
for (int i = 0; i < gpu_data_out_.size(); ++i) { for (int i = 0; i < gpu_data_out_.size(); ++i) {
gpu_data_out_[i]->buffer = gpu_data_out_[i]->buffer =
[device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float) [device newBufferWithLength:gpu_data_out_[i]->elements * sizeof(float)
@ -575,6 +613,14 @@ REGISTER_CALCULATOR(TfLiteInferenceCalculator);
delegate_, output_indices[i], gpu_data_out_[i]->buffer), delegate_, output_indices[i], gpu_data_out_[i]->buffer),
true); true);
} }
// Create converter for GPU output.
converter_from_BPHWC4_ = [[TFLBufferConvert alloc] initWithDevice:device
isFloat16:false
convertToPBHWC4:false];
if (converter_from_BPHWC4_ == nil) {
return mediapipe::InternalError(
"Error initializating output buffer converter");
}
} }
#endif // iOS #endif // iOS

View File

@ -18,6 +18,7 @@
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.pb.h"
#include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_path.h" #include "mediapipe/framework/deps/file_path.h"
#include "mediapipe/framework/formats/detection.pb.h" #include "mediapipe/framework/formats/detection.pb.h"
@ -26,28 +27,61 @@
#include "mediapipe/framework/port/ret_check.h" #include "mediapipe/framework/port/ret_check.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_program.h" #include "tensorflow/lite/delegates/gpu/gl/gl_program.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // ANDROID #endif // !MEDIAPIPE_DISABLE_GPU
#if defined(__ANDROID__) #if defined(__APPLE__) && !TARGET_OS_OSX // iOS
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; #import <CoreVideo/CoreVideo.h>
using ::tflite::gpu::gl::GlBuffer; #import <Metal/Metal.h>
using ::tflite::gpu::gl::GlProgram; #import <MetalKit/MetalKit.h>
using ::tflite::gpu::gl::GlShader;
#endif // ANDROID
namespace mediapipe { #import "mediapipe/gpu/MPPMetalHelper.h"
#include "mediapipe/gpu/MPPMetalUtil.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "tensorflow/lite/delegates/gpu/metal_delegate.h"
#endif // iOS
namespace { namespace {
constexpr int kNumInputTensorsWithAnchors = 3; constexpr int kNumInputTensorsWithAnchors = 3;
constexpr int kNumCoordsPerBox = 4; constexpr int kNumCoordsPerBox = 4;
} // namespace
namespace mediapipe {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlShader;
#endif
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
typedef ::tflite::gpu::gl::GlBuffer GpuTensor;
typedef ::tflite::gpu::gl::GlProgram GpuProgram;
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
typedef id<MTLBuffer> GpuTensor;
typedef id<MTLComputePipelineState> GpuProgram;
#endif
namespace {
#if !defined(MEDIAPIPE_DISABLE_GPU)
struct GPUData {
GpuProgram decode_program;
GpuProgram score_program;
GpuTensor decoded_boxes_buffer;
GpuTensor raw_boxes_buffer;
GpuTensor raw_anchors_buffer;
GpuTensor scored_boxes_buffer;
GpuTensor raw_scores_buffer;
};
#endif
void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes, void ConvertRawValuesToAnchors(const float* raw_anchors, int num_boxes,
std::vector<Anchor>* anchors) { std::vector<Anchor>* anchors) {
anchors->clear(); anchors->clear();
@ -88,7 +122,7 @@ void ConvertAnchorsToRawValues(const std::vector<Anchor>& anchors,
// optional to pass in a third tensor for anchors (e.g. for SSD // optional to pass in a third tensor for anchors (e.g. for SSD
// models) depend on the outputs of the detection model. The size // models) depend on the outputs of the detection model. The size
// of anchor tensor must be (num_boxes * 4). // of anchor tensor must be (num_boxes * 4).
// TENSORS_GPU - vector of GlBuffer. // TENSORS_GPU - vector of GlBuffer of MTLBuffer.
// Output: // Output:
// DETECTIONS - Result MediaPipe detections. // DETECTIONS - Result MediaPipe detections.
// //
@ -126,7 +160,7 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
std::vector<Detection>* output_detections); std::vector<Detection>* output_detections);
::mediapipe::Status LoadOptions(CalculatorContext* cc); ::mediapipe::Status LoadOptions(CalculatorContext* cc);
::mediapipe::Status GlSetup(CalculatorContext* cc); ::mediapipe::Status GpuInit(CalculatorContext* cc);
::mediapipe::Status DecodeBoxes(const float* raw_boxes, ::mediapipe::Status DecodeBoxes(const float* raw_boxes,
const std::vector<Anchor>& anchors, const std::vector<Anchor>& anchors,
std::vector<float>* boxes); std::vector<float>* boxes);
@ -146,15 +180,12 @@ class TfLiteTensorsToDetectionsCalculator : public CalculatorBase {
std::vector<Anchor> anchors_; std::vector<Anchor> anchors_;
bool side_packet_anchors_{}; bool side_packet_anchors_{};
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GlProgram> decode_program_; std::unique_ptr<GPUData> gpu_data_;
std::unique_ptr<GlProgram> score_program_; #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
std::unique_ptr<GlBuffer> decoded_boxes_buffer_; MPPMetalHelper* gpu_helper_ = nullptr;
std::unique_ptr<GlBuffer> raw_boxes_buffer_; std::unique_ptr<GPUData> gpu_data_;
std::unique_ptr<GlBuffer> raw_anchors_buffer_;
std::unique_ptr<GlBuffer> scored_boxes_buffer_;
std::unique_ptr<GlBuffer> raw_scores_buffer_;
#endif #endif
bool gpu_input_ = false; bool gpu_input_ = false;
@ -167,15 +198,18 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty());
bool use_gpu = false;
if (cc->Inputs().HasTag("TENSORS")) { if (cc->Inputs().HasTag("TENSORS")) {
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
} }
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>(); cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GpuTensor>>();
use_gpu |= true;
} }
#endif #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag("DETECTIONS")) { if (cc->Outputs().HasTag("DETECTIONS")) {
cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>(); cc->Outputs().Tag("DETECTIONS").Set<std::vector<Detection>>();
@ -187,9 +221,13 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
} }
} }
#if defined(__ANDROID__) if (use_gpu) {
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
MP_RETURN_IF_ERROR([MPPMetalHelper updateContract:cc]);
#endif #endif
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -200,8 +238,11 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
gpu_input_ = true; gpu_input_ = true;
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
gpu_helper_ = [[MPPMetalHelper alloc] initWithCalculatorContext:cc];
RET_CHECK(gpu_helper_);
#endif #endif
} }
@ -209,7 +250,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
side_packet_anchors_ = cc->InputSidePackets().HasTag("ANCHORS"); side_packet_anchors_ = cc->InputSidePackets().HasTag("ANCHORS");
if (gpu_input_) { if (gpu_input_) {
MP_RETURN_IF_ERROR(GlSetup(cc)); MP_RETURN_IF_ERROR(GpuInit(cc));
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -228,7 +269,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get())); MP_RETURN_IF_ERROR(ProcessGPU(cc, output_detections.get()));
} else { } else {
MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get())); MP_RETURN_IF_ERROR(ProcessCPU(cc, output_detections.get()));
} // if gpu_input_ }
// Output // Output
if (cc->Outputs().HasTag("DETECTIONS")) { if (cc->Outputs().HasTag("DETECTIONS")) {
@ -245,7 +286,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>(); cc->Inputs().Tag("TENSORS").Get<std::vector<TfLiteTensor>>();
if (input_tensors.size() == 2) { if (input_tensors.size() == 2 ||
input_tensors.size() == kNumInputTensorsWithAnchors) {
// Postprocessing on CPU for model without postprocessing op. E.g. output // Postprocessing on CPU for model without postprocessing op. E.g. output
// raw score tensor and box tensor. Anchor decoding will be handled below. // raw score tensor and box tensor. Anchor decoding will be handled below.
const TfLiteTensor* raw_box_tensor = &input_tensors[0]; const TfLiteTensor* raw_box_tensor = &input_tensors[0];
@ -358,13 +400,84 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
} }
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::ProcessGPU(
CalculatorContext* cc, std::vector<Detection>* output_detections) { CalculatorContext* cc, std::vector<Detection>* output_detections) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>(); cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_GE(input_tensors.size(), 2);
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this, &input_tensors, &cc,
&output_detections]()
-> ::mediapipe::Status {
// Copy inputs.
RET_CHECK_CALL(CopyBuffer(input_tensors[0], gpu_data_->raw_boxes_buffer));
RET_CHECK_CALL(CopyBuffer(input_tensors[1], gpu_data_->raw_scores_buffer));
if (!anchors_init_) {
if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
const auto& anchors =
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.Write<float>(
absl::MakeSpan(raw_anchors)));
} else {
CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
RET_CHECK_CALL(
CopyBuffer(input_tensors[2], gpu_data_->raw_anchors_buffer));
}
anchors_init_ = true;
}
// Run shaders.
// Decode boxes.
RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.BindToIndex(0));
RET_CHECK_CALL(gpu_data_->raw_boxes_buffer.BindToIndex(1));
RET_CHECK_CALL(gpu_data_->raw_anchors_buffer.BindToIndex(2));
const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1};
RET_CHECK_CALL(gpu_data_->decode_program.Dispatch(decode_workgroups));
// Score boxes.
RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.BindToIndex(0));
RET_CHECK_CALL(gpu_data_->raw_scores_buffer.BindToIndex(1));
const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1};
RET_CHECK_CALL(gpu_data_->score_program.Dispatch(score_workgroups));
// Copy decoded boxes from GPU to CPU.
std::vector<float> boxes(num_boxes_ * num_coords_);
RET_CHECK_CALL(gpu_data_->decoded_boxes_buffer.Read(absl::MakeSpan(boxes)));
std::vector<float> score_class_id_pairs(num_boxes_ * 2);
RET_CHECK_CALL(gpu_data_->scored_boxes_buffer.Read(
absl::MakeSpan(score_class_id_pairs)));
// TODO: b/138851969. Is it possible to output a float vector
// for score and an int vector for class so that we can avoid copying twice?
std::vector<float> detection_scores(num_boxes_);
std::vector<int> detection_classes(num_boxes_);
for (int i = 0; i < num_boxes_; ++i) {
detection_scores[i] = score_class_id_pairs[i * 2];
detection_classes[i] = static_cast<int>(score_class_id_pairs[i * 2 + 1]);
}
MP_RETURN_IF_ERROR(
ConvertToDetections(boxes.data(), detection_scores.data(),
detection_classes.data(), output_detections));
return ::mediapipe::OkStatus();
}));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GpuTensor>>();
RET_CHECK_GE(input_tensors.size(), 2);
// Copy inputs. // Copy inputs.
tflite::gpu::gl::CopyBuffer(input_tensors[0], *raw_boxes_buffer_.get()); [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_boxes_buffer
tflite::gpu::gl::CopyBuffer(input_tensors[1], *raw_scores_buffer_.get()); from:input_tensors[0]
blocking:true
commandBuffer:[gpu_helper_ commandBuffer]];
[MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_scores_buffer
from:input_tensors[1]
blocking:true
commandBuffer:[gpu_helper_ commandBuffer]];
if (!anchors_init_) { if (!anchors_init_) {
if (side_packet_anchors_) { if (side_packet_anchors_) {
CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty()); CHECK(!cc->InputSidePackets().Tag("ANCHORS").IsEmpty());
@ -372,47 +485,65 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>(); cc->InputSidePackets().Tag("ANCHORS").Get<std::vector<Anchor>>();
std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox); std::vector<float> raw_anchors(num_boxes_ * kNumCoordsPerBox);
ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data()); ConvertAnchorsToRawValues(anchors, num_boxes_, raw_anchors.data());
raw_anchors_buffer_->Write<float>(absl::MakeSpan(raw_anchors)); memcpy([gpu_data_->raw_anchors_buffer contents], raw_anchors.data(),
raw_anchors.size() * sizeof(float));
} else { } else {
CHECK_EQ(input_tensors.size(), 3); RET_CHECK_EQ(input_tensors.size(), kNumInputTensorsWithAnchors);
tflite::gpu::gl::CopyBuffer(input_tensors[2], *raw_anchors_buffer_.get()); [MPPMetalUtil blitMetalBufferTo:gpu_data_->raw_anchors_buffer
from:input_tensors[2]
blocking:true
commandBuffer:[gpu_helper_ commandBuffer]];
} }
anchors_init_ = true; anchors_init_ = true;
} }
// Run shaders. // Run shaders.
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( {
[this, &input_tensors]() -> ::mediapipe::Status { id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
// Decode boxes. command_buffer.label = @"TfLiteDecodeBoxes";
decoded_boxes_buffer_->BindToIndex(0); id<MTLComputeCommandEncoder> decode_command =
raw_boxes_buffer_->BindToIndex(1); [command_buffer computeCommandEncoder];
raw_anchors_buffer_->BindToIndex(2); [decode_command setComputePipelineState:gpu_data_->decode_program];
const tflite::gpu::uint3 decode_workgroups = {num_boxes_, 1, 1}; [decode_command setBuffer:gpu_data_->decoded_boxes_buffer
decode_program_->Dispatch(decode_workgroups); offset:0
atIndex:0];
// Score boxes. [decode_command setBuffer:gpu_data_->raw_boxes_buffer offset:0 atIndex:1];
scored_boxes_buffer_->BindToIndex(0); [decode_command setBuffer:gpu_data_->raw_anchors_buffer offset:0 atIndex:2];
raw_scores_buffer_->BindToIndex(1); MTLSize decode_threads_per_group = MTLSizeMake(1, 1, 1);
const tflite::gpu::uint3 score_workgroups = {num_boxes_, 1, 1}; MTLSize decode_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
score_program_->Dispatch(score_workgroups); [decode_command dispatchThreadgroups:decode_threadgroups
threadsPerThreadgroup:decode_threads_per_group];
return ::mediapipe::OkStatus(); [decode_command endEncoding];
})); [command_buffer commit];
[command_buffer waitUntilCompleted];
}
{
id<MTLCommandBuffer> command_buffer = [gpu_helper_ commandBuffer];
command_buffer.label = @"TfLiteScoreBoxes";
id<MTLComputeCommandEncoder> score_command =
[command_buffer computeCommandEncoder];
[score_command setComputePipelineState:gpu_data_->score_program];
[score_command setBuffer:gpu_data_->scored_boxes_buffer offset:0 atIndex:0];
[score_command setBuffer:gpu_data_->raw_scores_buffer offset:0 atIndex:1];
MTLSize score_threads_per_group = MTLSizeMake(1, num_classes_, 1);
MTLSize score_threadgroups = MTLSizeMake(num_boxes_, 1, 1);
[score_command dispatchThreadgroups:score_threadgroups
threadsPerThreadgroup:score_threads_per_group];
[score_command endEncoding];
[command_buffer commit];
[command_buffer waitUntilCompleted];
}
// Copy decoded boxes from GPU to CPU. // Copy decoded boxes from GPU to CPU.
std::vector<float> boxes(num_boxes_ * num_coords_); std::vector<float> boxes(num_boxes_ * num_coords_);
auto status = decoded_boxes_buffer_->Read(absl::MakeSpan(boxes)); memcpy(boxes.data(), [gpu_data_->decoded_boxes_buffer contents],
if (!status.ok()) { num_boxes_ * num_coords_ * sizeof(float));
return ::mediapipe::InternalError(status.error_message());
}
std::vector<float> score_class_id_pairs(num_boxes_ * 2); std::vector<float> score_class_id_pairs(num_boxes_ * 2);
status = scored_boxes_buffer_->Read(absl::MakeSpan(score_class_id_pairs)); memcpy(score_class_id_pairs.data(), [gpu_data_->scored_boxes_buffer contents],
if (!status.ok()) { num_boxes_ * 2 * sizeof(float));
return ::mediapipe::InternalError(status.error_message());
}
// TODO: b/138851969. Is it possible to output a float vector // Output detections.
// for score and an int vector for class so that we can avoid copying twice? // TODO Adjust shader to avoid copying shader output twice.
std::vector<float> detection_scores(num_boxes_); std::vector<float> detection_scores(num_boxes_);
std::vector<int> detection_classes(num_boxes_); std::vector<int> detection_classes(num_boxes_);
for (int i = 0; i < num_boxes_; ++i) { for (int i = 0; i < num_boxes_; ++i) {
@ -422,25 +553,20 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
MP_RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(), MP_RETURN_IF_ERROR(ConvertToDetections(boxes.data(), detection_scores.data(),
detection_classes.data(), detection_classes.data(),
output_detections)); output_detections));
#else #else
LOG(ERROR) << "GPU input on non-Android not supported yet."; LOG(ERROR) << "GPU input on non-Android not supported yet.";
#endif // defined(__ANDROID__) #endif
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::Close(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] { gpu_data_.reset(); });
decode_program_.reset(); #elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
score_program_.reset(); gpu_data_.reset();
decoded_boxes_buffer_.reset(); #endif // !MEDIAPIPE_DISABLE_GPU
raw_boxes_buffer_.reset();
raw_anchors_buffer_.reset();
scored_boxes_buffer_.reset();
raw_scores_buffer_.reset();
});
#endif // __ANDROID__
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -530,6 +656,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToDetectionsCalculator);
} }
} }
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -586,12 +713,16 @@ Detection TfLiteTensorsToDetectionsCalculator::ConvertToDetection(
return detection; return detection;
} }
::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GlSetup( ::mediapipe::Status TfLiteTensorsToDetectionsCalculator::GpuInit(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
// A shader to decode detection boxes. MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
const std::string decode_src = absl::Substitute( -> ::mediapipe::Status {
R"( #version 310 es gpu_data_ = absl::make_unique<GPUData>();
// A shader to decode detection boxes.
const std::string decode_src = absl::Substitute(
R"( #version 310 es
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
@ -665,7 +796,7 @@ void main() {
if (num_keypoints > int(0)){ if (num_keypoints > int(0)){
for (int k = 0; k < num_keypoints; ++k) { for (int k = 0; k < num_keypoints; ++k) {
int kp_offset = int kp_offset =
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt; int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
float kp_y, kp_x; float kp_y, kp_x;
if (reverse_output_order == int(0)) { if (reverse_output_order == int(0)) {
kp_y = raw_boxes.data[kp_offset + int(0)]; kp_y = raw_boxes.data[kp_offset + int(0)];
@ -679,55 +810,37 @@ void main() {
} }
} }
})", })",
options_.num_coords(), // box xywh options_.num_coords(), // box xywh
options_.reverse_output_order() ? 1 : 0, options_.reverse_output_order() ? 1 : 0,
options_.apply_exponential_on_box_size() ? 1 : 0, options_.apply_exponential_on_box_size() ? 1 : 0,
options_.box_coord_offset(), options_.num_keypoints(), options_.box_coord_offset(), options_.num_keypoints(),
options_.keypoint_coord_offset(), options_.num_values_per_keypoint()); options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
// Shader program // Shader program
GlShader decode_shader; GlShader decode_shader;
auto status = RET_CHECK_CALL(
GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader); GlShader::CompileShader(GL_COMPUTE_SHADER, decode_src, &decode_shader));
if (!status.ok()) { RET_CHECK_CALL(GpuProgram::CreateWithShader(decode_shader,
return ::mediapipe::InternalError(status.error_message()); &gpu_data_->decode_program));
} // Outputs
decode_program_ = absl::make_unique<GlProgram>(); size_t decoded_boxes_length = num_boxes_ * num_coords_;
status = GlProgram::CreateWithShader(decode_shader, decode_program_.get()); RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
if (!status.ok()) { decoded_boxes_length, &gpu_data_->decoded_boxes_buffer));
return ::mediapipe::InternalError(status.error_message()); // Inputs
} size_t raw_boxes_length = num_boxes_ * num_coords_;
// Outputs RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
size_t decoded_boxes_length = num_boxes_ * num_coords_; raw_boxes_length, &gpu_data_->raw_boxes_buffer));
decoded_boxes_buffer_ = absl::make_unique<GlBuffer>(); size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox;
status = CreateReadWriteShaderStorageBuffer<float>( RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
decoded_boxes_length, decoded_boxes_buffer_.get()); raw_anchors_length, &gpu_data_->raw_anchors_buffer));
if (!status.ok()) { // Parameters
return ::mediapipe::InternalError(status.error_message()); glUseProgram(gpu_data_->decode_program.id());
} glUniform4f(0, options_.x_scale(), options_.y_scale(), options_.w_scale(),
// Inputs options_.h_scale());
size_t raw_boxes_length = num_boxes_ * num_coords_;
raw_boxes_buffer_ = absl::make_unique<GlBuffer>();
status = CreateReadWriteShaderStorageBuffer<float>(raw_boxes_length,
raw_boxes_buffer_.get());
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox;
raw_anchors_buffer_ = absl::make_unique<GlBuffer>();
status = CreateReadWriteShaderStorageBuffer<float>(raw_anchors_length,
raw_anchors_buffer_.get());
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
// Parameters
glUseProgram(decode_program_->id());
glUniform4f(0, options_.x_scale(), options_.y_scale(), options_.w_scale(),
options_.h_scale());
// A shader to score detection boxes. // A shader to score detection boxes.
const std::string score_src = absl::Substitute( const std::string score_src = absl::Substitute(
R"( #version 310 es R"( #version 310 es
layout(local_size_x = 1, local_size_y = $0, local_size_z = 1) in; layout(local_size_x = 1, local_size_y = $0, local_size_z = 1) in;
@ -781,6 +894,228 @@ void main() {
scored_boxes.data[g_idx * uint(2) + uint(0)] = max_score; scored_boxes.data[g_idx * uint(2) + uint(0)] = max_score;
scored_boxes.data[g_idx * uint(2) + uint(1)] = max_class; scored_boxes.data[g_idx * uint(2) + uint(1)] = max_class;
} }
})",
num_classes_, options_.sigmoid_score() ? 1 : 0,
options_.has_score_clipping_thresh() ? 1 : 0,
options_.has_score_clipping_thresh() ? options_.score_clipping_thresh()
: 0,
!ignore_classes_.empty() ? 1 : 0);
// # filter classes supported is hardware dependent.
int max_wg_size; // typically <= 1024
glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1,
&max_wg_size); // y-dim
CHECK_LT(num_classes_, max_wg_size)
<< "# classes must be < " << max_wg_size;
// TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed";
// Shader program
GlShader score_shader;
RET_CHECK_CALL(
GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader));
RET_CHECK_CALL(
GpuProgram::CreateWithShader(score_shader, &gpu_data_->score_program));
// Outputs
size_t scored_boxes_length = num_boxes_ * 2; // score, class
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
scored_boxes_length, &gpu_data_->scored_boxes_buffer));
// Inputs
size_t raw_scores_length = num_boxes_ * num_classes_;
RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
raw_scores_length, &gpu_data_->raw_scores_buffer));
return ::mediapipe::OkStatus();
}));
#elif defined(__APPLE__) && !TARGET_OS_OSX // iOS
// TODO consolidate Metal and OpenGL shaders via vulkan.
gpu_data_ = absl::make_unique<GPUData>();
id<MTLDevice> device = gpu_helper_.mtlDevice;
// A shader to decode detection boxes.
std::string decode_src = absl::Substitute(
R"(
#include <metal_stdlib>
using namespace metal;
kernel void decodeKernel(
device float* boxes [[ buffer(0) ]],
device float* raw_boxes [[ buffer(1) ]],
device float* raw_anchors [[ buffer(2) ]],
uint2 gid [[ thread_position_in_grid ]]) {
uint num_coords = uint($0);
int reverse_output_order = int($1);
int apply_exponential = int($2);
int box_coord_offset = int($3);
int num_keypoints = int($4);
int keypt_coord_offset = int($5);
int num_values_per_keypt = int($6);
)",
options_.num_coords(), // box xywh
options_.reverse_output_order() ? 1 : 0,
options_.apply_exponential_on_box_size() ? 1 : 0,
options_.box_coord_offset(), options_.num_keypoints(),
options_.keypoint_coord_offset(), options_.num_values_per_keypoint());
decode_src += absl::Substitute(
R"(
float4 scale = float4(($0),($1),($2),($3));
)",
options_.x_scale(), options_.y_scale(), options_.w_scale(),
options_.h_scale());
decode_src += R"(
uint g_idx = gid.x;
uint box_offset = g_idx * num_coords + uint(box_coord_offset);
uint anchor_offset = g_idx * uint(4); // check kNumCoordsPerBox
float y_center, x_center, h, w;
if (reverse_output_order == int(0)) {
y_center = raw_boxes[box_offset + uint(0)];
x_center = raw_boxes[box_offset + uint(1)];
h = raw_boxes[box_offset + uint(2)];
w = raw_boxes[box_offset + uint(3)];
} else {
x_center = raw_boxes[box_offset + uint(0)];
y_center = raw_boxes[box_offset + uint(1)];
w = raw_boxes[box_offset + uint(2)];
h = raw_boxes[box_offset + uint(3)];
}
float anchor_yc = raw_anchors[anchor_offset + uint(0)];
float anchor_xc = raw_anchors[anchor_offset + uint(1)];
float anchor_h = raw_anchors[anchor_offset + uint(2)];
float anchor_w = raw_anchors[anchor_offset + uint(3)];
x_center = x_center / scale.x * anchor_w + anchor_xc;
y_center = y_center / scale.y * anchor_h + anchor_yc;
if (apply_exponential == int(1)) {
h = exp(h / scale.w) * anchor_h;
w = exp(w / scale.z) * anchor_w;
} else {
h = (h / scale.w) * anchor_h;
w = (w / scale.z) * anchor_w;
}
float ymin = y_center - h / 2.0;
float xmin = x_center - w / 2.0;
float ymax = y_center + h / 2.0;
float xmax = x_center + w / 2.0;
boxes[box_offset + uint(0)] = ymin;
boxes[box_offset + uint(1)] = xmin;
boxes[box_offset + uint(2)] = ymax;
boxes[box_offset + uint(3)] = xmax;
if (num_keypoints > int(0)){
for (int k = 0; k < num_keypoints; ++k) {
int kp_offset =
int(g_idx * num_coords) + keypt_coord_offset + k * num_values_per_keypt;
float kp_y, kp_x;
if (reverse_output_order == int(0)) {
kp_y = raw_boxes[kp_offset + int(0)];
kp_x = raw_boxes[kp_offset + int(1)];
} else {
kp_x = raw_boxes[kp_offset + int(0)];
kp_y = raw_boxes[kp_offset + int(1)];
}
boxes[kp_offset + int(0)] = kp_x / scale.x * anchor_w + anchor_xc;
boxes[kp_offset + int(1)] = kp_y / scale.y * anchor_h + anchor_yc;
}
}
})";
{
// Shader program
NSString* library_source =
[NSString stringWithUTF8String:decode_src.c_str()];
NSError* error = nil;
id<MTLLibrary> library = [device newLibraryWithSource:library_source
options:nullptr
error:&error];
RET_CHECK(library != nil) << "Couldn't create shader library "
<< [[error localizedDescription] UTF8String];
id<MTLFunction> kernel_func = nil;
kernel_func = [library newFunctionWithName:@"decodeKernel"];
RET_CHECK(kernel_func != nil) << "Couldn't create kernel function.";
gpu_data_->decode_program =
[device newComputePipelineStateWithFunction:kernel_func error:&error];
RET_CHECK(gpu_data_->decode_program != nil)
<< "Couldn't create pipeline state "
<< [[error localizedDescription] UTF8String];
// Outputs
size_t decoded_boxes_length = num_boxes_ * num_coords_ * sizeof(float);
gpu_data_->decoded_boxes_buffer =
[device newBufferWithLength:decoded_boxes_length
options:MTLResourceStorageModeShared];
// Inputs
size_t raw_boxes_length = num_boxes_ * num_coords_ * sizeof(float);
gpu_data_->raw_boxes_buffer =
[device newBufferWithLength:raw_boxes_length
options:MTLResourceStorageModeShared];
size_t raw_anchors_length = num_boxes_ * kNumCoordsPerBox * sizeof(float);
gpu_data_->raw_anchors_buffer =
[device newBufferWithLength:raw_anchors_length
options:MTLResourceStorageModeShared];
}
// A shader to score detection boxes.
const std::string score_src = absl::Substitute(
R"(
#include <metal_stdlib>
using namespace metal;
float optional_sigmoid(float x) {
int apply_sigmoid = int($1);
int apply_clipping_thresh = int($2);
float clipping_thresh = float($3);
if (apply_sigmoid == int(0)) return x;
if (apply_clipping_thresh == int(1)) {
x = clamp(x, -clipping_thresh, clipping_thresh);
}
x = 1.0 / (1.0 + exp(-x));
return x;
}
kernel void scoreKernel(
device float* scored_boxes [[ buffer(0) ]],
device float* raw_scores [[ buffer(1) ]],
uint2 tid [[ thread_position_in_threadgroup ]],
uint2 gid [[ thread_position_in_grid ]]) {
uint num_classes = uint($0);
int apply_sigmoid = int($1);
int apply_clipping_thresh = int($2);
float clipping_thresh = float($3);
int ignore_class_0 = int($4);
uint g_idx = gid.x; // box idx
uint s_idx = tid.y; // score/class idx
// load all scores into shared memory
threadgroup float local_scores[$0];
float score = raw_scores[g_idx * num_classes + s_idx];
local_scores[s_idx] = optional_sigmoid(score);
threadgroup_barrier(mem_flags::mem_threadgroup);
// find max score in shared memory
if (s_idx == uint(0)) {
float max_score = -FLT_MAX;
float max_class = -1.0;
for (int i=ignore_class_0; i<int(num_classes); ++i) {
if (local_scores[i] > max_score) {
max_score = local_scores[i];
max_class = float(i);
}
}
scored_boxes[g_idx * uint(2) + uint(0)] = max_score;
scored_boxes[g_idx * uint(2) + uint(1)] = max_class;
}
})", })",
num_classes_, options_.sigmoid_score() ? 1 : 0, num_classes_, options_.sigmoid_score() ? 1 : 0,
options_.has_score_clipping_thresh() ? 1 : 0, options_.has_score_clipping_thresh() ? 1 : 0,
@ -788,42 +1123,44 @@ void main() {
: 0, : 0,
ignore_classes_.size() ? 1 : 0); ignore_classes_.size() ? 1 : 0);
// # filter classes supported is hardware dependent.
int max_wg_size; // typically <= 1024
glGetIntegeri_v(GL_MAX_COMPUTE_WORK_GROUP_SIZE, 1, &max_wg_size); // y-dim
CHECK_LT(num_classes_, max_wg_size) << "# classes must be < " << max_wg_size;
// TODO support better filtering. // TODO support better filtering.
CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed"; CHECK_LE(ignore_classes_.size(), 1) << "Only ignore class 0 is allowed";
// Shader program {
GlShader score_shader; // Shader program
status = GlShader::CompileShader(GL_COMPUTE_SHADER, score_src, &score_shader); NSString* library_source =
if (!status.ok()) { [NSString stringWithUTF8String:score_src.c_str()];
return ::mediapipe::InternalError(status.error_message()); NSError* error = nil;
} id<MTLLibrary> library = [device newLibraryWithSource:library_source
score_program_ = absl::make_unique<GlProgram>(); options:nullptr
status = GlProgram::CreateWithShader(score_shader, score_program_.get()); error:&error];
if (!status.ok()) { RET_CHECK(library != nil) << "Couldn't create shader library "
return ::mediapipe::InternalError(status.error_message()); << [[error localizedDescription] UTF8String];
} id<MTLFunction> kernel_func = nil;
// Outputs kernel_func = [library newFunctionWithName:@"scoreKernel"];
size_t scored_boxes_length = num_boxes_ * 2; // score, class RET_CHECK(kernel_func != nil) << "Couldn't create kernel function.";
scored_boxes_buffer_ = absl::make_unique<GlBuffer>(); gpu_data_->score_program =
status = CreateReadWriteShaderStorageBuffer<float>( [device newComputePipelineStateWithFunction:kernel_func error:&error];
scored_boxes_length, scored_boxes_buffer_.get()); RET_CHECK(gpu_data_->score_program != nil)
if (!status.ok()) { << "Couldn't create pipeline state "
return ::mediapipe::InternalError(status.error_message()); << [[error localizedDescription] UTF8String];
} // Outputs
// Inputs size_t scored_boxes_length = num_boxes_ * 2 * sizeof(float); // score,class
size_t raw_scores_length = num_boxes_ * num_classes_; gpu_data_->scored_boxes_buffer =
raw_scores_buffer_ = absl::make_unique<GlBuffer>(); [device newBufferWithLength:scored_boxes_length
status = CreateReadWriteShaderStorageBuffer<float>(raw_scores_length, options:MTLResourceStorageModeShared];
raw_scores_buffer_.get()); // Inputs
if (!status.ok()) { size_t raw_scores_length = num_boxes_ * num_classes_ * sizeof(float);
return ::mediapipe::InternalError(status.error_message()); gpu_data_->raw_scores_buffer =
[device newBufferWithLength:raw_scores_length
options:MTLResourceStorageModeShared];
// # filter classes supported is hardware dependent.
int max_wg_size = gpu_data_->score_program.maxTotalThreadsPerThreadgroup;
CHECK_LT(num_classes_, max_wg_size) << "# classes must be <" << max_wg_size;
} }
#endif // defined(__ANDROID__) #endif // __ANDROID__ or iOS
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -96,7 +96,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
options_.has_input_image_width()) options_.has_input_image_width())
<< "Must provide input with/height for getting normalized landmarks."; << "Must provide input with/height for getting normalized landmarks.";
} }
if (cc->Outputs().HasTag("LANDMARKS") && options_.flip_vertically()) { if (cc->Outputs().HasTag("LANDMARKS") &&
(options_.flip_vertically() || options_.flip_horizontally())) {
RET_CHECK(options_.has_input_image_height() && RET_CHECK(options_.has_input_image_height() &&
options_.has_input_image_width()) options_.has_input_image_width())
<< "Must provide input with/height for using flip_vertically option " << "Must provide input with/height for using flip_vertically option "
@ -133,7 +134,12 @@ REGISTER_CALCULATOR(TfLiteTensorsToLandmarksCalculator);
for (int ld = 0; ld < num_landmarks_; ++ld) { for (int ld = 0; ld < num_landmarks_; ++ld) {
const int offset = ld * num_dimensions; const int offset = ld * num_dimensions;
Landmark landmark; Landmark landmark;
landmark.set_x(raw_landmarks[offset]);
if (options_.flip_horizontally()) {
landmark.set_x(options_.input_image_width() - raw_landmarks[offset]);
} else {
landmark.set_x(raw_landmarks[offset]);
}
if (num_dimensions > 1) { if (num_dimensions > 1) {
if (options_.flip_vertically()) { if (options_.flip_vertically()) {
landmark.set_y(options_.input_image_height() - landmark.set_y(options_.input_image_height() -

View File

@ -40,6 +40,12 @@ message TfLiteTensorsToLandmarksCalculatorOptions {
// representation has a bottom-left origin (e.g., in OpenGL). // representation has a bottom-left origin (e.g., in OpenGL).
optional bool flip_vertically = 4 [default = false]; optional bool flip_vertically = 4 [default = false];
// Whether the detection coordinates from the input tensors should be flipped
// horizontally (along the x-direction). This is useful, for example, when the
// input image is horizontally flipped in ImageTransformationCalculator
// beforehand.
optional bool flip_horizontally = 6 [default = false];
// A value that z values should be divided by. // A value that z values should be divided by.
optional float normalize_z = 5 [default = 1.0]; optional float normalize_z = 5 [default = 1.0];
} }

View File

@ -17,6 +17,7 @@
#include "absl/strings/str_format.h" #include "absl/strings/str_format.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.pb.h" #include "mediapipe/calculators/tflite/tflite_tensors_to_segmentation_calculator.pb.h"
#include "mediapipe/calculators/tflite/util.h"
#include "mediapipe/framework/calculator_context.h" #include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h" #include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h" #include "mediapipe/framework/formats/image_frame.h"
@ -27,7 +28,7 @@
#include "mediapipe/util/resource_util.h" #include "mediapipe/util/resource_util.h"
#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/interpreter.h"
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
@ -36,7 +37,7 @@
#include "tensorflow/lite/delegates/gpu/gl/gl_shader.h" #include "tensorflow/lite/delegates/gpu/gl/gl_shader.h"
#include "tensorflow/lite/delegates/gpu/gl/gl_texture.h" #include "tensorflow/lite/delegates/gpu/gl/gl_texture.h"
#include "tensorflow/lite/delegates/gpu/gl_delegate.h" #include "tensorflow/lite/delegates/gpu/gl_delegate.h"
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
namespace { namespace {
constexpr int kWorkgroupSize = 8; // Block size for GPU shader. constexpr int kWorkgroupSize = 8; // Block size for GPU shader.
@ -52,12 +53,14 @@ float Clamp(float val, float min, float max) {
namespace mediapipe { namespace mediapipe {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
using ::tflite::gpu::gl::CopyBuffer;
using ::tflite::gpu::gl::CreateReadWriteRgbaImageTexture;
using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer; using ::tflite::gpu::gl::CreateReadWriteShaderStorageBuffer;
using ::tflite::gpu::gl::GlBuffer; using ::tflite::gpu::gl::GlBuffer;
using ::tflite::gpu::gl::GlProgram; using ::tflite::gpu::gl::GlProgram;
using ::tflite::gpu::gl::GlShader; using ::tflite::gpu::gl::GlShader;
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
// Converts TFLite tensors from a tflite segmentation model to an image mask. // Converts TFLite tensors from a tflite segmentation model to an image mask.
// //
@ -126,13 +129,13 @@ class TfLiteTensorsToSegmentationCalculator : public CalculatorBase {
int tensor_channels_ = 0; int tensor_channels_ = 0;
bool use_gpu_ = false; bool use_gpu_ = false;
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
std::unique_ptr<GlProgram> mask_program_with_prev_; std::unique_ptr<GlProgram> mask_program_with_prev_;
std::unique_ptr<GlProgram> mask_program_no_prev_; std::unique_ptr<GlProgram> mask_program_no_prev_;
std::unique_ptr<GlBuffer> tensor_buffer_; std::unique_ptr<GlBuffer> tensor_buffer_;
GLuint upsample_program_; GLuint upsample_program_;
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator); REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
@ -142,6 +145,8 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Inputs().GetTags().empty());
RET_CHECK(!cc->Outputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty());
bool use_gpu = false;
// Inputs CPU. // Inputs CPU.
if (cc->Inputs().HasTag("TENSORS")) { if (cc->Inputs().HasTag("TENSORS")) {
cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>(); cc->Inputs().Tag("TENSORS").Set<std::vector<TfLiteTensor>>();
@ -154,32 +159,37 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
} }
// Inputs GPU. // Inputs GPU.
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>(); cc->Inputs().Tag("TENSORS_GPU").Set<std::vector<GlBuffer>>();
use_gpu |= true;
} }
if (cc->Inputs().HasTag("PREV_MASK_GPU")) { if (cc->Inputs().HasTag("PREV_MASK_GPU")) {
cc->Inputs().Tag("PREV_MASK_GPU").Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag("PREV_MASK_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
if (cc->Inputs().HasTag("REFERENCE_IMAGE_GPU")) { if (cc->Inputs().HasTag("REFERENCE_IMAGE_GPU")) {
cc->Inputs().Tag("REFERENCE_IMAGE_GPU").Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag("REFERENCE_IMAGE_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
// Outputs. // Outputs.
if (cc->Outputs().HasTag("MASK")) { if (cc->Outputs().HasTag("MASK")) {
cc->Outputs().Tag("MASK").Set<ImageFrame>(); cc->Outputs().Tag("MASK").Set<ImageFrame>();
} }
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
if (cc->Outputs().HasTag("MASK_GPU")) { if (cc->Outputs().HasTag("MASK_GPU")) {
cc->Outputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>(); cc->Outputs().Tag("MASK_GPU").Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
#if defined(__ANDROID__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // __ANDROID__
if (use_gpu) {
#if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -189,24 +199,23 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Inputs().HasTag("TENSORS_GPU")) { if (cc->Inputs().HasTag("TENSORS_GPU")) {
use_gpu_ = true; use_gpu_ = true;
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
} }
MP_RETURN_IF_ERROR(LoadOptions(cc)); MP_RETURN_IF_ERROR(LoadOptions(cc));
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
MP_RETURN_IF_ERROR(InitGpu(cc)); MP_RETURN_IF_ERROR(InitGpu(cc));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#else #else
RET_CHECK_FAIL() RET_CHECK_FAIL() << "GPU processing not enabled.";
<< "GPU processing on non-Android devices is not supported yet."; #endif // !MEDIAPIPE_DISABLE_GPU
#endif // __ANDROID__
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -215,13 +224,13 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process( ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Process(
CalculatorContext* cc) { CalculatorContext* cc) {
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
MP_RETURN_IF_ERROR(ProcessGpu(cc)); MP_RETURN_IF_ERROR(ProcessGpu(cc));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
MP_RETURN_IF_ERROR(ProcessCpu(cc)); MP_RETURN_IF_ERROR(ProcessCpu(cc));
} }
@ -231,7 +240,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close( ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::Close(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] {
if (upsample_program_) glDeleteProgram(upsample_program_); if (upsample_program_) glDeleteProgram(upsample_program_);
upsample_program_ = 0; upsample_program_ = 0;
@ -239,7 +248,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
mask_program_no_prev_.reset(); mask_program_no_prev_.reset();
tensor_buffer_.reset(); tensor_buffer_.reset();
}); });
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -358,7 +367,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) { if (cc->Inputs().Tag("TENSORS_GPU").IsEmpty()) {
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
// Get input streams. // Get input streams.
const auto& input_tensors = const auto& input_tensors =
cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>(); cc->Inputs().Tag("TENSORS_GPU").Get<std::vector<GlBuffer>>();
@ -379,9 +388,9 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
// Create initial working mask texture. // Create initial working mask texture.
::tflite::gpu::gl::GlTexture small_mask_texture; ::tflite::gpu::gl::GlTexture small_mask_texture;
::tflite::gpu::gl::CreateReadWriteRgbaImageTexture( RET_CHECK_CALL(CreateReadWriteRgbaImageTexture(
tflite::gpu::DataType::UINT8, // GL_RGBA8 tflite::gpu::DataType::UINT8, // GL_RGBA8
{tensor_width_, tensor_height_}, &small_mask_texture); {tensor_width_, tensor_height_}, &small_mask_texture));
// Get input previous mask. // Get input previous mask.
auto input_mask_texture = has_prev_mask auto input_mask_texture = has_prev_mask
@ -389,7 +398,7 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
: mediapipe::GlTexture(); : mediapipe::GlTexture();
// Copy input tensor. // Copy input tensor.
tflite::gpu::gl::CopyBuffer(input_tensors[0], *tensor_buffer_); RET_CHECK_CALL(CopyBuffer(input_tensors[0], *tensor_buffer_));
// Run shader, process mask tensor. // Run shader, process mask tensor.
// Run softmax over tensor output and blend with previous mask. // Run softmax over tensor output and blend with previous mask.
@ -397,18 +406,18 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
const int output_index = 0; const int output_index = 0;
glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0, glBindImageTexture(output_index, small_mask_texture.id(), 0, GL_FALSE, 0,
GL_WRITE_ONLY, GL_RGBA8); GL_WRITE_ONLY, GL_RGBA8);
tensor_buffer_->BindToIndex(2); RET_CHECK_CALL(tensor_buffer_->BindToIndex(2));
const tflite::gpu::uint3 workgroups = { const tflite::gpu::uint3 workgroups = {
NumGroups(tensor_width_, kWorkgroupSize), NumGroups(tensor_width_, kWorkgroupSize),
NumGroups(tensor_height_, kWorkgroupSize), 1}; NumGroups(tensor_height_, kWorkgroupSize), 1};
if (!has_prev_mask) { if (!has_prev_mask) {
mask_program_no_prev_->Dispatch(workgroups); RET_CHECK_CALL(mask_program_no_prev_->Dispatch(workgroups));
} else { } else {
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, input_mask_texture.name()); glBindTexture(GL_TEXTURE_2D, input_mask_texture.name());
mask_program_with_prev_->Dispatch(workgroups); RET_CHECK_CALL(mask_program_with_prev_->Dispatch(workgroups));
glActiveTexture(GL_TEXTURE1); glActiveTexture(GL_TEXTURE1);
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
} }
@ -438,13 +447,13 @@ REGISTER_CALCULATOR(TfLiteTensorsToSegmentationCalculator);
// Cleanup // Cleanup
input_mask_texture.Release(); input_mask_texture.Release();
output_texture.Release(); output_texture.Release();
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
void TfLiteTensorsToSegmentationCalculator::GlRender() { void TfLiteTensorsToSegmentationCalculator::GlRender() {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
static const GLfloat square_vertices[] = { static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left -1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right 1.0f, -1.0f, // bottom right
@ -492,7 +501,7 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() {
glBindVertexArray(0); glBindVertexArray(0);
glDeleteVertexArrays(1, &vao); glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo); glDeleteBuffers(2, vbo);
#endif // __ANDROID__ #endif // !MEDIAPIPE_DISABLE_GPU
} }
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::LoadOptions( ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::LoadOptions(
@ -516,14 +525,15 @@ void TfLiteTensorsToSegmentationCalculator::GlRender() {
::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu( ::mediapipe::Status TfLiteTensorsToSegmentationCalculator::InitGpu(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) #if !defined(MEDIAPIPE_DISABLE_GPU) && !defined(__APPLE__)
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext([this]()
// A shader to process a segmentation tensor into an output mask, -> ::mediapipe::Status {
// and use an optional previous mask as input. // A shader to process a segmentation tensor into an output mask,
// Currently uses 4 channels for output, // and use an optional previous mask as input.
// and sets both R and A channels as mask value. // Currently uses 4 channels for output,
const std::string shader_src_template = // and sets both R and A channels as mask value.
R"( #version 310 es const std::string shader_src_template =
R"( #version 310 es
layout(local_size_x = $0, local_size_y = $0, local_size_z = 1) in; layout(local_size_x = $0, local_size_y = $0, local_size_z = 1) in;
@ -589,76 +599,60 @@ void main() {
imageStore(output_texture, output_coordinate, out_value); imageStore(output_texture, output_coordinate, out_value);
})"; })";
const std::string shader_src_no_previous = absl::Substitute( const std::string shader_src_no_previous = absl::Substitute(
shader_src_template, kWorkgroupSize, options_.output_layer_index(), shader_src_template, kWorkgroupSize, options_.output_layer_index(),
options_.combine_with_previous_ratio(), "", options_.combine_with_previous_ratio(), "",
options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y"); options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y");
const std::string shader_src_with_previous = absl::Substitute( const std::string shader_src_with_previous = absl::Substitute(
shader_src_template, kWorkgroupSize, options_.output_layer_index(), shader_src_template, kWorkgroupSize, options_.output_layer_index(),
options_.combine_with_previous_ratio(), "#define READ_PREVIOUS", options_.combine_with_previous_ratio(), "#define READ_PREVIOUS",
options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y"); options_.flip_vertically() ? "out_height - gid.y - 1" : "gid.y");
auto status = ::tflite::gpu::OkStatus(); // Shader programs.
GlShader shader_without_previous;
RET_CHECK_CALL(GlShader::CompileShader(
GL_COMPUTE_SHADER, shader_src_no_previous, &shader_without_previous));
mask_program_no_prev_ = absl::make_unique<GlProgram>();
RET_CHECK_CALL(GlProgram::CreateWithShader(shader_without_previous,
mask_program_no_prev_.get()));
GlShader shader_with_previous;
RET_CHECK_CALL(GlShader::CompileShader(
GL_COMPUTE_SHADER, shader_src_with_previous, &shader_with_previous));
mask_program_with_prev_ = absl::make_unique<GlProgram>();
RET_CHECK_CALL(GlProgram::CreateWithShader(shader_with_previous,
mask_program_with_prev_.get()));
// Shader programs. // Buffer storage for input tensor.
GlShader shader_without_previous; size_t tensor_length = tensor_width_ * tensor_height_ * tensor_channels_;
status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src_no_previous, tensor_buffer_ = absl::make_unique<GlBuffer>();
&shader_without_previous); RET_CHECK_CALL(CreateReadWriteShaderStorageBuffer<float>(
if (!status.ok()) { tensor_length, tensor_buffer_.get()));
return ::mediapipe::InternalError(status.error_message());
}
mask_program_no_prev_ = absl::make_unique<GlProgram>();
status = GlProgram::CreateWithShader(shader_without_previous,
mask_program_no_prev_.get());
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
GlShader shader_with_previous;
status = GlShader::CompileShader(GL_COMPUTE_SHADER, shader_src_with_previous,
&shader_with_previous);
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
mask_program_with_prev_ = absl::make_unique<GlProgram>();
status = GlProgram::CreateWithShader(shader_with_previous,
mask_program_with_prev_.get());
if (!status.ok()) {
return ::mediapipe::InternalError(status.error_message());
}
// Buffer storage for input tensor. // Parameters.
size_t tensor_length = tensor_width_ * tensor_height_ * tensor_channels_; glUseProgram(mask_program_with_prev_->id());
tensor_buffer_ = absl::make_unique<GlBuffer>(); glUniform2i(glGetUniformLocation(mask_program_with_prev_->id(), "out_size"),
status = CreateReadWriteShaderStorageBuffer<float>(tensor_length, tensor_width_, tensor_height_);
tensor_buffer_.get()); glUniform1i(
if (!status.ok()) { glGetUniformLocation(mask_program_with_prev_->id(), "input_texture"),
return ::mediapipe::InternalError(status.error_message()); 1);
} glUseProgram(mask_program_no_prev_->id());
glUniform2i(glGetUniformLocation(mask_program_no_prev_->id(), "out_size"),
tensor_width_, tensor_height_);
glUniform1i(
glGetUniformLocation(mask_program_no_prev_->id(), "input_texture"), 1);
// Parameters. // Vertex shader attributes.
glUseProgram(mask_program_with_prev_->id()); const GLint attr_location[NUM_ATTRIBUTES] = {
glUniform2i(glGetUniformLocation(mask_program_with_prev_->id(), "out_size"), ATTRIB_VERTEX,
tensor_width_, tensor_height_); ATTRIB_TEXTURE_POSITION,
glUniform1i( };
glGetUniformLocation(mask_program_with_prev_->id(), "input_texture"), 1); const GLchar* attr_name[NUM_ATTRIBUTES] = {
glUseProgram(mask_program_no_prev_->id()); "position",
glUniform2i(glGetUniformLocation(mask_program_no_prev_->id(), "out_size"), "texture_coordinate",
tensor_width_, tensor_height_); };
glUniform1i(
glGetUniformLocation(mask_program_no_prev_->id(), "input_texture"), 1);
// Vertex shader attributes. // Simple pass-through shader, used for hardware upsampling.
const GLint attr_location[NUM_ATTRIBUTES] = { std::string upsample_shader_base = R"(
ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION,
};
const GLchar* attr_name[NUM_ATTRIBUTES] = {
"position",
"texture_coordinate",
};
// Simple pass-through shader, used for hardware upsampling.
std::string upsample_shader_base = R"(
#if __VERSION__ < 130 #if __VERSION__ < 130
#define in varying #define in varying
#endif // __VERSION__ < 130 #endif // __VERSION__ < 130
@ -683,16 +677,19 @@ void main() {
} }
)"; )";
// Program // Program
mediapipe::GlhCreateProgram(mediapipe::kBasicVertexShader, mediapipe::GlhCreateProgram(
upsample_shader_base.c_str(), NUM_ATTRIBUTES, mediapipe::kBasicVertexShader, upsample_shader_base.c_str(),
&attr_name[0], attr_location, &upsample_program_); NUM_ATTRIBUTES, &attr_name[0], attr_location, &upsample_program_);
RET_CHECK(upsample_program_) << "Problem initializing the program."; RET_CHECK(upsample_program_) << "Problem initializing the program.";
// Parameters // Parameters
glUseProgram(upsample_program_); glUseProgram(upsample_program_);
glUniform1i(glGetUniformLocation(upsample_program_, "input_data"), 1); glUniform1i(glGetUniformLocation(upsample_program_, "input_data"), 1);
#endif // __ANDROID__
return ::mediapipe::OkStatus();
}));
#endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -0,0 +1,25 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
#define MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_
#define RET_CHECK_CALL(call) \
do { \
const auto status = (call); \
if (ABSL_PREDICT_FALSE(!status.ok())) \
return ::mediapipe::InternalError(status.error_message()); \
} while (0);
#endif // MEDIAPIPE_CALCULATORS_TFLITE_UTIL_H_

View File

@ -235,19 +235,13 @@ cc_library(
"//mediapipe/framework/port:vector", "//mediapipe/framework/port:vector",
"//mediapipe/util:annotation_renderer", "//mediapipe/util:annotation_renderer",
] + select({ ] + select({
"//mediapipe:android": [ "//mediapipe/gpu:disable_gpu": [],
"//conditions:default": [
"//mediapipe/gpu:gl_calculator_helper", "//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders", "//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer", "//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util", "//mediapipe/gpu:shader_util",
], ],
"//mediapipe:ios": [
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gl_simple_shaders",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:shader_util",
],
"//conditions:default": [],
}), }),
alwayslink = 1, alwayslink = 1,
) )
@ -694,3 +688,65 @@ cc_test(
"//mediapipe/framework/tool:validate_type", "//mediapipe/framework/tool:validate_type",
], ],
) )
proto_library(
name = "top_k_scores_calculator_proto",
srcs = ["top_k_scores_calculator.proto"],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/framework:calculator_proto",
],
)
mediapipe_cc_proto_library(
name = "top_k_scores_calculator_cc_proto",
srcs = ["top_k_scores_calculator.proto"],
cc_deps = [
"//mediapipe/framework:calculator_cc_proto",
],
visibility = ["//visibility:public"],
deps = [":top_k_scores_calculator_proto"],
)
cc_library(
name = "top_k_scores_calculator",
srcs = ["top_k_scores_calculator.cc"],
visibility = ["//visibility:public"],
deps = [
":top_k_scores_calculator_cc_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor",
"//mediapipe/framework:calculator_framework",
"//mediapipe/util:resource_util",
] + select({
"//mediapipe:android": [
"//mediapipe/util/android/file/base",
],
"//mediapipe:apple": [
"//mediapipe/util/android/file/base",
],
"//mediapipe:macos": [
"//mediapipe/framework/port:file_helpers",
],
"//conditions:default": [
"//mediapipe/framework/port:file_helpers",
],
}),
alwayslink = 1,
)
cc_test(
name = "top_k_scores_calculator_test",
srcs = ["top_k_scores_calculator_test.cc"],
deps = [
":top_k_scores_calculator",
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework:calculator_runner",
"//mediapipe/framework:packet",
"//mediapipe/framework/deps:message_matchers",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)

View File

@ -27,12 +27,12 @@
#include "mediapipe/util/annotation_renderer.h" #include "mediapipe/util/annotation_renderer.h"
#include "mediapipe/util/color.pb.h" #include "mediapipe/util/color.pb.h"
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
#include "mediapipe/gpu/gl_calculator_helper.h" #include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gl_simple_shaders.h" #include "mediapipe/gpu/gl_simple_shaders.h"
#include "mediapipe/gpu/gpu_buffer.h" #include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/shader_util.h" #include "mediapipe/gpu/shader_util.h"
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
namespace mediapipe { namespace mediapipe {
@ -146,13 +146,13 @@ class AnnotationOverlayCalculator : public CalculatorBase {
bool use_gpu_ = false; bool use_gpu_ = false;
bool gpu_initialized_ = false; bool gpu_initialized_ = false;
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
mediapipe::GlCalculatorHelper gpu_helper_; mediapipe::GlCalculatorHelper gpu_helper_;
GLuint program_ = 0; GLuint program_ = 0;
GLuint image_mat_tex_ = 0; // Overlay drawing image for GPU. GLuint image_mat_tex_ = 0; // Overlay drawing image for GPU.
int width_ = 0; int width_ = 0;
int height_ = 0; int height_ = 0;
#endif // __ANDROID__ or iOS #endif // MEDIAPIPE_DISABLE_GPU
}; };
REGISTER_CALCULATOR(AnnotationOverlayCalculator); REGISTER_CALCULATOR(AnnotationOverlayCalculator);
@ -160,6 +160,8 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
CalculatorContract* cc) { CalculatorContract* cc) {
CHECK_GE(cc->Inputs().NumEntries(), 1); CHECK_GE(cc->Inputs().NumEntries(), 1);
bool use_gpu = false;
if (cc->Inputs().HasTag(kInputFrameTag) && if (cc->Inputs().HasTag(kInputFrameTag) &&
cc->Inputs().HasTag(kInputFrameTagGpu)) { cc->Inputs().HasTag(kInputFrameTagGpu)) {
return ::mediapipe::InternalError("Cannot have multiple input images."); return ::mediapipe::InternalError("Cannot have multiple input images.");
@ -173,12 +175,13 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
int num_render_streams = cc->Inputs().NumEntries(); int num_render_streams = cc->Inputs().NumEntries();
// Input image to render onto copy of. // Input image to render onto copy of.
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Inputs().HasTag(kInputFrameTagGpu)) { if (cc->Inputs().HasTag(kInputFrameTagGpu)) {
cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Set<mediapipe::GpuBuffer>();
num_render_streams = cc->Inputs().NumEntries() - 1; num_render_streams = cc->Inputs().NumEntries() - 1;
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Inputs().HasTag(kInputFrameTag)) { if (cc->Inputs().HasTag(kInputFrameTag)) {
cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>(); cc->Inputs().Tag(kInputFrameTag).Set<ImageFrame>();
num_render_streams = cc->Inputs().NumEntries() - 1; num_render_streams = cc->Inputs().NumEntries() - 1;
@ -190,18 +193,21 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
} }
// Rendered image. // Rendered image.
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (cc->Outputs().HasTag(kOutputFrameTagGpu)) { if (cc->Outputs().HasTag(kOutputFrameTagGpu)) {
cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>(); cc->Outputs().Tag(kOutputFrameTagGpu).Set<mediapipe::GpuBuffer>();
use_gpu |= true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
if (cc->Outputs().HasTag(kOutputFrameTag)) { if (cc->Outputs().HasTag(kOutputFrameTag)) {
cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>(); cc->Outputs().Tag(kOutputFrameTag).Set<ImageFrame>();
} }
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) if (use_gpu) {
MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc)); #if !defined(MEDIAPIPE_DISABLE_GPU)
#endif // __ANDROID__ or iOS MP_RETURN_IF_ERROR(mediapipe::GlCalculatorHelper::UpdateContract(cc));
#endif // !MEDIAPIPE_DISABLE_GPU
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -212,11 +218,11 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
options_ = cc->Options<AnnotationOverlayCalculatorOptions>(); options_ = cc->Options<AnnotationOverlayCalculatorOptions>();
if (cc->Inputs().HasTag(kInputFrameTagGpu) && if (cc->Inputs().HasTag(kInputFrameTagGpu) &&
cc->Outputs().HasTag(kOutputFrameTagGpu)) { cc->Outputs().HasTag(kOutputFrameTagGpu)) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
use_gpu_ = true; use_gpu_ = true;
#else #else
RET_CHECK_FAIL() << "GPU processing is for Android and iOS only."; RET_CHECK_FAIL() << "GPU processing not enabled.";
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
if (cc->Inputs().HasTag(kInputFrameTagGpu) || if (cc->Inputs().HasTag(kInputFrameTagGpu) ||
@ -246,9 +252,9 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
} }
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
MP_RETURN_IF_ERROR(gpu_helper_.Open(cc)); MP_RETURN_IF_ERROR(gpu_helper_.Open(cc));
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} }
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -260,7 +266,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
std::unique_ptr<cv::Mat> image_mat; std::unique_ptr<cv::Mat> image_mat;
ImageFormat::Format target_format; ImageFormat::Format target_format;
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (!gpu_initialized_) { if (!gpu_initialized_) {
MP_RETURN_IF_ERROR( MP_RETURN_IF_ERROR(
gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status { gpu_helper_.RunInGlContext([this, cc]() -> ::mediapipe::Status {
@ -269,7 +275,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
})); }));
gpu_initialized_ = true; gpu_initialized_ = true;
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
MP_RETURN_IF_ERROR(CreateRenderTargetGpu(cc, image_mat)); MP_RETURN_IF_ERROR(CreateRenderTargetGpu(cc, image_mat));
} else { } else {
MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format)); MP_RETURN_IF_ERROR(CreateRenderTargetCpu(cc, image_mat, &target_format));
@ -288,7 +294,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
} }
if (use_gpu_) { if (use_gpu_) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
// Overlay rendered image in OpenGL, onto a copy of input. // Overlay rendered image in OpenGL, onto a copy of input.
uchar* image_mat_ptr = image_mat->data; uchar* image_mat_ptr = image_mat->data;
MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext( MP_RETURN_IF_ERROR(gpu_helper_.RunInGlContext(
@ -296,7 +302,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
MP_RETURN_IF_ERROR(RenderToGpu(cc, image_mat_ptr)); MP_RETURN_IF_ERROR(RenderToGpu(cc, image_mat_ptr));
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
})); }));
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
} else { } else {
// Copy the rendered image to output. // Copy the rendered image to output.
uchar* image_mat_ptr = image_mat->data; uchar* image_mat_ptr = image_mat->data;
@ -307,14 +313,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
} }
::mediapipe::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) { ::mediapipe::Status AnnotationOverlayCalculator::Close(CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
gpu_helper_.RunInGlContext([this] { gpu_helper_.RunInGlContext([this] {
if (program_) glDeleteProgram(program_); if (program_) glDeleteProgram(program_);
program_ = 0; program_ = 0;
if (image_mat_tex_) glDeleteTextures(1, &image_mat_tex_); if (image_mat_tex_) glDeleteTextures(1, &image_mat_tex_);
image_mat_tex_ = 0; image_mat_tex_ = 0;
}); });
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -325,7 +331,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
auto output_frame = absl::make_unique<ImageFrame>( auto output_frame = absl::make_unique<ImageFrame>(
target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight()); target_format, renderer_->GetImageWidth(), renderer_->GetImageHeight());
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(),
renderer_->GetImageHeight(), data_image, renderer_->GetImageHeight(), data_image,
ImageFrame::kGlDefaultAlignmentBoundary); ImageFrame::kGlDefaultAlignmentBoundary);
@ -333,7 +339,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(), output_frame->CopyPixelData(target_format, renderer_->GetImageWidth(),
renderer_->GetImageHeight(), data_image, renderer_->GetImageHeight(), data_image,
ImageFrame::kDefaultAlignmentBoundary); ImageFrame::kDefaultAlignmentBoundary);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
cc->Outputs() cc->Outputs()
.Tag(kOutputFrameTag) .Tag(kOutputFrameTag)
@ -344,7 +350,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
::mediapipe::Status AnnotationOverlayCalculator::RenderToGpu( ::mediapipe::Status AnnotationOverlayCalculator::RenderToGpu(
CalculatorContext* cc, uchar* overlay_image) { CalculatorContext* cc, uchar* overlay_image) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
// Source and destination textures. // Source and destination textures.
const auto& input_frame = const auto& input_frame =
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
@ -390,7 +396,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
// Cleanup // Cleanup
input_texture.Release(); input_texture.Release();
output_texture.Release(); output_texture.Release();
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -451,15 +457,16 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu( ::mediapipe::Status AnnotationOverlayCalculator::CreateRenderTargetGpu(
CalculatorContext* cc, std::unique_ptr<cv::Mat>& image_mat) { CalculatorContext* cc, std::unique_ptr<cv::Mat>& image_mat) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
if (image_frame_available_) { if (image_frame_available_) {
const auto& input_frame = const auto& input_frame =
cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>(); cc->Inputs().Tag(kInputFrameTagGpu).Get<mediapipe::GpuBuffer>();
const mediapipe::ImageFormat::Format format = const mediapipe::ImageFormat::Format format =
mediapipe::ImageFormatForGpuBufferFormat(input_frame.format()); mediapipe::ImageFormatForGpuBufferFormat(input_frame.format());
if (format != mediapipe::ImageFormat::SRGBA) if (format != mediapipe::ImageFormat::SRGBA &&
RET_CHECK_FAIL() << "Unsupported GPU input format."; format != mediapipe::ImageFormat::SRGB)
RET_CHECK_FAIL() << "Unsupported GPU input format: " << format;
image_mat = absl::make_unique<cv::Mat>( image_mat = absl::make_unique<cv::Mat>(
height_, width_, CV_8UC3, height_, width_, CV_8UC3,
@ -471,14 +478,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(), cv::Scalar(options_.canvas_color().r(), options_.canvas_color().g(),
options_.canvas_color().b())); options_.canvas_color().b()));
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status AnnotationOverlayCalculator::GlRender( ::mediapipe::Status AnnotationOverlayCalculator::GlRender(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
static const GLfloat square_vertices[] = { static const GLfloat square_vertices[] = {
-1.0f, -1.0f, // bottom left -1.0f, -1.0f, // bottom left
1.0f, -1.0f, // bottom right 1.0f, -1.0f, // bottom right
@ -526,14 +533,14 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
glBindVertexArray(0); glBindVertexArray(0);
glDeleteVertexArrays(1, &vao); glDeleteVertexArrays(1, &vao);
glDeleteBuffers(2, vbo); glDeleteBuffers(2, vbo);
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status AnnotationOverlayCalculator::GlSetup( ::mediapipe::Status AnnotationOverlayCalculator::GlSetup(
CalculatorContext* cc) { CalculatorContext* cc) {
#if defined(__ANDROID__) || (defined(__APPLE__) && !TARGET_OS_OSX) #if !defined(MEDIAPIPE_DISABLE_GPU)
const GLint attr_location[NUM_ATTRIBUTES] = { const GLint attr_location[NUM_ATTRIBUTES] = {
ATTRIB_VERTEX, ATTRIB_VERTEX,
ATTRIB_TEXTURE_POSITION, ATTRIB_TEXTURE_POSITION,
@ -609,7 +616,7 @@ REGISTER_CALCULATOR(AnnotationOverlayCalculator);
glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); glTexParameterf(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
} }
#endif // __ANDROID__ or iOS #endif // !MEDIAPIPE_DISABLE_GPU
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -23,10 +23,25 @@ namespace mediapipe {
namespace { namespace {
constexpr char kNormalizedRectTag[] = "NORM_RECT"; constexpr char kNormRectTag[] = "NORM_RECT";
constexpr char kRectTag[] = "RECT"; constexpr char kRectTag[] = "RECT";
constexpr char kNormRectsTag[] = "NORM_RECTS";
constexpr char kRectsTag[] = "RECTS";
constexpr char kRenderDataTag[] = "RENDER_DATA"; constexpr char kRenderDataTag[] = "RENDER_DATA";
RenderAnnotation::Rectangle* NewRect(
const RectToRenderDataCalculatorOptions& options, RenderData* render_data) {
auto* annotation = render_data->add_render_annotations();
annotation->mutable_color()->set_r(options.color().r());
annotation->mutable_color()->set_g(options.color().g());
annotation->mutable_color()->set_b(options.color().b());
annotation->set_thickness(options.thickness());
return options.filled()
? annotation->mutable_filled_rectangle()->mutable_rectangle()
: annotation->mutable_rectangle();
}
void SetRect(bool normalized, double xmin, double ymin, double width, void SetRect(bool normalized, double xmin, double ymin, double width,
double height, double rotation, double height, double rotation,
RenderAnnotation::Rectangle* rect) { RenderAnnotation::Rectangle* rect) {
@ -51,6 +66,8 @@ void SetRect(bool normalized, double xmin, double ymin, double width,
// One of the following: // One of the following:
// NORM_RECT: A NormalizedRect // NORM_RECT: A NormalizedRect
// RECT: A Rect // RECT: A Rect
// NORM_RECTS: An std::vector<NormalizedRect>
// RECTS: An std::vector<Rect>
// //
// Output: // Output:
// RENDER_DATA: A RenderData // RENDER_DATA: A RenderData
@ -83,16 +100,27 @@ REGISTER_CALCULATOR(RectToRenderDataCalculator);
::mediapipe::Status RectToRenderDataCalculator::GetContract( ::mediapipe::Status RectToRenderDataCalculator::GetContract(
CalculatorContract* cc) { CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag(kNormalizedRectTag) ^ RET_CHECK_EQ((cc->Inputs().HasTag(kNormRectTag) ? 1 : 0) +
cc->Inputs().HasTag(kRectTag)); (cc->Inputs().HasTag(kRectTag) ? 1 : 0) +
(cc->Inputs().HasTag(kNormRectsTag) ? 1 : 0) +
(cc->Inputs().HasTag(kRectsTag) ? 1 : 0),
1)
<< "Exactly one of NORM_RECT, RECT, NORM_RECTS or RECTS input stream "
"should be provided.";
RET_CHECK(cc->Outputs().HasTag(kRenderDataTag)); RET_CHECK(cc->Outputs().HasTag(kRenderDataTag));
if (cc->Inputs().HasTag(kNormalizedRectTag)) { if (cc->Inputs().HasTag(kNormRectTag)) {
cc->Inputs().Tag(kNormalizedRectTag).Set<NormalizedRect>(); cc->Inputs().Tag(kNormRectTag).Set<NormalizedRect>();
} }
if (cc->Inputs().HasTag(kRectTag)) { if (cc->Inputs().HasTag(kRectTag)) {
cc->Inputs().Tag(kRectTag).Set<Rect>(); cc->Inputs().Tag(kRectTag).Set<Rect>();
} }
if (cc->Inputs().HasTag(kNormRectsTag)) {
cc->Inputs().Tag(kNormRectsTag).Set<std::vector<NormalizedRect>>();
}
if (cc->Inputs().HasTag(kRectsTag)) {
cc->Inputs().Tag(kRectsTag).Set<std::vector<Rect>>();
}
cc->Outputs().Tag(kRenderDataTag).Set<RenderData>(); cc->Outputs().Tag(kRenderDataTag).Set<RenderData>();
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
@ -108,31 +136,43 @@ REGISTER_CALCULATOR(RectToRenderDataCalculator);
::mediapipe::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) { ::mediapipe::Status RectToRenderDataCalculator::Process(CalculatorContext* cc) {
auto render_data = absl::make_unique<RenderData>(); auto render_data = absl::make_unique<RenderData>();
auto* annotation = render_data->add_render_annotations();
annotation->mutable_color()->set_r(options_.color().r());
annotation->mutable_color()->set_g(options_.color().g());
annotation->mutable_color()->set_b(options_.color().b());
annotation->set_thickness(options_.thickness());
auto* rectangle = if (cc->Inputs().HasTag(kNormRectTag) &&
options_.filled() !cc->Inputs().Tag(kNormRectTag).IsEmpty()) {
? annotation->mutable_filled_rectangle()->mutable_rectangle() const auto& rect = cc->Inputs().Tag(kNormRectTag).Get<NormalizedRect>();
: annotation->mutable_rectangle(); auto* rectangle = NewRect(options_, render_data.get());
if (cc->Inputs().HasTag(kNormalizedRectTag) &&
!cc->Inputs().Tag(kNormalizedRectTag).IsEmpty()) {
const auto& rect =
cc->Inputs().Tag(kNormalizedRectTag).Get<NormalizedRect>();
SetRect(/*normalized=*/true, rect.x_center() - rect.width() / 2.f, SetRect(/*normalized=*/true, rect.x_center() - rect.width() / 2.f,
rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(), rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(),
rect.rotation(), rectangle); rect.rotation(), rectangle);
} }
if (cc->Inputs().HasTag(kRectTag) && !cc->Inputs().Tag(kRectTag).IsEmpty()) { if (cc->Inputs().HasTag(kRectTag) && !cc->Inputs().Tag(kRectTag).IsEmpty()) {
const auto& rect = cc->Inputs().Tag(kRectTag).Get<Rect>(); const auto& rect = cc->Inputs().Tag(kRectTag).Get<Rect>();
auto* rectangle = NewRect(options_, render_data.get());
SetRect(/*normalized=*/false, rect.x_center() - rect.width() / 2.f, SetRect(/*normalized=*/false, rect.x_center() - rect.width() / 2.f,
rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(), rect.y_center() - rect.height() / 2.f, rect.width(), rect.height(),
rect.rotation(), rectangle); rect.rotation(), rectangle);
} }
if (cc->Inputs().HasTag(kNormRectsTag) &&
!cc->Inputs().Tag(kNormRectsTag).IsEmpty()) {
const auto& rects =
cc->Inputs().Tag(kNormRectsTag).Get<std::vector<NormalizedRect>>();
for (auto& rect : rects) {
auto* rectangle = NewRect(options_, render_data.get());
SetRect(/*normalized=*/true, rect.x_center() - rect.width() / 2.f,
rect.y_center() - rect.height() / 2.f, rect.width(),
rect.height(), rect.rotation(), rectangle);
}
}
if (cc->Inputs().HasTag(kRectsTag) &&
!cc->Inputs().Tag(kRectsTag).IsEmpty()) {
const auto& rects = cc->Inputs().Tag(kRectsTag).Get<std::vector<Rect>>();
for (auto& rect : rects) {
auto* rectangle = NewRect(options_, render_data.get());
SetRect(/*normalized=*/false, rect.x_center() - rect.width() / 2.f,
rect.y_center() - rect.height() / 2.f, rect.width(),
rect.height(), rect.rotation(), rectangle);
}
}
cc->Outputs() cc->Outputs()
.Tag(kRenderDataTag) .Tag(kRenderDataTag)

View File

@ -0,0 +1,194 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <istream>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mediapipe/calculators/util/top_k_scores_calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h"
#include "mediapipe/util/resource_util.h"
#if defined(MEDIAPIPE_LITE) || defined(__ANDROID__) || \
(defined(__APPLE__) && !TARGET_OS_OSX)
#include "mediapipe/util/android/file/base/file.h"
#include "mediapipe/util/android/file/base/helpers.h"
#else
#include "mediapipe/framework/port/file_helpers.h"
#endif
namespace mediapipe {
// A calculator that takes a vector of scores and returns the indexes, scores,
// labels of the top k elements.
//
// Usage example:
// node {
// calculator: "TopKScoresCalculator"
// input_stream: "SCORES:score_vector"
// output_stream: "TOP_K_INDEXES:top_k_indexes"
// output_stream: "TOP_K_SCORES:top_k_scores"
// output_stream: "TOP_K_LABELS:top_k_labels"
// options: {
// [mediapipe.TopKScoresCalculatorOptions.ext] {
// top_k: 5
// threshold: 0.1
// label_map_path: "/path/to/label/map"
// }
// }
// }
class TopKScoresCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc);
::mediapipe::Status Open(CalculatorContext* cc) override;
::mediapipe::Status Process(CalculatorContext* cc) override;
private:
::mediapipe::Status LoadLabelmap(std::string label_map_path);
int top_k_ = -1;
float threshold_ = 0.0;
std::unordered_map<int, std::string> label_map_;
};
REGISTER_CALCULATOR(TopKScoresCalculator);
::mediapipe::Status TopKScoresCalculator::GetContract(CalculatorContract* cc) {
RET_CHECK(cc->Inputs().HasTag("SCORES"));
cc->Inputs().Tag("SCORES").Set<std::vector<float>>();
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
cc->Outputs().Tag("TOP_K_INDEXES").Set<std::vector<int>>();
}
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
cc->Outputs().Tag("TOP_K_SCORES").Set<std::vector<float>>();
}
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
cc->Outputs().Tag("TOP_K_LABELS").Set<std::vector<std::string>>();
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TopKScoresCalculator::Open(CalculatorContext* cc) {
const auto& options = cc->Options<::mediapipe::TopKScoresCalculatorOptions>();
RET_CHECK(options.has_top_k() || options.has_threshold())
<< "Must specify at least one of the top_k and threshold fields in "
"TopKScoresCalculatorOptions.";
if (options.has_top_k()) {
RET_CHECK(options.top_k() > 0) << "top_k must be greater than zero.";
top_k_ = options.top_k();
}
if (options.has_threshold()) {
threshold_ = options.threshold();
}
if (options.has_label_map_path()) {
MP_RETURN_IF_ERROR(LoadLabelmap(options.label_map_path()));
}
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
RET_CHECK(!label_map_.empty());
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TopKScoresCalculator::Process(CalculatorContext* cc) {
const std::vector<float>& input_vector =
cc->Inputs().Tag("SCORES").Get<std::vector<float>>();
std::vector<int> top_k_indexes;
std::vector<float> top_k_scores;
std::vector<std::string> top_k_labels;
if (top_k_ > 0) {
top_k_indexes.reserve(top_k_);
top_k_scores.reserve(top_k_);
top_k_labels.reserve(top_k_);
}
std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
std::greater<std::pair<float, int>>>
pq;
for (int i = 0; i < input_vector.size(); ++i) {
if (input_vector[i] < threshold_) {
continue;
}
if (top_k_ > 0) {
if (pq.size() < top_k_) {
pq.push(std::pair<float, int>(input_vector[i], i));
} else if (pq.top().first < input_vector[i]) {
pq.pop();
pq.push(std::pair<float, int>(input_vector[i], i));
}
} else {
pq.push(std::pair<float, int>(input_vector[i], i));
}
}
while (!pq.empty()) {
top_k_indexes.push_back(pq.top().second);
top_k_scores.push_back(pq.top().first);
pq.pop();
}
reverse(top_k_indexes.begin(), top_k_indexes.end());
reverse(top_k_scores.begin(), top_k_scores.end());
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
for (int index : top_k_indexes) {
top_k_labels.push_back(label_map_[index]);
}
}
if (cc->Outputs().HasTag("TOP_K_INDEXES")) {
cc->Outputs()
.Tag("TOP_K_INDEXES")
.AddPacket(MakePacket<std::vector<int>>(top_k_indexes)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("TOP_K_SCORES")) {
cc->Outputs()
.Tag("TOP_K_SCORES")
.AddPacket(MakePacket<std::vector<float>>(top_k_scores)
.At(cc->InputTimestamp()));
}
if (cc->Outputs().HasTag("TOP_K_LABELS")) {
cc->Outputs()
.Tag("TOP_K_LABELS")
.AddPacket(MakePacket<std::vector<std::string>>(top_k_labels)
.At(cc->InputTimestamp()));
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status TopKScoresCalculator::LoadLabelmap(
std::string label_map_path) {
std::string string_path;
ASSIGN_OR_RETURN(string_path, PathToResourceAsFile(label_map_path));
std::string label_map_string;
MP_RETURN_IF_ERROR(file::GetContents(string_path, &label_map_string));
std::istringstream stream(label_map_string);
std::string line;
int i = 0;
while (std::getline(stream, line)) {
label_map_[i++] = line;
}
return ::mediapipe::OkStatus();
}
} // namespace mediapipe

View File

@ -0,0 +1,33 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package mediapipe;
import "mediapipe/framework/calculator.proto";
message TopKScoresCalculatorOptions {
extend CalculatorOptions {
optional TopKScoresCalculatorOptions ext = 271211788;
}
// How many highest scoring packets to output.
optional int32 top_k = 1;
// If set, only keep the scores that are greater than the threshold.
optional float threshold = 2;
// Path to a label map file for getting the actual name of classes.
optional string label_map_path = 3;
}

View File

@ -0,0 +1,150 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mediapipe/framework/calculator_runner.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
TEST(TopKScoresCalculatorTest, TestNodeConfig) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TopKScoresCalculator"
input_stream: "SCORES:score_vector"
output_stream: "TOP_K_INDEXES:top_k_indexes"
output_stream: "TOP_K_SCORES:top_k_scores"
options: {
[mediapipe.TopKScoresCalculatorOptions.ext] {}
}
)"));
auto status = runner.Run();
ASSERT_TRUE(!status.ok());
EXPECT_THAT(
status.ToString(),
testing::HasSubstr(
"Must specify at least one of the top_k and threshold fields"));
}
TEST(TopKScoresCalculatorTest, TestTopKOnly) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TopKScoresCalculator"
input_stream: "SCORES:score_vector"
output_stream: "TOP_K_INDEXES:top_k_indexes"
output_stream: "TOP_K_SCORES:top_k_scores"
options: {
[mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 2 }
}
)"));
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets;
ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(2, indexes.size());
EXPECT_EQ(3, indexes[0]);
EXPECT_EQ(0, indexes[1]);
const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets;
ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(2, scores.size());
EXPECT_NEAR(1, scores[0], 1e-5);
EXPECT_NEAR(0.9, scores[1], 1e-5);
}
TEST(TopKScoresCalculatorTest, TestThresholdOnly) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TopKScoresCalculator"
input_stream: "SCORES:score_vector"
output_stream: "TOP_K_INDEXES:top_k_indexes"
output_stream: "TOP_K_SCORES:top_k_scores"
options: {
[mediapipe.TopKScoresCalculatorOptions.ext] { threshold: 0.2 }
}
)"));
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets;
ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(4, indexes.size());
EXPECT_EQ(3, indexes[0]);
EXPECT_EQ(0, indexes[1]);
EXPECT_EQ(2, indexes[2]);
EXPECT_EQ(1, indexes[3]);
const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets;
ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(4, scores.size());
EXPECT_NEAR(1.0, scores[0], 1e-5);
EXPECT_NEAR(0.9, scores[1], 1e-5);
EXPECT_NEAR(0.3, scores[2], 1e-5);
EXPECT_NEAR(0.2, scores[3], 1e-5);
}
TEST(TopKScoresCalculatorTest, TestBothTopKAndThreshold) {
CalculatorRunner runner(ParseTextProtoOrDie<CalculatorGraphConfig::Node>(R"(
calculator: "TopKScoresCalculator"
input_stream: "SCORES:score_vector"
output_stream: "TOP_K_INDEXES:top_k_indexes"
output_stream: "TOP_K_SCORES:top_k_scores"
options: {
[mediapipe.TopKScoresCalculatorOptions.ext] { top_k: 4 threshold: 0.3 }
}
)"));
std::vector<float> score_vector{0.9, 0.2, 0.3, 1.0, 0.1};
runner.MutableInputs()->Tag("SCORES").packets.push_back(
MakePacket<std::vector<float>>(score_vector).At(Timestamp(0)));
MP_ASSERT_OK(runner.Run());
const std::vector<Packet>& indexes_outputs =
runner.Outputs().Tag("TOP_K_INDEXES").packets;
ASSERT_EQ(1, indexes_outputs.size());
const auto& indexes = indexes_outputs[0].Get<std::vector<int>>();
EXPECT_EQ(3, indexes.size());
EXPECT_EQ(3, indexes[0]);
EXPECT_EQ(0, indexes[1]);
EXPECT_EQ(2, indexes[2]);
const std::vector<Packet>& scores_outputs =
runner.Outputs().Tag("TOP_K_SCORES").packets;
ASSERT_EQ(1, scores_outputs.size());
const auto& scores = scores_outputs[0].Get<std::vector<float>>();
EXPECT_EQ(3, scores.size());
EXPECT_NEAR(1.0, scores[0], 1e-5);
EXPECT_NEAR(0.9, scores[1], 1e-5);
EXPECT_NEAR(0.3, scores[2], 1e-5);
}
} // namespace mediapipe

View File

@ -51,6 +51,15 @@ and model details are described in the
* [Android](./face_detection_mobile_gpu.md) * [Android](./face_detection_mobile_gpu.md)
* [iOS](./face_detection_mobile_gpu.md) * [iOS](./face_detection_mobile_gpu.md)
### Face Detection with CPU
[Face Detection with CPU](./face_detection_mobile_cpu.md) illustrates using the
same TFLite model in a CPU-based pipeline. This example highlights how graphs
can be easily adapted to run on CPU v.s. GPU.
* [Android](./face_detection_mobile_cpu.md)
* [iOS](./face_detection_mobile_cpu.md)
### Hand Detection with GPU ### Hand Detection with GPU
[Hand Detection with GPU](./hand_detection_mobile_gpu.md) illustrates how to use [Hand Detection with GPU](./hand_detection_mobile_gpu.md) illustrates how to use
@ -103,3 +112,29 @@ object detection models (TensorFlow and TFLite) using the MediaPipe C++ APIs.
[Sobel edge detection]:https://en.wikipedia.org/wiki/Sobel_operator [Sobel edge detection]:https://en.wikipedia.org/wiki/Sobel_operator
[CameraX]:https://developer.android.com/training/camerax [CameraX]:https://developer.android.com/training/camerax
### Face Detection on Desktop with Webcam
[Face Detection on Desktop with Webcam](./face_detection_desktop.md) shows how
to use MediaPipe with a TFLite model for face detection on desktop using CPU or
GPU with live video from a webcam.
* [Desktop GPU](./face_detection_desktop.md)
* [Desktop CPU](./face_detection_desktop.md)
### Hand Tracking on Desktop with Webcam
[Hand Tracking on Desktop with Webcam](./hand_tracking_desktop.md) shows how to
use MediaPipe with a TFLite model for hand tracking on desktop using CPU or GPU
with live video from a webcam.
* [Desktop GPU](./hand_tracking_desktop.md)
* [Desktop CPU](./hand_tracking_desktop.md)
### Hair Segmentation on Desktop with Webcam
[Hair Segmentation on Desktop with Webcam](./hair_segmentation_desktop.md) shows
how to use MediaPipe with a TFLite model for hair segmentation on desktop using
GPU with live video from a webcam.
* [Desktop GPU](./hair_segmentation_desktop.md)

View File

@ -0,0 +1,265 @@
## Face Detection on Desktop
This is an example of using MediaPipe to run face detection models (TensorFlow
Lite) and render bounding boxes on the detected faces. To know more about the
face detection models, please refer to the model [`README file`]. Moreover, if
you are interested in running the same TensorfFlow Lite model on Android/iOS,
please see the
[Face Detection on GPU on Android/iOS](face_detection_mobile_gpu.md) and
[Face Detection on CPU on Android/iOS](face_detection_mobile_cpu.md) examples.
We show the face detection demos with TensorFlow Lite model using the Webcam:
- [TensorFlow Lite Face Detection Demo with Webcam (CPU)](#tensorflow-lite-face-detection-demo-with-webcam-cpu)
- [TensorFlow Lite Face Detection Demo with Webcam (GPU)](#tensorflow-lite-face-detection-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section.
### TensorFlow Lite Face Detection Demo with Webcam (CPU)
To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run:
```bash
# Video from webcam running on desktop CPU
$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
mediapipe/examples/desktop/face_detection:face_detection_cpu
# It should print:
# Target //mediapipe/examples/desktop/face_detection:face_detection_cpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu
# INFO: Elapsed time: 36.417s, Critical Path: 23.22s
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_cpu \
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_desktop_live.pbtxt
```
### TensorFlow Lite Face Detection Demo with Webcam (GPU)
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
mediapipe/examples/desktop/face_detection:face_detection_gpu
# It should print:
# Target //mediapipe/examples/desktop/face_detection:face_detection_gpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu
# INFO: Elapsed time: 36.417s, Critical Path: 23.22s
# INFO: 711 processes: 710 linux-sandbox, 1 local.
# INFO: Build completed successfully, 734 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/face_detection/face_detection_gpu \
--calculator_graph_config_file=mediapipe/graphs/face_detection/face_detection_mobile_gpu.pbtxt
```
#### Graph
![graph visualization](images/face_detection_desktop.png)
To visualize the graph as shown above, copy the text specification of the graph
below and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev).
```bash
# MediaPipe graph that performs face detection with TensorFlow Lite on CPU & GPU.
# Used in the examples in
# mediapipie/examples/desktop/face_detection:face_detection_cpu.
# Images on CPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish
# generating the corresponding detections before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between
# from queuing up incoming images and data excessively, which leads to increased
# latency and memory usage, unwanted in real-time mobile applications. It also
# eliminates unnecessarily computation, e.g., a transformed image produced by
# ImageTransformationCalculator may get dropped downstream if the subsequent
# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy
# processing previous inputs.
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:detections"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
}
output_stream: "throttled_input_video"
}
# Transforms the input image on CPU to a 128x128 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:throttled_input_video"
output_stream: "IMAGE:transformed_input_video_cpu"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 128
output_height: 128
scale_mode: FIT
}
}
}
# Converts the transformed input image on CPU into an image tensor stored as a
# TfLiteTensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE:transformed_input_video_cpu"
output_stream: "TENSORS:image_tensor"
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:detection_tensors"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/face_detection_front.tflite"
}
}
}
# Generates a single side packet containing a vector of SSD anchors based on
# the specification in the options.
node {
calculator: "SsdAnchorsCalculator"
output_side_packet: "anchors"
node_options: {
[type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] {
num_layers: 4
min_scale: 0.1484375
max_scale: 0.75
input_size_height: 128
input_size_width: 128
anchor_offset_x: 0.5
anchor_offset_y: 0.5
strides: 8
strides: 16
strides: 16
strides: 16
aspect_ratios: 1.0
fixed_anchor_size: true
}
}
}
# Decodes the detection tensors generated by the TensorFlow Lite model, based on
# the SSD anchors and the specification in the options, into a vector of
# detections. Each detection describes a detected object.
node {
calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS:detection_tensors"
input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] {
num_classes: 1
num_boxes: 896
num_coords: 16
box_coord_offset: 0
keypoint_coord_offset: 4
num_keypoints: 6
num_values_per_keypoint: 2
sigmoid_score: true
score_clipping_thresh: 100.0
reverse_output_order: true
x_scale: 128.0
y_scale: 128.0
h_scale: 128.0
w_scale: 128.0
min_score_thresh: 0.75
}
}
}
# Performs non-max suppression to remove excessive detections.
node {
calculator: "NonMaxSuppressionCalculator"
input_stream: "detections"
output_stream: "filtered_detections"
node_options: {
[type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] {
min_suppression_threshold: 0.3
overlap_type: INTERSECTION_OVER_UNION
algorithm: WEIGHTED
return_empty_detections: true
}
}
}
# Maps detection label IDs to the corresponding label text ("Face"). The label
# map is provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "filtered_detections"
output_stream: "labeled_detections"
node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
}
}
}
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation).
node {
calculator: "DetectionLetterboxRemovalCalculator"
input_stream: "DETECTIONS:labeled_detections"
input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "DETECTIONS:output_detections"
}
# Converts the detections to drawing primitives for annotation overlay.
node {
calculator: "DetectionsToRenderDataCalculator"
input_stream: "DETECTIONS:output_detections"
output_stream: "RENDER_DATA:render_data"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 4.0
color { r: 255 g: 0 b: 0 }
}
}
}
# Draws annotations and overlays them on top of the input images.
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "INPUT_FRAME:throttled_input_video"
input_stream: "render_data"
output_stream: "OUTPUT_FRAME:output_video"
}
```
[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/models/object_detection_saved_model/README.md

View File

@ -0,0 +1,244 @@
# Face Detection (CPU)
This doc focuses on the
[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt)
that performs face detection with TensorFlow Lite on CPU.
## Android
[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu)
To build and install the app:
```bash
bazel build -c opt --config=android_arm64 mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu
adb install bazel-bin/mediapipe/examples/android/src/java/com/google/mediapipe/apps/facedetectioncpu/facedetectioncpu.apk
```
## iOS
[Source](https://github.com/google/mediapipe/tree/master/mediapipe/examples/ios/facedetectioncpu).
See the general [instructions](./mediapipe_ios_setup.md) for building iOS
examples and generating an Xcode project. This will be the FaceDetectionCpuApp
target.
To build on the command line:
```bash
bazel build -c opt --config=ios_arm64 mediapipe/examples/ios/facedetectioncpu:FaceDetectionCpuApp
```
## Graph
![face_detection_mobile_cpu_graph](images/mobile/face_detection_mobile_cpu.png)
To visualize the graph as shown above, copy the text specification of the graph
below and paste it into [MediaPipe Visualizer](https://viz.mediapipe.dev/).
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/face_detection/face_detection_mobile_cpu.pbtxt)
```bash
# MediaPipe graph that performs face detection with TensorFlow Lite on CPU.
# Used in the examples in
# mediapipie/examples/android/src/java/com/mediapipe/apps/facedetectioncpu and
# mediapipie/examples/ios/facedetectioncpu.
# Images on GPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish
# generating the corresponding detections before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between
# from queuing up incoming images and data excessively, which leads to increased
# latency and memory usage, unwanted in real-time mobile applications. It also
# eliminates unnecessarily computation, e.g., a transformed image produced by
# ImageTransformationCalculator may get dropped downstream if the subsequent
# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy
# processing previous inputs.
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:detections"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
}
output_stream: "throttled_input_video"
}
# Transfers the input image from GPU to CPU memory for the purpose of
# demonstrating a CPU-based pipeline. Note that the input image on GPU has the
# origin defined at the bottom-left corner (OpenGL convention). As a result,
# the transferred image on CPU also shares the same representation.
node: {
calculator: "GpuBufferToImageFrameCalculator"
input_stream: "throttled_input_video"
output_stream: "input_video_cpu"
}
# Transforms the input image on CPU to a 128x128 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:input_video_cpu"
output_stream: "IMAGE:transformed_input_video_cpu"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 128
output_height: 128
scale_mode: FIT
}
}
}
# Converts the transformed input image on CPU into an image tensor stored as a
# TfLiteTensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE:transformed_input_video_cpu"
output_stream: "TENSORS:image_tensor"
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:detection_tensors"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/face_detection_front.tflite"
}
}
}
# Generates a single side packet containing a vector of SSD anchors based on
# the specification in the options.
node {
calculator: "SsdAnchorsCalculator"
output_side_packet: "anchors"
node_options: {
[type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] {
num_layers: 4
min_scale: 0.1484375
max_scale: 0.75
input_size_height: 128
input_size_width: 128
anchor_offset_x: 0.5
anchor_offset_y: 0.5
strides: 8
strides: 16
strides: 16
strides: 16
aspect_ratios: 1.0
fixed_anchor_size: true
}
}
}
# Decodes the detection tensors generated by the TensorFlow Lite model, based on
# the SSD anchors and the specification in the options, into a vector of
# detections. Each detection describes a detected object.
node {
calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS:detection_tensors"
input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] {
num_classes: 1
num_boxes: 896
num_coords: 16
box_coord_offset: 0
keypoint_coord_offset: 4
num_keypoints: 6
num_values_per_keypoint: 2
sigmoid_score: true
score_clipping_thresh: 100.0
reverse_output_order: true
x_scale: 128.0
y_scale: 128.0
h_scale: 128.0
w_scale: 128.0
min_score_thresh: 0.75
}
}
}
# Performs non-max suppression to remove excessive detections.
node {
calculator: "NonMaxSuppressionCalculator"
input_stream: "detections"
output_stream: "filtered_detections"
node_options: {
[type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] {
min_suppression_threshold: 0.3
overlap_type: INTERSECTION_OVER_UNION
algorithm: WEIGHTED
return_empty_detections: true
}
}
}
# Maps detection label IDs to the corresponding label text ("Face"). The label
# map is provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "filtered_detections"
output_stream: "labeled_detections"
node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
}
}
}
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation).
node {
calculator: "DetectionLetterboxRemovalCalculator"
input_stream: "DETECTIONS:labeled_detections"
input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "DETECTIONS:output_detections"
}
# Converts the detections to drawing primitives for annotation overlay.
node {
calculator: "DetectionsToRenderDataCalculator"
input_stream: "DETECTIONS:output_detections"
output_stream: "RENDER_DATA:render_data"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 4.0
color { r: 255 g: 0 b: 0 }
}
}
}
# Draws annotations and overlays them on top of the input images.
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "INPUT_FRAME:input_video_cpu"
input_stream: "render_data"
output_stream: "OUTPUT_FRAME:output_video_cpu"
}
# Transfers the annotated image from CPU back to GPU memory, to be sent out of
# the graph.
node: {
calculator: "ImageFrameToGpuBufferCalculator"
input_stream: "output_video_cpu"
output_stream: "output_video"
}
```

View File

@ -0,0 +1,209 @@
## Hair Segmentation on Desktop
This is an example of using MediaPipe to run hair segmentation models
(TensorFlow Lite) and render a color to the detected hair. To know more about
the hair segmentation models, please refer to the model [`README file`].
Moreover, if you are interested in running the same TensorfFlow Lite model on
Android/iOS, please see the
[Hair Segmentation on GPU on Android/iOS](hair_segmentation_mobile_gpu.md) and
We show the hair segmentation demos with TensorFlow Lite model using the Webcam:
- [TensorFlow Lite Hair Segmentation Demo with Webcam (GPU)](#tensorflow-lite-hair-segmentation-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section.
### TensorFlow Lite Hair Segmentation Demo with Webcam (GPU)
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
mediapipe/examples/desktop/hair_segmentation:hair_segmentation_gpu
# It should print:
#INFO: Found 1 target...
#Target //mediapipe/examples/desktop/hair_segmentation:hair_segmentation_gpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu
#INFO: Elapsed time: 18.209s, Forge stats: 13026/13057 actions cached, 20.8s CPU used, 0.0s queue time, 89.3 MB ObjFS output (novel bytes: 87.4 MB), 0.0 MB local output, Critical Path: 11.88s, Remote (86.01% of the time): [queue: 0.00%, network: 16.83%, setup: 4.59%, process: 38.92%]
#INFO: Streaming build results to: http://sponge2/37d5a184-293b-4e98-a43e-b22084db3142
#INFO: Build completed successfully, 12210 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/hair_segmentation/hair_segmentation_gpu \
--calculator_graph_config_file=mediapipe/graphs/hair_segmentation/hair_segmentation_mobile_gpu.pbtxt
```
#### Graph
![hair_segmentation_mobile_gpu_graph](images/mobile/hair_segmentation_mobile_gpu.png)
To visualize the graph as shown above, copy the text specification of the graph
below and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev).
```bash
# MediaPipe graph that performs hair segmentation with TensorFlow Lite on GPU.
# Used in the example in
# mediapipie/examples/android/src/java/com/mediapipe/apps/hairsegmentationgpu.
# Images on GPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# TfLiteTensorsToSegmentationCalculator downstream in the graph to finish
# generating the corresponding hair mask before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# TfLiteTensorsToSegmentationCalculator to 1. This prevents the nodes in between
# from queuing up incoming images and data excessively, which leads to increased
# latency and memory usage, unwanted in real-time mobile applications. It also
# eliminates unnecessarily computation, e.g., a transformed image produced by
# ImageTransformationCalculator may get dropped downstream if the subsequent
# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy
# processing previous inputs.
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:hair_mask"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
}
output_stream: "throttled_input_video"
}
# Transforms the input image on GPU to a 512x512 image. To scale the image, by
# default it uses the STRETCH scale mode that maps the entire input image to the
# entire transformed image. As a result, image aspect ratio may be changed and
# objects in the image may be deformed (stretched or squeezed), but the hair
# segmentation model used in this graph is agnostic to that deformation.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE_GPU:throttled_input_video"
output_stream: "IMAGE_GPU:transformed_input_video"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 512
output_height: 512
}
}
}
# Caches a mask fed back from the previous round of hair segmentation, and upon
# the arrival of the next input image sends out the cached mask with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous mask. Note that upon the arrival of the very first
# input image, an empty packet is sent out to jump start the feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:throttled_input_video"
input_stream: "LOOP:hair_mask"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:previous_hair_mask"
}
# Embeds the hair mask generated from the previous round of hair segmentation
# as the alpha channel of the current input image.
node {
calculator: "SetAlphaCalculator"
input_stream: "IMAGE_GPU:transformed_input_video"
input_stream: "ALPHA_GPU:previous_hair_mask"
output_stream: "IMAGE_GPU:mask_embedded_input_video"
}
# Converts the transformed input image on GPU into an image tensor stored in
# tflite::gpu::GlBuffer. The zero_center option is set to false to normalize the
# pixel values to [0.f, 1.f] as opposed to [-1.f, 1.f]. With the
# max_num_channels option set to 4, all 4 RGBA channels are contained in the
# image tensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE_GPU:mask_embedded_input_video"
output_stream: "TENSORS_GPU:image_tensor"
node_options: {
[type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
zero_center: false
max_num_channels: 4
}
}
}
# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
calculator: "TfLiteCustomOpResolverCalculator"
output_side_packet: "op_resolver"
node_options: {
[type.googleapis.com/mediapipe.TfLiteCustomOpResolverCalculatorOptions] {
use_gpu: true
}
}
}
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# tensor representing the hair segmentation, which has the same width and height
# as the input image tensor.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS_GPU:image_tensor"
output_stream: "TENSORS_GPU:segmentation_tensor"
input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/hair_segmentation.tflite"
use_gpu: true
}
}
}
# Decodes the segmentation tensor generated by the TensorFlow Lite model into a
# mask of values in [0.f, 1.f], stored in the R channel of a GPU buffer. It also
# takes the mask generated previously as another input to improve the temporal
# consistency.
node {
calculator: "TfLiteTensorsToSegmentationCalculator"
input_stream: "TENSORS_GPU:segmentation_tensor"
input_stream: "PREV_MASK_GPU:previous_hair_mask"
output_stream: "MASK_GPU:hair_mask"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToSegmentationCalculatorOptions] {
tensor_width: 512
tensor_height: 512
tensor_channels: 2
combine_with_previous_ratio: 0.9
output_layer_index: 1
}
}
}
# Colors the hair segmentation with the color specified in the option.
node {
calculator: "RecolorCalculator"
input_stream: "IMAGE_GPU:throttled_input_video"
input_stream: "MASK_GPU:hair_mask"
output_stream: "IMAGE_GPU:output_video"
node_options: {
[type.googleapis.com/mediapipe.RecolorCalculatorOptions] {
color { r: 0 g: 0 b: 255 }
mask_channel: RED
}
}
}
```
[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/README.md

View File

@ -1,7 +1,7 @@
# Hand Detection (GPU) # Hand Detection (GPU)
This doc focuses on the This doc focuses on the
[example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) [example graph](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt)
that performs hand detection with TensorFlow Lite on GPU. It is related to the that performs hand detection with TensorFlow Lite on GPU. It is related to the
[hand tracking example](./hand_tracking_mobile_gpu.md). [hand tracking example](./hand_tracking_mobile_gpu.md).
@ -147,7 +147,7 @@ node {
![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png) ![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png)
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) [Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt)
```bash ```bash
# MediaPipe hand detection subgraph. # MediaPipe hand detection subgraph.

View File

@ -0,0 +1,184 @@
## Hand Tracking on Desktop
This is an example of using MediaPipe to run hand tracking models (TensorFlow
Lite) and render bounding boxes on the detected hand (one hand only). To know
more about the hand tracking models, please refer to the model [`README file`].
Moreover, if you are interested in running the same TensorfFlow Lite model on
Android/iOS, please see the
[Hand Tracking on GPU on Android/iOS](hand_tracking_mobile_gpu.md) and
We show the hand tracking demos with TensorFlow Lite model using the Webcam:
- [TensorFlow Lite Hand Tracking Demo with Webcam (CPU)](#tensorflow-lite-hand-tracking-demo-with-webcam-cpu)
- [TensorFlow Lite Hand Tracking Demo with Webcam (GPU)](#tensorflow-lite-hand-tracking-demo-with-webcam-gpu)
Note: Desktop GPU works only on Linux. Mesa drivers need to be installed. Please
see
[step 4 of "Installing on Debian and Ubuntu" in the installation guide](./install.md).
Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section.
### TensorFlow Lite Hand Tracking Demo with Webcam (CPU)
To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run:
```bash
# Video from webcam running on desktop CPU
$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu
# It should print:
#Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_cpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu
#INFO: Elapsed time: 22.645s, Forge stats: 13356/13463 actions cached, 1.5m CPU used, 0.0s queue time, 819.8 MB ObjFS output (novel bytes: 85.6 MB), 0.0 MB local output, Critical Path: 14.43s, Remote (87.25% of the time): [queue: 0.00%, network: 14.88%, setup: 4.80%, process: 39.80%, fetch: 18.15%]
#INFO: Streaming build results to: http://sponge2/360196b9-33ab-44b1-84a7-1022b5043307
#INFO: Build completed successfully, 12517 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_cpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_desktop_live.pbtxt
```
### TensorFlow Lite Hand Tracking Demo with Webcam (GPU)
To build and run the TensorFlow Lite example on desktop (GPU) with Webcam, run:
```bash
# Video from webcam running on desktop GPU
# This works only for linux currently
$ bazel build -c opt --copt -DMESA_EGL_NO_X11_HEADERS \
mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu
# It should print:
# Target //mediapipe/examples/desktop/hand_tracking:hand_tracking_gpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu
#INFO: Elapsed time: 84.055s, Forge stats: 6858/19343 actions cached, 1.6h CPU used, 0.9s queue time, 1.68 GB ObjFS output (novel bytes: 485.1 MB), 0.0 MB local output, Critical Path: 48.14s, Remote (99.40% of the time): [queue: 0.00%, setup: 5.59%, process: 74.44%]
#INFO: Streaming build results to: http://sponge2/00c7f95f-6fbc-432d-8978-f5d361efca3b
#INFO: Build completed successfully, 22455 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible,
# or GPU drivers not setup properly.
$ bazel-bin/mediapipe/examples/desktop/hand_tracking/hand_tracking_gpu \
--calculator_graph_config_file=mediapipe/graphs/hand_tracking/hand_tracking_mobile.pbtxt
```
#### Graph
![graph visualization](images/hand_tracking_desktop.png)
To visualize the graph as shown above, copy the text specification of the graph
below and paste it into
[MediaPipe Visualizer](https://viz.mediapipe.dev).
```bash
# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite
# on CPU & GPU.
# Used in the example in
# mediapipie/examples/desktop/hand_tracking:hand_tracking_cpu.
# Images coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon
# the arrival of the next input image sends out the cached decision with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand-presence decision. Note that upon the arrival
# of the very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:hand_presence"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_hand_presence"
}
# Drops the incoming image if HandLandmarkSubgraph was able to identify hand
# presence in the previous image. Otherwise, passes the incoming image through
# to trigger a new round of hand detection in HandDetectionSubgraph.
node {
calculator: "GateCalculator"
input_stream: "input_video"
input_stream: "DISALLOW:prev_hand_presence"
output_stream: "hand_detection_input_video"
node_options: {
[type.googleapis.com/mediapipe.GateCalculatorOptions] {
empty_packets_as_allow: true
}
}
}
# Subgraph that detections hands (see hand_detection_cpu.pbtxt).
node {
calculator: "HandDetectionSubgraph"
input_stream: "hand_detection_input_video"
output_stream: "DETECTIONS:palm_detections"
output_stream: "NORM_RECT:hand_rect_from_palm_detections"
}
# Subgraph that localizes hand landmarks (see hand_landmark_cpu.pbtxt).
node {
calculator: "HandLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECT:hand_rect"
output_stream: "LANDMARKS:hand_landmarks"
output_stream: "NORM_RECT:hand_rect_from_landmarks"
output_stream: "PRESENCE:hand_presence"
}
# Caches a hand rectangle fed back from HandLandmarkSubgraph, and upon the
# arrival of the next input image sends out the cached rectangle with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand rectangle. Note that upon the arrival of the
# very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:hand_rect_from_landmarks"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks"
}
# Merges a stream of hand rectangles generated by HandDetectionSubgraph and that
# generated by HandLandmarkSubgraph into a single output stream by selecting
# between one of the two streams. The former is selected if the incoming packet
# is not empty, i.e., hand detection is performed on the current image by
# HandDetectionSubgraph (because HandLandmarkSubgraph could not identify hand
# presence in the previous image). Otherwise, the latter is selected, which is
# never empty because HandLandmarkSubgraphs processes all images (that went
# through FlowLimiterCaculator).
node {
calculator: "MergeCalculator"
input_stream: "hand_rect_from_palm_detections"
input_stream: "prev_hand_rect_from_landmarks"
output_stream: "hand_rect"
}
# Subgraph that renders annotations and overlays them on top of the input
# images (see renderer_cpu.pbtxt).
node {
calculator: "RendererSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "LANDMARKS:hand_landmarks"
input_stream: "NORM_RECT:hand_rect"
input_stream: "DETECTIONS:palm_detections"
output_stream: "IMAGE:output_video"
}
```
[`README file`]:https://github.com/google/mediapipe/tree/master/mediapipe/README.md

View File

@ -227,7 +227,7 @@ node {
![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png) ![hand_detection_gpu_subgraph](images/mobile/hand_detection_gpu_subgraph.png)
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt) [Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt)
```bash ```bash
# MediaPipe hand detection subgraph. # MediaPipe hand detection subgraph.
@ -433,7 +433,7 @@ node {
![hand_landmark_gpu_subgraph.pbtxt](images/mobile/hand_landmark_gpu_subgraph.png) ![hand_landmark_gpu_subgraph.pbtxt](images/mobile/hand_landmark_gpu_subgraph.png)
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_landmark_gpu.pbtxt) [Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_landmark_gpu.pbtxt)
```bash ```bash
# MediaPipe hand landmark localization subgraph. # MediaPipe hand landmark localization subgraph.
@ -617,7 +617,7 @@ node {
![hand_renderer_gpu_subgraph.pbtxt](images/mobile/hand_renderer_gpu_subgraph.png) ![hand_renderer_gpu_subgraph.pbtxt](images/mobile/hand_renderer_gpu_subgraph.png)
[Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/renderer_gpu.pbtxt) [Source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/renderer_gpu.pbtxt)
```bash ```bash
# MediaPipe hand tracking rendering subgraph. # MediaPipe hand tracking rendering subgraph.

Binary file not shown.

After

Width:  |  Height:  |  Size: 118 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 96 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 121 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 49 KiB

After

Width:  |  Height:  |  Size: 52 KiB

View File

@ -107,14 +107,34 @@ To build and run iOS apps:
) )
``` ```
4. Run the [Hello World desktop example](./hello_world_desktop.md). 4. For running desktop examples on Linux only (not on OS X) with GPU
acceleration.
```bash
# Requires a GPU with EGL driver support.
# Can use mesa GPU libraries for desktop, (or Nvidia/AMD equivalent).
sudo apt-get install mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev
# To compile with GPU support, replace
--define MEDIAPIPE_DISABLE_GPU=1
# with
--copt -DMESA_EGL_NO_X11_HEADERS
# when building GPU examples.
```
5. Run the [Hello World desktop example](./hello_world_desktop.md).
```bash ```bash
$ export GLOG_logtostderr=1 $ export GLOG_logtostderr=1
# Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported
# if you are running on Linux desktop with CPU only
$ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \
mediapipe/examples/desktop/hello_world:hello_world mediapipe/examples/desktop/hello_world:hello_world
# If you are running on Linux desktop with GPU support enabled (via mesa drivers)
$ bazel run --copt -DMESA_EGL_NO_X11_HEADERS \
mediapipe/examples/desktop/hello_world:hello_world
# Should print: # Should print:
# Hello World! # Hello World!
# Hello World! # Hello World!
@ -194,7 +214,7 @@ To build and run iOS apps:
```bash ```bash
$ export GLOG_logtostderr=1 $ export GLOG_logtostderr=1
# Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' as desktop GPU is currently not supported # Need bazel flag 'MEDIAPIPE_DISABLE_GPU=1' if you are running on Linux desktop with CPU only
$ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \ $ bazel run --define MEDIAPIPE_DISABLE_GPU=1 \
mediapipe/examples/desktop/hello_world:hello_world mediapipe/examples/desktop/hello_world:hello_world

View File

@ -10,8 +10,9 @@ interested in running the same TensorfFlow Lite model on Android, please see the
We show the object detection demo with both TensorFlow model and TensorFlow Lite model: We show the object detection demo with both TensorFlow model and TensorFlow Lite model:
- [TensorFlow Object Detection Demo](#tensorflow-object-detection-demo) - [TensorFlow Object Detection Demo](#tensorflow-object-detection-demo)
- [TensorFlow Lite Object Detection Demo](#tensorflow-lite-object-detection-demo) - [TensorFlow Lite Object Detection Demo](#tensorflow-lite-object-detection-demo)
- [TensorFlow Lite Object Detection Demo with Webcam (CPU)](#tensorflow-lite-object-detection-demo)
Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section. Note: If MediaPipe depends on OpenCV 2, please see the [known issues with OpenCV 2](#known-issues-with-opencv-2) section.
@ -207,6 +208,29 @@ $ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_tflite
--input_side_packets=input_video_path=<input video path>,output_video_path=<output video path> --input_side_packets=input_video_path=<input video path>,output_video_path=<output video path>
``` ```
### TensorFlow Lite Object Detection Demo with Webcam (CPU)
To build and run the TensorFlow Lite example on desktop (CPU) with Webcam, run:
```bash
# Video from webcam running on desktop CPU
$ bazel build -c opt --define MEDIAPIPE_DISABLE_GPU=1 \
mediapipe/examples/desktop/object_detection:object_detection_cpu
# It should print:
#Target //mediapipe/examples/desktop/object_detection:object_detection_cpu up-to-date:
# bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu
#INFO: Elapsed time: 16.020s, Forge stats: 13001/13003 actions cached, 2.1s CPU used, 0.0s queue time, 89.0 MB ObjFS output (novel bytes: 88.0 MB), 0.0 MB local output, Critical Path: 10.01s, Remote (41.42% of the time): [queue: 0.00%, setup: 4.21%, process: 12.48%]
#INFO: Streaming build results to: http://sponge2/1824d4cc-ba63-4350-bdc0-aacbd45b902b
#INFO: Build completed successfully, 12154 total actions
$ export GLOG_logtostderr=1
# This will open up your webcam as long as it is connected and on
# Any errors is likely due to your webcam being not accessible
$ bazel-bin/mediapipe/examples/desktop/object_detection/object_detection_cpu \
--calculator_graph_config_file=mediapipe/graphs/object_detection/object_detection_desktop_live.pbtxt
```
#### Graph #### Graph
![graph visualization](images/object_detection_desktop_tflite.png) ![graph visualization](images/object_detection_desktop_tflite.png)

View File

@ -77,7 +77,7 @@ For instance, there are two graphs involved in the
[hand detection example](./hand_detection_mobile_gpu.md): the main graph [hand detection example](./hand_detection_mobile_gpu.md): the main graph
([source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt)) ([source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_mobile.pbtxt))
and its associated subgraph and its associated subgraph
([source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/hand_detection_gpu.pbtxt)). ([source pbtxt file](https://github.com/google/mediapipe/tree/master/mediapipe/graphs/hand_tracking/subgraphs/hand_detection_gpu.pbtxt)).
To visualize them: To visualize them:
* In the MediaPipe visualizer, click on the upload graph button and select the * In the MediaPipe visualizer, click on the upload graph button and select the

View File

@ -14,7 +14,9 @@
licenses(["notice"]) # Apache 2.0 licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe/examples:__subpackages__"]) package(default_visibility = [
"//visibility:public",
])
cc_library( cc_library(
name = "simple_run_graph_main", name = "simple_run_graph_main",
@ -29,3 +31,44 @@ cc_library(
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
], ],
) )
cc_library(
name = "demo_run_graph_main",
srcs = ["demo_run_graph_main.cc"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:opencv_highgui",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
],
)
# Linux only.
# Must have a GPU with EGL support:
# ex: sudo aptitude install mesa-common-dev libegl1-mesa-dev libgles2-mesa-dev
# (or similar nvidia/amd equivalent)
cc_library(
name = "demo_run_graph_main_gpu",
srcs = ["demo_run_graph_main_gpu.cc"],
deps = [
"//mediapipe/framework:calculator_framework",
"//mediapipe/framework/formats:image_frame",
"//mediapipe/framework/formats:image_frame_opencv",
"//mediapipe/framework/port:commandlineflags",
"//mediapipe/framework/port:file_helpers",
"//mediapipe/framework/port:opencv_highgui",
"//mediapipe/framework/port:opencv_imgproc",
"//mediapipe/framework/port:opencv_video",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status",
"//mediapipe/gpu:gl_calculator_helper",
"//mediapipe/gpu:gpu_buffer",
"//mediapipe/gpu:gpu_shared_data_internal",
],
)

View File

@ -0,0 +1,146 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// An example of sending OpenCV webcam frames into a MediaPipe graph.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/opencv_video_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
constexpr char kInputStream[] = "input_video";
constexpr char kOutputStream[] = "output_video";
constexpr char kWindowName[] = "MediaPipe";
DEFINE_string(
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_video_path, "",
"Full path of video to load. "
"If not provided, attempt to use a webcam.");
DEFINE_string(output_video_path, "",
"Full path of where to save result (.mp4 only). "
"If not provided, show result in a window.");
::mediapipe::Status RunMPPGraph() {
std::string calculator_graph_config_contents;
MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
FLAGS_calculator_graph_config_file, &calculator_graph_config_contents));
LOG(INFO) << "Get calculator graph config contents: "
<< calculator_graph_config_contents;
mediapipe::CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
calculator_graph_config_contents);
LOG(INFO) << "Initialize the calculator graph.";
mediapipe::CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(config));
LOG(INFO) << "Initialize the camera or load the video.";
cv::VideoCapture capture;
const bool load_video = !FLAGS_input_video_path.empty();
if (load_video) {
capture.open(FLAGS_input_video_path);
} else {
capture.open(0);
}
RET_CHECK(capture.isOpened());
cv::VideoWriter writer;
const bool save_video = !FLAGS_output_video_path.empty();
if (save_video) {
LOG(INFO) << "Prepare video writer.";
cv::Mat test_frame;
capture.read(test_frame); // Consume first frame.
capture.set(cv::CAP_PROP_POS_AVI_RATIO, 0); // Rewind to beginning.
writer.open(FLAGS_output_video_path,
mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4
capture.get(cv::CAP_PROP_FPS), test_frame.size());
RET_CHECK(writer.isOpened());
} else {
cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1);
}
LOG(INFO) << "Start running the calculator graph.";
ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller,
graph.AddOutputStreamPoller(kOutputStream));
MP_RETURN_IF_ERROR(graph.StartRun({}));
LOG(INFO) << "Start grabbing and processing frames.";
size_t frame_timestamp = 0;
bool grab_frames = true;
while (grab_frames) {
// Capture opencv camera or video frame.
cv::Mat camera_frame_raw;
capture >> camera_frame_raw;
if (camera_frame_raw.empty()) break; // End of video.
cv::Mat camera_frame;
cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB);
if (!load_video) {
cv::flip(camera_frame, camera_frame, /*flipcode=HORIZONTAL*/ 1);
}
// Wrap Mat into an ImageFrame.
auto input_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows,
mediapipe::ImageFrame::kDefaultAlignmentBoundary);
cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
camera_frame.copyTo(input_frame_mat);
// Send image packet into the graph.
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
kInputStream, mediapipe::Adopt(input_frame.release())
.At(mediapipe::Timestamp(frame_timestamp++))));
// Get the graph result packet, or stop if that fails.
mediapipe::Packet packet;
if (!poller.Next(&packet)) break;
auto& output_frame = packet.Get<mediapipe::ImageFrame>();
// Convert back to opencv for display or saving.
cv::Mat output_frame_mat = mediapipe::formats::MatView(&output_frame);
cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR);
if (save_video) {
writer.write(output_frame_mat);
} else {
cv::imshow(kWindowName, output_frame_mat);
// Press any key to exit.
const int pressed_key = cv::waitKey(5);
if (pressed_key >= 0 && pressed_key != 255) grab_frames = false;
}
}
LOG(INFO) << "Shutting down.";
if (writer.isOpened()) writer.release();
MP_RETURN_IF_ERROR(graph.CloseInputStream(kInputStream));
return graph.WaitUntilDone();
}
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true);
::mediapipe::Status run_status = RunMPPGraph();
if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
} else {
LOG(INFO) << "Success!";
}
return 0;
}

View File

@ -0,0 +1,186 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// An example of sending OpenCV webcam frames into a MediaPipe graph.
// This example requires a linux computer and a GPU with EGL support drivers.
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/image_frame_opencv.h"
#include "mediapipe/framework/port/commandlineflags.h"
#include "mediapipe/framework/port/file_helpers.h"
#include "mediapipe/framework/port/opencv_highgui_inc.h"
#include "mediapipe/framework/port/opencv_imgproc_inc.h"
#include "mediapipe/framework/port/opencv_video_inc.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/gpu/gpu_shared_data_internal.h"
constexpr char kInputStream[] = "input_video";
constexpr char kOutputStream[] = "output_video";
constexpr char kWindowName[] = "MediaPipe";
DEFINE_string(
calculator_graph_config_file, "",
"Name of file containing text format CalculatorGraphConfig proto.");
DEFINE_string(input_video_path, "",
"Full path of video to load. "
"If not provided, attempt to use a webcam.");
DEFINE_string(output_video_path, "",
"Full path of where to save result (.mp4 only). "
"If not provided, show result in a window.");
::mediapipe::Status RunMPPGraph() {
std::string calculator_graph_config_contents;
MP_RETURN_IF_ERROR(mediapipe::file::GetContents(
FLAGS_calculator_graph_config_file, &calculator_graph_config_contents));
LOG(INFO) << "Get calculator graph config contents: "
<< calculator_graph_config_contents;
mediapipe::CalculatorGraphConfig config =
mediapipe::ParseTextProtoOrDie<mediapipe::CalculatorGraphConfig>(
calculator_graph_config_contents);
LOG(INFO) << "Initialize the calculator graph.";
mediapipe::CalculatorGraph graph;
MP_RETURN_IF_ERROR(graph.Initialize(config));
LOG(INFO) << "Initialize the GPU.";
ASSIGN_OR_RETURN(auto gpu_resources, mediapipe::GpuResources::Create());
MP_RETURN_IF_ERROR(graph.SetGpuResources(std::move(gpu_resources)));
mediapipe::GlCalculatorHelper gpu_helper;
gpu_helper.InitializeForTest(graph.GetGpuResources().get());
LOG(INFO) << "Initialize the camera or load the video.";
cv::VideoCapture capture;
const bool load_video = !FLAGS_input_video_path.empty();
if (load_video) {
capture.open(FLAGS_input_video_path);
} else {
capture.open(0);
}
RET_CHECK(capture.isOpened());
cv::VideoWriter writer;
const bool save_video = !FLAGS_output_video_path.empty();
if (save_video) {
LOG(INFO) << "Prepare video writer.";
cv::Mat test_frame;
capture.read(test_frame); // Consume first frame.
capture.set(cv::CAP_PROP_POS_AVI_RATIO, 0); // Rewind to beginning.
writer.open(FLAGS_output_video_path,
mediapipe::fourcc('a', 'v', 'c', '1'), // .mp4
capture.get(cv::CAP_PROP_FPS), test_frame.size());
RET_CHECK(writer.isOpened());
} else {
cv::namedWindow(kWindowName, /*flags=WINDOW_AUTOSIZE*/ 1);
}
LOG(INFO) << "Start running the calculator graph.";
ASSIGN_OR_RETURN(mediapipe::OutputStreamPoller poller,
graph.AddOutputStreamPoller(kOutputStream));
MP_RETURN_IF_ERROR(graph.StartRun({}));
LOG(INFO) << "Start grabbing and processing frames.";
size_t frame_timestamp = 0;
bool grab_frames = true;
while (grab_frames) {
// Capture opencv camera or video frame.
cv::Mat camera_frame_raw;
capture >> camera_frame_raw;
if (camera_frame_raw.empty()) break; // End of video.
cv::Mat camera_frame;
cv::cvtColor(camera_frame_raw, camera_frame, cv::COLOR_BGR2RGB);
if (!load_video) {
cv::flip(camera_frame, camera_frame, /*flipcode=HORIZONTAL*/ 1);
}
// Wrap Mat into an ImageFrame.
auto input_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormat::SRGB, camera_frame.cols, camera_frame.rows,
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
cv::Mat input_frame_mat = mediapipe::formats::MatView(input_frame.get());
camera_frame.copyTo(input_frame_mat);
// Prepare and add graph input packet.
MP_RETURN_IF_ERROR(
gpu_helper.RunInGlContext([&input_frame, &frame_timestamp, &graph,
&gpu_helper]() -> ::mediapipe::Status {
// Convert ImageFrame to GpuBuffer.
auto texture = gpu_helper.CreateSourceTexture(*input_frame.get());
auto gpu_frame = texture.GetFrame<mediapipe::GpuBuffer>();
glFlush();
texture.Release();
// Send GPU image packet into the graph.
MP_RETURN_IF_ERROR(graph.AddPacketToInputStream(
kInputStream, mediapipe::Adopt(gpu_frame.release())
.At(mediapipe::Timestamp(frame_timestamp++))));
return ::mediapipe::OkStatus();
}));
// Get the graph result packet, or stop if that fails.
mediapipe::Packet packet;
if (!poller.Next(&packet)) break;
std::unique_ptr<mediapipe::ImageFrame> output_frame;
// Convert GpuBuffer to ImageFrame.
MP_RETURN_IF_ERROR(gpu_helper.RunInGlContext(
[&packet, &output_frame, &gpu_helper]() -> ::mediapipe::Status {
auto& gpu_frame = packet.Get<mediapipe::GpuBuffer>();
auto texture = gpu_helper.CreateSourceTexture(gpu_frame);
output_frame = absl::make_unique<mediapipe::ImageFrame>(
mediapipe::ImageFormatForGpuBufferFormat(gpu_frame.format()),
gpu_frame.width(), gpu_frame.height(),
mediapipe::ImageFrame::kGlDefaultAlignmentBoundary);
gpu_helper.BindFramebuffer(texture);
const auto info =
mediapipe::GlTextureInfoForGpuBufferFormat(gpu_frame.format(), 0);
glReadPixels(0, 0, texture.width(), texture.height(), info.gl_format,
info.gl_type, output_frame->MutablePixelData());
glFlush();
texture.Release();
return ::mediapipe::OkStatus();
}));
// Convert back to opencv for display or saving.
cv::Mat output_frame_mat = mediapipe::formats::MatView(output_frame.get());
cv::cvtColor(output_frame_mat, output_frame_mat, cv::COLOR_RGB2BGR);
if (save_video) {
writer.write(output_frame_mat);
} else {
cv::imshow(kWindowName, output_frame_mat);
// Press any key to exit.
const int pressed_key = cv::waitKey(5);
if (pressed_key >= 0 && pressed_key != 255) grab_frames = false;
}
}
LOG(INFO) << "Shutting down.";
if (writer.isOpened()) writer.release();
MP_RETURN_IF_ERROR(graph.CloseInputStream(kInputStream));
return graph.WaitUntilDone();
}
int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);
gflags::ParseCommandLineFlags(&argc, &argv, true);
::mediapipe::Status run_status = RunMPPGraph();
if (!run_status.ok()) {
LOG(ERROR) << "Failed to run the graph: " << run_status.message();
} else {
LOG(INFO) << "Success!";
}
return 0;
}

View File

@ -0,0 +1,34 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
cc_binary(
name = "face_detection_cpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main",
"//mediapipe/graphs/face_detection:desktop_tflite_calculators",
],
)
# Linux only
cc_binary(
name = "face_detection_gpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main_gpu",
"//mediapipe/graphs/face_detection:mobile_calculators",
],
)

View File

@ -0,0 +1,26 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
# Linux only
cc_binary(
name = "hair_segmentation_gpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main_gpu",
"//mediapipe/graphs/hair_segmentation:mobile_calculators",
],
)

View File

@ -0,0 +1,42 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//mediapipe/examples:__subpackages__"])
cc_binary(
name = "hand_tracking_tflite",
deps = [
"//mediapipe/examples/desktop:simple_run_graph_main",
"//mediapipe/graphs/hand_tracking:desktop_tflite_calculators",
],
)
cc_binary(
name = "hand_tracking_cpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main",
"//mediapipe/graphs/hand_tracking:desktop_tflite_calculators",
],
)
# Linux only
cc_binary(
name = "hand_tracking_gpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main_gpu",
"//mediapipe/graphs/hand_tracking:mobile_calculators",
],
)

View File

@ -72,3 +72,11 @@ cc_binary(
"//mediapipe/graphs/object_detection:desktop_tflite_calculators", "//mediapipe/graphs/object_detection:desktop_tflite_calculators",
], ],
) )
cc_binary(
name = "object_detection_cpu",
deps = [
"//mediapipe/examples/desktop:demo_run_graph_main",
"//mediapipe/graphs/object_detection:desktop_tflite_calculators",
],
)

View File

@ -27,7 +27,7 @@ cc_binary(
"//mediapipe/framework/port:map_util", "//mediapipe/framework/port:map_util",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/graphs/youtube8m:yt8m_calculators_deps", "//mediapipe/graphs/youtube8m:yt8m_feature_extraction_calculators",
# TODO: Figure out the minimum set of the kernels needed by this example. # TODO: Figure out the minimum set of the kernels needed by this example.
"@org_tensorflow//tensorflow/core:all_kernels", "@org_tensorflow//tensorflow/core:all_kernels",
"@org_tensorflow//tensorflow/core:direct_session", "@org_tensorflow//tensorflow/core:direct_session",

View File

@ -37,7 +37,9 @@
```bash ```bash
python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \ python -m mediapipe.examples.desktop.youtube8m.generate_input_sequence_example \
--path_to_input_video=/absolute/path/to/the/local/video/file --path_to_input_video=/absolute/path/to/the/local/video/file \
--clip_start_time_sec=0 \
--clip_end_time_sec=10
``` ```
5. Run the MediaPipe binary to extract the features 5. Run the MediaPipe binary to extract the features

View File

@ -37,20 +37,29 @@ def bytes23(string):
def main(argv): def main(argv):
if len(argv) > 1: if len(argv) > 3:
raise app.UsageError('Too many command-line arguments.') raise app.UsageError('Too many command-line arguments.')
if not flags.FLAGS.path_to_input_video: if not flags.FLAGS.path_to_input_video:
raise ValueError('You must specify the path to the input video.') raise ValueError('You must specify the path to the input video.')
if not flags.FLAGS.clip_end_time_sec:
raise ValueError('You must specify the clip end timestamp in seconds.')
if flags.FLAGS.clip_start_time_sec >= flags.FLAGS.clip_end_time_sec:
raise ValueError(
'The clip start time must be greater than the clip end time.')
metadata = tf.train.SequenceExample() metadata = tf.train.SequenceExample()
ms.set_clip_data_path(bytes23(flags.FLAGS.path_to_input_video), metadata) ms.set_clip_data_path(bytes23(flags.FLAGS.path_to_input_video), metadata)
ms.set_clip_start_timestamp(0, metadata) ms.set_clip_start_timestamp(
flags.FLAGS.clip_start_time_sec * SECONDS_TO_MICROSECONDS, metadata)
ms.set_clip_end_timestamp( ms.set_clip_end_timestamp(
int(float(300 * SECONDS_TO_MICROSECONDS)), metadata) flags.FLAGS.clip_end_time_sec * SECONDS_TO_MICROSECONDS, metadata)
with open('/tmp/mediapipe/metadata.tfrecord', 'wb') as writer: with open('/tmp/mediapipe/metadata.tfrecord', 'wb') as writer:
writer.write(metadata.SerializeToString()) writer.write(metadata.SerializeToString())
if __name__ == '__main__': if __name__ == '__main__':
flags.DEFINE_string('path_to_input_video', '', 'Path to the input video.') flags.DEFINE_string('path_to_input_video', '', 'Path to the input video.')
flags.DEFINE_integer('clip_start_time_sec', 0,
'Clip start timestamp in seconds')
flags.DEFINE_integer('clip_end_time_sec', 10, 'Clip end timestamp in seconds')
app.run(main) app.run(main)

View File

@ -1134,6 +1134,7 @@ cc_library(
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/port:statusor", "//mediapipe/framework/port:statusor",
"//mediapipe/framework/tool:calculator_graph_template_cc_proto", "//mediapipe/framework/tool:calculator_graph_template_cc_proto",
"//mediapipe/framework/tool:options_util",
"//mediapipe/framework/tool:template_expander", "//mediapipe/framework/tool:template_expander",
"@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
@ -1499,13 +1500,47 @@ cc_test(
deps = [ deps = [
":calculator_context", ":calculator_context",
":calculator_framework", ":calculator_framework",
":test_calculators",
":thread_pool_executor",
":timestamp", ":timestamp",
":type_map",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator", "//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework/port:gtest_main", "//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging", "//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto", "//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:status", "//mediapipe/framework/port:status",
"//mediapipe/framework/stream_handler:barrier_input_stream_handler",
"//mediapipe/framework/stream_handler:early_close_input_stream_handler",
"//mediapipe/framework/stream_handler:fixed_size_input_stream_handler",
"//mediapipe/framework/stream_handler:immediate_input_stream_handler", "//mediapipe/framework/stream_handler:immediate_input_stream_handler",
"//mediapipe/framework/stream_handler:mux_input_stream_handler",
"//mediapipe/framework/tool:sink",
],
)
cc_test(
name = "calculator_graph_side_packet_test",
size = "small",
srcs = [
"calculator_graph_side_packet_test.cc",
],
visibility = ["//visibility:public"],
deps = [
":calculator_framework",
":test_calculators",
"//mediapipe/calculators/core:counting_source_calculator",
"//mediapipe/calculators/core:mux_calculator",
"//mediapipe/calculators/core:pass_through_calculator",
"//mediapipe/framework:calculator_cc_proto",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:logging",
"//mediapipe/framework/port:parse_text_proto",
"//mediapipe/framework/port:ret_check",
"//mediapipe/framework/port:status",
"//mediapipe/framework/tool:sink",
"@com_google_absl//absl/time",
], ],
) )

View File

@ -21,11 +21,258 @@
#include "mediapipe/framework/port/parse_text_proto.h" #include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h" #include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/thread_pool_executor.h"
#include "mediapipe/framework/timestamp.h" #include "mediapipe/framework/timestamp.h"
namespace mediapipe { namespace mediapipe {
namespace { namespace {
typedef std::function<::mediapipe::Status(CalculatorContext* cc)>
CalculatorContextFunction;
// A simple Semaphore for synchronizing test threads.
class AtomicSemaphore {
public:
AtomicSemaphore(int64_t supply) : supply_(supply) {}
void Acquire(int64_t amount) {
while (supply_.fetch_sub(amount) - amount < 0) {
Release(amount);
}
}
void Release(int64_t amount) { supply_ += amount; }
private:
std::atomic<int64_t> supply_;
};
// A mediapipe::Executor that signals the start and finish of each task.
class CountingExecutor : public Executor {
public:
CountingExecutor(int num_threads, std::function<void()> start_callback,
std::function<void()> finish_callback)
: thread_pool_(num_threads),
start_callback_(std::move(start_callback)),
finish_callback_(std::move(finish_callback)) {
thread_pool_.StartWorkers();
}
void Schedule(std::function<void()> task) override {
start_callback_();
thread_pool_.Schedule([this, task] {
task();
finish_callback_();
});
}
private:
ThreadPool thread_pool_;
std::function<void()> start_callback_;
std::function<void()> finish_callback_;
};
// A Calculator that adds the integer values in the packets in all the input
// streams and outputs the sum to the output stream.
class IntAdderCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
cc->Inputs().Index(i).Set<int>();
}
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
int sum = 0;
for (int i = 0; i < cc->Inputs().NumEntries(); ++i) {
sum += cc->Inputs().Index(i).Get<int>();
}
cc->Outputs().Index(0).Add(new int(sum), cc->InputTimestamp());
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(IntAdderCalculator);
template <typename InputType>
class TypedSinkCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<InputType>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
return ::mediapipe::OkStatus();
}
};
typedef TypedSinkCalculator<std::string> StringSinkCalculator;
typedef TypedSinkCalculator<int> IntSinkCalculator;
REGISTER_CALCULATOR(StringSinkCalculator);
REGISTER_CALCULATOR(IntSinkCalculator);
// A Calculator that passes an input packet through if it contains an even
// integer.
class EvenIntFilterCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
int value = cc->Inputs().Index(0).Get<int>();
if (value % 2 == 0) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
} else {
cc->Outputs().Index(0).SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream());
}
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(EvenIntFilterCalculator);
// A Calculator that passes packets through or not, depending on a second
// input. The first input stream's packets are only propagated if the second
// input stream carries the value true.
class ValveCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Inputs().Index(1).Set<bool>();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header());
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
if (cc->Inputs().Index(1).Get<bool>()) {
cc->GetCounter("PassThrough")->Increment();
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
} else {
cc->GetCounter("Block")->Increment();
// The next timestamp bound is the minimum timestamp that the next packet
// can have, so, if we want to inform the downstream that no packet at
// InputTimestamp() is coming, we need to set it to the next value.
// We could also just call SetOffset(TimestampDiff(0)) in Open, and then
// we would not have to call this manually.
cc->Outputs().Index(0).SetNextTimestampBound(
cc->InputTimestamp().NextAllowedInStream());
}
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(ValveCalculator);
// A Calculator that simply passes its input Packets and header through,
// but shifts the timestamp.
class TimeShiftCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
cc->InputSidePackets().Index(0).Set<TimestampDiff>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
// Input: arbitrary Packets.
// Output: copy of the input.
cc->Outputs().Index(0).SetHeader(cc->Inputs().Index(0).Header());
shift_ = cc->InputSidePackets().Index(0).Get<TimestampDiff>();
cc->SetOffset(shift_);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
cc->GetCounter("PassThrough")->Increment();
cc->Outputs().Index(0).AddPacket(
cc->Inputs().Index(0).Value().At(cc->InputTimestamp() + shift_));
return ::mediapipe::OkStatus();
}
private:
TimestampDiff shift_;
};
REGISTER_CALCULATOR(TimeShiftCalculator);
// A source calculator that alternates between outputting an integer (0, 1, 2,
// ..., 100) and setting the next timestamp bound. The timestamps of the output
// packets and next timestamp bounds are 0, 10, 20, 30, ...
//
// T=0 Output 0
// T=10 Set timestamp bound
// T=20 Output 1
// T=30 Set timestamp bound
// ...
// T=2000 Output 100
class OutputAndBoundSourceCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) override {
counter_ = 0;
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
Timestamp timestamp(counter_);
if (counter_ % 20 == 0) {
cc->Outputs().Index(0).AddPacket(
MakePacket<int>(counter_ / 20).At(timestamp));
} else {
cc->Outputs().Index(0).SetNextTimestampBound(timestamp);
}
if (counter_ == 2000) {
return tool::StatusStop();
}
counter_ += 10;
return ::mediapipe::OkStatus();
}
private:
int counter_;
};
REGISTER_CALCULATOR(OutputAndBoundSourceCalculator);
// A calculator that outputs an initial packet of value 0 at time 0 in the
// Open() method, and then delays each input packet by 20 time units in the
// Process() method. The input stream and output stream have the integer type.
class Delay20Calculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(TimestampDiff(20));
cc->Outputs().Index(0).AddPacket(MakePacket<int>(0).At(Timestamp(0)));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
const Packet& packet = cc->Inputs().Index(0).Value();
Timestamp timestamp = packet.Timestamp() + 20;
cc->Outputs().Index(0).AddPacket(packet.At(timestamp));
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(Delay20Calculator);
class CustomBoundCalculator : public CalculatorBase { class CustomBoundCalculator : public CalculatorBase {
public: public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) { static ::mediapipe::Status GetContract(CalculatorContract* cc) {
@ -45,8 +292,280 @@ class CustomBoundCalculator : public CalculatorBase {
}; };
REGISTER_CALCULATOR(CustomBoundCalculator); REGISTER_CALCULATOR(CustomBoundCalculator);
// Test that SetNextTimestampBound propagates.
TEST(CalculatorGraph, SetNextTimestampBoundPropagation) {
CalculatorGraph graph;
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in'
input_stream: 'gate'
node {
calculator: 'ValveCalculator'
input_stream: 'in'
input_stream: 'gate'
output_stream: 'gated'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'gated'
output_stream: 'passed'
}
node {
calculator: 'TimeShiftCalculator'
input_stream: 'passed'
output_stream: 'shifted'
input_side_packet: 'shift'
}
node {
calculator: 'MergeCalculator'
input_stream: 'in'
input_stream: 'shifted'
output_stream: 'merged'
}
node {
name: 'merged_output'
calculator: 'PassThroughCalculator'
input_stream: 'merged'
output_stream: 'out'
}
)");
Timestamp timestamp = Timestamp(0);
auto send_inputs = [&graph, &timestamp](int input, bool pass) {
++timestamp;
MP_EXPECT_OK(graph.AddPacketToInputStream(
"in", MakePacket<int>(input).At(timestamp)));
MP_EXPECT_OK(graph.AddPacketToInputStream(
"gate", MakePacket<bool>(pass).At(timestamp)));
};
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({{"shift", MakePacket<TimestampDiff>(0)}}));
auto pass_counter =
graph.GetCounterFactory()->GetCounter("ValveCalculator-PassThrough");
auto block_counter =
graph.GetCounterFactory()->GetCounter("ValveCalculator-Block");
auto merged_counter =
graph.GetCounterFactory()->GetCounter("merged_output-PassThrough");
send_inputs(1, true);
send_inputs(2, true);
send_inputs(3, false);
send_inputs(4, false);
MP_ASSERT_OK(graph.WaitUntilIdle());
// Verify that MergeCalculator was able to run even when the gated branch
// was blocked.
EXPECT_EQ(2, pass_counter->Get());
EXPECT_EQ(2, block_counter->Get());
EXPECT_EQ(4, merged_counter->Get());
send_inputs(5, true);
send_inputs(6, false);
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(3, pass_counter->Get());
EXPECT_EQ(3, block_counter->Get());
EXPECT_EQ(6, merged_counter->Get());
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
// Now test with time shift
MP_ASSERT_OK(graph.StartRun({{"shift", MakePacket<TimestampDiff>(-1)}}));
send_inputs(7, true);
MP_ASSERT_OK(graph.WaitUntilIdle());
// The merger should have run only once now, at timestamp 6, with inputs
// <null, 7>. If we do not respect the offset and unblock the merger for
// timestamp 7 too, then it will have run twice, with 6: <null,7> and
// 7: <7, null>.
EXPECT_EQ(4, pass_counter->Get());
EXPECT_EQ(3, block_counter->Get());
EXPECT_EQ(7, merged_counter->Get());
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_EQ(4, pass_counter->Get());
EXPECT_EQ(3, block_counter->Get());
EXPECT_EQ(8, merged_counter->Get());
}
// Both input streams of the calculator node have the same next timestamp
// bound. One input stream has a packet at that timestamp. The other input
// stream is empty. We should not run the Process() method of the node in this
// case.
TEST(CalculatorGraph, NotAllInputPacketsAtNextTimestampBoundAvailable) {
//
// in0_unfiltered in1_to_be_filtered
// | |
// | V
// | +-----------------------+
// | |EvenIntFilterCalculator|
// | +-----------------------+
// | |
// \ /
// \ / in1_filtered
// \ /
// | |
// V V
// +------------------+
// |IntAdderCalculator|
// +------------------+
// |
// V
// out
//
CalculatorGraph graph;
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'in0_unfiltered'
input_stream: 'in1_to_be_filtered'
node {
calculator: 'EvenIntFilterCalculator'
input_stream: 'in1_to_be_filtered'
output_stream: 'in1_filtered'
}
node {
calculator: 'IntAdderCalculator'
input_stream: 'in0_unfiltered'
input_stream: 'in1_filtered'
output_stream: 'out'
}
)");
std::vector<Packet> packet_dump;
tool::AddVectorSink("out", &config, &packet_dump);
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
Timestamp timestamp = Timestamp(0);
// We send an integer with timestamp 1 to the in0_unfiltered input stream of
// the IntAdderCalculator. We then send an even integer with timestamp 1 to
// the EvenIntFilterCalculator. This packet will go through and
// the IntAdderCalculator will run. The next timestamp bounds of both the
// input streams of the IntAdderCalculator will become 2.
++timestamp; // Timestamp 1.
MP_EXPECT_OK(graph.AddPacketToInputStream("in0_unfiltered",
MakePacket<int>(1).At(timestamp)));
MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered",
MakePacket<int>(2).At(timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
EXPECT_EQ(3, packet_dump[0].Get<int>());
// We send an odd integer with timestamp 2 to the EvenIntFilterCalculator.
// This packet will be filtered out and the next timestamp bound of the
// in1_filtered input stream of the IntAdderCalculator will become 3.
++timestamp; // Timestamp 2.
MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered",
MakePacket<int>(3).At(timestamp)));
// We send an integer with timestamp 3 to the in0_unfiltered input stream of
// the IntAdderCalculator. MediaPipe should propagate the next timestamp bound
// across the IntAdderCalculator but should not run its Process() method.
++timestamp; // Timestamp 3.
MP_EXPECT_OK(graph.AddPacketToInputStream("in0_unfiltered",
MakePacket<int>(3).At(timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(1, packet_dump.size());
// We send an even integer with timestamp 3 to the IntAdderCalculator. This
// packet will go through and the IntAdderCalculator will run.
MP_EXPECT_OK(graph.AddPacketToInputStream("in1_to_be_filtered",
MakePacket<int>(4).At(timestamp)));
MP_ASSERT_OK(graph.WaitUntilIdle());
ASSERT_EQ(2, packet_dump.size());
EXPECT_EQ(7, packet_dump[1].Get<int>());
MP_ASSERT_OK(graph.CloseAllInputStreams());
MP_ASSERT_OK(graph.WaitUntilDone());
EXPECT_EQ(2, packet_dump.size());
}
TEST(CalculatorGraph, PropagateBoundLoop) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
node {
calculator: 'OutputAndBoundSourceCalculator'
output_stream: 'integers'
}
node {
calculator: 'IntAdderCalculator'
input_stream: 'integers'
input_stream: 'old_sum'
input_stream_info: {
tag_index: ':1' # 'old_sum'
back_edge: true
}
output_stream: 'sum'
input_stream_handler {
input_stream_handler: 'EarlyCloseInputStreamHandler'
}
}
node {
calculator: 'Delay20Calculator'
input_stream: 'sum'
output_stream: 'old_sum'
}
)");
std::vector<Packet> packet_dump;
tool::AddVectorSink("sum", &config, &packet_dump);
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.Run());
ASSERT_EQ(101, packet_dump.size());
int sum = 0;
for (int i = 0; i < 101; ++i) {
sum += i;
EXPECT_EQ(sum, packet_dump[i].Get<int>());
EXPECT_EQ(Timestamp(i * 20), packet_dump[i].Timestamp());
}
}
TEST(CalculatorGraph, CheckBatchProcessingBoundPropagation) {
// The timestamp bound sent by OutputAndBoundSourceCalculator shouldn't be
// directly propagated to the output stream when PassThroughCalculator has
// anything in its default calculator context for batch processing. Otherwise,
// the sink calculator's input stream should report packet timestamp
// mismatches.
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
node {
calculator: 'OutputAndBoundSourceCalculator'
output_stream: 'integers'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'integers'
output_stream: 'output'
input_stream_handler {
input_stream_handler: "DefaultInputStreamHandler"
options: {
[mediapipe.DefaultInputStreamHandlerOptions.ext]: {
batch_size: 10
}
}
}
}
node { calculator: 'IntSinkCalculator' input_stream: 'output' }
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.Run());
}
// Shows that ImmediateInputStreamHandler allows bounds propagation. // Shows that ImmediateInputStreamHandler allows bounds propagation.
TEST(CalculatorGraphBounds, ImmediateHandlerBounds) { TEST(CalculatorGraphBoundsTest, ImmediateHandlerBounds) {
// CustomBoundCalculator produces only timestamp bounds. // CustomBoundCalculator produces only timestamp bounds.
// The first PassThroughCalculator propagates bounds using SetOffset(0). // The first PassThroughCalculator propagates bounds using SetOffset(0).
// The second PassthroughCalculator delivers an output packet whenever the // The second PassthroughCalculator delivers an output packet whenever the
@ -101,5 +620,261 @@ TEST(CalculatorGraphBounds, ImmediateHandlerBounds) {
EXPECT_EQ(output_packets.size(), 4); EXPECT_EQ(output_packets.size(), 4);
} }
// A Calculator that only sets timestamp bound by SetOffset().
class OffsetBoundCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->SetOffset(0);
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(OffsetBoundCalculator);
// A Calculator that produces a packet for each call to Process.
class BoundToPacketCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<int>();
cc->Outputs().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
cc->Outputs().Index(0).AddPacket(Adopt(new int(33)));
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(BoundToPacketCalculator);
// Verifies that SetOffset still propagates when Process is called and
// produces no output packets.
TEST(CalculatorGraphBoundsTest, OffsetBoundPropagation) {
// OffsetBoundCalculator produces only timestamp bounds.
// The PassthroughCalculator delivers an output packet whenever the
// OffsetBoundCalculator delivers a timestamp bound.
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input'
node {
calculator: 'OffsetBoundCalculator'
input_stream: 'input'
output_stream: 'bounds'
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'bounds'
input_stream: 'input'
output_stream: 'bounds_output'
output_stream: 'output'
}
)");
CalculatorGraph graph;
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) {
output_packets.push_back(p);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add four packets into the graph.
constexpr int kNumInputs = 4;
for (int i = 0; i < kNumInputs; ++i) {
Packet p = MakePacket<int>(33).At(Timestamp(i));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
}
// Four packets arrive at the output only if timestamp bounds are propagated.
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(output_packets.size(), kNumInputs);
// Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows that bounds changes alone do not invoke Process.
// Note: Bounds changes alone will invoke Process eventually
// when SetOffset is cleared, see: go/mediapipe-realtime-graph.
TEST(CalculatorGraphBoundsTest, BoundWithoutInputPackets) {
// OffsetBoundCalculator produces only timestamp bounds.
// The BoundToPacketCalculator delivers an output packet whenever the
// OffsetBoundCalculator delivers a timestamp bound.
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input'
node {
calculator: 'OffsetBoundCalculator'
input_stream: 'input'
output_stream: 'bounds'
}
node {
calculator: 'BoundToPacketCalculator'
input_stream: 'bounds'
output_stream: 'output'
}
)");
CalculatorGraph graph;
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) {
output_packets.push_back(p);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add four packets into the graph.
constexpr int kNumInputs = 4;
for (int i = 0; i < kNumInputs; ++i) {
Packet p = MakePacket<int>(33).At(Timestamp(i));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
}
// No packets arrive, because updated timestamp bounds do not invoke
// BoundToPacketCalculator::Process.
MP_ASSERT_OK(graph.WaitUntilIdle());
EXPECT_EQ(output_packets.size(), 0);
// Shutdown the graph.
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
// Shows that when fixed-size-input-stream-hanlder drops packets,
// no timetamp bounds are announced.
TEST(CalculatorGraphBoundsTest, FixedSizeHandlerBounds) {
// LambdaCalculator with FixedSizeInputStreamHandler will drop packets
// while it is busy. Timetamps for the dropped packets are only relevant
// when SetOffset is active on the LambdaCalculator.
// The PassthroughCalculator delivers an output packet whenever the
// LambdaCalculator delivers a timestamp bound.
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: 'input'
input_side_packet: 'open_function'
input_side_packet: 'process_function'
node {
calculator: 'LambdaCalculator'
input_stream: 'input'
output_stream: 'thinned'
input_side_packet: 'OPEN:open_fn'
input_side_packet: 'PROCESS:process_fn'
input_stream_handler {
input_stream_handler: "FixedSizeInputStreamHandler"
}
}
node {
calculator: 'PassThroughCalculator'
input_stream: 'thinned'
input_stream: 'input'
output_stream: 'thinned_output'
output_stream: 'output'
}
)");
CalculatorGraph graph;
// The task_semaphore counts the number of running tasks.
constexpr int kTaskSupply = 10;
AtomicSemaphore task_semaphore(/*supply=*/kTaskSupply);
// This executor invokes a callback at the start and finish of each task.
auto executor = std::make_shared<CountingExecutor>(
4, /*start_callback=*/[&]() { task_semaphore.Acquire(1); },
/*finish_callback=*/[&]() { task_semaphore.Release(1); });
MP_ASSERT_OK(graph.SetExecutor(/*name=*/"", executor));
// Monitor output from the graph.
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> outputs;
MP_ASSERT_OK(graph.ObserveOutputStream("output", [&](const Packet& p) {
outputs.push_back(p);
return ::mediapipe::OkStatus();
}));
std::vector<Packet> thinned_outputs;
MP_ASSERT_OK(
graph.ObserveOutputStream("thinned_output", [&](const Packet& p) {
thinned_outputs.push_back(p);
return ::mediapipe::OkStatus();
}));
// The enter_semaphore is used to wait for LambdaCalculator::Process.
// The exit_semaphore blocks and unblocks LambdaCalculator::Process.
AtomicSemaphore enter_semaphore(0);
AtomicSemaphore exit_semaphore(0);
CalculatorContextFunction open_fn = [&](CalculatorContext* cc) {
cc->SetOffset(0);
return ::mediapipe::OkStatus();
};
CalculatorContextFunction process_fn = [&](CalculatorContext* cc) {
enter_semaphore.Release(1);
exit_semaphore.Acquire(1);
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
return ::mediapipe::OkStatus();
};
MP_ASSERT_OK(graph.StartRun({
{"open_fn", Adopt(new auto(open_fn))},
{"process_fn", Adopt(new auto(process_fn))},
}));
MP_ASSERT_OK(graph.WaitUntilIdle());
// Add four packets into the graph.
constexpr int kNumInputs = 4;
for (int i = 0; i < kNumInputs; ++i) {
Packet p = MakePacket<int>(33).At(Timestamp(i));
MP_ASSERT_OK(graph.AddPacketToInputStream("input", p));
}
// Wait until only the LambdaCalculator is running,
// by wating until the task_semaphore has only one token occupied.
// At this point 2 packets were dropped by the FixedSizeInputStreamHandler.
task_semaphore.Acquire(kTaskSupply - 1);
task_semaphore.Release(kTaskSupply - 1);
// No timestamp bounds and no packets are emitted yet.
EXPECT_EQ(outputs.size(), 0);
EXPECT_EQ(thinned_outputs.size(), 0);
// Allow the first LambdaCalculator::Process call to complete.
// Wait for the second LambdaCalculator::Process call to begin.
// Wait until only the LambdaCalculator is running.
enter_semaphore.Acquire(1);
exit_semaphore.Release(1);
enter_semaphore.Acquire(1);
task_semaphore.Acquire(kTaskSupply - 1);
task_semaphore.Release(kTaskSupply - 1);
// Only one timestamp bound and one packet are emitted.
EXPECT_EQ(outputs.size(), 1);
EXPECT_EQ(thinned_outputs.size(), 1);
// Allow the second LambdaCalculator::Process call to complete.
exit_semaphore.Release(1);
MP_ASSERT_OK(graph.WaitUntilIdle());
// Packets 1 and 2 were dropped by the FixedSizeInputStreamHandler.
EXPECT_EQ(thinned_outputs.size(), 2);
EXPECT_EQ(thinned_outputs[0].Timestamp(), Timestamp(0));
EXPECT_EQ(thinned_outputs[1].Timestamp(), Timestamp(kNumInputs - 1));
EXPECT_EQ(outputs.size(), kNumInputs);
MP_ASSERT_OK(graph.CloseAllPacketSources());
MP_ASSERT_OK(graph.WaitUntilDone());
}
} // namespace } // namespace
} // namespace mediapipe } // namespace mediapipe

View File

@ -0,0 +1,747 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "mediapipe/framework/calculator.pb.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/canonical_errors.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/framework/port/logging.h"
#include "mediapipe/framework/port/parse_text_proto.h"
#include "mediapipe/framework/port/ret_check.h"
#include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/status_matchers.h"
namespace mediapipe {
namespace {
// Takes an input stream packet and passes it (with timestamp removed) as an
// output side packet.
class OutputSidePacketInProcessCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(
cc->Inputs().Index(0).Value().At(Timestamp::Unset()));
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(OutputSidePacketInProcessCalculator);
// Takes an input stream packet and counts the number of the packets it
// receives. Outputs the total number of packets as a side packet in Close.
class CountAndOutputSummarySidePacketInCloseCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->OutputSidePackets().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
++count_;
return ::mediapipe::OkStatus();
}
::mediapipe::Status Close(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(
MakePacket<int>(count_).At(Timestamp::Unset()));
return ::mediapipe::OkStatus();
}
int count_ = 0;
};
REGISTER_CALCULATOR(CountAndOutputSummarySidePacketInCloseCalculator);
// Takes an input stream packet and passes it (with timestamp intact) as an
// output side packet. This triggers an error in the graph.
class OutputSidePacketWithTimestampCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->OutputSidePackets().Index(0).SetSameAs(&cc->Inputs().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(cc->Inputs().Index(0).Value());
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(OutputSidePacketWithTimestampCalculator);
// Generates an output side packet containing the integer 1.
class IntegerOutputSidePacketCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->OutputSidePackets().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(MakePacket<int>(1));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
LOG(FATAL) << "Not reached.";
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(IntegerOutputSidePacketCalculator);
// Generates an output side packet containing the sum of the two integer input
// side packets.
class SidePacketAdderCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Index(0).Set<int>();
cc->InputSidePackets().Index(1).Set<int>();
cc->OutputSidePackets().Index(0).Set<int>();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->OutputSidePackets().Index(0).Set(
MakePacket<int>(cc->InputSidePackets().Index(1).Get<int>() +
cc->InputSidePackets().Index(0).Get<int>()));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
LOG(FATAL) << "Not reached.";
return ::mediapipe::OkStatus();
}
};
REGISTER_CALCULATOR(SidePacketAdderCalculator);
// Produces an output packet with the PostStream timestamp containing the
// input side packet.
class SidePacketToStreamPacketCalculator : public CalculatorBase {
public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->InputSidePackets().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->InputSidePackets().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Open(CalculatorContext* cc) final {
cc->Outputs().Index(0).AddPacket(
cc->InputSidePackets().Index(0).At(Timestamp::PostStream()));
cc->Outputs().Index(0).Close();
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) final {
return ::mediapipe::tool::StatusStop();
}
};
REGISTER_CALCULATOR(SidePacketToStreamPacketCalculator);
// Packet generator for an arbitrary unit64 packet.
class Uint64PacketGenerator : public PacketGenerator {
public:
static ::mediapipe::Status FillExpectations(
const PacketGeneratorOptions& extendable_options,
PacketTypeSet* input_side_packets, PacketTypeSet* output_side_packets) {
output_side_packets->Index(0).Set<uint64>();
return ::mediapipe::OkStatus();
}
static ::mediapipe::Status Generate(
const PacketGeneratorOptions& extendable_options,
const PacketSet& input_side_packets, PacketSet* output_side_packets) {
output_side_packets->Index(0) = Adopt(new uint64(15LL << 32 | 5));
return ::mediapipe::OkStatus();
}
};
REGISTER_PACKET_GENERATOR(Uint64PacketGenerator);
TEST(CalculatorGraph, OutputSidePacketInProcess) {
const int64 offset = 100;
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "offset"
node {
calculator: "OutputSidePacketInProcessCalculator"
input_stream: "offset"
output_side_packet: "offset"
}
node {
calculator: "SidePacketToStreamPacketCalculator"
output_stream: "output"
input_side_packet: "offset"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream(
"output", [&output_packets](const Packet& packet) {
output_packets.push_back(packet);
return ::mediapipe::OkStatus();
}));
// Run the graph twice.
for (int run = 0; run < 2; ++run) {
output_packets.clear();
MP_ASSERT_OK(graph.StartRun({}));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"offset", MakePacket<TimestampDiff>(offset).At(Timestamp(0))));
MP_ASSERT_OK(graph.CloseInputStream("offset"));
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(1, output_packets.size());
EXPECT_EQ(offset, output_packets[0].Get<TimestampDiff>().Value());
}
}
// A PacketGenerator that simply passes its input Packets through
// unchanged. The inputs may be specified by tag or index. The outputs
// must match the inputs exactly. Any options may be specified and will
// also be ignored.
class PassThroughGenerator : public PacketGenerator {
public:
static ::mediapipe::Status FillExpectations(
const PacketGeneratorOptions& extendable_options, PacketTypeSet* inputs,
PacketTypeSet* outputs) {
if (!inputs->TagMap()->SameAs(*outputs->TagMap())) {
return ::mediapipe::InvalidArgumentError(
"Input and outputs to PassThroughGenerator must use the same tags "
"and indexes.");
}
for (CollectionItemId id = inputs->BeginId(); id < inputs->EndId(); ++id) {
inputs->Get(id).SetAny();
outputs->Get(id).SetSameAs(&inputs->Get(id));
}
return ::mediapipe::OkStatus();
}
static ::mediapipe::Status Generate(
const PacketGeneratorOptions& extendable_options,
const PacketSet& input_side_packets, PacketSet* output_side_packets) {
for (CollectionItemId id = input_side_packets.BeginId();
id < input_side_packets.EndId(); ++id) {
output_side_packets->Get(id) = input_side_packets.Get(id);
}
return ::mediapipe::OkStatus();
}
};
REGISTER_PACKET_GENERATOR(PassThroughGenerator);
TEST(CalculatorGraph, SharePacketGeneratorGraph) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count1'
input_side_packet: 'MAX_COUNT:max_count1'
}
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count2'
input_side_packet: 'MAX_COUNT:max_count2'
}
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count3'
input_side_packet: 'MAX_COUNT:max_count3'
}
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count4'
input_side_packet: 'MAX_COUNT:max_count4'
}
node {
calculator: 'PassThroughCalculator'
input_side_packet: 'MAX_COUNT:max_count5'
output_side_packet: 'MAX_COUNT:max_count6'
}
node {
calculator: 'CountingSourceCalculator'
output_stream: 'count5'
input_side_packet: 'MAX_COUNT:max_count6'
}
packet_generator {
packet_generator: 'PassThroughGenerator'
input_side_packet: 'max_count1'
output_side_packet: 'max_count2'
}
packet_generator {
packet_generator: 'PassThroughGenerator'
input_side_packet: 'max_count4'
output_side_packet: 'max_count5'
}
)");
// At this point config is a standard config which specifies both
// calculators and packet_factories/packet_genators. The following
// code is an example of reusing side packets across a number of
// CalculatorGraphs. It is particularly informative to note how each
// side packet is created.
//
// max_count1 is set for all graphs by a PacketFactory in the config.
// The side packet is created by generator_graph.InitializeGraph().
//
// max_count2 is set for all graphs by a PacketGenerator in the config.
// The side packet is created by generator_graph.InitializeGraph()
// because max_count1 is available at that time.
//
// max_count3 is set for all graphs by directly being specified as an
// argument to generator_graph.InitializeGraph().
//
// max_count4 is set per graph because it is directly specified as an
// argument to generator_graph.ProcessGraph().
//
// max_count5 is set per graph by a PacketGenerator which is run when
// generator_graph.ProcessGraph() is run (because max_count4 isn't
// available until then).
// Before anything else, split the graph config into two parts, one
// with the PacketFactory and PacketGenerator config and the other
// with the Calculator config.
CalculatorGraphConfig calculator_config = config;
calculator_config.clear_packet_factory();
calculator_config.clear_packet_generator();
CalculatorGraphConfig generator_config = config;
generator_config.clear_node();
// Next, create a ValidatedGraphConfig for both configs.
ValidatedGraphConfig validated_calculator_config;
MP_ASSERT_OK(validated_calculator_config.Initialize(calculator_config));
ValidatedGraphConfig validated_generator_config;
MP_ASSERT_OK(validated_generator_config.Initialize(generator_config));
// Create a PacketGeneratorGraph. Side packets max_count1, max_count2,
// and max_count3 are created upon initialization.
// Note that validated_generator_config must outlive generator_graph.
PacketGeneratorGraph generator_graph;
MP_ASSERT_OK(
generator_graph.Initialize(&validated_generator_config, nullptr,
{{"max_count1", MakePacket<int>(10)},
{"max_count3", MakePacket<int>(20)}}));
ASSERT_THAT(generator_graph.BasePackets(),
testing::ElementsAre(testing::Key("max_count1"),
testing::Key("max_count2"),
testing::Key("max_count3")));
// Create a bunch of graphs.
std::vector<std::unique_ptr<CalculatorGraph>> graphs;
for (int i = 0; i < 100; ++i) {
graphs.emplace_back(absl::make_unique<CalculatorGraph>());
// Do not pass extra side packets here.
// Note that validated_calculator_config must outlive the graph.
MP_ASSERT_OK(graphs.back()->Initialize(calculator_config, {}));
}
// Run a bunch of graphs, reusing side packets max_count1, max_count2,
// and max_count3. The side packet max_count4 is added per run,
// and triggers the execution of a packet generator which generates
// max_count5.
for (int i = 0; i < 100; ++i) {
std::map<std::string, Packet> all_side_packets;
// Creates max_count4 and max_count5.
MP_ASSERT_OK(generator_graph.RunGraphSetup(
{{"max_count4", MakePacket<int>(30 + i)}}, &all_side_packets));
ASSERT_THAT(all_side_packets,
testing::ElementsAre(
testing::Key("max_count1"), testing::Key("max_count2"),
testing::Key("max_count3"), testing::Key("max_count4"),
testing::Key("max_count5")));
// Pass all the side packets prepared by generator_graph here.
MP_ASSERT_OK(graphs[i]->Run(all_side_packets));
// TODO Verify the actual output.
}
// Destroy all the graphs.
graphs.clear();
}
TEST(CalculatorGraph, OutputSidePacketAlreadySet) {
const int64 offset = 100;
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "offset"
node {
calculator: "OutputSidePacketInProcessCalculator"
input_stream: "offset"
output_side_packet: "offset"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
// Send two input packets to cause OutputSidePacketInProcessCalculator to
// set the output side packet twice.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"offset", MakePacket<TimestampDiff>(offset).At(Timestamp(0))));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"offset", MakePacket<TimestampDiff>(offset).At(Timestamp(1))));
MP_ASSERT_OK(graph.CloseInputStream("offset"));
::mediapipe::Status status = graph.WaitUntilDone();
EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kAlreadyExists);
EXPECT_THAT(status.message(), testing::HasSubstr("was already set."));
}
TEST(CalculatorGraph, OutputSidePacketWithTimestamp) {
const int64 offset = 100;
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "offset"
node {
calculator: "OutputSidePacketWithTimestampCalculator"
input_stream: "offset"
output_side_packet: "offset"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
MP_ASSERT_OK(graph.StartRun({}));
// The OutputSidePacketWithTimestampCalculator neglects to clear the
// timestamp in the input packet when it copies the input packet to the
// output side packet. The timestamp value should appear in the error
// message.
MP_ASSERT_OK(graph.AddPacketToInputStream(
"offset", MakePacket<TimestampDiff>(offset).At(Timestamp(237))));
MP_ASSERT_OK(graph.CloseInputStream("offset"));
::mediapipe::Status status = graph.WaitUntilDone();
EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kInvalidArgument);
EXPECT_THAT(status.message(), testing::HasSubstr("has a timestamp 237."));
}
TEST(CalculatorGraph, OutputSidePacketConsumedBySourceNode) {
const int max_count = 10;
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "max_count"
node {
calculator: "OutputSidePacketInProcessCalculator"
input_stream: "max_count"
output_side_packet: "max_count"
}
node {
calculator: "CountingSourceCalculator"
output_stream: "count"
input_side_packet: "MAX_COUNT:max_count"
}
node {
calculator: "PassThroughCalculator"
input_stream: "count"
output_stream: "output"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream(
"output", [&output_packets](const Packet& packet) {
output_packets.push_back(packet);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.StartRun({}));
// Wait until the graph is idle so that
// Scheduler::TryToScheduleNextSourceLayer() gets called.
// Scheduler::TryToScheduleNextSourceLayer() should not activate source
// nodes that haven't been opened. We can't call graph.WaitUntilIdle()
// because the graph has a source node.
absl::SleepFor(absl::Milliseconds(10));
MP_ASSERT_OK(graph.AddPacketToInputStream(
"max_count", MakePacket<int>(max_count).At(Timestamp(0))));
MP_ASSERT_OK(graph.CloseInputStream("max_count"));
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(max_count, output_packets.size());
for (int i = 0; i < output_packets.size(); ++i) {
EXPECT_EQ(i, output_packets[i].Get<int>());
EXPECT_EQ(Timestamp(i), output_packets[i].Timestamp());
}
}
// Returns the first packet of the input stream.
class FirstPacketFilterCalculator : public CalculatorBase {
public:
FirstPacketFilterCalculator() {}
~FirstPacketFilterCalculator() override {}
static ::mediapipe::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).SetAny();
cc->Outputs().Index(0).SetSameAs(&cc->Inputs().Index(0));
return ::mediapipe::OkStatus();
}
::mediapipe::Status Process(CalculatorContext* cc) override {
if (!seen_first_packet_) {
cc->Outputs().Index(0).AddPacket(cc->Inputs().Index(0).Value());
cc->Outputs().Index(0).Close();
seen_first_packet_ = true;
}
return ::mediapipe::OkStatus();
}
private:
bool seen_first_packet_ = false;
};
REGISTER_CALCULATOR(FirstPacketFilterCalculator);
TEST(CalculatorGraph, SourceLayerInversion) {
// There are three CountingSourceCalculators, indexed 0, 1, and 2. Each of
// them outputs 10 packets.
//
// CountingSourceCalculator 0 should output 0, 1, 2, 3, ..., 9.
// CountingSourceCalculator 1 should output 100, 101, 102, 103, ..., 109.
// CountingSourceCalculator 2 should output 0, 100, 200, 300, ..., 900.
// However, there is a source layer inversion.
// CountingSourceCalculator 0 is in source layer 0.
// CountingSourceCalculator 1 is in source layer 1.
// CountingSourceCalculator 2 is in source layer 0, but consumes an output
// side packet generated by a downstream calculator of
// CountingSourceCalculator 1.
//
// This graph will deadlock when CountingSourceCalculator 0 runs to
// completion and CountingSourceCalculator 1 cannot be activated because
// CountingSourceCalculator 2 cannot be opened.
const int max_count = 10;
const int initial_value1 = 100;
// Set num_threads to 1 to force sequential execution for deterministic
// outputs.
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
num_threads: 1
node {
calculator: "CountingSourceCalculator"
output_stream: "count0"
input_side_packet: "MAX_COUNT:max_count"
source_layer: 0
}
node {
calculator: "CountingSourceCalculator"
output_stream: "count1"
input_side_packet: "MAX_COUNT:max_count"
input_side_packet: "INITIAL_VALUE:initial_value1"
source_layer: 1
}
node {
calculator: "FirstPacketFilterCalculator"
input_stream: "count1"
output_stream: "first_count1"
}
node {
calculator: "OutputSidePacketInProcessCalculator"
input_stream: "first_count1"
output_side_packet: "increment2"
}
node {
calculator: "CountingSourceCalculator"
output_stream: "count2"
input_side_packet: "MAX_COUNT:max_count"
input_side_packet: "INCREMENT:increment2"
source_layer: 0
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(
config, {{"max_count", MakePacket<int>(max_count)},
{"initial_value1", MakePacket<int>(initial_value1)}}));
::mediapipe::Status status = graph.Run();
EXPECT_EQ(status.code(), ::mediapipe::StatusCode::kUnknown);
EXPECT_THAT(status.message(), testing::HasSubstr("deadlock"));
}
// Tests a graph of packet-generator-like calculators, which have no input
// streams and no output streams.
TEST(CalculatorGraph, PacketGeneratorLikeCalculators) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
node {
calculator: "IntegerOutputSidePacketCalculator"
output_side_packet: "one"
}
node {
calculator: "IntegerOutputSidePacketCalculator"
output_side_packet: "another_one"
}
node {
calculator: "SidePacketAdderCalculator"
input_side_packet: "one"
input_side_packet: "another_one"
output_side_packet: "two"
}
node {
calculator: "IntegerOutputSidePacketCalculator"
output_side_packet: "yet_another_one"
}
node {
calculator: "SidePacketAdderCalculator"
input_side_packet: "two"
input_side_packet: "yet_another_one"
output_side_packet: "three"
}
node {
calculator: "SidePacketToStreamPacketCalculator"
input_side_packet: "three"
output_stream: "output"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream(
"output", [&output_packets](const Packet& packet) {
output_packets.push_back(packet);
return ::mediapipe::OkStatus();
}));
MP_ASSERT_OK(graph.Run());
ASSERT_EQ(1, output_packets.size());
EXPECT_EQ(3, output_packets[0].Get<int>());
EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp());
}
TEST(CalculatorGraph, OutputSummarySidePacketInClose) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "input_packets"
node {
calculator: "CountAndOutputSummarySidePacketInCloseCalculator"
input_stream: "input_packets"
output_side_packet: "num_of_packets"
}
node {
calculator: "SidePacketToStreamPacketCalculator"
input_side_packet: "num_of_packets"
output_stream: "output"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
std::vector<Packet> output_packets;
MP_ASSERT_OK(graph.ObserveOutputStream(
"output", [&output_packets](const Packet& packet) {
output_packets.push_back(packet);
return ::mediapipe::OkStatus();
}));
// Run the graph twice.
int max_count = 100;
for (int run = 0; run < 1; ++run) {
output_packets.clear();
MP_ASSERT_OK(graph.StartRun({}));
for (int i = 0; i < max_count; ++i) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_packets", MakePacket<int>(i).At(Timestamp(i))));
}
MP_ASSERT_OK(graph.CloseInputStream("input_packets"));
MP_ASSERT_OK(graph.WaitUntilDone());
ASSERT_EQ(1, output_packets.size());
EXPECT_EQ(max_count, output_packets[0].Get<int>());
EXPECT_EQ(Timestamp::PostStream(), output_packets[0].Timestamp());
}
}
TEST(CalculatorGraph, GetOutputSidePacket) {
CalculatorGraphConfig config =
::mediapipe::ParseTextProtoOrDie<CalculatorGraphConfig>(R"(
input_stream: "input_packets"
node {
calculator: "CountAndOutputSummarySidePacketInCloseCalculator"
input_stream: "input_packets"
output_side_packet: "num_of_packets"
}
packet_generator {
packet_generator: "Uint64PacketGenerator"
output_side_packet: "output_uint64"
}
packet_generator {
packet_generator: "IntSplitterPacketGenerator"
input_side_packet: "input_uint64"
output_side_packet: "output_uint32_pair"
}
)");
CalculatorGraph graph;
MP_ASSERT_OK(graph.Initialize(config));
// Check a packet generated by the PacketGenerator, which is available after
// graph initialization, can be fetched before graph starts.
::mediapipe::StatusOr<Packet> status_or_packet =
graph.GetOutputSidePacket("output_uint64");
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp());
// IntSplitterPacketGenerator is missing its input side packet and we
// won't be able to get its output side packet now.
status_or_packet = graph.GetOutputSidePacket("output_uint32_pair");
EXPECT_EQ(::mediapipe::StatusCode::kUnavailable,
status_or_packet.status().code());
// Run the graph twice.
int max_count = 100;
std::map<std::string, Packet> extra_side_packets;
extra_side_packets.insert({"input_uint64", MakePacket<uint64>(1123)});
for (int run = 0; run < 1; ++run) {
MP_ASSERT_OK(graph.StartRun(extra_side_packets));
status_or_packet = graph.GetOutputSidePacket("output_uint32_pair");
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp());
for (int i = 0; i < max_count; ++i) {
MP_ASSERT_OK(graph.AddPacketToInputStream(
"input_packets", MakePacket<int>(i).At(Timestamp(i))));
}
MP_ASSERT_OK(graph.CloseInputStream("input_packets"));
// Should return NOT_FOUND for invalid side packets.
status_or_packet = graph.GetOutputSidePacket("unknown");
EXPECT_FALSE(status_or_packet.ok());
EXPECT_EQ(::mediapipe::StatusCode::kNotFound,
status_or_packet.status().code());
// Should return UNAVAILABLE before graph is done for valid non-base
// packets.
status_or_packet = graph.GetOutputSidePacket("num_of_packets");
EXPECT_FALSE(status_or_packet.ok());
EXPECT_EQ(::mediapipe::StatusCode::kUnavailable,
status_or_packet.status().code());
// Should stil return a base even before graph is done.
status_or_packet = graph.GetOutputSidePacket("output_uint64");
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp());
MP_ASSERT_OK(graph.WaitUntilDone());
// Check packets are available after graph is done.
status_or_packet = graph.GetOutputSidePacket("num_of_packets");
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(max_count, status_or_packet.ValueOrDie().Get<int>());
EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp());
// Should still return a base packet after graph is done.
status_or_packet = graph.GetOutputSidePacket("output_uint64");
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp());
// Should still return a non-base packet after graph is done.
status_or_packet = graph.GetOutputSidePacket("output_uint32_pair");
MP_ASSERT_OK(status_or_packet);
EXPECT_EQ(Timestamp::Unset(), status_or_packet.ValueOrDie().Timestamp());
}
}
} // namespace
} // namespace mediapipe

File diff suppressed because it is too large Load Diff

View File

@ -34,7 +34,7 @@ message MatrixData {
ROW_MAJOR = 1; ROW_MAJOR = 1;
} }
// Order in which the data are stored. Implicitly defaults to COLUMN_MAJOR, // Order in which the data are stored. Defaults to COLUMN_MAJOR, which matches
// which matches the default for mediapipe::Matrix and Eigen::Matrix*. // the default for mediapipe::Matrix and Eigen::Matrix*.
optional Layout layout = 4; optional Layout layout = 4 [default = COLUMN_MAJOR];
} }

View File

@ -154,11 +154,13 @@ TEST(ValidatedGraphConfigTest, InitializeTemplateFromProtos) {
} }
)"); )");
auto options = ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"( auto options = ParseTextProtoOrDie<Subgraph::SubgraphOptions>(R"(
[mediapipe.TemplateSubgraphOptions.ext]: { options: {
dict: { [mediapipe.TemplateSubgraphOptions.ext]: {
arg: { dict: {
key: "in_name" arg: {
value: { str: "stream_9" } key: "in_name"
value: { str: "stream_9" }
}
} }
} }
})"); })");

View File

@ -44,8 +44,8 @@ TemplateSubgraph::~TemplateSubgraph() {}
::mediapipe::StatusOr<CalculatorGraphConfig> TemplateSubgraph::GetConfig( ::mediapipe::StatusOr<CalculatorGraphConfig> TemplateSubgraph::GetConfig(
const Subgraph::SubgraphOptions& options) { const Subgraph::SubgraphOptions& options) {
const TemplateDict& arguments = TemplateDict arguments =
options.GetExtension(TemplateSubgraphOptions::ext).dict(); Subgraph::GetOptions<mediapipe::TemplateSubgraphOptions>(options).dict();
tool::TemplateExpander expander; tool::TemplateExpander expander;
CalculatorGraphConfig config; CalculatorGraphConfig config;
MP_RETURN_IF_ERROR(expander.ExpandTemplates(arguments, templ_, &config)); MP_RETURN_IF_ERROR(expander.ExpandTemplates(arguments, templ_, &config));

View File

@ -24,6 +24,7 @@
#include "mediapipe/framework/port/status.h" #include "mediapipe/framework/port/status.h"
#include "mediapipe/framework/port/statusor.h" #include "mediapipe/framework/port/statusor.h"
#include "mediapipe/framework/tool/calculator_graph_template.pb.h" #include "mediapipe/framework/tool/calculator_graph_template.pb.h"
#include "mediapipe/framework/tool/options_util.h"
namespace mediapipe { namespace mediapipe {
@ -32,7 +33,7 @@ namespace mediapipe {
// the graph is running. // the graph is running.
class Subgraph { class Subgraph {
public: public:
using SubgraphOptions = CalculatorOptions; using SubgraphOptions = CalculatorGraphConfig::Node;
Subgraph(); Subgraph();
virtual ~Subgraph(); virtual ~Subgraph();
// Returns the config to use for one instantiation of the subgraph. The // Returns the config to use for one instantiation of the subgraph. The
@ -42,6 +43,12 @@ class Subgraph {
// TODO: make this static? // TODO: make this static?
virtual ::mediapipe::StatusOr<CalculatorGraphConfig> GetConfig( virtual ::mediapipe::StatusOr<CalculatorGraphConfig> GetConfig(
const SubgraphOptions& options) = 0; const SubgraphOptions& options) = 0;
// Returns options of a specific type.
template <typename T>
static T GetOptions(Subgraph::SubgraphOptions supgraph_options) {
return tool::OptionsMap().Initialize(supgraph_options).Get<T>();
}
}; };
using SubgraphRegistry = GlobalFactoryRegistry<std::unique_ptr<Subgraph>>; using SubgraphRegistry = GlobalFactoryRegistry<std::unique_ptr<Subgraph>>;

View File

@ -548,8 +548,12 @@ typedef std::function<::mediapipe::Status(const InputStreamShardSet&,
OutputStreamShardSet*)> OutputStreamShardSet*)>
ProcessFunction; ProcessFunction;
// A callback function for Calculator::Open, Process, or Close.
typedef std::function<::mediapipe::Status(CalculatorContext* cc)>
CalculatorContextFunction;
// A Calculator that runs a testing callback function in Process, // A Calculator that runs a testing callback function in Process,
// which is specified as an input side packet. // Open, or Close, which is specified as an input side packet.
class LambdaCalculator : public CalculatorBase { class LambdaCalculator : public CalculatorBase {
public: public:
static ::mediapipe::Status GetContract(CalculatorContract* cc) { static ::mediapipe::Status GetContract(CalculatorContract* cc) {
@ -561,21 +565,49 @@ class LambdaCalculator : public CalculatorBase {
id < cc->Outputs().EndId(); ++id) { id < cc->Outputs().EndId(); ++id) {
cc->Outputs().Get(id).SetAny(); cc->Outputs().Get(id).SetAny();
} }
cc->InputSidePackets().Index(0).Set<ProcessFunction>(); if (cc->InputSidePackets().HasTag("") > 0) {
cc->InputSidePackets().Tag("").Set<ProcessFunction>();
}
for (std::string tag : {"OPEN", "PROCESS", "CLOSE"}) {
if (cc->InputSidePackets().HasTag(tag)) {
cc->InputSidePackets().Tag(tag).Set<CalculatorContextFunction>();
}
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status Open(CalculatorContext* cc) final { ::mediapipe::Status Open(CalculatorContext* cc) final {
callback_ = cc->InputSidePackets().Index(0).Get<ProcessFunction>(); if (cc->InputSidePackets().HasTag("OPEN")) {
return GetContextFn(cc, "OPEN")(cc);
}
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
::mediapipe::Status Process(CalculatorContext* cc) final { ::mediapipe::Status Process(CalculatorContext* cc) final {
return callback_(cc->Inputs(), &(cc->Outputs())); if (cc->InputSidePackets().HasTag("PROCESS")) {
return GetContextFn(cc, "PROCESS")(cc);
}
if (cc->InputSidePackets().HasTag("") > 0) {
return GetProcessFn(cc, "")(cc->Inputs(), &cc->Outputs());
}
return ::mediapipe::OkStatus();
}
::mediapipe::Status Close(CalculatorContext* cc) final {
if (cc->InputSidePackets().HasTag("CLOSE")) {
return GetContextFn(cc, "CLOSE")(cc);
}
return ::mediapipe::OkStatus();
} }
private: private:
ProcessFunction callback_; ProcessFunction GetProcessFn(CalculatorContext* cc, std::string tag) {
return cc->InputSidePackets().Tag(tag).Get<ProcessFunction>();
}
CalculatorContextFunction GetContextFn(CalculatorContext* cc,
std::string tag) {
return cc->InputSidePackets().Tag(tag).Get<CalculatorContextFunction>();
}
}; };
REGISTER_CALCULATOR(LambdaCalculator); REGISTER_CALCULATOR(LambdaCalculator);

View File

@ -55,6 +55,14 @@ proto_library(
deps = ["@com_google_protobuf//:any_proto"], deps = ["@com_google_protobuf//:any_proto"],
) )
mediapipe_cc_proto_library(
name = "zoo_mutator_cc_proto",
srcs = ["zoo_mutator.proto"],
cc_deps = ["@com_google_protobuf//:cc_wkt_protos"],
visibility = ["//mediapipe:__subpackages__"],
deps = [":zoo_mutator_proto"],
)
proto_library( proto_library(
name = "zoo_mutation_calculator_proto", name = "zoo_mutation_calculator_proto",
srcs = ["zoo_mutation_calculator.proto"], srcs = ["zoo_mutation_calculator.proto"],

View File

@ -237,9 +237,9 @@ static ::mediapipe::Status PrefixNames(int subgraph_index,
for (auto it = subgraph_nodes_start; it != nodes->end(); ++it) { for (auto it = subgraph_nodes_start; it != nodes->end(); ++it) {
const auto& node = *it; const auto& node = *it;
MP_RETURN_IF_ERROR(ValidateSubgraphFields(node)); MP_RETURN_IF_ERROR(ValidateSubgraphFields(node));
ASSIGN_OR_RETURN(auto subgraph, graph_registry->CreateByName( ASSIGN_OR_RETURN(auto subgraph,
config->package(), node.calculator(), graph_registry->CreateByName(config->package(),
&node.options())); node.calculator(), &node));
MP_RETURN_IF_ERROR(PrefixNames(subgraph_counter++, &subgraph)); MP_RETURN_IF_ERROR(PrefixNames(subgraph_counter++, &subgraph));
MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph)); MP_RETURN_IF_ERROR(ConnectSubgraphStreams(node, &subgraph));
subgraphs.push_back(subgraph); subgraphs.push_back(subgraph);

View File

@ -128,8 +128,8 @@ class NodeChainSubgraph : public Subgraph {
public: public:
::mediapipe::StatusOr<CalculatorGraphConfig> GetConfig( ::mediapipe::StatusOr<CalculatorGraphConfig> GetConfig(
const SubgraphOptions& options) override { const SubgraphOptions& options) override {
const mediapipe::NodeChainSubgraphOptions& opts = auto opts =
options.GetExtension(mediapipe::NodeChainSubgraphOptions::ext); Subgraph::GetOptions<mediapipe::NodeChainSubgraphOptions>(options);
const ProtoString& node_type = opts.node_type(); const ProtoString& node_type = opts.node_type();
int chain_length = opts.chain_length(); int chain_length = opts.chain_length();
RET_CHECK(!node_type.empty()); RET_CHECK(!node_type.empty());

View File

@ -80,8 +80,7 @@ GL_BASE_LINK_OPTS_OSS = GL_BASE_LINK_OPTS + select({
"-lEGL", "-lEGL",
], ],
"//mediapipe:android": [], "//mediapipe:android": [],
"//mediapipe:apple": [], "//mediapipe:ios": [],
"//mediapipe:macos": [],
":disable_gpu": [], ":disable_gpu": [],
}) })
@ -289,6 +288,25 @@ objc_library(
], ],
) )
objc_library(
name = "MPPMetalUtil",
srcs = ["MPPMetalUtil.mm"],
hdrs = ["MPPMetalUtil.h"],
copts = [
"-x objective-c++",
"-Wno-shorten-64-to-32",
],
sdk_frameworks = [
"CoreVideo",
"Metal",
],
visibility = ["//visibility:public"],
deps = [
"//mediapipe/objc:mediapipe_framework_ios",
"@google_toolbox_for_mac//:GTM_Defines",
],
)
proto_library( proto_library(
name = "gl_context_options_proto", name = "gl_context_options_proto",
srcs = ["gl_context_options.proto"], srcs = ["gl_context_options.proto"],
@ -499,6 +517,7 @@ cc_library(
":gl_base", ":gl_base",
":gl_context", ":gl_context",
":gpu_buffer", ":gpu_buffer",
":gpu_buffer_format",
":gpu_buffer_multi_pool", ":gpu_buffer_multi_pool",
":gpu_shared_data_internal", ":gpu_shared_data_internal",
":gpu_service", ":gpu_service",

View File

@ -0,0 +1,49 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_GPU_MPP_METAL_UTIL_H_
#define MEDIAPIPE_GPU_MPP_METAL_UTIL_H_
#import <CoreVideo/CVMetalTextureCache.h>
#import <CoreVideo/CoreVideo.h>
#import <Metal/Metal.h>
NS_ASSUME_NONNULL_BEGIN
@interface MPPMetalUtil : NSObject {
}
/// Copies a Metal Buffer from source to destination.
/// Uses blitCommandEncoder and assumes offset of 0.
+ (void)blitMetalBufferTo:(id<MTLBuffer>)destination
from:(id<MTLBuffer>)source
blocking:(bool)blocking
commandBuffer:(id<MTLCommandBuffer>)commandBuffer;
/// Copies a Metal Buffer from source to destination.
/// Simple wrapper for blitCommandEncoder.
/// Optionally block until operation is completed.
+ (void)blitMetalBufferTo:(id<MTLBuffer>)destination
destinationOffset:(int)destinationOffset
from:(id<MTLBuffer>)source
sourceOffset:(int)sourceOffset
bytes:(size_t)bytes
blocking:(bool)blocking
commandBuffer:(id<MTLCommandBuffer>)commandBuffer;
@end
NS_ASSUME_NONNULL_END
#endif // MEDIAPIPE_GPU_MPP_METAL_UTIL_H_

View File

@ -0,0 +1,51 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "mediapipe/gpu/MPPMetalUtil.h"
@implementation MPPMetalUtil
+ (void)blitMetalBufferTo:(id<MTLBuffer>)destination
from:(id<MTLBuffer>)source
blocking:(bool)blocking
commandBuffer:(id<MTLCommandBuffer>)commandBuffer {
size_t bytes = MIN(destination.length, source.length);
[self blitMetalBufferTo:destination
destinationOffset:0
from:source
sourceOffset:0
bytes:bytes
blocking:blocking
commandBuffer:commandBuffer];
}
+ (void)blitMetalBufferTo:(id<MTLBuffer>)destination
destinationOffset:(int)destinationOffset
from:(id<MTLBuffer>)source
sourceOffset:(int)sourceOffset
bytes:(size_t)bytes
blocking:(bool)blocking
commandBuffer:(id<MTLCommandBuffer>)commandBuffer {
id<MTLBlitCommandEncoder> blit_command = [commandBuffer blitCommandEncoder];
[blit_command copyFromBuffer:source
sourceOffset:sourceOffset
toBuffer:destination
destinationOffset:destinationOffset
size:bytes];
[blit_command endEncoding];
[commandBuffer commit];
if (blocking) [commandBuffer waitUntilCompleted];
}
@end

View File

@ -73,7 +73,7 @@ class GlCalculatorHelperImpl {
#endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER #endif // !MEDIAPIPE_GPU_BUFFER_USE_CV_PIXEL_BUFFER
// Sets default texture filtering parameters. // Sets default texture filtering parameters.
void SetStandardTextureParams(GLenum target); void SetStandardTextureParams(GLenum target, GLint internal_format);
// Create the framebuffer for rendering. // Create the framebuffer for rendering.
void CreateFramebuffer(); void CreateFramebuffer();

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "mediapipe/gpu/gl_calculator_helper_impl.h" #include "mediapipe/gpu/gl_calculator_helper_impl.h"
#include "mediapipe/gpu/gpu_buffer_format.h"
#include "mediapipe/gpu/gpu_shared_data_internal.h" #include "mediapipe/gpu/gpu_shared_data_internal.h"
namespace mediapipe { namespace mediapipe {
@ -86,9 +87,21 @@ void GlCalculatorHelperImpl::BindFramebuffer(const GlTexture& dst) {
#endif #endif
} }
void GlCalculatorHelperImpl::SetStandardTextureParams(GLenum target) { void GlCalculatorHelperImpl::SetStandardTextureParams(GLenum target,
glTexParameteri(target, GL_TEXTURE_MIN_FILTER, GL_LINEAR); GLint internal_format) {
glTexParameteri(target, GL_TEXTURE_MAG_FILTER, GL_LINEAR); GLint filter;
switch (internal_format) {
case GL_R32F:
case GL_RGBA32F:
// 32F (unlike 16f) textures do not support texture filtering
// (According to OpenGL ES specification [TEXTURE IMAGE SPECIFICATION])
filter = GL_NEAREST;
break;
default:
filter = GL_LINEAR;
}
glTexParameteri(target, GL_TEXTURE_MIN_FILTER, filter);
glTexParameteri(target, GL_TEXTURE_MAG_FILTER, filter);
glTexParameteri(target, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE); glTexParameteri(target, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE); glTexParameteri(target, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
} }
@ -136,7 +149,9 @@ GlTexture GlCalculatorHelperImpl::MapGlTextureBuffer(
// TODO: do the params need to be reset here?? // TODO: do the params need to be reset here??
glBindTexture(texture.target(), texture.name()); glBindTexture(texture.target(), texture.name());
SetStandardTextureParams(texture.target()); GlTextureInfo info =
GlTextureInfoForGpuBufferFormat(texture_buffer->format(), texture.plane_);
SetStandardTextureParams(texture.target(), info.gl_internal_format);
glBindTexture(texture.target(), 0); glBindTexture(texture.target(), 0);
return texture; return texture;
@ -150,7 +165,9 @@ GlTextureBufferSharedPtr GlCalculatorHelperImpl::MakeGlTextureBuffer(
GpuBufferFormatForImageFormat(image_frame.Format()), GpuBufferFormatForImageFormat(image_frame.Format()),
image_frame.PixelData()); image_frame.PixelData());
glBindTexture(GL_TEXTURE_2D, buffer->name_); glBindTexture(GL_TEXTURE_2D, buffer->name_);
SetStandardTextureParams(buffer->target_); GlTextureInfo info =
GlTextureInfoForGpuBufferFormat(buffer->format_, /*plane=*/0);
SetStandardTextureParams(buffer->target_, info.gl_internal_format);
glBindTexture(GL_TEXTURE_2D, 0); glBindTexture(GL_TEXTURE_2D, 0);
return buffer; return buffer;

View File

@ -54,7 +54,7 @@ GlTexture GlCalculatorHelperImpl::CreateSourceTexture(
glTexImage2D(GL_TEXTURE_2D, 0, info.gl_internal_format, texture.width_, glTexImage2D(GL_TEXTURE_2D, 0, info.gl_internal_format, texture.width_,
texture.height_, 0, info.gl_format, info.gl_type, texture.height_, 0, info.gl_format, info.gl_type,
image_frame.PixelData()); image_frame.PixelData());
SetStandardTextureParams(GL_TEXTURE_2D); SetStandardTextureParams(GL_TEXTURE_2D, info.gl_internal_format);
return texture; return texture;
} }
@ -107,7 +107,7 @@ GlTexture GlCalculatorHelperImpl::MapGpuBuffer(
#endif // TARGET_OS_OSX #endif // TARGET_OS_OSX
glBindTexture(texture.target(), texture.name()); glBindTexture(texture.target(), texture.name());
SetStandardTextureParams(texture.target()); SetStandardTextureParams(texture.target(), info.gl_internal_format);
return texture; return texture;
} }

View File

@ -110,7 +110,13 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context,
eglChooseConfig(display_, config_attr, &config_, 1, &num_configs); eglChooseConfig(display_, config_attr, &config_, 1, &num_configs);
if (!success) { if (!success) {
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC) return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "eglChooseConfig() returned error " << eglGetError(); << "eglChooseConfig() returned error " << std::showbase << std::hex
<< eglGetError();
}
if (!num_configs) {
return ::mediapipe::UnknownErrorBuilder(MEDIAPIPE_LOC)
<< "eglChooseConfig() returned no matching EGL configuration for "
<< "RGBA8888 D16 ES" << gl_version << " request. ";
} }
const EGLint context_attr[] = { const EGLint context_attr[] = {
@ -125,7 +131,8 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context,
int error = eglGetError(); int error = eglGetError();
RET_CHECK(context_ != EGL_NO_CONTEXT) RET_CHECK(context_ != EGL_NO_CONTEXT)
<< "Could not create GLES " << gl_version << " context; " << "Could not create GLES " << gl_version << " context; "
<< "eglCreateContext() returned error " << error << "eglCreateContext() returned error " << std::showbase << std::hex
<< error
<< (error == EGL_BAD_CONTEXT << (error == EGL_BAD_CONTEXT
? ": external context uses a different version of OpenGL" ? ": external context uses a different version of OpenGL"
: ""); : "");
@ -143,7 +150,8 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context,
display_ = eglGetDisplay(EGL_DEFAULT_DISPLAY); display_ = eglGetDisplay(EGL_DEFAULT_DISPLAY);
RET_CHECK(display_ != EGL_NO_DISPLAY) RET_CHECK(display_ != EGL_NO_DISPLAY)
<< "eglGetDisplay() returned error " << eglGetError(); << "eglGetDisplay() returned error " << std::showbase << std::hex
<< eglGetError();
EGLBoolean success = eglInitialize(display_, &major, &minor); EGLBoolean success = eglInitialize(display_, &major, &minor);
RET_CHECK(success) << "Unable to initialize EGL"; RET_CHECK(success) << "Unable to initialize EGL";
@ -162,7 +170,8 @@ GlContext::StatusOrGlContext GlContext::Create(EGLContext share_context,
surface_ = eglCreatePbufferSurface(display_, config_, pbuffer_attr); surface_ = eglCreatePbufferSurface(display_, config_, pbuffer_attr);
RET_CHECK(surface_ != EGL_NO_SURFACE) RET_CHECK(surface_ != EGL_NO_SURFACE)
<< "eglCreatePbufferSurface() returned error " << eglGetError(); << "eglCreatePbufferSurface() returned error " << std::showbase
<< std::hex << eglGetError();
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }
@ -186,17 +195,21 @@ void GlContext::DestroyContext() {
if (IsCurrent()) { if (IsCurrent()) {
if (!eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE, if (!eglMakeCurrent(display_, EGL_NO_SURFACE, EGL_NO_SURFACE,
EGL_NO_CONTEXT)) { EGL_NO_CONTEXT)) {
LOG(ERROR) << "eglMakeCurrent() returned error " << eglGetError(); LOG(ERROR) << "eglMakeCurrent() returned error " << std::showbase
<< std::hex << eglGetError();
} }
} }
if (surface_ != EGL_NO_SURFACE) { if (surface_ != EGL_NO_SURFACE) {
if (!eglDestroySurface(display_, surface_)) { if (!eglDestroySurface(display_, surface_)) {
LOG(ERROR) << "eglDestroySurface() returned error " << eglGetError(); LOG(ERROR) << "eglDestroySurface() returned error " << std::showbase
<< std::hex << eglGetError();
} }
surface_ = EGL_NO_SURFACE;
} }
if (context_ != EGL_NO_CONTEXT) { if (context_ != EGL_NO_CONTEXT) {
if (!eglDestroyContext(display_, context_)) { if (!eglDestroyContext(display_, context_)) {
LOG(ERROR) << "eglDestroyContext() returned error " << eglGetError(); LOG(ERROR) << "eglDestroyContext() returned error " << std::showbase
<< std::hex << eglGetError();
} }
context_ = EGL_NO_CONTEXT; context_ = EGL_NO_CONTEXT;
} }
@ -245,7 +258,8 @@ void GlContext::GetCurrentContextBinding(GlContext::ContextBinding* binding) {
EGLBoolean success = EGLBoolean success =
eglMakeCurrent(display, new_binding.draw_surface, eglMakeCurrent(display, new_binding.draw_surface,
new_binding.read_surface, new_binding.context); new_binding.read_surface, new_binding.context);
RET_CHECK(success) << "eglMakeCurrent() returned error " << eglGetError(); RET_CHECK(success) << "eglMakeCurrent() returned error " << std::showbase
<< std::hex << eglGetError();
return ::mediapipe::OkStatus(); return ::mediapipe::OkStatus();
} }

View File

@ -77,7 +77,8 @@ bool GlTextureBuffer::CreateInternal(const void* data) {
// TODO: maybe we do not actually have to wait for the // TODO: maybe we do not actually have to wait for the
// consumer sync here. Check docs. // consumer sync here. Check docs.
sync_token->WaitOnGpu(); sync_token->WaitOnGpu();
DCHECK(glIsTexture(name_to_delete)); DLOG_IF(ERROR, !glIsTexture(name_to_delete))
<< "Deleting invalid texture id: " << name_to_delete;
glDeleteTextures(1, &name_to_delete); glDeleteTextures(1, &name_to_delete);
}); });
}; };

View File

@ -0,0 +1,50 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mediapipe/gpu/gpu_buffer.h"
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/gpu/gpu_test_base.h"
namespace mediapipe {
namespace {
class GpuBufferTest : public GpuTestBase {};
TEST_F(GpuBufferTest, BasicTest) {
RunInGlContext([this] {
GpuBuffer buffer = gpu_shared_.gpu_buffer_pool.GetBuffer(300, 200);
EXPECT_EQ(buffer.width(), 300);
EXPECT_EQ(buffer.height(), 200);
EXPECT_TRUE(buffer);
EXPECT_FALSE(buffer == nullptr);
GpuBuffer no_buffer;
EXPECT_FALSE(no_buffer);
EXPECT_TRUE(no_buffer == nullptr);
GpuBuffer buffer2 = buffer;
EXPECT_EQ(buffer, buffer);
EXPECT_EQ(buffer, buffer2);
EXPECT_NE(buffer, no_buffer);
buffer = nullptr;
EXPECT_TRUE(buffer == nullptr);
EXPECT_TRUE(buffer == no_buffer);
});
}
} // anonymous namespace
} // namespace mediapipe

View File

@ -123,7 +123,7 @@ struct GpuSharedData {
PlatformGlContext external_context) { PlatformGlContext external_context) {
auto status_or_resources = GpuResources::Create(external_context); auto status_or_resources = GpuResources::Create(external_context);
MEDIAPIPE_CHECK_OK(status_or_resources.status()) MEDIAPIPE_CHECK_OK(status_or_resources.status())
<< "could not create GpuResources"; << ": could not create GpuResources";
return std::move(status_or_resources).ValueOrDie(); return std::move(status_or_resources).ValueOrDie();
} }
}; };

View File

@ -0,0 +1,39 @@
// Copyright 2019 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MEDIAPIPE_GPU_GPU_TEST_BASE_H_
#define MEDIAPIPE_GPU_GPU_TEST_BASE_H_
#include "mediapipe/framework/port/gmock.h"
#include "mediapipe/framework/port/gtest.h"
#include "mediapipe/gpu/gl_calculator_helper.h"
#include "mediapipe/gpu/gpu_shared_data_internal.h"
namespace mediapipe {
class GpuTestBase : public ::testing::Test {
protected:
GpuTestBase() { helper_.InitializeForTest(&gpu_shared_); }
void RunInGlContext(std::function<void(void)> gl_func) {
helper_.RunInGlContext(std::move(gl_func));
}
GpuSharedData gpu_shared_;
GlCalculatorHelper helper_;
};
} // namespace mediapipe
#endif // MEDIAPIPE_GPU_GPU_TEST_BASE_H_

View File

@ -35,6 +35,23 @@ cc_library(
], ],
) )
cc_library(
name = "desktop_tflite_calculators",
deps = [
"//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator",
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detection_letterbox_removal_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:non_max_suppression_calculator",
],
)
load( load(
"//mediapipe/framework/tool:mediapipe_graph.bzl", "//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_binary_graph", "mediapipe_binary_graph",

View File

@ -0,0 +1,184 @@
# MediaPipe graph that performs face detection with TensorFlow Lite on CPU.
# Used in the examples in
# mediapipie/examples/desktop/face_detection:face_detection_cpu.
# Images on GPU coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Throttles the images flowing downstream for flow control. It passes through
# the very first incoming image unaltered, and waits for
# TfLiteTensorsToDetectionsCalculator downstream in the graph to finish
# generating the corresponding detections before it passes through another
# image. All images that come in while waiting are dropped, limiting the number
# of in-flight images between this calculator and
# TfLiteTensorsToDetectionsCalculator to 1. This prevents the nodes in between
# from queuing up incoming images and data excessively, which leads to increased
# latency and memory usage, unwanted in real-time mobile applications. It also
# eliminates unnecessarily computation, e.g., a transformed image produced by
# ImageTransformationCalculator may get dropped downstream if the subsequent
# TfLiteConverterCalculator or TfLiteInferenceCalculator is still busy
# processing previous inputs.
node {
calculator: "FlowLimiterCalculator"
input_stream: "input_video"
input_stream: "FINISHED:detections"
input_stream_info: {
tag_index: "FINISHED"
back_edge: true
}
output_stream: "throttled_input_video"
}
# Transforms the input image on CPU to a 128x128 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:throttled_input_video"
output_stream: "IMAGE:transformed_input_video_cpu"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 128
output_height: 128
scale_mode: FIT
}
}
}
# Converts the transformed input image on CPU into an image tensor stored as a
# TfLiteTensor.
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE:transformed_input_video_cpu"
output_stream: "TENSORS:image_tensor"
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:detection_tensors"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/face_detection_front.tflite"
}
}
}
# Generates a single side packet containing a vector of SSD anchors based on
# the specification in the options.
node {
calculator: "SsdAnchorsCalculator"
output_side_packet: "anchors"
node_options: {
[type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] {
num_layers: 4
min_scale: 0.1484375
max_scale: 0.75
input_size_height: 128
input_size_width: 128
anchor_offset_x: 0.5
anchor_offset_y: 0.5
strides: 8
strides: 16
strides: 16
strides: 16
aspect_ratios: 1.0
fixed_anchor_size: true
}
}
}
# Decodes the detection tensors generated by the TensorFlow Lite model, based on
# the SSD anchors and the specification in the options, into a vector of
# detections. Each detection describes a detected object.
node {
calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS:detection_tensors"
input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] {
num_classes: 1
num_boxes: 896
num_coords: 16
box_coord_offset: 0
keypoint_coord_offset: 4
num_keypoints: 6
num_values_per_keypoint: 2
sigmoid_score: true
score_clipping_thresh: 100.0
reverse_output_order: true
x_scale: 128.0
y_scale: 128.0
h_scale: 128.0
w_scale: 128.0
min_score_thresh: 0.75
}
}
}
# Performs non-max suppression to remove excessive detections.
node {
calculator: "NonMaxSuppressionCalculator"
input_stream: "detections"
output_stream: "filtered_detections"
node_options: {
[type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] {
min_suppression_threshold: 0.3
overlap_type: INTERSECTION_OVER_UNION
algorithm: WEIGHTED
return_empty_detections: true
}
}
}
# Maps detection label IDs to the corresponding label text ("Face"). The label
# map is provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "filtered_detections"
output_stream: "labeled_detections"
node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
}
}
}
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation).
node {
calculator: "DetectionLetterboxRemovalCalculator"
input_stream: "DETECTIONS:labeled_detections"
input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "DETECTIONS:output_detections"
}
# Converts the detections to drawing primitives for annotation overlay.
node {
calculator: "DetectionsToRenderDataCalculator"
input_stream: "DETECTIONS:output_detections"
output_stream: "RENDER_DATA:render_data"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 4.0
color { r: 255 g: 0 b: 0 }
}
}
}
# Draws annotations and overlays them on top of the input images.
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "INPUT_FRAME:throttled_input_video"
input_stream: "render_data"
output_stream: "OUTPUT_FRAME:output_video"
}

View File

@ -41,7 +41,7 @@ node: {
output_stream: "input_video_cpu" output_stream: "input_video_cpu"
} }
# Transforms the input image on GPU to a 128x128 image. To scale the input # Transforms the input image on CPU to a 128x128 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio, # image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image. # resulting in potential letterboxing in the transformed image.
node: { node: {
@ -75,7 +75,7 @@ node {
output_stream: "TENSORS:detection_tensors" output_stream: "TENSORS:detection_tensors"
node_options: { node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "face_detection_front.tflite" model_path: "mediapipe/models/face_detection_front.tflite"
} }
} }
} }
@ -156,7 +156,7 @@ node {
output_stream: "labeled_detections" output_stream: "labeled_detections"
node_options: { node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "face_detection_front_labelmap.txt" label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
} }
} }
} }
@ -179,7 +179,7 @@ node {
output_stream: "RENDER_DATA:render_data" output_stream: "RENDER_DATA:render_data"
node_options: { node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 10.0 thickness: 4.0
color { r: 255 g: 0 b: 0 } color { r: 255 g: 0 b: 0 }
} }
} }

View File

@ -62,10 +62,10 @@ node {
node { node {
calculator: "TfLiteInferenceCalculator" calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS_GPU:image_tensor" input_stream: "TENSORS_GPU:image_tensor"
output_stream: "TENSORS:detection_tensors" output_stream: "TENSORS_GPU:detection_tensors"
node_options: { node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "face_detection_front.tflite" model_path: "mediapipe/models/face_detection_front.tflite"
} }
} }
} }
@ -99,7 +99,7 @@ node {
# detections. Each detection describes a detected object. # detections. Each detection describes a detected object.
node { node {
calculator: "TfLiteTensorsToDetectionsCalculator" calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS:detection_tensors" input_stream: "TENSORS_GPU:detection_tensors"
input_side_packet: "ANCHORS:anchors" input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections" output_stream: "DETECTIONS:detections"
node_options: { node_options: {
@ -146,7 +146,7 @@ node {
output_stream: "labeled_detections" output_stream: "labeled_detections"
node_options: { node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "face_detection_front_labelmap.txt" label_map_path: "mediapipe/models/face_detection_front_labelmap.txt"
} }
} }
} }
@ -169,7 +169,7 @@ node {
output_stream: "RENDER_DATA:render_data" output_stream: "RENDER_DATA:render_data"
node_options: { node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] { [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 10.0 thickness: 4.0
color { r: 255 g: 0 b: 0 } color { r: 255 g: 0 b: 0 }
} }
} }

View File

@ -111,7 +111,7 @@ node {
input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver" input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver"
node_options: { node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "hair_segmentation.tflite" model_path: "mediapipe/models/hair_segmentation.tflite"
use_gpu: true use_gpu: true
} }
} }

View File

@ -19,73 +19,35 @@ package(default_visibility = ["//visibility:public"])
load( load(
"//mediapipe/framework/tool:mediapipe_graph.bzl", "//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_binary_graph", "mediapipe_binary_graph",
"mediapipe_simple_subgraph",
) )
mediapipe_simple_subgraph( cc_library(
name = "hand_detection_gpu", name = "desktop_tflite_calculators",
graph = "hand_detection_gpu.pbtxt",
register_as = "HandDetectionSubgraph",
deps = [ deps = [
"//mediapipe/calculators/image:image_properties_calculator", "//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/calculators/image:image_transformation_calculator", "//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/tflite:ssd_anchors_calculator", "//mediapipe/calculators/core:immediate_mux_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator", "//mediapipe/calculators/core:merge_calculator",
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator", "//mediapipe/calculators/core:packet_inner_join_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator", "//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator", "//mediapipe/calculators/video:opencv_video_decoder_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator", "//mediapipe/calculators/video:opencv_video_encoder_calculator",
"//mediapipe/calculators/util:detection_letterbox_removal_calculator", "//mediapipe/graphs/hand_tracking/subgraphs:hand_detection_cpu",
"//mediapipe/calculators/util:detections_to_rects_calculator", "//mediapipe/graphs/hand_tracking/subgraphs:hand_landmark_cpu",
"//mediapipe/calculators/util:non_max_suppression_calculator", "//mediapipe/graphs/hand_tracking/subgraphs:renderer_cpu",
"//mediapipe/calculators/util:rect_transformation_calculator",
],
)
mediapipe_simple_subgraph(
name = "hand_landmark_gpu",
graph = "hand_landmark_gpu.pbtxt",
register_as = "HandLandmarkSubgraph",
deps = [
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
"//mediapipe/calculators/util:landmark_projection_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:thresholding_calculator",
],
)
mediapipe_simple_subgraph(
name = "renderer_gpu",
graph = "renderer_gpu.pbtxt",
register_as = "RendererSubgraph",
deps = [
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_data_calculator",
], ],
) )
cc_library( cc_library(
name = "mobile_calculators", name = "mobile_calculators",
deps = [ deps = [
":hand_detection_gpu",
":hand_landmark_gpu",
":renderer_gpu",
"//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/calculators/core:gate_calculator", "//mediapipe/calculators/core:gate_calculator",
"//mediapipe/calculators/core:merge_calculator", "//mediapipe/calculators/core:merge_calculator",
"//mediapipe/calculators/core:previous_loopback_calculator", "//mediapipe/calculators/core:previous_loopback_calculator",
"//mediapipe/graphs/hand_tracking/subgraphs:hand_detection_gpu",
"//mediapipe/graphs/hand_tracking/subgraphs:hand_landmark_gpu",
"//mediapipe/graphs/hand_tracking/subgraphs:renderer_gpu",
], ],
) )
@ -99,9 +61,9 @@ mediapipe_binary_graph(
cc_library( cc_library(
name = "detection_mobile_calculators", name = "detection_mobile_calculators",
deps = [ deps = [
":hand_detection_gpu",
":renderer_gpu",
"//mediapipe/calculators/core:flow_limiter_calculator", "//mediapipe/calculators/core:flow_limiter_calculator",
"//mediapipe/graphs/hand_tracking/subgraphs:hand_detection_gpu",
"//mediapipe/graphs/hand_tracking/subgraphs:renderer_gpu",
], ],
) )

View File

@ -0,0 +1,62 @@
# MediaPipe graph that performs hand detection on desktop with TensorFlow Lite
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/hand_tracking:hand_detection_tflite.
# max_queue_size limits the number of packets enqueued on any input stream
# by throttling inputs to the graph. This makes the graph only process one
# frame per time.
max_queue_size: 1
# Decodes an input video file into images and a video header.
node {
calculator: "OpenCvVideoDecoderCalculator"
input_side_packet: "INPUT_FILE_PATH:input_video_path"
output_stream: "VIDEO:input_video"
output_stream: "VIDEO_PRESTREAM:input_video_header"
}
# Performs hand detection model on the input frames. See
# hand_detection_cpu.pbtxt for the detail of the sub-graph.
node {
calculator: "HandDetectionSubgraph"
input_stream: "input_video"
output_stream: "DETECTIONS:output_detections"
}
# Converts the detections to drawing primitives for annotation overlay.
node {
calculator: "DetectionsToRenderDataCalculator"
input_stream: "DETECTIONS:output_detections"
output_stream: "RENDER_DATA:render_data"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 4.0
color { r: 0 g: 255 b: 0 }
}
}
}
# Draws annotations and overlays them on top of the original image coming into
# the graph.
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "INPUT_FRAME:input_video"
input_stream: "render_data"
output_stream: "OUTPUT_FRAME:output_video"
}
# Encodes the annotated images into a video file, adopting properties specified
# in the input video header, e.g., video framerate.
node {
calculator: "OpenCvVideoEncoderCalculator"
input_stream: "VIDEO:output_video"
input_stream: "VIDEO_PRESTREAM:input_video_header"
input_side_packet: "OUTPUT_FILE_PATH:output_video_path"
node_options: {
[type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: {
codec: "avc1"
video_format: "mp4"
}
}
}

View File

@ -0,0 +1,38 @@
# MediaPipe graph that performs hand detection on desktop with TensorFlow Lite
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/hand_tracking:hand_detection_cpu.
# Images coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Performs hand detection model on the input frames. See
# hand_detection_cpu.pbtxt for the detail of the sub-graph.
node {
calculator: "HandDetectionSubgraph"
input_stream: "input_video"
output_stream: "DETECTIONS:output_detections"
}
# Converts the detections to drawing primitives for annotation overlay.
node {
calculator: "DetectionsToRenderDataCalculator"
input_stream: "DETECTIONS:output_detections"
output_stream: "RENDER_DATA:render_data"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
thickness: 4.0
color { r: 0 g: 255 b: 0 }
}
}
}
# Draws annotations and overlays them on top of the original image coming into
# the graph.
node {
calculator: "AnnotationOverlayCalculator"
input_stream: "INPUT_FRAME:input_video"
input_stream: "render_data"
output_stream: "OUTPUT_FRAME:output_video"
}

View File

@ -0,0 +1,126 @@
# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/hand_tracking:hand_tracking_tflite.
# max_queue_size limits the number of packets enqueued on any input stream
# by throttling inputs to the graph. This makes the graph only process one
# frame per time.
max_queue_size: 1
# Decodes an input video file into images and a video header.
node {
calculator: "OpenCvVideoDecoderCalculator"
input_side_packet: "INPUT_FILE_PATH:input_video_path"
output_stream: "VIDEO:input_video"
output_stream: "VIDEO_PRESTREAM:input_video_header"
}
# Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon
# the arrival of the next input image sends out the cached decision with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand-presence decision. Note that upon the arrival
# of the very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:hand_presence"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_hand_presence"
}
# Drops the incoming image if HandLandmarkSubgraph was able to identify hand
# presence in the previous image. Otherwise, passes the incoming image through
# to trigger a new round of hand detection in HandDetectionSubgraph.
node {
calculator: "GateCalculator"
input_stream: "input_video"
input_stream: "DISALLOW:prev_hand_presence"
output_stream: "hand_detection_input_video"
node_options: {
[type.googleapis.com/mediapipe.GateCalculatorOptions] {
empty_packets_as_allow: true
}
}
}
# Subgraph that detections hands (see hand_detection_cpu.pbtxt).
node {
calculator: "HandDetectionSubgraph"
input_stream: "hand_detection_input_video"
output_stream: "DETECTIONS:palm_detections"
output_stream: "NORM_RECT:hand_rect_from_palm_detections"
}
# Subgraph that localizes hand landmarks (see hand_landmark_cpu.pbtxt).
node {
calculator: "HandLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECT:hand_rect"
output_stream: "LANDMARKS:hand_landmarks"
output_stream: "NORM_RECT:hand_rect_from_landmarks"
output_stream: "PRESENCE:hand_presence"
}
# Caches a hand rectangle fed back from HandLandmarkSubgraph, and upon the
# arrival of the next input image sends out the cached rectangle with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand rectangle. Note that upon the arrival of the
# very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:hand_rect_from_landmarks"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks"
}
# Merges a stream of hand rectangles generated by HandDetectionSubgraph and that
# generated by HandLandmarkSubgraph into a single output stream by selecting
# between one of the two streams. The former is selected if the incoming packet
# is not empty, i.e., hand detection is performed on the current image by
# HandDetectionSubgraph (because HandLandmarkSubgraph could not identify hand
# presence in the previous image). Otherwise, the latter is selected, which is
# never empty because HandLandmarkSubgraphs processes all images (that went
# through FlowLimiterCaculator).
node {
calculator: "MergeCalculator"
input_stream: "hand_rect_from_palm_detections"
input_stream: "prev_hand_rect_from_landmarks"
output_stream: "hand_rect"
}
# Subgraph that renders annotations and overlays them on top of the input
# images (see renderer_cpu.pbtxt).
node {
calculator: "RendererSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "LANDMARKS:hand_landmarks"
input_stream: "NORM_RECT:hand_rect"
input_stream: "DETECTIONS:palm_detections"
output_stream: "IMAGE:output_video"
}
# Encodes the annotated images into a video file, adopting properties specified
# in the input video header, e.g., video framerate.
node {
calculator: "OpenCvVideoEncoderCalculator"
input_stream: "VIDEO:output_video"
input_stream: "VIDEO_PRESTREAM:input_video_header"
input_side_packet: "OUTPUT_FILE_PATH:output_video_path"
node_options: {
[type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: {
codec: "avc1"
video_format: "mp4"
}
}
}

View File

@ -0,0 +1,103 @@
# MediaPipe graph that performs hand tracking on desktop with TensorFlow Lite
# on CPU.
# Used in the example in
# mediapipie/examples/desktop/hand_tracking:hand_tracking_cpu.
# Images coming into and out of the graph.
input_stream: "input_video"
output_stream: "output_video"
# Caches a hand-presence decision fed back from HandLandmarkSubgraph, and upon
# the arrival of the next input image sends out the cached decision with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand-presence decision. Note that upon the arrival
# of the very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:hand_presence"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_hand_presence"
}
# Drops the incoming image if HandLandmarkSubgraph was able to identify hand
# presence in the previous image. Otherwise, passes the incoming image through
# to trigger a new round of hand detection in HandDetectionSubgraph.
node {
calculator: "GateCalculator"
input_stream: "input_video"
input_stream: "DISALLOW:prev_hand_presence"
output_stream: "hand_detection_input_video"
node_options: {
[type.googleapis.com/mediapipe.GateCalculatorOptions] {
empty_packets_as_allow: true
}
}
}
# Subgraph that detections hands (see hand_detection_cpu.pbtxt).
node {
calculator: "HandDetectionSubgraph"
input_stream: "hand_detection_input_video"
output_stream: "DETECTIONS:palm_detections"
output_stream: "NORM_RECT:hand_rect_from_palm_detections"
}
# Subgraph that localizes hand landmarks (see hand_landmark_cpu.pbtxt).
node {
calculator: "HandLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECT:hand_rect"
output_stream: "LANDMARKS:hand_landmarks"
output_stream: "NORM_RECT:hand_rect_from_landmarks"
output_stream: "PRESENCE:hand_presence"
}
# Caches a hand rectangle fed back from HandLandmarkSubgraph, and upon the
# arrival of the next input image sends out the cached rectangle with the
# timestamp replaced by that of the input image, essentially generating a packet
# that carries the previous hand rectangle. Note that upon the arrival of the
# very first input image, an empty packet is sent out to jump start the
# feedback loop.
node {
calculator: "PreviousLoopbackCalculator"
input_stream: "MAIN:input_video"
input_stream: "LOOP:hand_rect_from_landmarks"
input_stream_info: {
tag_index: "LOOP"
back_edge: true
}
output_stream: "PREV_LOOP:prev_hand_rect_from_landmarks"
}
# Merges a stream of hand rectangles generated by HandDetectionSubgraph and that
# generated by HandLandmarkSubgraph into a single output stream by selecting
# between one of the two streams. The former is selected if the incoming packet
# is not empty, i.e., hand detection is performed on the current image by
# HandDetectionSubgraph (because HandLandmarkSubgraph could not identify hand
# presence in the previous image). Otherwise, the latter is selected, which is
# never empty because HandLandmarkSubgraphs processes all images (that went
# through FlowLimiterCaculator).
node {
calculator: "MergeCalculator"
input_stream: "hand_rect_from_palm_detections"
input_stream: "prev_hand_rect_from_landmarks"
output_stream: "hand_rect"
}
# Subgraph that renders annotations and overlays them on top of the input
# images (see renderer_cpu.pbtxt).
node {
calculator: "RendererSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "LANDMARKS:hand_landmarks"
input_stream: "NORM_RECT:hand_rect"
input_stream: "DETECTIONS:palm_detections"
output_stream: "IMAGE:output_video"
}

View File

@ -0,0 +1,132 @@
# Copyright 2019 The MediaPipe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])
load(
"//mediapipe/framework/tool:mediapipe_graph.bzl",
"mediapipe_simple_subgraph",
)
mediapipe_simple_subgraph(
name = "hand_detection_cpu",
graph = "hand_detection_cpu.pbtxt",
register_as = "HandDetectionSubgraph",
deps = [
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detection_letterbox_removal_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:non_max_suppression_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
],
)
mediapipe_simple_subgraph(
name = "hand_landmark_cpu",
graph = "hand_landmark_cpu.pbtxt",
register_as = "HandLandmarkSubgraph",
deps = [
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
"//mediapipe/calculators/util:landmark_projection_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:thresholding_calculator",
],
)
mediapipe_simple_subgraph(
name = "renderer_cpu",
graph = "renderer_cpu.pbtxt",
register_as = "RendererSubgraph",
deps = [
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_data_calculator",
],
)
mediapipe_simple_subgraph(
name = "hand_detection_gpu",
graph = "hand_detection_gpu.pbtxt",
register_as = "HandDetectionSubgraph",
deps = [
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:ssd_anchors_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_custom_op_resolver_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_detections_calculator",
"//mediapipe/calculators/util:detection_label_id_to_text_calculator",
"//mediapipe/calculators/util:detection_letterbox_removal_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:non_max_suppression_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
],
)
mediapipe_simple_subgraph(
name = "hand_landmark_gpu",
graph = "hand_landmark_gpu.pbtxt",
register_as = "HandLandmarkSubgraph",
deps = [
"//mediapipe/calculators/core:split_vector_calculator",
"//mediapipe/calculators/image:image_cropping_calculator",
"//mediapipe/calculators/image:image_properties_calculator",
"//mediapipe/calculators/image:image_transformation_calculator",
"//mediapipe/calculators/tflite:tflite_converter_calculator",
"//mediapipe/calculators/tflite:tflite_inference_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_floats_calculator",
"//mediapipe/calculators/tflite:tflite_tensors_to_landmarks_calculator",
"//mediapipe/calculators/util:detections_to_rects_calculator",
"//mediapipe/calculators/util:landmark_letterbox_removal_calculator",
"//mediapipe/calculators/util:landmark_projection_calculator",
"//mediapipe/calculators/util:landmarks_to_detection_calculator",
"//mediapipe/calculators/util:rect_transformation_calculator",
"//mediapipe/calculators/util:thresholding_calculator",
],
)
mediapipe_simple_subgraph(
name = "renderer_gpu",
graph = "renderer_gpu.pbtxt",
register_as = "RendererSubgraph",
deps = [
"//mediapipe/calculators/util:annotation_overlay_calculator",
"//mediapipe/calculators/util:detections_to_render_data_calculator",
"//mediapipe/calculators/util:landmarks_to_render_data_calculator",
"//mediapipe/calculators/util:rect_to_render_data_calculator",
],
)

View File

@ -0,0 +1,193 @@
# MediaPipe hand detection subgraph.
type: "HandDetectionSubgraph"
input_stream: "input_video"
output_stream: "DETECTIONS:palm_detections"
output_stream: "NORM_RECT:hand_rect_from_palm_detections"
# Transforms the input image on CPU to a 256x256 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:input_video"
output_stream: "IMAGE:transformed_input_video"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 256
output_height: 256
scale_mode: FIT
}
}
}
# Generates a single side packet containing a TensorFlow Lite op resolver that
# supports custom ops needed by the model used in this graph.
node {
calculator: "TfLiteCustomOpResolverCalculator"
output_side_packet: "op_resolver"
}
# Converts the transformed input image on CPU into an image tensor as a
# TfLiteTensor. The zero_center option is set to true to normalize the
# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f].
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE:transformed_input_video"
output_stream: "TENSORS:image_tensor"
}
# Runs a TensorFlow Lite model on CPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:detection_tensors"
input_side_packet: "CUSTOM_OP_RESOLVER:op_resolver"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/palm_detection.tflite"
}
}
}
# Generates a single side packet containing a vector of SSD anchors based on
# the specification in the options.
node {
calculator: "SsdAnchorsCalculator"
output_side_packet: "anchors"
node_options: {
[type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] {
num_layers: 5
min_scale: 0.1171875
max_scale: 0.75
input_size_height: 256
input_size_width: 256
anchor_offset_x: 0.5
anchor_offset_y: 0.5
strides: 8
strides: 16
strides: 32
strides: 32
strides: 32
aspect_ratios: 1.0
fixed_anchor_size: true
}
}
}
# Decodes the detection tensors generated by the TensorFlow Lite model, based on
# the SSD anchors and the specification in the options, into a vector of
# detections. Each detection describes a detected object.
node {
calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS:detection_tensors"
input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToDetectionsCalculatorOptions] {
num_classes: 1
num_boxes: 2944
num_coords: 18
box_coord_offset: 0
keypoint_coord_offset: 4
num_keypoints: 7
num_values_per_keypoint: 2
sigmoid_score: true
score_clipping_thresh: 100.0
reverse_output_order: true
x_scale: 256.0
y_scale: 256.0
h_scale: 256.0
w_scale: 256.0
min_score_thresh: 0.5
}
}
}
# Performs non-max suppression to remove excessive detections.
node {
calculator: "NonMaxSuppressionCalculator"
input_stream: "detections"
output_stream: "filtered_detections"
node_options: {
[type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] {
min_suppression_threshold: 0.3
min_score_threshold: 0.5
overlap_type: INTERSECTION_OVER_UNION
algorithm: WEIGHTED
return_empty_detections: true
}
}
}
# Maps detection label IDs to the corresponding label text. The label map is
# provided in the label_map_path option.
node {
calculator: "DetectionLabelIdToTextCalculator"
input_stream: "filtered_detections"
output_stream: "labeled_detections"
node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "mediapipe/models/palm_detection_labelmap.txt"
}
}
}
# Adjusts detection locations (already normalized to [0.f, 1.f]) on the
# letterboxed image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (the
# input image to the graph before image transformation).
node {
calculator: "DetectionLetterboxRemovalCalculator"
input_stream: "DETECTIONS:labeled_detections"
input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "DETECTIONS:palm_detections"
}
# Extracts image size from the input images.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:input_video"
output_stream: "SIZE:image_size"
}
# Converts results of palm detection into a rectangle (normalized by image size)
# that encloses the palm and is rotated such that the line connecting center of
# the wrist and MCP of the middle finger is aligned with the Y-axis of the
# rectangle.
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTIONS:palm_detections"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_RECT:palm_rect"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] {
rotation_vector_start_keypoint_index: 0 # Center of wrist.
rotation_vector_end_keypoint_index: 2 # MCP of middle finger.
rotation_vector_target_angle_degrees: 90
output_zero_rect_for_empty_detections: true
}
}
}
# Expands and shifts the rectangle that contains the palm so that it's likely
# to cover the entire hand.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:palm_rect"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "hand_rect_from_palm_detections"
node_options: {
[type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] {
scale_x: 2.6
scale_y: 2.6
shift_y: -0.5
square_long: true
}
}
}

View File

@ -49,11 +49,11 @@ node {
node { node {
calculator: "TfLiteInferenceCalculator" calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS_GPU:image_tensor" input_stream: "TENSORS_GPU:image_tensor"
output_stream: "TENSORS:detection_tensors" output_stream: "TENSORS_GPU:detection_tensors"
input_side_packet: "CUSTOM_OP_RESOLVER:opresolver" input_side_packet: "CUSTOM_OP_RESOLVER:opresolver"
node_options: { node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] { [type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "palm_detection.tflite" model_path: "mediapipe/models/palm_detection.tflite"
use_gpu: true use_gpu: true
} }
} }
@ -89,7 +89,7 @@ node {
# detections. Each detection describes a detected object. # detections. Each detection describes a detected object.
node { node {
calculator: "TfLiteTensorsToDetectionsCalculator" calculator: "TfLiteTensorsToDetectionsCalculator"
input_stream: "TENSORS:detection_tensors" input_stream: "TENSORS_GPU:detection_tensors"
input_side_packet: "ANCHORS:anchors" input_side_packet: "ANCHORS:anchors"
output_stream: "DETECTIONS:detections" output_stream: "DETECTIONS:detections"
node_options: { node_options: {
@ -137,7 +137,7 @@ node {
output_stream: "labeled_detections" output_stream: "labeled_detections"
node_options: { node_options: {
[type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] { [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
label_map_path: "palm_detection_labelmap.txt" label_map_path: "mediapipe/models/palm_detection_labelmap.txt"
} }
} }
} }

View File

@ -0,0 +1,185 @@
# MediaPipe hand landmark localization subgraph.
type: "HandLandmarkSubgraph"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECT:hand_rect"
output_stream: "LANDMARKS:hand_landmarks"
output_stream: "NORM_RECT:hand_rect_for_next_frame"
output_stream: "PRESENCE:hand_presence"
# Crops the rectangle that contains a hand from the input image.
node {
calculator: "ImageCroppingCalculator"
input_stream: "IMAGE:input_video"
input_stream: "NORM_RECT:hand_rect"
output_stream: "IMAGE:hand_image"
}
# Transforms the input image on CPU to a 256x256 image. To scale the input
# image, the scale_mode option is set to FIT to preserve the aspect ratio,
# resulting in potential letterboxing in the transformed image.
node: {
calculator: "ImageTransformationCalculator"
input_stream: "IMAGE:hand_image"
output_stream: "IMAGE:transformed_input_video"
output_stream: "LETTERBOX_PADDING:letterbox_padding"
node_options: {
[type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
output_width: 256
output_height: 256
scale_mode: FIT
}
}
}
# Converts the transformed input image on GPU into an image tensor stored in
# tflite::gpu::GlBuffer. The zero_center option is set to true to normalize the
# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f]. The flip_vertically
# option is set to true to account for the descrepancy between the
# representation of the input image (origin at the bottom-left corner, the
# OpenGL convention) and what the model used in this graph is expecting (origin
# at the top-left corner).
node {
calculator: "TfLiteConverterCalculator"
input_stream: "IMAGE:transformed_input_video"
output_stream: "TENSORS:image_tensor"
node_options: {
[type.googleapis.com/mediapipe.TfLiteConverterCalculatorOptions] {
zero_center: false
}
}
}
# Runs a TensorFlow Lite model on GPU that takes an image tensor and outputs a
# vector of tensors representing, for instance, detection boxes/keypoints and
# scores.
node {
calculator: "TfLiteInferenceCalculator"
input_stream: "TENSORS:image_tensor"
output_stream: "TENSORS:output_tensors"
node_options: {
[type.googleapis.com/mediapipe.TfLiteInferenceCalculatorOptions] {
model_path: "mediapipe/models/hand_landmark.tflite"
}
}
}
# Splits a vector of TFLite tensors to multiple vectors according to the ranges
# specified in option.
node {
calculator: "SplitTfLiteTensorVectorCalculator"
input_stream: "output_tensors"
output_stream: "landmark_tensors"
output_stream: "hand_flag_tensor"
node_options: {
[type.googleapis.com/mediapipe.SplitVectorCalculatorOptions] {
ranges: { begin: 0 end: 1 }
ranges: { begin: 1 end: 2 }
}
}
}
# Converts the hand-flag tensor into a float that represents the confidence
# score of hand presence.
node {
calculator: "TfLiteTensorsToFloatsCalculator"
input_stream: "TENSORS:hand_flag_tensor"
output_stream: "FLOAT:hand_presence_score"
}
# Applies a threshold to the confidence score to determine whether a hand is
# present.
node {
calculator: "ThresholdingCalculator"
input_stream: "FLOAT:hand_presence_score"
output_stream: "FLAG:hand_presence"
node_options: {
[type.googleapis.com/mediapipe.ThresholdingCalculatorOptions] {
threshold: 0.1
}
}
}
# Decodes the landmark tensors into a vector of lanmarks, where the landmark
# coordinates are normalized by the size of the input image to the model.
node {
calculator: "TfLiteTensorsToLandmarksCalculator"
input_stream: "TENSORS:landmark_tensors"
output_stream: "NORM_LANDMARKS:landmarks"
node_options: {
[type.googleapis.com/mediapipe.TfLiteTensorsToLandmarksCalculatorOptions] {
num_landmarks: 21
input_image_width: 256
input_image_height: 256
}
}
}
# Adjusts landmarks (already normalized to [0.f, 1.f]) on the letterboxed hand
# image (after image transformation with the FIT scale mode) to the
# corresponding locations on the same image with the letterbox removed (hand
# image before image transformation).
node {
calculator: "LandmarkLetterboxRemovalCalculator"
input_stream: "LANDMARKS:landmarks"
input_stream: "LETTERBOX_PADDING:letterbox_padding"
output_stream: "LANDMARKS:scaled_landmarks"
}
# Projects the landmarks from the cropped hand image to the corresponding
# locations on the full image before cropping (input to the graph).
node {
calculator: "LandmarkProjectionCalculator"
input_stream: "NORM_LANDMARKS:scaled_landmarks"
input_stream: "NORM_RECT:hand_rect"
output_stream: "NORM_LANDMARKS:hand_landmarks"
}
# Extracts image size from the input images.
node {
calculator: "ImagePropertiesCalculator"
input_stream: "IMAGE:input_video"
output_stream: "SIZE:image_size"
}
# Converts hand landmarks to a detection that tightly encloses all landmarks.
node {
calculator: "LandmarksToDetectionCalculator"
input_stream: "NORM_LANDMARKS:hand_landmarks"
output_stream: "DETECTION:hand_detection"
}
# Converts the hand detection into a rectangle (normalized by image size)
# that encloses the hand and is rotated such that the line connecting center of
# the wrist and MCP of the middle finger is aligned with the Y-axis of the
# rectangle.
node {
calculator: "DetectionsToRectsCalculator"
input_stream: "DETECTION:hand_detection"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "NORM_RECT:hand_rect_from_landmarks"
node_options: {
[type.googleapis.com/mediapipe.DetectionsToRectsCalculatorOptions] {
rotation_vector_start_keypoint_index: 0 # Center of wrist.
rotation_vector_end_keypoint_index: 9 # MCP of middle finger.
rotation_vector_target_angle_degrees: 90
}
}
}
# Expands the hand rectangle so that in the next video frame it's likely to
# still contain the hand even with some motion.
node {
calculator: "RectTransformationCalculator"
input_stream: "NORM_RECT:hand_rect_from_landmarks"
input_stream: "IMAGE_SIZE:image_size"
output_stream: "hand_rect_for_next_frame"
node_options: {
[type.googleapis.com/mediapipe.RectTransformationCalculatorOptions] {
scale_x: 1.6
scale_y: 1.6
square_long: true
}
}
}

Some files were not shown because too many files have changed in this diff Show More